bug 236: add optional support for tcp and udp checksum.
authorSebastien Vincent <vincent@clarinet.u-strasbg.fr>
Tue, 01 Jul 2008 10:52:11 -0700
changeset 3363 33d1ca2e4ba4
parent 3362 9a6f1b3c6e0b
child 3364 8e6ac6061680
bug 236: add optional support for tcp and udp checksum.
src/common/buffer.cc
src/common/buffer.h
src/internet-stack/tcp-header.cc
src/internet-stack/tcp-header.h
src/internet-stack/tcp-l4-protocol.cc
src/internet-stack/tcp-l4-protocol.h
src/internet-stack/udp-header.cc
src/internet-stack/udp-header.h
src/internet-stack/udp-l4-protocol.cc
src/internet-stack/udp-l4-protocol.h
src/node/ipv4-header.cc
src/node/ipv4-header.h
--- a/src/common/buffer.cc	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/common/buffer.cc	Tue Jul 01 10:52:11 2008 -0700
@@ -1105,6 +1105,29 @@
 
 #endif /* BUFFER_USE_INLINE */
 
+uint16_t
+Buffer::Iterator::CalculateIpChecksum(uint16_t size)
+{
+  return CalculateIpChecksum(size, 0);
+}
+
+uint16_t
+Buffer::Iterator::CalculateIpChecksum(uint16_t size, uint32_t initialChecksum)
+{
+  /* see RFC 1071 to understand this code. */
+  uint32_t sum = initialChecksum;
+
+  for (int j = 0; j < size/2; j++)
+    sum += ReadU16 ();
+
+  if (size & 1)
+     sum += ReadU8 ();
+
+  while (sum >> 16)
+    sum = (sum & 0xffff) + (sum >> 16);
+  return ~sum;
+}
+
 } // namespace ns3
 
 
--- a/src/common/buffer.h	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/common/buffer.h	Tue Jul 01 10:52:11 2008 -0700
@@ -342,6 +342,22 @@
        * bytes read.
        */
       void Read (uint8_t *buffer, uint32_t size);
+
+      /**
+       * \brief Calculate the checksum.
+       * \param size size of the buffer.
+       * \return checksum
+       */
+      uint16_t CalculateIpChecksum(uint16_t size);
+
+      /**
+       * \brief Calculate the checksum.
+       * \param size size of the buffer.
+       * \param initialChecksum initial value
+       * \return checksum
+       */
+      uint16_t CalculateIpChecksum(uint16_t size, uint32_t initialChecksum);
+
   private:
       friend class Buffer;
       Iterator (Buffer const*buffer);
--- a/src/internet-stack/tcp-header.cc	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-header.cc	Tue Jul 01 10:52:11 2008 -0700
@@ -28,8 +28,6 @@
 
 NS_OBJECT_ENSURE_REGISTERED (TcpHeader);
 
-bool TcpHeader::m_calcChecksum = false;
-
 TcpHeader::TcpHeader ()
   : m_sourcePort (0),
     m_destinationPort (0),
@@ -38,8 +36,11 @@
     m_length (5),
     m_flags (0),
     m_windowSize (0xffff),
+    m_urgentPointer (0),
+    m_initialChecksum(0),
     m_checksum (0),
-    m_urgentPointer (0)
+    m_calcChecksum(false),
+    m_goodChecksum(true)
 {}
 
 TcpHeader::~TcpHeader ()
@@ -51,6 +52,11 @@
   m_calcChecksum = true;
 }
 
