src/internet-stack/ipv4-l3-protocol.cc
changeset 3820 c04ecfdce1ef
parent 3744 bb6876ea0851
child 3877 5091e3a14b26
--- a/src/internet-stack/ipv4-l3-protocol.cc	Mon Oct 27 15:12:00 2008 -0400
+++ b/src/internet-stack/ipv4-l3-protocol.cc	Wed Oct 29 11:18:39 2008 -0700
@@ -35,9 +35,11 @@
 
 #include "ipv4-l3-protocol.h"
 #include "ipv4-l4-protocol.h"
+#include "icmpv4-l4-protocol.h"
 #include "ipv4-interface.h"
 #include "ipv4-loopback-interface.h"
 #include "arp-ipv4-interface.h"
+#include "ipv4-raw-socket-impl.h"
 
 NS_LOG_COMPONENT_DEFINE ("Ipv4L3Protocol");
 
@@ -120,6 +122,30 @@
   SetupLoopback ();
 }
 
+Ptr<Socket> 
+Ipv4L3Protocol::CreateRawSocket (void)
+{
+  NS_LOG_FUNCTION (this);
+  Ptr<Ipv4RawSocketImpl> socket = CreateObject<Ipv4RawSocketImpl> ();
+  socket->SetNode (m_node);
+  m_sockets.push_back (socket);
+  return socket;
+}
+void 
+Ipv4L3Protocol::DeleteRawSocket (Ptr<Socket> socket)
+{
+  NS_LOG_FUNCTION (this << socket);
+  for (SocketList::iterator i = m_sockets.begin (); i != m_sockets.end (); ++i)
+    {
+      if ((*i) == socket)
+        {
+          m_sockets.erase (i);
+          return;
+        }
+    }
+  return;
+}
+
 void 
 Ipv4L3Protocol::DoDispose (void)
 {
@@ -514,6 +540,12 @@
       return;
     }
 
+  for (SocketList::iterator i = m_sockets.begin (); i != m_sockets.end (); ++i)
+    {
+      Ptr<Ipv4RawSocketImpl> socket = *i;
+      socket->ForwardUp (packet, ipHeader, device);
+    }
+
   if (Forwarding (index, packet, ipHeader, device)) 
     {
       return;
@@ -522,6 +554,25 @@
   ForwardUp (packet, ipHeader, ipv4Interface);
 }
 
