bug 236: add optional support for tcp and udp checksum.
--- a/src/common/buffer.cc Mon Jun 30 22:41:22 2008 -0700
+++ b/src/common/buffer.cc Tue Jul 01 10:52:11 2008 -0700
@@ -1105,6 +1105,29 @@
#endif /* BUFFER_USE_INLINE */
+uint16_t
+Buffer::Iterator::CalculateIpChecksum(uint16_t size)
+{
+ return CalculateIpChecksum(size, 0);
+}
+
+uint16_t
+Buffer::Iterator::CalculateIpChecksum(uint16_t size, uint32_t initialChecksum)
+{
+ /* see RFC 1071 to understand this code. */
+ uint32_t sum = initialChecksum;
+
+ for (int j = 0; j < size/2; j++)
+ sum += ReadU16 ();
+
+ if (size & 1)
+ sum += ReadU8 ();
+
+ while (sum >> 16)
+ sum = (sum & 0xffff) + (sum >> 16);
+ return ~sum;
+}
+
} // namespace ns3
--- a/src/common/buffer.h Mon Jun 30 22:41:22 2008 -0700
+++ b/src/common/buffer.h Tue Jul 01 10:52:11 2008 -0700
@@ -342,6 +342,22 @@
* bytes read.
*/
void Read (uint8_t *buffer, uint32_t size);
+
+ /**
+ * \brief Calculate the checksum.
+ * \param size size of the buffer.
+ * \return checksum
+ */
+ uint16_t CalculateIpChecksum(uint16_t size);
+
+ /**
+ * \brief Calculate the checksum.
+ * \param size size of the buffer.
+ * \param initialChecksum initial value
+ * \return checksum
+ */
+ uint16_t CalculateIpChecksum(uint16_t size, uint32_t initialChecksum);
+
private:
friend class Buffer;
Iterator (Buffer const*buffer);
--- a/src/internet-stack/tcp-header.cc Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-header.cc Tue Jul 01 10:52:11 2008 -0700
@@ -28,8 +28,6 @@
NS_OBJECT_ENSURE_REGISTERED (TcpHeader);
-bool TcpHeader::m_calcChecksum = false;
-
TcpHeader::TcpHeader ()
: m_sourcePort (0),
m_destinationPort (0),
@@ -38,8 +36,11 @@
m_length (5),
m_flags (0),
m_windowSize (0xffff),
+ m_urgentPointer (0),
+ m_initialChecksum(0),
m_checksum (0),
- m_urgentPointer (0)
+ m_calcChecksum(false),
+ m_goodChecksum(true)
{}
TcpHeader::~TcpHeader ()
@@ -51,6 +52,11 @@
m_calcChecksum = true;
}
+void TcpHeader::SetPayloadSize(uint16_t payloadSize)
+{
+ m_payloadSize = payloadSize;
+}
+
void TcpHeader::SetSourcePort (uint16_t port)
{
m_sourcePort = port;
@@ -130,8 +136,32 @@
Ipv4Address destination,
uint8_t protocol)
{
- m_checksum = 0;
-//XXX requires peeking into IP to get length of the TCP segment
+ Buffer buf = Buffer(12);
+ uint8_t tmp[4];
+ Buffer::Iterator it;
+ uint16_t tcpLength = m_payloadSize + GetSerializedSize();
+
+ buf.AddAtStart(12);
+ it = buf.Begin();
+
+ source.Serialize(tmp);
+ it.Write(tmp, 4); /* source IP address */
+ destination.Serialize(tmp);
+ it.Write(tmp, 4); /* destination IP address */
+ it.WriteU8(0); /* protocol */
+ it.WriteU8(protocol); /* protocol */
+ it.WriteU8(tcpLength >> 8); /* length */
+ it.WriteU8(tcpLength & 0xff); /* length */
+
+ it = buf.Begin();
+ /* we don't CompleteChecksum ( ~ ) now */
+ m_initialChecksum = ~(it.CalculateIpChecksum(12));
+}
+
+bool
+TcpHeader::IsChecksumOk (void) const
+{
+ return m_goodChecksum;
}
TypeId
@@ -188,28 +218,49 @@
}
void TcpHeader::Serialize (Buffer::Iterator start) const
{
- start.WriteHtonU16 (m_sourcePort);
- start.WriteHtonU16 (m_destinationPort);
- start.WriteHtonU32 (m_sequenceNumber);
- start.WriteHtonU32 (m_ackNumber);
- start.WriteHtonU16 (m_length << 12 | m_flags); //reserved bits are all zero
- start.WriteHtonU16 (m_windowSize);
- //XXX calculate checksum here
- start.WriteHtonU16 (m_checksum);
- start.WriteHtonU16 (m_urgentPointer);
+ Buffer::Iterator i = start;
+ uint16_t tcpLength = m_payloadSize + GetSerializedSize();
+ i.WriteHtonU16 (m_sourcePort);
+ i.WriteHtonU16 (m_destinationPort);
+ i.WriteHtonU32 (m_sequenceNumber);
+ i.WriteHtonU32 (m_ackNumber);
+ i.WriteHtonU16 (m_length << 12 | m_flags); //reserved bits are all zero
+ i.WriteHtonU16 (m_windowSize);
+ i.WriteHtonU16 (0);
+ i.WriteHtonU16 (m_urgentPointer);
+
+ if(m_calcChecksum)
+ {
+ i = start;
+ uint16_t checksum = i.CalculateIpChecksum(tcpLength, m_initialChecksum);
+
+ i = start;
+ i.Next(16);
+ i.WriteU16(checksum);
+ }
}
uint32_t TcpHeader::Deserialize (Buffer::Iterator start)
{
- m_sourcePort = start.ReadNtohU16 ();
- m_destinationPort = start.ReadNtohU16 ();
- m_sequenceNumber = start.ReadNtohU32 ();
- m_ackNumber = start.ReadNtohU32 ();
- uint16_t field = start.ReadNtohU16 ();
+ Buffer::Iterator i = start;
+ m_sourcePort = i.ReadNtohU16 ();
+ m_destinationPort = i.ReadNtohU16 ();
+ m_sequenceNumber = i.ReadNtohU32 ();
+ m_ackNumber = i.ReadNtohU32 ();
+ uint16_t field = i.ReadNtohU16 ();
m_flags = field & 0x3F;
m_length = field>>12;
- m_windowSize = start.ReadNtohU16 ();
- m_checksum = start.ReadNtohU16 ();
- m_urgentPointer = start.ReadNtohU16 ();
+ m_windowSize = i.ReadNtohU16 ();
+ m_checksum = i.ReadU16 ();
+ m_urgentPointer = i.ReadNtohU16 ();
+
+ if(m_calcChecksum)
+ {
+ i = start;
+ uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
+
+ m_goodChecksum = (checksum == 0);
+ }
+
return GetSerializedSize ();
}
--- a/src/internet-stack/tcp-header.h Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-header.h Tue Jul 01 10:52:11 2008 -0700
@@ -47,7 +47,7 @@
/**
* \brief Enable checksum calculation for TCP (XXX currently has no effect)
*/
- static void EnableChecksums (void);
+ void EnableChecksums (void);
//Setters
/**
* \param port The source port for this TcpHeader
@@ -150,6 +150,17 @@
virtual void Serialize (Buffer::Iterator start) const;
virtual uint32_t Deserialize (Buffer::Iterator start);
+ /**
+ * \param size The payload size in bytes
+ */
+ void SetPayloadSize (uint16_t size);
+
+ /**
+ * \brief Is the TCP checksum correct ?
+ * \returns true if the checksum is correct, false otherwise.
+ */
+ bool IsChecksumOk (void) const;
+
private:
uint16_t m_sourcePort;
uint16_t m_destinationPort;
@@ -158,10 +169,14 @@
uint8_t m_length; // really a uint4_t
uint8_t m_flags; // really a uint6_t
uint16_t m_windowSize;
+ uint16_t m_urgentPointer;
+ uint16_t m_payloadSize;
+ uint16_t m_initialChecksum;
uint16_t m_checksum;
- uint16_t m_urgentPointer;
- static bool m_calcChecksum;
+ bool m_calcChecksum;
+ bool m_goodChecksum;
+
};
}; // namespace ns3
--- a/src/internet-stack/tcp-l4-protocol.cc Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-l4-protocol.cc Tue Jul 01 10:52:11 2008 -0700
@@ -21,6 +21,7 @@
#include "ns3/assert.h"
#include "ns3/log.h"
#include "ns3/nstime.h"
+#include "ns3/boolean.h"
#include "ns3/packet.h"
#include "ns3/node.h"
@@ -328,6 +329,11 @@
ObjectFactoryValue (GetDefaultRttEstimatorFactory ()),
MakeObjectFactoryAccessor (&TcpL4Protocol::m_rttFactory),
MakeObjectFactoryChecker ())
+ .AddAttribute ("CalcChecksum", "If true, we calculate the checksum of outgoing packets"
+ " and verify the checksum of incoming packets.",
+ BooleanValue (false),
+ MakeBooleanAccessor (&TcpL4Protocol::m_calcChecksum),
+ MakeBooleanChecker ())
;
return tid;
}
@@ -439,14 +445,31 @@
NS_LOG_FUNCTION (this << packet << source << destination << incomingInterface);
TcpHeader tcpHeader;
+ if(m_calcChecksum)
+ {
+ tcpHeader.EnableChecksums();
+ }
+ /* XXX very dirty but needs this to AddHeader again because of checksum */
+ tcpHeader.SetLength(5); /* XXX TCP without options */
+ tcpHeader.SetPayloadSize(packet->GetSize() - tcpHeader.GetSerializedSize());
+ tcpHeader.InitializeChecksum(source, destination, PROT_NUMBER);
+
//these two do a peek, so that the packet can be forwarded up
packet->RemoveHeader (tcpHeader);
+
NS_LOG_LOGIC("TcpL4Protocol " << this
<< " receiving seq " << tcpHeader.GetSequenceNumber()
<< " ack " << tcpHeader.GetAckNumber()
<< " flags "<< std::hex << (int)tcpHeader.GetFlags() << std::dec
<< " data size " << packet->GetSize());
- packet->AddHeader (tcpHeader);
+
+ if(!tcpHeader.IsChecksumOk ())
+ {
+ NS_LOG_INFO("Bad checksum, dropping packet!");
+ return;
+ }
+
+ packet->AddHeader (tcpHeader);
NS_LOG_LOGIC ("TcpL4Protocol "<<this<<" received a packet");
Ipv4EndPointDemux::EndPoints endPoints =
m_endPoints->Lookup (destination, tcpHeader.GetDestinationPort (),
@@ -478,6 +501,11 @@
TcpHeader tcpHeader;
tcpHeader.SetDestinationPort (dport);
tcpHeader.SetSourcePort (sport);
+ tcpHeader.SetPayloadSize(packet->GetSize());
+ if(m_calcChecksum)
+ {
+ tcpHeader.EnableChecksums();
+ }
tcpHeader.InitializeChecksum (saddr,
daddr,
PROT_NUMBER);
@@ -507,8 +535,13 @@
// XXX outgoingHeader cannot be logged
outgoingHeader.SetLength (5); //header length in units of 32bit words
- outgoingHeader.SetChecksum (0); //XXX
- outgoingHeader.SetUrgentPointer (0); //XXX
+ outgoingHeader.SetPayloadSize(packet->GetSize());
+ /* outgoingHeader.SetUrgentPointer (0); //XXX */
+ if(m_calcChecksum)
+ {
+ outgoingHeader.EnableChecksums();
+ }
+ outgoingHeader.InitializeChecksum(saddr, daddr, PROT_NUMBER);
packet->AddHeader (outgoingHeader);
--- a/src/internet-stack/tcp-l4-protocol.h Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-l4-protocol.h Tue Jul 01 10:52:11 2008 -0700
@@ -117,6 +117,9 @@
void SendPacket (Ptr<Packet>, TcpHeader,
Ipv4Address, Ipv4Address);
static ObjectFactory GetDefaultRttEstimatorFactory (void);
+
+ bool m_goodChecksum;
+ bool m_calcChecksum;
};
}; // namespace ns3
--- a/src/internet-stack/udp-header.cc Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/udp-header.cc Tue Jul 01 10:52:11 2008 -0700
@@ -19,14 +19,11 @@
*/
#include "udp-header.h"
-#include "ipv4-checksum.h"
namespace ns3 {
NS_OBJECT_ENSURE_REGISTERED (UdpHeader);
-bool UdpHeader::m_calcChecksum = false;
-
/* The magic values below are used only for debugging.
* They can be used to easily detect memory corruption
* problems so you can see the patterns in memory.
@@ -35,7 +32,10 @@
: m_sourcePort (0xfffd),
m_destinationPort (0xfffd),
m_payloadSize (0xfffd),
- m_initialChecksum (0)
+ m_initialChecksum (0),
+ m_checksum(0),
+ m_calcChecksum(false),
+ m_goodChecksum(true)
{}
UdpHeader::~UdpHeader ()
{
@@ -80,18 +80,35 @@
Ipv4Address destination,
uint8_t protocol)
{
- uint8_t buf[12];
- source.Serialize (buf);
- destination.Serialize (buf+4);
- buf[8] = 0;
- buf[9] = protocol;
- uint16_t udpLength = m_payloadSize + GetSerializedSize ();
- buf[10] = udpLength >> 8;
- buf[11] = udpLength & 0xff;
+ Buffer buf = Buffer(12);
+ uint8_t tmp[4];
+ Buffer::Iterator it;
+ uint16_t udpLength = m_payloadSize + GetSerializedSize();
+
+ buf.AddAtStart(12);
+ it = buf.Begin();
- m_initialChecksum = Ipv4ChecksumCalculate (0, buf, 12);
+ source.Serialize(tmp);
+ it.Write(tmp, 4); /* source IP address */
+ destination.Serialize(tmp);
+ it.Write(tmp, 4); /* destination IP address */
+ it.WriteU8(0); /* protocol */
+ it.WriteU8(protocol); /* protocol */
+ it.WriteU8(udpLength >> 8); /* length */
+ it.WriteU8(udpLength & 0xff); /* length */
+
+ it = buf.Begin();
+ /* we don't CompleteChecksum ( ~ ) now */
+ m_initialChecksum = ~(it.CalculateIpChecksum(12));
}
+bool
+UdpHeader::IsChecksumOk (void) const
+{
+ return m_goodChecksum;
+}
+
+
TypeId
UdpHeader::GetTypeId (void)
{
@@ -125,23 +142,21 @@
UdpHeader::Serialize (Buffer::Iterator start) const
{
Buffer::Iterator i = start;
+ uint16_t udpLength = m_payloadSize + GetSerializedSize();
+
i.WriteHtonU16 (m_sourcePort);
i.WriteHtonU16 (m_destinationPort);
- i.WriteHtonU16 (m_payloadSize + GetSerializedSize ());
+ i.WriteHtonU16 (udpLength);
i.WriteU16 (0);
- if (m_calcChecksum)
+ if (m_calcChecksum)
{
-#if 0
- //XXXX
- uint16_t checksum = Ipv4ChecksumCalculate (m_initialChecksum,
- buffer->PeekData (),
- GetSerializedSize () + m_payloadSize);
- checksum = Ipv4ChecksumComplete (checksum);
- i = buffer->Begin ();
- i.Next (6);
- i.WriteU16 (checksum);
-#endif
+ i = start;
+ uint16_t checksum = i.CalculateIpChecksum(udpLength, m_initialChecksum);
+
+ i = start;
+ i.Next(6);
+ i.WriteU16(checksum);
}
}
uint32_t
@@ -151,10 +166,16 @@
m_sourcePort = i.ReadNtohU16 ();
m_destinationPort = i.ReadNtohU16 ();
m_payloadSize = i.ReadNtohU16 () - GetSerializedSize ();
- if (m_calcChecksum)
- {
- // XXX verify checksum.
- }
+ m_checksum = i.ReadU16();
+
+ if(m_calcChecksum)
+ {
+ i = start;
+ uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
+
+ m_goodChecksum = (checksum == 0);
+ }
+
return GetSerializedSize ();
}
--- a/src/internet-stack/udp-header.h Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/udp-header.h Tue Jul 01 10:52:11 2008 -0700
@@ -49,7 +49,7 @@
/**
* \brief Enable checksum calculation for UDP (XXX currently has no effect)
*/
- static void EnableChecksums (void);
+ void EnableChecksums (void);
/**
* \param port the destination port for this UdpHeader
*/
@@ -93,13 +93,21 @@
virtual void Serialize (Buffer::Iterator start) const;
virtual uint32_t Deserialize (Buffer::Iterator start);
+ /**
+ * \brief Is the UDP checksum correct ?
+ * \returns true if the checksum is correct, false otherwise.
+ */
+ bool IsChecksumOk (void) const;
+
private:
uint16_t m_sourcePort;
uint16_t m_destinationPort;
uint16_t m_payloadSize;
uint16_t m_initialChecksum;
+ uint16_t m_checksum;
- static bool m_calcChecksum;
+ bool m_calcChecksum;
+ bool m_goodChecksum;
};
} // namespace ns3
--- a/src/internet-stack/udp-l4-protocol.cc Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/udp-l4-protocol.cc Tue Jul 01 10:52:11 2008 -0700
@@ -22,6 +22,7 @@
#include "ns3/assert.h"
#include "ns3/packet.h"
#include "ns3/node.h"
+#include "ns3/boolean.h"
#include "udp-l4-protocol.h"
#include "udp-header.h"
@@ -45,6 +46,11 @@
static TypeId tid = TypeId ("ns3::UdpL4Protocol")
.SetParent<Ipv4L4Protocol> ()
.AddConstructor<UdpL4Protocol> ()
+ .AddAttribute ("CalcChecksum", "If true, we calculate the checksum of outgoing packets"
+ " and verify the checksum of incoming packets.",
+ BooleanValue (false),
+ MakeBooleanAccessor (&UdpL4Protocol::m_calcChecksum),
+ MakeBooleanChecker ())
;
return tid;
}
@@ -151,9 +157,23 @@
Ptr<Ipv4Interface> interface)
{
NS_LOG_FUNCTION (this << packet << source << destination);
+ UdpHeader udpHeader;
+ if(m_calcChecksum)
+ {
+ udpHeader.EnableChecksums();
+ }
- UdpHeader udpHeader;
+ udpHeader.SetPayloadSize (packet->GetSize () - udpHeader.GetSerializedSize ());
+ udpHeader.InitializeChecksum (source, destination, PROT_NUMBER);
+
packet->RemoveHeader (udpHeader);
+
+ if(!udpHeader.IsChecksumOk ())
+ {
+ NS_LOG_INFO("Bad checksum : dropping packet!");
+ return;
+ }
+
Ipv4EndPointDemux::EndPoints endPoints =
m_endPoints->Lookup (destination, udpHeader.GetDestinationPort (),
source, udpHeader.GetSourcePort (), interface);
@@ -172,6 +192,10 @@
NS_LOG_FUNCTION (this << packet << saddr << daddr << sport << dport);
UdpHeader udpHeader;
+ if(m_calcChecksum)
+ {
+ udpHeader.EnableChecksums();
+ }
udpHeader.SetDestinationPort (dport);
udpHeader.SetSourcePort (sport);
udpHeader.SetPayloadSize (packet->GetSize ());
--- a/src/internet-stack/udp-l4-protocol.h Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/udp-l4-protocol.h Tue Jul 01 10:52:11 2008 -0700
@@ -93,6 +93,7 @@
private:
Ptr<Node> m_node;
Ipv4EndPointDemux *m_endPoints;
+ bool m_calcChecksum;
};
}; // namespace ns3
--- a/src/node/ipv4-header.cc Mon Jun 30 22:41:22 2008 -0700
+++ b/src/node/ipv4-header.cc Tue Jul 01 10:52:11 2008 -0700
@@ -38,6 +38,7 @@
m_protocol (0),
m_flags (0),
m_fragmentOffset (0),
+ m_checksum(0),
m_goodChecksum (true)
{}
@@ -177,23 +178,6 @@
return m_goodChecksum;
}
-uint16_t
-Ipv4Header::ChecksumCalculate(Buffer::Iterator &i, uint16_t size)
-{
- /* see RFC 1071 to understand this code. */
- uint32_t sum = 0;
-
- for (int j = 0; j < size/2; j++)
- sum += i.ReadU16 ();
-
- if (size & 1)
- sum += i.ReadU8 ();
-
- while (sum >> 16)
- sum = (sum & 0xffff) + (sum >> 16);
- return ~sum;
-}
-
TypeId
Ipv4Header::GetTypeId (void)
{
@@ -282,7 +266,7 @@
if (m_calcChecksum)
{
i = start;
- uint16_t checksum = ChecksumCalculate(i, 20);
+ uint16_t checksum = i.CalculateIpChecksum(20);
NS_LOG_LOGIC ("checksum=" <<checksum);
i = start;
i.Next (10);
@@ -318,14 +302,15 @@
m_fragmentOffset <<= 3;
m_ttl = i.ReadU8 ();
m_protocol = i.ReadU8 ();
- i.Next (2); // checksum
+ m_checksum = i.ReadU16();
+ /* i.Next (2); // checksum */
m_source.Set (i.ReadNtohU32 ());
m_destination.Set (i.ReadNtohU32 ());
if (m_calcChecksum)
{
i = start;
- uint16_t checksum = ChecksumCalculate(i, headerSize);
+ uint16_t checksum = i.CalculateIpChecksum(headerSize);
NS_LOG_LOGIC ("checksum=" <<checksum);
m_goodChecksum = (checksum == 0);
--- a/src/node/ipv4-header.h Mon Jun 30 22:41:22 2008 -0700
+++ b/src/node/ipv4-header.h Tue Jul 01 10:52:11 2008 -0700
@@ -146,7 +146,6 @@
virtual uint32_t Deserialize (Buffer::Iterator start);
private:
- static uint16_t ChecksumCalculate(Buffer::Iterator &i, uint16_t len);
enum FlagsE {
DONT_FRAGMENT = (1<<0),
MORE_FRAGMENTS = (1<<1)
@@ -163,6 +162,7 @@
uint16_t m_fragmentOffset : 13;
Ipv4Address m_source;
Ipv4Address m_destination;
+ uint16_t m_checksum;
bool m_goodChecksum;
};