+void TcpHeader::SetPayloadSize(uint16_t payloadSize)
+{
+  m_payloadSize = payloadSize;
+}
+
 void TcpHeader::SetSourcePort (uint16_t port)
 {
   m_sourcePort = port;
@@ -130,8 +136,32 @@
                                    Ipv4Address destination,
                                    uint8_t protocol)
 {
-  m_checksum = 0;
-//XXX requires peeking into IP to get length of the TCP segment
+  Buffer buf = Buffer(12);
+  uint8_t tmp[4];
+  Buffer::Iterator it;
+  uint16_t tcpLength = m_payloadSize + GetSerializedSize();
+
+  buf.AddAtStart(12);
+  it = buf.Begin();
+
+  source.Serialize(tmp);
+  it.Write(tmp, 4); /* source IP address */
+  destination.Serialize(tmp);
+  it.Write(tmp, 4); /* destination IP address */
+  it.WriteU8(0); /* protocol */
+  it.WriteU8(protocol); /* protocol */
+  it.WriteU8(tcpLength >> 8); /* length */
+  it.WriteU8(tcpLength & 0xff); /* length */
+
+  it = buf.Begin();
+  /* we don't CompleteChecksum ( ~ ) now */
+  m_initialChecksum = ~(it.CalculateIpChecksum(12));
+}
+
+bool
+TcpHeader::IsChecksumOk (void) const
+{
+  return m_goodChecksum;
 }
 
 TypeId 
@@ -188,28 +218,49 @@
 }
 void TcpHeader::Serialize (Buffer::Iterator start)  const
 {
-  start.WriteHtonU16 (m_sourcePort);
-  start.WriteHtonU16 (m_destinationPort);
-  start.WriteHtonU32 (m_sequenceNumber);
-  start.WriteHtonU32 (m_ackNumber);
-  start.WriteHtonU16 (m_length << 12 | m_flags); //reserved bits are all zero
-  start.WriteHtonU16 (m_windowSize);
-  //XXX calculate checksum here
-  start.WriteHtonU16 (m_checksum);
-  start.WriteHtonU16 (m_urgentPointer);
+  Buffer::Iterator i = start;
+  uint16_t tcpLength = m_payloadSize + GetSerializedSize();
+  i.WriteHtonU16 (m_sourcePort);
+  i.WriteHtonU16 (m_destinationPort);
+  i.WriteHtonU32 (m_sequenceNumber);
+  i.WriteHtonU32 (m_ackNumber);
+  i.WriteHtonU16 (m_length << 12 | m_flags); //reserved bits are all zero
+  i.WriteHtonU16 (m_windowSize);
+  i.WriteHtonU16 (0);
+  i.WriteHtonU16 (m_urgentPointer);
+
+  if(m_calcChecksum)
+  {
+    i = start;
+    uint16_t checksum = i.CalculateIpChecksum(tcpLength, m_initialChecksum);
+    
+    i = start;
+    i.Next(16);
+    i.WriteU16(checksum);
+  }
 }
 uint32_t TcpHeader::Deserialize (Buffer::Iterator start)
 {
-  m_sourcePort = start.ReadNtohU16 ();
-  m_destinationPort = start.ReadNtohU16 ();
-  m_sequenceNumber = start.ReadNtohU32 ();
-  m_ackNumber = start.ReadNtohU32 ();
-  uint16_t field = start.ReadNtohU16 ();
+  Buffer::Iterator i = start;
+  m_sourcePort = i.ReadNtohU16 ();
+  m_destinationPort = i.ReadNtohU16 ();
+  m_sequenceNumber = i.ReadNtohU32 ();
+  m_ackNumber = i.ReadNtohU32 ();
+  uint16_t field = i.ReadNtohU16 ();
   m_flags = field & 0x3F;
   m_length = field>>12;
-  m_windowSize = start.ReadNtohU16 ();
-  m_checksum = start.ReadNtohU16 ();
-  m_urgentPointer = start.ReadNtohU16 ();
+  m_windowSize = i.ReadNtohU16 ();
+  m_checksum = i.ReadU16 ();
+  m_urgentPointer = i.ReadNtohU16 ();
+
+  if(m_calcChecksum)
+  {
+      i = start;
+      uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
+
+      m_goodChecksum = (checksum == 0);
+  }
+
   return GetSerializedSize ();
 }
 
--- a/src/internet-stack/tcp-header.h	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-header.h	Tue Jul 01 10:52:11 2008 -0700
@@ -47,7 +47,7 @@
   /**
    * \brief Enable checksum calculation for TCP (XXX currently has no effect)
    */
