bug 247: tcp checksum crashes when enabled.
authorMathieu Lacage <mathieu.lacage@sophia.inria.fr>
Thu, 10 Jul 2008 15:58:24 -0700
changeset 3404 b5d4a04c7b68
parent 3403 ac82ff1f6736
child 3405 7e943b537495
bug 247: tcp checksum crashes when enabled.
src/common/buffer.cc
src/common/buffer.h
src/internet-stack/tcp-header.cc
src/internet-stack/tcp-header.h
src/internet-stack/tcp-l4-protocol.cc
src/internet-stack/udp-header.cc
src/internet-stack/udp-header.h
src/internet-stack/udp-l4-protocol.cc
--- a/src/common/buffer.cc	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/common/buffer.cc	Thu Jul 10 15:58:24 2008 -0700
@@ -1128,6 +1128,12 @@
   return ~sum;
 }
 
+uint32_t 
+Buffer::Iterator::GetSize (void) const
+{
+  return m_dataEnd - m_dataStart;
+}
+
 } // namespace ns3
 
 
--- a/src/common/buffer.h	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/common/buffer.h	Thu Jul 10 15:58:24 2008 -0700
@@ -358,6 +358,11 @@
        */
       uint16_t CalculateIpChecksum(uint16_t size, uint32_t initialChecksum);
 
+      /**
+       * \returns the size of the underlying buffer we are iterating
+       */
+      uint32_t GetSize (void) const;
+
   private:
       friend class Buffer;
       Iterator (Buffer const*buffer);
--- a/src/internet-stack/tcp-header.cc	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/internet-stack/tcp-header.cc	Thu Jul 10 15:58:24 2008 -0700
@@ -23,6 +23,7 @@
 #include "tcp-socket-impl.h"
 #include "tcp-header.h"
 #include "ns3/buffer.h"
+#include "ns3/address-utils.h"
 
 namespace ns3 {
 
@@ -37,8 +38,6 @@
     m_flags (0),
     m_windowSize (0xffff),
     m_urgentPointer (0),
-    m_initialChecksum(0),
-    m_checksum (0),
     m_calcChecksum(false),
     m_goodChecksum(true)
 {}
@@ -52,11 +51,6 @@
   m_calcChecksum = true;
 }
 