+Ptr<Icmpv4L4Protocol> 
+Ipv4L3Protocol::GetIcmp (void) const
+{
+  Ptr<Ipv4L4Protocol> prot = GetProtocol (Icmpv4L4Protocol::GetStaticProtocolNumber ());
+  if (prot != 0)
+    {
+      return prot->GetObject<Icmpv4L4Protocol> ();
+    }
+  else
+    {
+      return 0;
+    }
+}
+
+bool
+Ipv4L3Protocol::IsUnicast (Ipv4Address ad, Ipv4Mask interfaceMask) const
+{
+  return !ad.IsMulticast () && !ad.IsSubnetDirectedBroadcast (interfaceMask);
+}
 
 void 
 Ipv4L3Protocol::Send (Ptr<Packet> packet, 
@@ -542,16 +593,29 @@
   ipHeader.SetDestination (destination);
   ipHeader.SetProtocol (protocol);
   ipHeader.SetPayloadSize (packet->GetSize ());
-  ipHeader.SetTtl (m_defaultTtl);
-  ipHeader.SetMayFragment ();
   ipHeader.SetIdentification (m_identification);
 
   m_identification ++;
 
+  SocketSetDontFragmentTag dfTag;
+  bool found = packet->FindFirstMatchingTag (dfTag);
+  if (found)
+    {
+      if (dfTag.IsEnabled ())
+        {
+          ipHeader.SetDontFragment ();
+        }
+      else
+        {
+          ipHeader.SetMayFragment ();
+        }
+    }
+  
+
   // Set TTL to 1 if it is a broadcast packet of any type.  Otherwise,
   // possibly override the default TTL if the packet is tagged
   SocketIpTtlTag tag;
-  bool found = packet->FindFirstMatchingTag (tag);
+  found = packet->FindFirstMatchingTag (tag);
 
   if (destination.IsBroadcast ()) 
     {
@@ -564,6 +628,7 @@
     }
   else
     {
+      ipHeader.SetTtl (m_defaultTtl);
       uint32_t ifaceIndex = 0;
       for (Ipv4InterfaceList::iterator ifaceIter = m_interfaces.begin ();
            ifaceIter != m_interfaces.end (); ifaceIter++, ifaceIndex++)
@@ -585,10 +650,28 @@
           Ptr<Ipv4Interface> outInterface = *ifaceIter;
           Ptr<Packet> packetCopy = packet->Copy ();
 
-          NS_ASSERT (packetCopy->GetSize () <= outInterface->GetMtu ());
           packetCopy->AddHeader (ipHeader);
-          m_txTrace (packetCopy, ifaceIndex);
-          outInterface->Send (packetCopy, destination);
+          if (packetCopy->GetSize () > outInterface->GetMtu () &&
+              ipHeader.IsDontFragment () &&
+              IsUnicast (ipHeader.GetDestination (), outInterface->GetNetworkMask ()))
+            {
+              Ptr<Icmpv4L4Protocol> icmp = GetIcmp ();
+              NS_ASSERT (icmp != 0);
+              icmp->SendDestUnreachFragNeeded (ipHeader, packet, outInterface->GetMtu ());
+              m_dropTrace (packetCopy);
+            }
+          else if (packet->GetSize () > outInterface->GetMtu () &&
+                   !ipHeader.IsDontFragment ())
+            {
+              NS_LOG_LOGIC ("Too big: need fragmentation but no frag support.");
+              m_dropTrace (packet);
+            }
+          else
+            {
+              NS_ASSERT (packetCopy->GetSize () <= outInterface->GetMtu ());
+              m_txTrace (packetCopy, ifaceIndex);
+              outInterface->Send (packetCopy, destination);
+            }
         }
     }
   else
@@ -625,17 +708,38 @@
   NS_LOG_LOGIC ("Send via interface " << route.GetInterface ());
 
   Ptr<Ipv4Interface> outInterface = GetInterface (route.GetInterface ());
-  NS_ASSERT (packet->GetSize () <= outInterface->GetMtu ());
-  m_txTrace (packet, route.GetInterface ());
-  if (route.IsGateway ()) 
+  if (packet->GetSize () > outInterface->GetMtu () &&
+      ipHeader.IsDontFragment () &&
+      IsUnicast (ipHeader.GetDestination (), outInterface->GetNetworkMask ()))
+    {
+      NS_LOG_LOGIC ("Too big: need fragmentation but not allowed");
+      Ptr<Icmpv4L4Protocol> icmp = GetIcmp ();
+      NS_ASSERT (icmp != 0);
+      Ptr<Packet> copyNoHeader = packet->Copy ();
+      Ipv4Header tmp;
+      copyNoHeader->RemoveHeader (tmp);
+      icmp->SendDestUnreachFragNeeded (ipHeader, copyNoHeader, outInterface->GetMtu ());
+      m_dropTrace (packet);
+    }
+  else if (packet->GetSize () > outInterface->GetMtu () &&
+           !ipHeader.IsDontFragment ())
     {
-      NS_LOG_LOGIC ("Send to gateway " << route.GetGateway ());
-      outInterface->Send (packet, route.GetGateway ());
-    } 
-  else 
+      NS_LOG_LOGIC ("Too big: need fragmentation but no frag support.");
+      m_dropTrace (packet);
+    }
+  else
     {
-      NS_LOG_LOGIC ("Send to destination " << ipHeader.GetDestination ());
-      outInterface->Send (packet, ipHeader.GetDestination ());
+      m_txTrace (packet, route.GetInterface ());
+      if (route.IsGateway ()) 
+        {
+          NS_LOG_LOGIC ("Send to gateway " << route.GetGateway ());
+          outInterface->Send (packet, route.GetGateway ());
+        } 
+      else 
+        {
+          NS_LOG_LOGIC ("Send to destination " << ipHeader.GetDestination ());
+          outInterface->Send (packet, ipHeader.GetDestination ());
+        }
     }
 }
 