-  static void EnableChecksums (void);
+  void EnableChecksums (void);
 //Setters
   /**
    * \param port The source port for this TcpHeader
@@ -150,6 +150,17 @@
   virtual void Serialize (Buffer::Iterator start) const;
   virtual uint32_t Deserialize (Buffer::Iterator start);
 
+  /**
+   * \param size The payload size in bytes
+   */
+  void SetPayloadSize (uint16_t size);
+
+  /**
+   * \brief Is the TCP checksum correct ?
+   * \returns true if the checksum is correct, false otherwise.
+   */
+  bool IsChecksumOk (void) const;
+
 private:
   uint16_t m_sourcePort;
   uint16_t m_destinationPort;
@@ -158,10 +169,14 @@
   uint8_t m_length; // really a uint4_t
   uint8_t m_flags;      // really a uint6_t
   uint16_t m_windowSize;
+  uint16_t m_urgentPointer;
+  uint16_t m_payloadSize;
+  uint16_t m_initialChecksum;
   uint16_t m_checksum;
-  uint16_t m_urgentPointer;
 
-  static bool m_calcChecksum;
+  bool m_calcChecksum;
+  bool m_goodChecksum;
+
 };
 
 }; // namespace ns3
--- a/src/internet-stack/tcp-l4-protocol.cc	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-l4-protocol.cc	Tue Jul 01 10:52:11 2008 -0700
@@ -21,6 +21,7 @@
 #include "ns3/assert.h"
 #include "ns3/log.h"
 #include "ns3/nstime.h"
+#include "ns3/boolean.h"
 
 #include "ns3/packet.h"
 #include "ns3/node.h"
@@ -328,6 +329,11 @@
                    ObjectFactoryValue (GetDefaultRttEstimatorFactory ()),
                    MakeObjectFactoryAccessor (&TcpL4Protocol::m_rttFactory),
                    MakeObjectFactoryChecker ())
+    .AddAttribute ("CalcChecksum", "If true, we calculate the checksum of outgoing packets"
+                   " and verify the checksum of incoming packets.",
+                   BooleanValue (false),
+                   MakeBooleanAccessor (&TcpL4Protocol::m_calcChecksum),
+                   MakeBooleanChecker ())
     ;
   return tid;
 }
@@ -439,14 +445,31 @@
   NS_LOG_FUNCTION (this << packet << source << destination << incomingInterface);
 
   TcpHeader tcpHeader;
+  if(m_calcChecksum)
+  {
+    tcpHeader.EnableChecksums();
+  }
+  /* XXX very dirty but needs this to AddHeader again because of checksum */
+  tcpHeader.SetLength(5); /* XXX TCP without options */
+  tcpHeader.SetPayloadSize(packet->GetSize() - tcpHeader.GetSerializedSize());
+  tcpHeader.InitializeChecksum(source, destination, PROT_NUMBER);
+
   //these two do a peek, so that the packet can be forwarded up
   packet->RemoveHeader (tcpHeader);
+
   NS_LOG_LOGIC("TcpL4Protocol " << this
                << " receiving seq " << tcpHeader.GetSequenceNumber()
                << " ack " << tcpHeader.GetAckNumber()
                << " flags "<< std::hex << (int)tcpHeader.GetFlags() << std::dec
                << " data size " << packet->GetSize());
