--- a/src/internet-stack/tcp-header.cc Mon Jun 30 22:41:22 2008 -0700
+++ b/src/internet-stack/tcp-header.cc Tue Jul 01 10:52:11 2008 -0700
@@ -28,8 +28,6 @@
NS_OBJECT_ENSURE_REGISTERED (TcpHeader);
-bool TcpHeader::m_calcChecksum = false;
-
TcpHeader::TcpHeader ()
: m_sourcePort (0),
m_destinationPort (0),
@@ -38,8 +36,11 @@
m_length (5),
m_flags (0),
m_windowSize (0xffff),
+ m_urgentPointer (0),
+ m_initialChecksum(0),
m_checksum (0),
- m_urgentPointer (0)
+ m_calcChecksum(false),
+ m_goodChecksum(true)
{}
TcpHeader::~TcpHeader ()
@@ -51,6 +52,11 @@
m_calcChecksum = true;
}
+void TcpHeader::SetPayloadSize(uint16_t payloadSize)
+{
+ m_payloadSize = payloadSize;
+}
+
void TcpHeader::SetSourcePort (uint16_t port)
{
m_sourcePort = port;
@@ -130,8 +136,32 @@
Ipv4Address destination,
uint8_t protocol)
{
- m_checksum = 0;
-//XXX requires peeking into IP to get length of the TCP segment
+ Buffer buf = Buffer(12);
+ uint8_t tmp[4];
+ Buffer::Iterator it;
+ uint16_t tcpLength = m_payloadSize + GetSerializedSize();
+
+ buf.AddAtStart(12);
+ it = buf.Begin();
+
+ source.Serialize(tmp);
+ it.Write(tmp, 4); /* source IP address */
+ destination.Serialize(tmp);
+ it.Write(tmp, 4); /* destination IP address */
+ it.WriteU8(0); /* protocol */
+ it.WriteU8(protocol); /* protocol */
+ it.WriteU8(tcpLength >> 8); /* length */
+ it.WriteU8(tcpLength & 0xff); /* length */
+
+ it = buf.Begin();
+ /* we don't CompleteChecksum ( ~ ) now */
+ m_initialChecksum = ~(it.CalculateIpChecksum(12));
+}
+
+bool
+TcpHeader::IsChecksumOk (void) const
+{
+ return m_goodChecksum;
}
TypeId
@@ -188,28 +218,49 @@
}
void TcpHeader::Serialize (Buffer::Iterator start) const
{
- start.WriteHtonU16 (m_sourcePort);
- start.WriteHtonU16 (m_destinationPort);
- start.WriteHtonU32 (m_sequenceNumber);
- start.WriteHtonU32 (m_ackNumber);
- start.WriteHtonU16 (m_length << 12 | m_flags); //reserved bits are all zero
- start.WriteHtonU16 (m_windowSize);
- //XXX calculate checksum here
- start.WriteHtonU16 (m_checksum);
- start.WriteHtonU16 (m_urgentPointer);
+ Buffer::Iterator i = start;
+ uint16_t tcpLength = m_payloadSize + GetSerializedSize();
+ i.WriteHtonU16 (m_sourcePort);
+ i.WriteHtonU16 (m_destinationPort);
+ i.WriteHtonU32 (m_sequenceNumber);
+ i.WriteHtonU32 (m_ackNumber);
+ i.WriteHtonU16 (m_length << 12 | m_flags); //reserved bits are all zero
+ i.WriteHtonU16 (m_windowSize);
+ i.WriteHtonU16 (0);
+ i.WriteHtonU16 (m_urgentPointer);
+
+ if(m_calcChecksum)
+ {
+ i = start;
+ uint16_t checksum = i.CalculateIpChecksum(tcpLength, m_initialChecksum);
+
+ i = start;
+ i.Next(16);
+ i.WriteU16(checksum);
+ }
}
uint32_t TcpHeader::Deserialize (Buffer::Iterator start)
{
- m_sourcePort = start.ReadNtohU16 ();
- m_destinationPort = start.ReadNtohU16 ();
- m_sequenceNumber = start.ReadNtohU32 ();
- m_ackNumber = start.ReadNtohU32 ();
- uint16_t field = start.ReadNtohU16 ();
+ Buffer::Iterator i = start;
+ m_sourcePort = i.ReadNtohU16 ();
+ m_destinationPort = i.ReadNtohU16 ();
+ m_sequenceNumber = i.ReadNtohU32 ();
+ m_ackNumber = i.ReadNtohU32 ();
+ uint16_t field = i.ReadNtohU16 ();
m_flags = field & 0x3F;
m_length = field>>12;
- m_windowSize = start.ReadNtohU16 ();
- m_checksum = start.ReadNtohU16 ();
- m_urgentPointer = start.ReadNtohU16 ();
+ m_windowSize = i.ReadNtohU16 ();
+ m_checksum = i.ReadU16 ();
+ m_urgentPointer = i.ReadNtohU16 ();
+
+ if(m_calcChecksum)
+ {
+ i = start;
+ uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum);
+
+ m_goodChecksum = (checksum == 0);
+ }
+
return GetSerializedSize ();
}