src/internet-stack/udp-header.cc
changeset 3404 b5d4a04c7b68
parent 3363 33d1ca2e4ba4
--- a/src/internet-stack/udp-header.cc	Wed Jul 09 20:12:05 2008 -0700
+++ b/src/internet-stack/udp-header.cc	Thu Jul 10 15:58:24 2008 -0700
@@ -19,6 +19,7 @@
  */
 
 #include "udp-header.h"
+#include "ns3/address-utils.h"
 
 namespace ns3 {
 
@@ -32,8 +33,6 @@
   : m_sourcePort (0xfffd),
     m_destinationPort (0xfffd),
     m_payloadSize (0xfffd),
-    m_initialChecksum (0),
-    m_checksum(0),
     m_calcChecksum(false),
     m_goodChecksum(true)
 {}
@@ -71,35 +70,31 @@
   return m_destinationPort;
 }
 void 
-UdpHeader::SetPayloadSize (uint16_t size)
-{
-  m_payloadSize = size;
-}
-void 
 UdpHeader::InitializeChecksum (Ipv4Address source, 
                               Ipv4Address destination,
                               uint8_t protocol)
 {
-  Buffer buf = Buffer(12);
-  uint8_t tmp[4];
-  Buffer::Iterator it;
-  uint16_t udpLength = m_payloadSize + GetSerializedSize();
-
-  buf.AddAtStart(12);
-  it = buf.Begin();
+  m_source = source;
+  m_destination = destination;
+  m_protocol = protocol;
+}
+uint16_t
+UdpHeader::CalculateHeaderChecksum (uint16_t size) const
+{
+  Buffer buf = Buffer (12);
+  buf.AddAtStart (12);
+  Buffer::Iterator it = buf.Begin ();
 
-  source.Serialize(tmp);
-  it.Write(tmp, 4); /* source IP address */
-  destination.Serialize(tmp);
-  it.Write(tmp, 4); /* destination IP address */
-  it.WriteU8(0); /* protocol */
-  it.WriteU8(protocol); /* protocol */
-  it.WriteU8(udpLength >> 8); /* length */
-  it.WriteU8(udpLength & 0xff); /* length */
+  WriteTo (it, m_source);
+  WriteTo (it, m_destination);
+  it.WriteU8 (0); /* protocol */
+  it.WriteU8 (m_protocol); /* protocol */
+  it.WriteU8 (size >> 8); /* length */
+  it.WriteU8 (size & 0xff); /* length */
 
-  it = buf.Begin();
+  it = buf.Begin ();
   /* we don't CompleteChecksum ( ~ ) now */
-  m_initialChecksum = ~(it.CalculateIpChecksum(12));
+  return ~(it.CalculateIpChecksum (12));
 }
 
 bool
@@ -142,17 +137,17 @@
 UdpHeader::Serialize (Buffer::Iterator start) const
 {
   Buffer::Iterator i = start;
-  uint16_t udpLength = m_payloadSize + GetSerializedSize();
 
   i.WriteHtonU16 (m_sourcePort);
   i.WriteHtonU16 (m_destinationPort);
-  i.WriteHtonU16 (udpLength);
+  i.WriteHtonU16 (start.GetSize ());
   i.WriteU16 (0);
 
   if (m_calcChecksum)
     {
+      uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ());
       i = start;
-      uint16_t checksum = i.CalculateIpChecksum(udpLength, m_initialChecksum);
+      uint16_t checksum = i.CalculateIpChecksum (start.GetSize (), headerChecksum);
 
       i = start;
       i.Next(6);
@@ -166,12 +161,13 @@
   m_sourcePort = i.ReadNtohU16 ();
   m_destinationPort = i.ReadNtohU16 ();
   m_payloadSize = i.ReadNtohU16 () - GetSerializedSize ();
-  m_checksum = i.ReadU16();
+  i.Next (2);
 
   if(m_calcChecksum)
   {
+      uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ());
       i = start;
-      uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
+      uint16_t checksum = i.CalculateIpChecksum (start.GetSize (), headerChecksum);
 
       m_goodChecksum = (checksum == 0);
   }