-  packet->AddHeader (tcpHeader); 
+
+  if(!tcpHeader.IsChecksumOk ())
+  {
+    NS_LOG_INFO("Bad checksum, dropping packet!");
+    return;
+  }
+
+  packet->AddHeader (tcpHeader);
   NS_LOG_LOGIC ("TcpL4Protocol "<<this<<" received a packet");
   Ipv4EndPointDemux::EndPoints endPoints =
     m_endPoints->Lookup (destination, tcpHeader.GetDestinationPort (),
@@ -478,6 +501,11 @@
   TcpHeader tcpHeader;
   tcpHeader.SetDestinationPort (dport);
   tcpHeader.SetSourcePort (sport);
+  tcpHeader.SetPayloadSize(packet->GetSize());
+  if(m_calcChecksum)
+  {
+    tcpHeader.EnableChecksums();
+  }
   tcpHeader.InitializeChecksum (saddr,
                                daddr,
                                PROT_NUMBER);
@@ -507,8 +535,13 @@
   // XXX outgoingHeader cannot be logged
 
   outgoingHeader.SetLength (5); //header length in units of 32bit words
-  outgoingHeader.SetChecksum (0);  //XXX
-  outgoingHeader.SetUrgentPointer (0); //XXX
+  outgoingHeader.SetPayloadSize(packet->GetSize());
+  /* outgoingHeader.SetUrgentPointer (0); //XXX */
+  if(m_calcChecksum)
+  {
+    outgoingHeader.EnableChecksums();
+  }
+  outgoingHeader.InitializeChecksum(saddr, daddr, PROT_NUMBER);
 
   packet->AddHeader (outgoingHeader);
 
--- a/src/internet-stack/tcp-l4-protocol.h	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-l4-protocol.h	Tue Jul 01 10:52:11 2008 -0700
@@ -117,6 +117,9 @@
   void SendPacket (Ptr<Packet>, TcpHeader,
                   Ipv4Address, Ipv4Address);
   static ObjectFactory GetDefaultRttEstimatorFactory (void);
+
+  bool m_goodChecksum;
+  bool m_calcChecksum;
 };
 
 }; // namespace ns3
--- a/src/internet-stack/udp-header.cc	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/udp-header.cc	Tue Jul 01 10:52:11 2008 -0700
@@ -19,14 +19,11 @@
  */
 
 #include "udp-header.h"
-#include "ipv4-checksum.h"
 
 namespace ns3 {
 
 NS_OBJECT_ENSURE_REGISTERED (UdpHeader);
 
-bool UdpHeader::m_calcChecksum = false;
-
 /* The magic values below are used only for debugging.
  * They can be used to easily detect memory corruption
  * problems so you can see the patterns in memory.
@@ -35,7 +32,10 @@
   : m_sourcePort (0xfffd),
     m_destinationPort (0xfffd),
     m_payloadSize (0xfffd),
-    m_initialChecksum (0)
+    m_initialChecksum (0),
+    m_checksum(0),
+    m_calcChecksum(false),
+    m_goodChecksum(true)
 {}
 UdpHeader::~UdpHeader ()
 {
@@ -80,18 +80,35 @@
                               Ipv4Address destination,
                               uint8_t protocol)
 {
-  uint8_t buf[12];
-  source.Serialize (buf);
-  destination.Serialize (buf+4);
-  buf[8] = 0;
-  buf[9] = protocol;
-  uint16_t udpLength = m_payloadSize + GetSerializedSize ();
-  buf[10] = udpLength >> 8;
-  buf[11] = udpLength & 0xff;
+  Buffer buf = Buffer(12);
+  uint8_t tmp[4];
+  Buffer::Iterator it;
+  uint16_t udpLength = m_payloadSize + GetSerializedSize();
+
+  buf.AddAtStart(12);
+  it = buf.Begin();
 
-  m_initialChecksum = Ipv4ChecksumCalculate (0, buf, 12);
+  source.Serialize(tmp);
+  it.Write(tmp, 4); /* source IP address */
+  destination.Serialize(tmp);
+  it.Write(tmp, 4); /* destination IP address */
+  it.WriteU8(0); /* protocol */
+  it.WriteU8(protocol); /* protocol */
+  it.WriteU8(udpLength >> 8); /* length */
+  it.WriteU8(udpLength & 0xff); /* length */
+
+  it = buf.Begin();
+  /* we don't CompleteChecksum ( ~ ) now */
+  m_initialChecksum = ~(it.CalculateIpChecksum(12));
 }
 