-void TcpHeader::SetPayloadSize(uint16_t payloadSize)
-{
-  m_payloadSize = payloadSize;
-}
-
 void TcpHeader::SetSourcePort (uint16_t port)
 {
   m_sourcePort = port;
@@ -85,10 +79,6 @@
 {
   m_windowSize = windowSize;
 }
-void TcpHeader::SetChecksum (uint16_t checksum)
-{
-  m_checksum = checksum;
-}
 void TcpHeader::SetUrgentPointer (uint16_t urgentPointer)
 {
   m_urgentPointer = urgentPointer;
@@ -122,10 +112,6 @@
 {
   return m_windowSize;
 }
-uint16_t TcpHeader::GetChecksum () const
-{
-  return m_checksum;
-}
 uint16_t TcpHeader::GetUrgentPointer () const
 {
   return m_urgentPointer;
@@ -133,29 +119,31 @@
 
 void 
 TcpHeader::InitializeChecksum (Ipv4Address source, 
-                                   Ipv4Address destination,
-                                   uint8_t protocol)
+                               Ipv4Address destination,
+                               uint8_t protocol)
 {
-  Buffer buf = Buffer(12);
-  uint8_t tmp[4];
-  Buffer::Iterator it;
-  uint16_t tcpLength = m_payloadSize + GetSerializedSize();
-
-  buf.AddAtStart(12);
-  it = buf.Begin();
+  m_source = source;
+  m_destination = destination;
+  m_protocol = protocol;
+}
 
-  source.Serialize(tmp);
-  it.Write(tmp, 4); /* source IP address */
-  destination.Serialize(tmp);
-  it.Write(tmp, 4); /* destination IP address */
-  it.WriteU8(0); /* protocol */
-  it.WriteU8(protocol); /* protocol */
-  it.WriteU8(tcpLength >> 8); /* length */
-  it.WriteU8(tcpLength & 0xff); /* length */
+uint16_t
+TcpHeader::CalculateHeaderChecksum (uint16_t size) const
+{
+  Buffer buf = Buffer (12);
+  buf.AddAtStart (12);
+  Buffer::Iterator it = buf.Begin ();
 
-  it = buf.Begin();
+  WriteTo (it, m_source);
+  WriteTo (it, m_destination);
+  it.WriteU8 (0); /* protocol */
+  it.WriteU8 (m_protocol); /* protocol */
+  it.WriteU8 (size >> 8); /* length */
+  it.WriteU8 (size & 0xff); /* length */
+
+  it = buf.Begin ();
   /* we don't CompleteChecksum ( ~ ) now */
-  m_initialChecksum = ~(it.CalculateIpChecksum(12));
+  return ~(it.CalculateIpChecksum (12));
 }
 
 bool
@@ -219,7 +207,6 @@
 void TcpHeader::Serialize (Buffer::Iterator start)  const
 {
   Buffer::Iterator i = start;
-  uint16_t tcpLength = m_payloadSize + GetSerializedSize();
   i.WriteHtonU16 (m_sourcePort);
   i.WriteHtonU16 (m_destinationPort);
   i.WriteHtonU32 (m_sequenceNumber);
@@ -231,8 +218,9 @@
 
   if(m_calcChecksum)
   {
+    uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ());
     i = start;
-    uint16_t checksum = i.CalculateIpChecksum(tcpLength, m_initialChecksum);
+    uint16_t checksum = i.CalculateIpChecksum(start.GetSize (), headerChecksum);
     
     i = start;
     i.Next(16);
@@ -250,16 +238,16 @@
   m_flags = field & 0x3F;
   m_length = field>>12;
   m_windowSize = i.ReadNtohU16 ();
-  m_checksum = i.ReadU16 ();
+  i.Next (2);
   m_urgentPointer = i.ReadNtohU16 ();
 
   if(m_calcChecksum)
-  {
+    {
+      uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ());
       i = start;
-      uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
-
+      uint16_t checksum = i.CalculateIpChecksum(start.GetSize (), headerChecksum);
       m_goodChecksum = (checksum == 0);
-  }
+    }
 
   return GetSerializedSize ();
 }
--- a/src/internet-stack/tcp-header.h	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/internet-stack/tcp-header.h	Thu Jul 10 15:58:24 2008 -0700
@@ -78,10 +78,6 @@
    */
   void SetWindowSize (uint16_t windowSize);
   /**
-   * \param checksum the checksum for this TcpHeader
-   */
-  void SetChecksum (uint16_t checksum);
-  /**
    * \param urgentPointer the urgent pointer for this TcpHeader
    */
   void SetUrgentPointer (uint16_t urgentPointer);
@@ -117,10 +113,6 @@
    */
   uint16_t GetWindowSize () const;
   /**
-   * \return the checksum for this TcpHeader
-   */
-  uint16_t GetChecksum () const;
-  /**
    * \return the urgent pointer for this TcpHeader
    */
   uint16_t GetUrgentPointer () const;
@@ -151,17 +143,13 @@
   virtual uint32_t Deserialize (Buffer::Iterator start);
 
   /**
-   * \param size The payload size in bytes
-   */
-  void SetPayloadSize (uint16_t size);
-
-  /**
    * \brief Is the TCP checksum correct ?
    * \returns true if the checksum is correct, false otherwise.
    */
   bool IsChecksumOk (void) const;
 
 private:
+  uint16_t CalculateHeaderChecksum (uint16_t size) const;
   uint16_t m_sourcePort;
   uint16_t m_destinationPort;
   uint32_t m_sequenceNumber;
@@ -170,13 +158,14 @@
   uint8_t m_flags;      // really a uint6_t
   uint16_t m_windowSize;
   uint16_t m_urgentPointer;
-  uint16_t m_payloadSize;
+
+  Ipv4Address m_source;
+  Ipv4Address m_destination;
+  uint8_t m_protocol;
+
   uint16_t m_initialChecksum;
-  uint16_t m_checksum;
-
   bool m_calcChecksum;
   bool m_goodChecksum;
-
 };
 
 }; // namespace ns3
--- a/src/internet-stack/tcp-l4-protocol.cc	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/internet-stack/tcp-l4-protocol.cc	Thu Jul 10 15:58:24 2008 -0700
@@ -448,6 +448,7 @@
   if(m_calcChecksum)
   {
     tcpHeader.EnableChecksums();
+    tcpHeader.InitializeChecksum (source, destination, PROT_NUMBER);
   }
 
   packet->PeekHeader (tcpHeader);
@@ -495,7 +496,6 @@
   TcpHeader tcpHeader;
   tcpHeader.SetDestinationPort (dport);
   tcpHeader.SetSourcePort (sport);