@@ -685,17 +789,6 @@
       NS_LOG_LOGIC ("For me (Ipv4Addr any address)");
       return false;
     }
-
-  if (ipHeader.GetTtl () == 1) 
-    {
-      // Should send ttl expired here
-      // XXX
-      NS_LOG_LOGIC ("Not for me (TTL expired).  Drop");
-      m_dropTrace (packet);
-      return true;
-    }
-  ipHeader.SetTtl (ipHeader.GetTtl () - 1);
-
 //  
 // If this is a to a multicast address and this node is a member of the 
 // indicated group we need to return false so the multicast is forwarded up.
@@ -710,18 +803,39 @@
           // We forward with a packet copy, since forwarding may change
           // the packet, affecting our local delivery
           NS_LOG_LOGIC ("Forwarding (multicast).");
-          Lookup (ifIndex, ipHeader, packet->Copy (),
-          MakeCallback (&Ipv4L3Protocol::SendRealOut, this));
+          DoForward (ifIndex, packet->Copy (), ipHeader);
           return false;
         }   
-    }     
+    }
+
+  DoForward (ifIndex, packet, ipHeader);
+  return true;
+}
+
+void
+Ipv4L3Protocol::DoForward (uint32_t ifIndex, 
+                           Ptr<Packet> packet, 
+                           Ipv4Header ipHeader)
+{
+  NS_LOG_FUNCTION (this << ifIndex << packet << ipHeader);
+
+  ipHeader.SetTtl (ipHeader.GetTtl () - 1);
+  if (ipHeader.GetTtl () == 0)
+    {
+      if (IsUnicast (ipHeader.GetDestination (), GetInterface (ifIndex)->GetNetworkMask ()))
+        {
+          Ptr<Icmpv4L4Protocol> icmp = GetIcmp ();
+          icmp->SendTimeExceededTtl (ipHeader, packet);
+        }
+      m_dropTrace (packet);
+      return;
+    }  
   NS_LOG_LOGIC ("Not for me, forwarding.");
   Lookup (ifIndex, ipHeader, packet,
-  MakeCallback (&Ipv4L3Protocol::SendRealOut, this));
-  
-  return true;
+          MakeCallback (&Ipv4L3Protocol::SendRealOut, this));
 }
 
+
 void
 Ipv4L3Protocol::ForwardUp (Ptr<Packet> p, Ipv4Header const&ip,
                            Ptr<Ipv4Interface> incomingInterface)
@@ -729,7 +843,26 @@
   NS_LOG_FUNCTION (this << p << &ip);
 
   Ptr<Ipv4L4Protocol> protocol = GetProtocol (ip.GetProtocol ());
-  protocol->Receive (p, ip.GetSource (), ip.GetDestination (), incomingInterface);
+  if (protocol != 0)
+    {
+      // we need to make a copy in the unlikely event we hit the
+      // RX_ENDPOINT_UNREACH codepath
+      Ptr<Packet> copy = p->Copy ();
+      enum Ipv4L4Protocol::RxStatus status = 
+        protocol->Receive (p, ip.GetSource (), ip.GetDestination (), incomingInterface);
+      switch (status) {
+      case Ipv4L4Protocol::RX_OK:
+        // fall through
+      case Ipv4L4Protocol::RX_CSUM_FAILED:
+        break;
+      case Ipv4L4Protocol::RX_ENDPOINT_UNREACH:
+        if (IsUnicast (ip.GetDestination (), incomingInterface->GetNetworkMask ()))
+          {
+            GetIcmp ()->SendDestUnreachPort (ip, copy);
+          }
+        break;
+      }
+    }
 }
 
 void