+bool
+UdpHeader::IsChecksumOk (void) const
+{
+  return m_goodChecksum; 
+}
+
+
 TypeId 
 UdpHeader::GetTypeId (void)
 {
@@ -125,23 +142,21 @@
 UdpHeader::Serialize (Buffer::Iterator start) const
 {
   Buffer::Iterator i = start;
+  uint16_t udpLength = m_payloadSize + GetSerializedSize();
+
   i.WriteHtonU16 (m_sourcePort);
   i.WriteHtonU16 (m_destinationPort);
-  i.WriteHtonU16 (m_payloadSize + GetSerializedSize ());
+  i.WriteHtonU16 (udpLength);
   i.WriteU16 (0);
 
-  if (m_calcChecksum) 
+  if (m_calcChecksum)
     {
-#if 0
-      //XXXX
-      uint16_t checksum = Ipv4ChecksumCalculate (m_initialChecksum, 
-                                                  buffer->PeekData (), 
-                                                  GetSerializedSize () + m_payloadSize);
-      checksum = Ipv4ChecksumComplete (checksum);
-      i = buffer->Begin ();
-      i.Next (6);
-      i.WriteU16 (checksum);
-#endif
+      i = start;
+      uint16_t checksum = i.CalculateIpChecksum(udpLength, m_initialChecksum);
+
+      i = start;
+      i.Next(6);
+      i.WriteU16(checksum);
     }
 }
 uint32_t
@@ -151,10 +166,16 @@
   m_sourcePort = i.ReadNtohU16 ();
   m_destinationPort = i.ReadNtohU16 ();
   m_payloadSize = i.ReadNtohU16 () - GetSerializedSize ();
-  if (m_calcChecksum) 
-    {
-      // XXX verify checksum.
-    }
+  m_checksum = i.ReadU16();
+
+  if(m_calcChecksum)
+  {
+      i = start;
+      uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
+
+      m_goodChecksum = (checksum == 0);
+  }
+
   return GetSerializedSize ();
 }
 
--- a/src/internet-stack/udp-header.h	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/udp-header.h	Tue Jul 01 10:52:11 2008 -0700
@@ -49,7 +49,7 @@
   /**
    * \brief Enable checksum calculation for UDP (XXX currently has no effect)
    */
-  static void EnableChecksums (void);
+  void EnableChecksums (void);
   /**
    * \param port the destination port for this UdpHeader
    */
@@ -93,13 +93,21 @@
   virtual void Serialize (Buffer::Iterator start) const;
   virtual uint32_t Deserialize (Buffer::Iterator start);
 
+  /**
+   * \brief Is the UDP checksum correct ?
+   * \returns true if the checksum is correct, false otherwise.
+   */
+  bool IsChecksumOk (void) const;
+
 private:
   uint16_t m_sourcePort;
   uint16_t m_destinationPort;
   uint16_t m_payloadSize;
   uint16_t m_initialChecksum;
+  uint16_t m_checksum;
 
-  static bool m_calcChecksum;
+  bool m_calcChecksum;
+  bool m_goodChecksum;
 };
 
 } // namespace ns3
--- a/src/internet-stack/udp-l4-protocol.cc	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/udp-l4-protocol.cc	Tue Jul 01 10:52:11 2008 -0700
@@ -22,6 +22,7 @@
 #include "ns3/assert.h"
 #include "ns3/packet.h"
 #include "ns3/node.h"
+#include "ns3/boolean.h"
 
 #include "udp-l4-protocol.h"
 #include "udp-header.h"
@@ -45,6 +46,11 @@
   static TypeId tid = TypeId ("ns3::UdpL4Protocol")
     .SetParent<Ipv4L4Protocol> ()
     .AddConstructor<UdpL4Protocol> ()
+    .AddAttribute ("CalcChecksum", "If true, we calculate the checksum of outgoing packets"
+                   " and verify the checksum of incoming packets.",
+                   BooleanValue (false),
+                   MakeBooleanAccessor (&UdpL4Protocol::m_calcChecksum),
+                   MakeBooleanChecker ())
     ;
   return tid;
 }