-  tcpHeader.SetPayloadSize(packet->GetSize());
   if(m_calcChecksum)
   {
     tcpHeader.EnableChecksums();
@@ -529,7 +529,6 @@
   // XXX outgoingHeader cannot be logged
 
   outgoingHeader.SetLength (5); //header length in units of 32bit words
-  outgoingHeader.SetPayloadSize(packet->GetSize());
   /* outgoingHeader.SetUrgentPointer (0); //XXX */
   if(m_calcChecksum)
   {
--- a/src/internet-stack/udp-header.cc	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/internet-stack/udp-header.cc	Thu Jul 10 15:58:24 2008 -0700
@@ -19,6 +19,7 @@
  */
 
 #include "udp-header.h"
+#include "ns3/address-utils.h"
 
 namespace ns3 {
 
@@ -32,8 +33,6 @@
   : m_sourcePort (0xfffd),
     m_destinationPort (0xfffd),
     m_payloadSize (0xfffd),
-    m_initialChecksum (0),
-    m_checksum(0),
     m_calcChecksum(false),
     m_goodChecksum(true)
 {}
@@ -71,35 +70,31 @@
   return m_destinationPort;
 }
 void 
-UdpHeader::SetPayloadSize (uint16_t size)
-{
-  m_payloadSize = size;
-}
-void 
 UdpHeader::InitializeChecksum (Ipv4Address source, 
                               Ipv4Address destination,
                               uint8_t protocol)
 {
-  Buffer buf = Buffer(12);
-  uint8_t tmp[4];
-  Buffer::Iterator it;
-  uint16_t udpLength = m_payloadSize + GetSerializedSize();
-
-  buf.AddAtStart(12);
-  it = buf.Begin();
+  m_source = source;
+  m_destination = destination;
+  m_protocol = protocol;
+}
+uint16_t
+UdpHeader::CalculateHeaderChecksum (uint16_t size) const
+{
+  Buffer buf = Buffer (12);
+  buf.AddAtStart (12);
+  Buffer::Iterator it = buf.Begin ();
 
-  source.Serialize(tmp);
-  it.Write(tmp, 4); /* source IP address */
-  destination.Serialize(tmp);
-  it.Write(tmp, 4); /* destination IP address */
-  it.WriteU8(0); /* protocol */
-  it.WriteU8(protocol); /* protocol */
-  it.WriteU8(udpLength >> 8); /* length */
-  it.WriteU8(udpLength & 0xff); /* length */
+  WriteTo (it, m_source);
+  WriteTo (it, m_destination);
+  it.WriteU8 (0); /* protocol */
+  it.WriteU8 (m_protocol); /* protocol */
+  it.WriteU8 (size >> 8); /* length */
+  it.WriteU8 (size & 0xff); /* length */
 
-  it = buf.Begin();
+  it = buf.Begin ();
   /* we don't CompleteChecksum ( ~ ) now */
-  m_initialChecksum = ~(it.CalculateIpChecksum(12));
+  return ~(it.CalculateIpChecksum (12));
 }
 
 bool
@@ -142,17 +137,17 @@
 UdpHeader::Serialize (Buffer::Iterator start) const
 {
   Buffer::Iterator i = start;
-  uint16_t udpLength = m_payloadSize + GetSerializedSize();
 
   i.WriteHtonU16 (m_sourcePort);
   i.WriteHtonU16 (m_destinationPort);
-  i.WriteHtonU16 (udpLength);
+  i.WriteHtonU16 (start.GetSize ());
   i.WriteU16 (0);
 
   if (m_calcChecksum)
     {
+      uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ());
       i = start;
-      uint16_t checksum = i.CalculateIpChecksum(udpLength, m_initialChecksum);
+      uint16_t checksum = i.CalculateIpChecksum (start.GetSize (), headerChecksum);
 
       i = start;
       i.Next(6);
@@ -166,12 +161,13 @@
   m_sourcePort = i.ReadNtohU16 ();
   m_destinationPort = i.ReadNtohU16 ();
   m_payloadSize = i.ReadNtohU16 () - GetSerializedSize ();
-  m_checksum = i.ReadU16();
+  i.Next (2);
 
   if(m_calcChecksum)
   {
+      uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ());
       i = start;
-      uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
+      uint16_t checksum = i.CalculateIpChecksum (start.GetSize (), headerChecksum);
 
       m_goodChecksum = (checksum == 0);
   }
--- a/src/internet-stack/udp-header.h	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/internet-stack/udp-header.h	Thu Jul 10 15:58:24 2008 -0700
@@ -66,10 +66,6 @@
    * \return the destination port for this UdpHeader
    */
   uint16_t GetDestinationPort (void) const;
-  /**
-   * \param size The payload size in bytes
-   */
-  void SetPayloadSize (uint16_t size);
 
   /**
    * \param source the ip source to use in the underlying
@@ -100,12 +96,14 @@
   bool IsChecksumOk (void) const;
 
 private:
+  uint16_t CalculateHeaderChecksum (uint16_t size) const;
   uint16_t m_sourcePort;
   uint16_t m_destinationPort;
   uint16_t m_payloadSize;
-  uint16_t m_initialChecksum;
-  uint16_t m_checksum;
 
+  Ipv4Address m_source;
+  Ipv4Address m_destination;
+  uint8_t m_protocol;
   bool m_calcChecksum;
   bool m_goodChecksum;
 };
--- a/src/internet-stack/udp-l4-protocol.cc	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/internet-stack/udp-l4-protocol.cc	Thu Jul 10 15:58:24 2008 -0700
@@ -163,7 +163,6 @@
     udpHeader.EnableChecksums();
   }
 
-  udpHeader.SetPayloadSize (packet->GetSize () - udpHeader.GetSerializedSize ());
   udpHeader.InitializeChecksum (source, destination, PROT_NUMBER);
 
   packet->RemoveHeader (udpHeader);
@@ -195,13 +194,12 @@
   if(m_calcChecksum)
   {
     udpHeader.EnableChecksums();
+    udpHeader.InitializeChecksum (saddr,
+                                  daddr,
+                                  PROT_NUMBER);
   }
   udpHeader.SetDestinationPort (dport);
   udpHeader.SetSourcePort (sport);
-  udpHeader.SetPayloadSize (packet->GetSize ());
-  udpHeader.InitializeChecksum (saddr,
-                                daddr,
-                                PROT_NUMBER);
 
   packet->AddHeader (udpHeader);