bug 236: add optional support for tcp and udp checksum.
authorSebastien Vincent <vincent@clarinet.u-strasbg.fr>
Tue Jul 01 10:52:11 2008 -0700 (19 months ago)
changeset 336333d1ca2e4ba4
parent 3362 9a6f1b3c6e0b
child 3364 8e6ac6061680
bug 236: add optional support for tcp and udp checksum.
src/common/buffer.cc
src/common/buffer.h
src/internet-stack/tcp-header.cc
src/internet-stack/tcp-header.h
src/internet-stack/tcp-l4-protocol.cc
src/internet-stack/tcp-l4-protocol.h
src/internet-stack/udp-header.cc
src/internet-stack/udp-header.h
src/internet-stack/udp-l4-protocol.cc
src/internet-stack/udp-l4-protocol.h
src/node/ipv4-header.cc
src/node/ipv4-header.h
     1.1 --- a/src/common/buffer.cc	Mon Jun 30 22:41:22 2008 -0700
     1.2 +++ b/src/common/buffer.cc	Tue Jul 01 10:52:11 2008 -0700
     1.3 @@ -1105,6 +1105,29 @@
     1.4  
     1.5  #endif /* BUFFER_USE_INLINE */
     1.6  
     1.7 +uint16_t
     1.8 +Buffer::Iterator::CalculateIpChecksum(uint16_t size)
     1.9 +{
    1.10 +  return CalculateIpChecksum(size, 0);
    1.11 +}
    1.12 +
    1.13 +uint16_t
    1.14 +Buffer::Iterator::CalculateIpChecksum(uint16_t size, uint32_t initialChecksum)
    1.15 +{
    1.16 +  /* see RFC 1071 to understand this code. */
    1.17 +  uint32_t sum = initialChecksum;
    1.18 +
    1.19 +  for (int j = 0; j < size/2; j++)
    1.20 +    sum += ReadU16 ();
    1.21 +
    1.22 +  if (size & 1)
    1.23 +     sum += ReadU8 ();
    1.24 +
    1.25 +  while (sum >> 16)
    1.26 +    sum = (sum & 0xffff) + (sum >> 16);
    1.27 +  return ~sum;
    1.28 +}
    1.29 +
    1.30  } // namespace ns3
    1.31  
    1.32  
     2.1 --- a/src/common/buffer.h	Mon Jun 30 22:41:22 2008 -0700
     2.2 +++ b/src/common/buffer.h	Tue Jul 01 10:52:11 2008 -0700
     2.3 @@ -342,6 +342,22 @@
     2.4         * bytes read.
     2.5         */
     2.6        void Read (uint8_t *buffer, uint32_t size);
     2.7 +
     2.8 +      /**
     2.9 +       * \brief Calculate the checksum.
    2.10 +       * \param size size of the buffer.
    2.11 +       * \return checksum
    2.12 +       */
    2.13 +      uint16_t CalculateIpChecksum(uint16_t size);
    2.14 +
    2.15 +      /**
    2.16 +       * \brief Calculate the checksum.
    2.17 +       * \param size size of the buffer.
    2.18 +       * \param initialChecksum initial value
    2.19 +       * \return checksum
    2.20 +       */
    2.21 +      uint16_t CalculateIpChecksum(uint16_t size, uint32_t initialChecksum);
    2.22 +
    2.23    private:
    2.24        friend class Buffer;
    2.25        Iterator (Buffer const*buffer);
     3.1 --- a/src/internet-stack/tcp-header.cc	Mon Jun 30 22:41:22 2008 -0700
     3.2 +++ b/src/internet-stack/tcp-header.cc	Tue Jul 01 10:52:11 2008 -0700
     3.3 @@ -28,8 +28,6 @@
     3.4  
     3.5  NS_OBJECT_ENSURE_REGISTERED (TcpHeader);
     3.6  
     3.7 -bool TcpHeader::m_calcChecksum = false;
     3.8 -
     3.9  TcpHeader::TcpHeader ()
    3.10    : m_sourcePort (0),
    3.11      m_destinationPort (0),
    3.12 @@ -38,8 +36,11 @@
    3.13      m_length (5),
    3.14      m_flags (0),
    3.15      m_windowSize (0xffff),
    3.16 +    m_urgentPointer (0),
    3.17 +    m_initialChecksum(0),
    3.18      m_checksum (0),
    3.19 -    m_urgentPointer (0)
    3.20 +    m_calcChecksum(false),
    3.21 +    m_goodChecksum(true)
    3.22  {}
    3.23  
    3.24  TcpHeader::~TcpHeader ()
    3.25 @@ -51,6 +52,11 @@
    3.26    m_calcChecksum = true;
    3.27  }
    3.28  
    3.29 +void TcpHeader::SetPayloadSize(uint16_t payloadSize)
    3.30 +{
    3.31 +  m_payloadSize = payloadSize;
    3.32 +}
    3.33 +
    3.34  void TcpHeader::SetSourcePort (uint16_t port)
    3.35  {
    3.36    m_sourcePort = port;
    3.37 @@ -130,8 +136,32 @@
    3.38                                     Ipv4Address destination,
    3.39                                     uint8_t protocol)
    3.40  {
    3.41 -  m_checksum = 0;
    3.42 -//XXX requires peeking into IP to get length of the TCP segment
    3.43 +  Buffer buf = Buffer(12);
    3.44 +  uint8_t tmp[4];
    3.45 +  Buffer::Iterator it;
    3.46 +  uint16_t tcpLength = m_payloadSize + GetSerializedSize();
    3.47 +
    3.48 +  buf.AddAtStart(12);
    3.49 +  it = buf.Begin();
    3.50 +
    3.51 +  source.Serialize(tmp);
    3.52 +  it.Write(tmp, 4); /* source IP address */
    3.53 +  destination.Serialize(tmp);
    3.54 +  it.Write(tmp, 4); /* destination IP address */
    3.55 +  it.WriteU8(0); /* protocol */
    3.56 +  it.WriteU8(protocol); /* protocol */
    3.57 +  it.WriteU8(tcpLength >> 8); /* length */
    3.58 +  it.WriteU8(tcpLength & 0xff); /* length */
    3.59 +
    3.60 +  it = buf.Begin();
    3.61 +  /* we don't CompleteChecksum ( ~ ) now */
    3.62 +  m_initialChecksum = ~(it.CalculateIpChecksum(12));
    3.63 +}
    3.64 +
    3.65 +bool
    3.66 +TcpHeader::IsChecksumOk (void) const
    3.67 +{
    3.68 +  return m_goodChecksum;
    3.69  }
    3.70  
    3.71  TypeId 
    3.72 @@ -188,28 +218,49 @@
    3.73  }
    3.74  void TcpHeader::Serialize (Buffer::Iterator start)  const
    3.75  {
    3.76 -  start.WriteHtonU16 (m_sourcePort);
    3.77 -  start.WriteHtonU16 (m_destinationPort);
    3.78 -  start.WriteHtonU32 (m_sequenceNumber);
    3.79 -  start.WriteHtonU32 (m_ackNumber);
    3.80 -  start.WriteHtonU16 (m_length << 12 | m_flags); //reserved bits are all zero
    3.81 -  start.WriteHtonU16 (m_windowSize);
    3.82 -  //XXX calculate checksum here
    3.83 -  start.WriteHtonU16 (m_checksum);
    3.84 -  start.WriteHtonU16 (m_urgentPointer);
    3.85 +  Buffer::Iterator i = start;
    3.86 +  uint16_t tcpLength = m_payloadSize + GetSerializedSize();
    3.87 +  i.WriteHtonU16 (m_sourcePort);
    3.88 +  i.WriteHtonU16 (m_destinationPort);
    3.89 +  i.WriteHtonU32 (m_sequenceNumber);
    3.90 +  i.WriteHtonU32 (m_ackNumber);
    3.91 +  i.WriteHtonU16 (m_length << 12 | m_flags); //reserved bits are all zero
    3.92 +  i.WriteHtonU16 (m_windowSize);
    3.93 +  i.WriteHtonU16 (0);
    3.94 +  i.WriteHtonU16 (m_urgentPointer);
    3.95 +
    3.96 +  if(m_calcChecksum)
    3.97 +  {
    3.98 +    i = start;
    3.99 +    uint16_t checksum = i.CalculateIpChecksum(tcpLength, m_initialChecksum);
   3.100 +    
   3.101 +    i = start;
   3.102 +    i.Next(16);
   3.103 +    i.WriteU16(checksum);
   3.104 +  }
   3.105  }
   3.106  uint32_t TcpHeader::Deserialize (Buffer::Iterator start)
   3.107  {
   3.108 -  m_sourcePort = start.ReadNtohU16 ();
   3.109 -  m_destinationPort = start.ReadNtohU16 ();
   3.110 -  m_sequenceNumber = start.ReadNtohU32 ();
   3.111 -  m_ackNumber = start.ReadNtohU32 ();
   3.112 -  uint16_t field = start.ReadNtohU16 ();
   3.113 +  Buffer::Iterator i = start;
   3.114 +  m_sourcePort = i.ReadNtohU16 ();
   3.115 +  m_destinationPort = i.ReadNtohU16 ();
   3.116 +  m_sequenceNumber = i.ReadNtohU32 ();
   3.117 +  m_ackNumber = i.ReadNtohU32 ();
   3.118 +  uint16_t field = i.ReadNtohU16 ();
   3.119    m_flags = field & 0x3F;
   3.120    m_length = field>>12;
   3.121 -  m_windowSize = start.ReadNtohU16 ();
   3.122 -  m_checksum = start.ReadNtohU16 ();
   3.123 -  m_urgentPointer = start.ReadNtohU16 ();
   3.124 +  m_windowSize = i.ReadNtohU16 ();
   3.125 +  m_checksum = i.ReadU16 ();
   3.126 +  m_urgentPointer = i.ReadNtohU16 ();
   3.127 +
   3.128 +  if(m_calcChecksum)
   3.129 +  {
   3.130 +      i = start;
   3.131 +      uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
   3.132 +
   3.133 +      m_goodChecksum = (checksum == 0);
   3.134 +  }
   3.135 +
   3.136    return GetSerializedSize ();
   3.137  }
   3.138  
     4.1 --- a/src/internet-stack/tcp-header.h	Mon Jun 30 22:41:22 2008 -0700
     4.2 +++ b/src/internet-stack/tcp-header.h	Tue Jul 01 10:52:11 2008 -0700
     4.3 @@ -47,7 +47,7 @@
     4.4    /**
     4.5     * \brief Enable checksum calculation for TCP (XXX currently has no effect)
     4.6     */
     4.7 -  static void EnableChecksums (void);
     4.8 +  void EnableChecksums (void);
     4.9  //Setters
    4.10    /**
    4.11     * \param port The source port for this TcpHeader
    4.12 @@ -150,6 +150,17 @@
    4.13    virtual void Serialize (Buffer::Iterator start) const;
    4.14    virtual uint32_t Deserialize (Buffer::Iterator start);
    4.15  
    4.16 +  /**
    4.17 +   * \param size The payload size in bytes
    4.18 +   */
    4.19 +  void SetPayloadSize (uint16_t size);
    4.20 +
    4.21 +  /**
    4.22 +   * \brief Is the TCP checksum correct ?
    4.23 +   * \returns true if the checksum is correct, false otherwise.
    4.24 +   */
    4.25 +  bool IsChecksumOk (void) const;
    4.26 +
    4.27  private:
    4.28    uint16_t m_sourcePort;
    4.29    uint16_t m_destinationPort;
    4.30 @@ -158,10 +169,14 @@
    4.31    uint8_t m_length; // really a uint4_t
    4.32    uint8_t m_flags;      // really a uint6_t
    4.33    uint16_t m_windowSize;
    4.34 +  uint16_t m_urgentPointer;
    4.35 +  uint16_t m_payloadSize;
    4.36 +  uint16_t m_initialChecksum;
    4.37    uint16_t m_checksum;
    4.38 -  uint16_t m_urgentPointer;
    4.39  
    4.40 -  static bool m_calcChecksum;
    4.41 +  bool m_calcChecksum;
    4.42 +  bool m_goodChecksum;
    4.43 +
    4.44  };
    4.45  
    4.46  }; // namespace ns3
     5.1 --- a/src/internet-stack/tcp-l4-protocol.cc	Mon Jun 30 22:41:22 2008 -0700
     5.2 +++ b/src/internet-stack/tcp-l4-protocol.cc	Tue Jul 01 10:52:11 2008 -0700
     5.3 @@ -21,6 +21,7 @@
     5.4  #include "ns3/assert.h"
     5.5  #include "ns3/log.h"
     5.6  #include "ns3/nstime.h"
     5.7 +#include "ns3/boolean.h"
     5.8  
     5.9  #include "ns3/packet.h"
    5.10  #include "ns3/node.h"
    5.11 @@ -328,6 +329,11 @@
    5.12                     ObjectFactoryValue (GetDefaultRttEstimatorFactory ()),
    5.13                     MakeObjectFactoryAccessor (&TcpL4Protocol::m_rttFactory),
    5.14                     MakeObjectFactoryChecker ())
    5.15 +    .AddAttribute ("CalcChecksum", "If true, we calculate the checksum of outgoing packets"
    5.16 +                   " and verify the checksum of incoming packets.",
    5.17 +                   BooleanValue (false),
    5.18 +                   MakeBooleanAccessor (&TcpL4Protocol::m_calcChecksum),
    5.19 +                   MakeBooleanChecker ())
    5.20      ;
    5.21    return tid;
    5.22  }
    5.23 @@ -439,14 +445,31 @@
    5.24    NS_LOG_FUNCTION (this << packet << source << destination << incomingInterface);
    5.25  
    5.26    TcpHeader tcpHeader;
    5.27 +  if(m_calcChecksum)
    5.28 +  {
    5.29 +    tcpHeader.EnableChecksums();
    5.30 +  }
    5.31 +  /* XXX very dirty but needs this to AddHeader again because of checksum */
    5.32 +  tcpHeader.SetLength(5); /* XXX TCP without options */
    5.33 +  tcpHeader.SetPayloadSize(packet->GetSize() - tcpHeader.GetSerializedSize());
    5.34 +  tcpHeader.InitializeChecksum(source, destination, PROT_NUMBER);
    5.35 +
    5.36    //these two do a peek, so that the packet can be forwarded up
    5.37    packet->RemoveHeader (tcpHeader);
    5.38 +
    5.39    NS_LOG_LOGIC("TcpL4Protocol " << this
    5.40                 << " receiving seq " << tcpHeader.GetSequenceNumber()
    5.41                 << " ack " << tcpHeader.GetAckNumber()
    5.42                 << " flags "<< std::hex << (int)tcpHeader.GetFlags() << std::dec
    5.43                 << " data size " << packet->GetSize());
    5.44 -  packet->AddHeader (tcpHeader); 
    5.45 +
    5.46 +  if(!tcpHeader.IsChecksumOk ())
    5.47 +  {
    5.48 +    NS_LOG_INFO("Bad checksum, dropping packet!");
    5.49 +    return;
    5.50 +  }
    5.51 +
    5.52 +  packet->AddHeader (tcpHeader);
    5.53    NS_LOG_LOGIC ("TcpL4Protocol "<<this<<" received a packet");
    5.54    Ipv4EndPointDemux::EndPoints endPoints =
    5.55      m_endPoints->Lookup (destination, tcpHeader.GetDestinationPort (),
    5.56 @@ -478,6 +501,11 @@
    5.57    TcpHeader tcpHeader;
    5.58    tcpHeader.SetDestinationPort (dport);
    5.59    tcpHeader.SetSourcePort (sport);
    5.60 +  tcpHeader.SetPayloadSize(packet->GetSize());
    5.61 +  if(m_calcChecksum)
    5.62 +  {
    5.63 +    tcpHeader.EnableChecksums();
    5.64 +  }
    5.65    tcpHeader.InitializeChecksum (saddr,
    5.66                                 daddr,
    5.67                                 PROT_NUMBER);
    5.68 @@ -507,8 +535,13 @@
    5.69    // XXX outgoingHeader cannot be logged
    5.70  
    5.71    outgoingHeader.SetLength (5); //header length in units of 32bit words
    5.72 -  outgoingHeader.SetChecksum (0);  //XXX
    5.73 -  outgoingHeader.SetUrgentPointer (0); //XXX
    5.74 +  outgoingHeader.SetPayloadSize(packet->GetSize());
    5.75 +  /* outgoingHeader.SetUrgentPointer (0); //XXX */
    5.76 +  if(m_calcChecksum)
    5.77 +  {
    5.78 +    outgoingHeader.EnableChecksums();
    5.79 +  }
    5.80 +  outgoingHeader.InitializeChecksum(saddr, daddr, PROT_NUMBER);
    5.81  
    5.82    packet->AddHeader (outgoingHeader);
    5.83  
     6.1 --- a/src/internet-stack/tcp-l4-protocol.h	Mon Jun 30 22:41:22 2008 -0700
     6.2 +++ b/src/internet-stack/tcp-l4-protocol.h	Tue Jul 01 10:52:11 2008 -0700
     6.3 @@ -117,6 +117,9 @@
     6.4    void SendPacket (Ptr<Packet>, TcpHeader,
     6.5                    Ipv4Address, Ipv4Address);
     6.6    static ObjectFactory GetDefaultRttEstimatorFactory (void);
     6.7 +
     6.8 +  bool m_goodChecksum;
     6.9 +  bool m_calcChecksum;
    6.10  };
    6.11  
    6.12  }; // namespace ns3
     7.1 --- a/src/internet-stack/udp-header.cc	Mon Jun 30 22:41:22 2008 -0700
     7.2 +++ b/src/internet-stack/udp-header.cc	Tue Jul 01 10:52:11 2008 -0700
     7.3 @@ -19,14 +19,11 @@
     7.4   */
     7.5  
     7.6  #include "udp-header.h"
     7.7 -#include "ipv4-checksum.h"
     7.8  
     7.9  namespace ns3 {
    7.10  
    7.11  NS_OBJECT_ENSURE_REGISTERED (UdpHeader);
    7.12  
    7.13 -bool UdpHeader::m_calcChecksum = false;
    7.14 -
    7.15  /* The magic values below are used only for debugging.
    7.16   * They can be used to easily detect memory corruption
    7.17   * problems so you can see the patterns in memory.
    7.18 @@ -35,7 +32,10 @@
    7.19    : m_sourcePort (0xfffd),
    7.20      m_destinationPort (0xfffd),
    7.21      m_payloadSize (0xfffd),
    7.22 -    m_initialChecksum (0)
    7.23 +    m_initialChecksum (0),
    7.24 +    m_checksum(0),
    7.25 +    m_calcChecksum(false),
    7.26 +    m_goodChecksum(true)
    7.27  {}
    7.28  UdpHeader::~UdpHeader ()
    7.29  {
    7.30 @@ -80,18 +80,35 @@
    7.31                                Ipv4Address destination,
    7.32                                uint8_t protocol)
    7.33  {
    7.34 -  uint8_t buf[12];
    7.35 -  source.Serialize (buf);
    7.36 -  destination.Serialize (buf+4);
    7.37 -  buf[8] = 0;
    7.38 -  buf[9] = protocol;
    7.39 -  uint16_t udpLength = m_payloadSize + GetSerializedSize ();
    7.40 -  buf[10] = udpLength >> 8;
    7.41 -  buf[11] = udpLength & 0xff;
    7.42 +  Buffer buf = Buffer(12);
    7.43 +  uint8_t tmp[4];
    7.44 +  Buffer::Iterator it;
    7.45 +  uint16_t udpLength = m_payloadSize + GetSerializedSize();
    7.46  
    7.47 -  m_initialChecksum = Ipv4ChecksumCalculate (0, buf, 12);
    7.48 +  buf.AddAtStart(12);
    7.49 +  it = buf.Begin();
    7.50 +
    7.51 +  source.Serialize(tmp);
    7.52 +  it.Write(tmp, 4); /* source IP address */
    7.53 +  destination.Serialize(tmp);
    7.54 +  it.Write(tmp, 4); /* destination IP address */
    7.55 +  it.WriteU8(0); /* protocol */
    7.56 +  it.WriteU8(protocol); /* protocol */
    7.57 +  it.WriteU8(udpLength >> 8); /* length */
    7.58 +  it.WriteU8(udpLength & 0xff); /* length */
    7.59 +
    7.60 +  it = buf.Begin();
    7.61 +  /* we don't CompleteChecksum ( ~ ) now */
    7.62 +  m_initialChecksum = ~(it.CalculateIpChecksum(12));
    7.63  }
    7.64  
    7.65 +bool
    7.66 +UdpHeader::IsChecksumOk (void) const
    7.67 +{
    7.68 +  return m_goodChecksum; 
    7.69 +}
    7.70 +
    7.71 +
    7.72  TypeId 
    7.73  UdpHeader::GetTypeId (void)
    7.74  {
    7.75 @@ -125,23 +142,21 @@
    7.76  UdpHeader::Serialize (Buffer::Iterator start) const
    7.77  {
    7.78    Buffer::Iterator i = start;
    7.79 +  uint16_t udpLength = m_payloadSize + GetSerializedSize();
    7.80 +
    7.81    i.WriteHtonU16 (m_sourcePort);
    7.82    i.WriteHtonU16 (m_destinationPort);
    7.83 -  i.WriteHtonU16 (m_payloadSize + GetSerializedSize ());
    7.84 +  i.WriteHtonU16 (udpLength);
    7.85    i.WriteU16 (0);
    7.86  
    7.87 -  if (m_calcChecksum) 
    7.88 +  if (m_calcChecksum)
    7.89      {
    7.90 -#if 0
    7.91 -      //XXXX
    7.92 -      uint16_t checksum = Ipv4ChecksumCalculate (m_initialChecksum, 
    7.93 -                                                  buffer->PeekData (), 
    7.94 -                                                  GetSerializedSize () + m_payloadSize);
    7.95 -      checksum = Ipv4ChecksumComplete (checksum);
    7.96 -      i = buffer->Begin ();
    7.97 -      i.Next (6);
    7.98 -      i.WriteU16 (checksum);
    7.99 -#endif
   7.100 +      i = start;
   7.101 +      uint16_t checksum = i.CalculateIpChecksum(udpLength, m_initialChecksum);
   7.102 +
   7.103 +      i = start;
   7.104 +      i.Next(6);
   7.105 +      i.WriteU16(checksum);
   7.106      }
   7.107  }
   7.108  uint32_t
   7.109 @@ -151,10 +166,16 @@
   7.110    m_sourcePort = i.ReadNtohU16 ();
   7.111    m_destinationPort = i.ReadNtohU16 ();
   7.112    m_payloadSize = i.ReadNtohU16 () - GetSerializedSize ();
   7.113 -  if (m_calcChecksum) 
   7.114 -    {
   7.115 -      // XXX verify checksum.
   7.116 -    }
   7.117 +  m_checksum = i.ReadU16();
   7.118 +
   7.119 +  if(m_calcChecksum)
   7.120 +  {
   7.121 +      i = start;
   7.122 +      uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
   7.123 +
   7.124 +      m_goodChecksum = (checksum == 0);
   7.125 +  }
   7.126 +
   7.127    return GetSerializedSize ();
   7.128  }
   7.129  
     8.1 --- a/src/internet-stack/udp-header.h	Mon Jun 30 22:41:22 2008 -0700
     8.2 +++ b/src/internet-stack/udp-header.h	Tue Jul 01 10:52:11 2008 -0700
     8.3 @@ -49,7 +49,7 @@
     8.4    /**
     8.5     * \brief Enable checksum calculation for UDP (XXX currently has no effect)
     8.6     */
     8.7 -  static void EnableChecksums (void);
     8.8 +  void EnableChecksums (void);
     8.9    /**
    8.10     * \param port the destination port for this UdpHeader
    8.11     */
    8.12 @@ -93,13 +93,21 @@
    8.13    virtual void Serialize (Buffer::Iterator start) const;
    8.14    virtual uint32_t Deserialize (Buffer::Iterator start);
    8.15  
    8.16 +  /**
    8.17 +   * \brief Is the UDP checksum correct ?
    8.18 +   * \returns true if the checksum is correct, false otherwise.
    8.19 +   */
    8.20 +  bool IsChecksumOk (void) const;
    8.21 +
    8.22  private:
    8.23    uint16_t m_sourcePort;
    8.24    uint16_t m_destinationPort;
    8.25    uint16_t m_payloadSize;
    8.26    uint16_t m_initialChecksum;
    8.27 +  uint16_t m_checksum;
    8.28  
    8.29 -  static bool m_calcChecksum;
    8.30 +  bool m_calcChecksum;
    8.31 +  bool m_goodChecksum;
    8.32  };
    8.33  
    8.34  } // namespace ns3
     9.1 --- a/src/internet-stack/udp-l4-protocol.cc	Mon Jun 30 22:41:22 2008 -0700
     9.2 +++ b/src/internet-stack/udp-l4-protocol.cc	Tue Jul 01 10:52:11 2008 -0700
     9.3 @@ -22,6 +22,7 @@
     9.4  #include "ns3/assert.h"
     9.5  #include "ns3/packet.h"
     9.6  #include "ns3/node.h"
     9.7 +#include "ns3/boolean.h"
     9.8  
     9.9  #include "udp-l4-protocol.h"
    9.10  #include "udp-header.h"
    9.11 @@ -45,6 +46,11 @@
    9.12    static TypeId tid = TypeId ("ns3::UdpL4Protocol")
    9.13      .SetParent<Ipv4L4Protocol> ()
    9.14      .AddConstructor<UdpL4Protocol> ()
    9.15 +    .AddAttribute ("CalcChecksum", "If true, we calculate the checksum of outgoing packets"
    9.16 +                   " and verify the checksum of incoming packets.",
    9.17 +                   BooleanValue (false),
    9.18 +                   MakeBooleanAccessor (&UdpL4Protocol::m_calcChecksum),
    9.19 +                   MakeBooleanChecker ())
    9.20      ;
    9.21    return tid;
    9.22  }
    9.23 @@ -151,9 +157,23 @@
    9.24                         Ptr<Ipv4Interface> interface)
    9.25  {
    9.26    NS_LOG_FUNCTION (this << packet << source << destination);
    9.27 +  UdpHeader udpHeader;
    9.28 +  if(m_calcChecksum)
    9.29 +  {
    9.30 +    udpHeader.EnableChecksums();
    9.31 +  }
    9.32  
    9.33 -  UdpHeader udpHeader;
    9.34 +  udpHeader.SetPayloadSize (packet->GetSize () - udpHeader.GetSerializedSize ());
    9.35 +  udpHeader.InitializeChecksum (source, destination, PROT_NUMBER);
    9.36 +
    9.37    packet->RemoveHeader (udpHeader);
    9.38 +
    9.39 +  if(!udpHeader.IsChecksumOk ())
    9.40 +  {
    9.41 +    NS_LOG_INFO("Bad checksum : dropping packet!");
    9.42 +    return;
    9.43 +  }
    9.44 +
    9.45    Ipv4EndPointDemux::EndPoints endPoints =
    9.46      m_endPoints->Lookup (destination, udpHeader.GetDestinationPort (),
    9.47                           source, udpHeader.GetSourcePort (), interface);
    9.48 @@ -172,6 +192,10 @@
    9.49    NS_LOG_FUNCTION (this << packet << saddr << daddr << sport << dport);
    9.50  
    9.51    UdpHeader udpHeader;
    9.52 +  if(m_calcChecksum)
    9.53 +  {
    9.54 +    udpHeader.EnableChecksums();
    9.55 +  }
    9.56    udpHeader.SetDestinationPort (dport);
    9.57    udpHeader.SetSourcePort (sport);
    9.58    udpHeader.SetPayloadSize (packet->GetSize ());
    10.1 --- a/src/internet-stack/udp-l4-protocol.h	Mon Jun 30 22:41:22 2008 -0700
    10.2 +++ b/src/internet-stack/udp-l4-protocol.h	Tue Jul 01 10:52:11 2008 -0700
    10.3 @@ -93,6 +93,7 @@
    10.4  private:
    10.5    Ptr<Node> m_node;
    10.6    Ipv4EndPointDemux *m_endPoints;
    10.7 +  bool m_calcChecksum;
    10.8  };
    10.9  
   10.10  }; // namespace ns3
    11.1 --- a/src/node/ipv4-header.cc	Mon Jun 30 22:41:22 2008 -0700
    11.2 +++ b/src/node/ipv4-header.cc	Tue Jul 01 10:52:11 2008 -0700
    11.3 @@ -38,6 +38,7 @@
    11.4      m_protocol (0),
    11.5      m_flags (0),
    11.6      m_fragmentOffset (0),
    11.7 +    m_checksum(0),
    11.8      m_goodChecksum (true)
    11.9  {}
   11.10  
   11.11 @@ -177,23 +178,6 @@
   11.12    return m_goodChecksum;
   11.13  }
   11.14  
   11.15 -uint16_t
   11.16 -Ipv4Header::ChecksumCalculate(Buffer::Iterator &i, uint16_t size)
   11.17 -{
   11.18 -  /* see RFC 1071 to understand this code. */
   11.19 -  uint32_t sum = 0;
   11.20 -
   11.21 -  for (int j = 0; j < size/2; j++)
   11.22 -    sum += i.ReadU16 ();
   11.23 -
   11.24 -  if (size & 1)
   11.25 -     sum += i.ReadU8 ();
   11.26 -
   11.27 -  while (sum >> 16)
   11.28 -    sum = (sum & 0xffff) + (sum >> 16);
   11.29 -  return ~sum;
   11.30 -}
   11.31 -
   11.32  TypeId 
   11.33  Ipv4Header::GetTypeId (void)
   11.34  {
   11.35 @@ -282,7 +266,7 @@
   11.36    if (m_calcChecksum) 
   11.37      {
   11.38        i = start;
   11.39 -      uint16_t checksum = ChecksumCalculate(i, 20);
   11.40 +      uint16_t checksum = i.CalculateIpChecksum(20);
   11.41        NS_LOG_LOGIC ("checksum=" <<checksum);
   11.42        i = start;
   11.43        i.Next (10);
   11.44 @@ -318,14 +302,15 @@
   11.45    m_fragmentOffset <<= 3;
   11.46    m_ttl = i.ReadU8 ();
   11.47    m_protocol = i.ReadU8 ();
   11.48 -  i.Next (2); // checksum
   11.49 +  m_checksum = i.ReadU16();
   11.50 +  /* i.Next (2); // checksum */
   11.51    m_source.Set (i.ReadNtohU32 ());
   11.52    m_destination.Set (i.ReadNtohU32 ());
   11.53  
   11.54    if (m_calcChecksum) 
   11.55      {
   11.56        i = start;
   11.57 -      uint16_t checksum = ChecksumCalculate(i, headerSize);
   11.58 +      uint16_t checksum = i.CalculateIpChecksum(headerSize);
   11.59        NS_LOG_LOGIC ("checksum=" <<checksum);
   11.60  
   11.61        m_goodChecksum = (checksum == 0);
    12.1 --- a/src/node/ipv4-header.h	Mon Jun 30 22:41:22 2008 -0700
    12.2 +++ b/src/node/ipv4-header.h	Tue Jul 01 10:52:11 2008 -0700
    12.3 @@ -146,7 +146,6 @@
    12.4    virtual uint32_t Deserialize (Buffer::Iterator start);
    12.5  private:
    12.6  
    12.7 -  static uint16_t ChecksumCalculate(Buffer::Iterator &i, uint16_t len);
    12.8    enum FlagsE {
    12.9      DONT_FRAGMENT = (1<<0),
   12.10      MORE_FRAGMENTS = (1<<1)
   12.11 @@ -163,6 +162,7 @@
   12.12    uint16_t m_fragmentOffset : 13;
   12.13    Ipv4Address m_source;
   12.14    Ipv4Address m_destination;
   12.15 +  uint16_t m_checksum;
   12.16    bool m_goodChecksum;
   12.17  };
   12.18