@@ -151,9 +157,23 @@
                        Ptr<Ipv4Interface> interface)
 {
   NS_LOG_FUNCTION (this << packet << source << destination);
+  UdpHeader udpHeader;
+  if(m_calcChecksum)
+  {
+    udpHeader.EnableChecksums();
+  }
 
-  UdpHeader udpHeader;
+  udpHeader.SetPayloadSize (packet->GetSize () - udpHeader.GetSerializedSize ());
+  udpHeader.InitializeChecksum (source, destination, PROT_NUMBER);
+
   packet->RemoveHeader (udpHeader);
+
+  if(!udpHeader.IsChecksumOk ())
+  {
+    NS_LOG_INFO("Bad checksum : dropping packet!");
+    return;
+  }
+
   Ipv4EndPointDemux::EndPoints endPoints =
     m_endPoints->Lookup (destination, udpHeader.GetDestinationPort (),
                          source, udpHeader.GetSourcePort (), interface);
@@ -172,6 +192,10 @@
   NS_LOG_FUNCTION (this << packet << saddr << daddr << sport << dport);
 
   UdpHeader udpHeader;
+  if(m_calcChecksum)
+  {
+    udpHeader.EnableChecksums();
+  }
   udpHeader.SetDestinationPort (dport);
   udpHeader.SetSourcePort (sport);
   udpHeader.SetPayloadSize (packet->GetSize ());
--- a/src/internet-stack/udp-l4-protocol.h	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/udp-l4-protocol.h	Tue Jul 01 10:52:11 2008 -0700
@@ -93,6 +93,7 @@
 private:
   Ptr<Node> m_node;
   Ipv4EndPointDemux *m_endPoints;
+  bool m_calcChecksum;
 };
 
 }; // namespace ns3
--- a/src/node/ipv4-header.cc	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/node/ipv4-header.cc	Tue Jul 01 10:52:11 2008 -0700
@@ -38,6 +38,7 @@
     m_protocol (0),
     m_flags (0),
     m_fragmentOffset (0),
+    m_checksum(0),
     m_goodChecksum (true)
 {}
 
@@ -177,23 +178,6 @@
   return m_goodChecksum;
 }
 
-uint16_t
-Ipv4Header::ChecksumCalculate(Buffer::Iterator &i, uint16_t size)
-{
-  /* see RFC 1071 to understand this code. */
-  uint32_t sum = 0;
-
-  for (int j = 0; j < size/2; j++)
-    sum += i.ReadU16 ();
-
-  if (size & 1)
-     sum += i.ReadU8 ();
-
-  while (sum >> 16)
-    sum = (sum & 0xffff) + (sum >> 16);
-  return ~sum;
-}
-
 TypeId 
 Ipv4Header::GetTypeId (void)
 {
@@ -282,7 +266,7 @@
   if (m_calcChecksum) 
     {
       i = start;
-      uint16_t checksum = ChecksumCalculate(i, 20);
+      uint16_t checksum = i.CalculateIpChecksum(20);
       NS_LOG_LOGIC ("checksum=" <<checksum);
       i = start;
       i.Next (10);
@@ -318,14 +302,15 @@
   m_fragmentOffset <<= 3;
   m_ttl = i.ReadU8 ();
   m_protocol = i.ReadU8 ();
-  i.Next (2); // checksum
+  m_checksum = i.ReadU16();
+  /* i.Next (2); // checksum */
   m_source.Set (i.ReadNtohU32 ());
   m_destination.Set (i.ReadNtohU32 ());
 
   if (m_calcChecksum) 
     {
       i = start;
-      uint16_t checksum = ChecksumCalculate(i, headerSize);
+      uint16_t checksum = i.CalculateIpChecksum(headerSize);
       NS_LOG_LOGIC ("checksum=" <<checksum);
 
       m_goodChecksum = (checksum == 0);
--- a/src/node/ipv4-header.h	Mon Jun 30 22:41:22 2008 -0700
+++ b/src/node/ipv4-header.h	Tue Jul 01 10:52:11 2008 -0700
@@ -146,7 +146,6 @@
   virtual uint32_t Deserialize (Buffer::Iterator start);
 private:
 
-  static uint16_t ChecksumCalculate(Buffer::Iterator &i, uint16_t len);
   enum FlagsE {
     DONT_FRAGMENT = (1<<0),
     MORE_FRAGMENTS = (1<<1)
@@ -163,6 +162,7 @@
   uint16_t m_fragmentOffset : 13;
   Ipv4Address m_source;
   Ipv4Address m_destination;
+  uint16_t m_checksum;
   bool m_goodChecksum;
 };