src/common/packet.cc
changeset 3035 644bfc099992
parent 2992 ba52f937610c
child 3039 722cf749a9e3
--- a/src/common/packet.cc	Wed Apr 23 15:01:27 2008 -0700
+++ b/src/common/packet.cc	Thu Apr 24 14:52:59 2008 -0700
@@ -24,6 +24,59 @@
 
 uint32_t Packet::m_globalUid = 0;
 
+TypeId 
+TagIterator::Item::GetTypeId (void) const
+{
+  return m_tid;
+}
+uint32_t 
+TagIterator::Item::GetStart (void) const
+{
+  return m_start;
+}
+uint32_t 
+TagIterator::Item::GetEnd (void) const
+{
+  return m_end;
+}
+void 
+TagIterator::Item::GetTag (Mtag &tag) const
+{
+  if (tag.GetInstanceTypeId () != GetTypeId ())
+    {
+      NS_FATAL_ERROR ("The tag you provided is not of the right type.");
+    }
+  tag.Deserialize (m_buffer);
+}
+TagIterator::Item::Item (TypeId tid, uint32_t start, uint32_t end, MtagBuffer buffer)
+  : m_tid (tid),
+    m_start (start),
+    m_end (end),
+    m_buffer (buffer)
+{}
+bool 
+TagIterator::HasNext (void) const
+{
+  return m_current.HasNext ();
+}
+TagIterator::Item 
+TagIterator::Next (void)
+{
+  MtagList::Iterator::Item i = m_current.Next ();
+  
+  TagIterator::Item item = TagIterator::Item (i.tid, 
+                                              i.start-m_current.GetOffsetStart (), 
+                                              i.end-m_current.GetOffsetStart (), 
+                                              i.buf);
+  
+  
+  return item;
+}
+TagIterator::TagIterator (MtagList::Iterator i)
+  : m_current (i)
+{}
+
+
 void 
 Packet::Ref (void) const
 {
@@ -51,6 +104,7 @@
 Packet::Packet ()
   : m_buffer (),
     m_tags (),
+    m_tagList (),
     m_metadata (m_globalUid, 0),
     m_refCount (1)
 {
@@ -60,6 +114,7 @@
 Packet::Packet (const Packet &o)
   : m_buffer (o.m_buffer),
     m_tags (o.m_tags),
+    m_tagList (o.m_tagList),
     m_metadata (o.m_metadata),
     m_refCount (1)
 {}
@@ -73,6 +128,7 @@
     }
   m_buffer = o.m_buffer;
   m_tags = o.m_tags;
+  m_tagList = o.m_tagList;
   m_metadata = o.m_metadata;
   return *this;
 }
@@ -80,6 +136,7 @@
 Packet::Packet (uint32_t size)
   : m_buffer (size),
     m_tags (),
+    m_tagList (),
     m_metadata (m_globalUid, size),
     m_refCount (1)
 {
@@ -88,6 +145,7 @@
 Packet::Packet (uint8_t const*buffer, uint32_t size)
   : m_buffer (),
     m_tags (),
+    m_tagList (),
     m_metadata (m_globalUid, size),
     m_refCount (1)
 {
@@ -97,9 +155,10 @@
   i.Write (buffer, size);
 }
 
-Packet::Packet (Buffer buffer, Tags tags, PacketMetadata metadata)
+Packet::Packet (const Buffer &buffer, const Tags &tags, const MtagList &tagList, const PacketMetadata &metadata)
   : m_buffer (buffer),
     m_tags (tags),
+    m_tagList (tagList),
     m_metadata (metadata),
     m_refCount (1)
 {}
@@ -113,7 +172,7 @@
   PacketMetadata metadata = m_metadata.CreateFragment (start, end);
   // again, call the constructor directly rather than
   // through Create because it is private.
-  return Ptr<Packet> (new Packet (buffer, m_tags, metadata), false);
+  return Ptr<Packet> (new Packet (buffer, m_tags, m_tagList, metadata), false);
 }
 
 uint32_t 
@@ -126,7 +185,13 @@
 Packet::AddHeader (const Header &header)
 {
   uint32_t size = header.GetSerializedSize ();
-  m_buffer.AddAtStart (size);
+  uint32_t orgStart = m_buffer.GetCurrentStartOffset ();
+  bool resized = m_buffer.AddAtStart (size);
+  if (resized)
+    {
+      m_tagList.AddAtStart (m_buffer.GetCurrentStartOffset () - orgStart,
+                            m_buffer.GetCurrentStartOffset () + size);
+    }
   header.Serialize (m_buffer.Begin ());
   m_metadata.AddHeader (header, size);
 }
@@ -142,7 +207,13 @@
 Packet::AddTrailer (const Trailer &trailer)
 {
   uint32_t size = trailer.GetSerializedSize ();
-  m_buffer.AddAtEnd (size);
+  uint32_t orgEnd = m_buffer.GetCurrentEndOffset ();
+  bool resized = m_buffer.AddAtEnd (size);
+  if (resized)
+    {
+      m_tagList.AddAtEnd (m_buffer.GetCurrentEndOffset () - orgEnd,
+                          m_buffer.GetCurrentEndOffset () - size);
+    }
   Buffer::Iterator end = m_buffer.End ();
   trailer.Serialize (end);
   m_metadata.AddTrailer (trailer, size);
@@ -159,17 +230,28 @@
 void 
 Packet::AddAtEnd (Ptr<const Packet> packet)
 {
+  uint32_t aStart = m_buffer.GetCurrentStartOffset ();
+  uint32_t bEnd = packet->m_buffer.GetCurrentEndOffset ();
   m_buffer.AddAtEnd (packet->m_buffer);
-  /**
-   * XXX: we might need to merge the tag list of the
-   * other packet into the current packet.
-   */
+  uint32_t appendPrependOffset = m_buffer.GetCurrentEndOffset () - packet->m_buffer.GetSize ();
+  m_tagList.AddAtEnd (m_buffer.GetCurrentStartOffset () - aStart, 
+                      appendPrependOffset);
+  MtagList copy = packet->m_tagList;
+  copy.AddAtStart (m_buffer.GetCurrentEndOffset () - bEnd,
+                   appendPrependOffset);
+  m_tagList.Add (copy);
   m_metadata.AddAtEnd (packet->m_metadata);
 }
 void
 Packet::AddPaddingAtEnd (uint32_t size)
 {
-  m_buffer.AddAtEnd (size);
+  uint32_t orgEnd = m_buffer.GetCurrentEndOffset ();
+  bool resized = m_buffer.AddAtEnd (size);
+  if (resized)
+    {
+      m_tagList.AddAtEnd (m_buffer.GetCurrentEndOffset () - orgEnd,
+                          m_buffer.GetCurrentEndOffset () - size);
+    }
   m_metadata.AddPaddingAtEnd (size);
 }
 void 
@@ -189,6 +271,7 @@
 Packet::RemoveAllTags (void)
 {
   m_tags.RemoveAll ();
+  m_tagList.RemoveAll ();
 }
 
 uint8_t const *
@@ -207,6 +290,7 @@
 Packet::PrintTags (std::ostream &os) const
 {
   m_tags.Print (os, " ");
+  // XXX: tagList.
 }
 
 void 
@@ -388,6 +472,38 @@
   buffer.RemoveAtStart (metadataDeserialized);
 }
 
+void 
+Packet::AddMtag (const Mtag &tag) const
+{
+  MtagList *list = const_cast<MtagList *> (&m_tagList);
+  MtagBuffer buffer = list->Add (tag.GetInstanceTypeId (), tag.GetSerializedSize (), 
+                                 m_buffer.GetCurrentStartOffset (),
+                                 m_buffer.GetCurrentEndOffset ());
+  tag.Serialize (buffer);
+}
+TagIterator 
+Packet::GetTagIterator (void) const
+{
+  return TagIterator (m_tagList.Begin (m_buffer.GetCurrentStartOffset (), m_buffer.GetCurrentEndOffset ()));
+}
+
+bool 
+Packet::FindFirstMatchingTag (Mtag &tag) const
+{
+  TypeId tid = tag.GetInstanceTypeId ();
+  TagIterator i = GetTagIterator ();
+  while (i.HasNext ())
+    {
+      TagIterator::Item item = i.Next ();
+      if (tid == item.GetTypeId ())
+        {
+          item.GetTag (tag);
+          return true;
+        }
+    }
+  return false;
+}
+
 std::ostream& operator<< (std::ostream& os, const Packet &packet)
 {
   packet.Print (os);
@@ -403,24 +519,136 @@
 
 #include "ns3/test.h"
 #include <string>
+#include <stdarg.h>
+
+using namespace ns3;
+
+namespace {
+
+class ATestTagBase : public Mtag
+{
+public:
+  ATestTagBase () : m_error (false) {}
+  bool m_error;
+};
+
+template <int N>
+class ATestTag : public ATestTagBase
+{
+public:
+  static TypeId GetTypeId (void) {
+    std::ostringstream oss;
+    oss << "anon::ATestTag<" << N << ">";
+    static TypeId tid = TypeId (oss.str ().c_str ())
+      .SetParent<Mtag> ()
+      .AddConstructor<ATestTag<N> > ()
+      .HideFromDocumentation ()
+      ;
+    return tid;
+  }
+  virtual TypeId GetInstanceTypeId (void) const {
+    return GetTypeId ();
+  }
+  virtual uint32_t GetSerializedSize (void) const {
+    return N;
+  }
+  virtual void Serialize (MtagBuffer buf) const {
+    for (uint32_t i = 0; i < N; ++i)
+      {
+        buf.WriteU8 (N);
+      }
+  }
+  virtual void Deserialize (MtagBuffer buf) {
+    for (uint32_t i = 0; i < N; ++i)
+      {
+        uint8_t v = buf.ReadU8 ();
+        if (v != N)
+          {
+            m_error = true;
+          }
+      }
+  }
+  ATestTag ()
+    : ATestTagBase () {}
+};
+
+struct Expected
+{
+  Expected (uint32_t n_, uint32_t start_, uint32_t end_)
+    : n (n_), start (start_), end (end_) {}
+  
+  uint32_t n;
+  uint32_t start;
+  uint32_t end;
+};
+
+}
+
+#define E(a,b,c) a,b,c
+
+#define CHECK(p, n, ...)                                \
+  if (!DoCheck (p, __FILE__, __LINE__, n, __VA_ARGS__)) \
+    {                                                   \
+      result = false;                                   \
+    }
 
 namespace ns3 {
 
-class PacketTest: public Test {
+
+class PacketTest: public Test 
+{
 public:
+  PacketTest ();
   virtual bool RunTests (void);
-  PacketTest ();
+private:
+  bool DoCheck (Ptr<const Packet> p, const char *file, int line, uint32_t n, ...);
 };
 
 
 PacketTest::PacketTest ()
   : Test ("Packet") {}
 
+bool
+PacketTest::DoCheck (Ptr<const Packet> p, const char *file, int line, uint32_t n, ...)
+{
+  bool result = true;
+  std::vector<struct Expected> expected;
+  va_list ap;
+  va_start (ap, n);
+  for (uint32_t k = 0; k < n; ++k)
+    {
+      uint32_t N = va_arg (ap, uint32_t);
+      uint32_t start = va_arg (ap, uint32_t);
+      uint32_t end = va_arg (ap, uint32_t);
+      expected.push_back (Expected (N, start, end));
+    }
+  va_end (ap);
+
+  TagIterator i = p->GetTagIterator ();
+  uint32_t j = 0;
+  while (i.HasNext () && j < expected.size ())
+    {
+      TagIterator::Item item = i.Next ();
+      struct Expected e = expected[j];
+      std::ostringstream oss;
+      oss << "anon::ATestTag<" << e.n << ">";
+      NS_TEST_ASSERT_EQUAL_FILELINE (item.GetTypeId ().GetName (), oss.str (), file, line);
+      NS_TEST_ASSERT_EQUAL_FILELINE (item.GetStart (), e.start, file, line);
+      NS_TEST_ASSERT_EQUAL_FILELINE (item.GetEnd (), e.end, file, line);
+      ATestTagBase *tag = dynamic_cast<ATestTagBase *> (item.GetTypeId ().GetConstructor () ());
+      NS_TEST_ASSERT (tag != 0);
+      item.GetTag (*tag);
+      NS_TEST_ASSERT (!tag->m_error);
+      delete tag;
+      j++;
+    }
+  return result;
+}
 
 bool
 PacketTest::RunTests (void)
 {
-  bool ok = true;
+  bool result = true;
 
   Ptr<Packet> pkt1 = Create<Packet> (reinterpret_cast<const uint8_t*> ("hello"), 5);
   Ptr<Packet> pkt2 = Create<Packet> (reinterpret_cast<const uint8_t*> (" world"), 6);
@@ -428,25 +656,49 @@
   packet->AddAtEnd (pkt1);
   packet->AddAtEnd (pkt2);
   
-  if (packet->GetSize () != 11)
-    {
-      Failure () << "expected size 11, got " << packet->GetSize () << std::endl;
-      ok = false;
-    }
+  NS_TEST_ASSERT_EQUAL (packet->GetSize (), 11);
 
   std::string msg = std::string (reinterpret_cast<const char *>(packet->PeekData ()),
                                  packet->GetSize ());
-  if (msg != "hello world")
-    {
-      Failure () << "expected 'hello world', got '" << msg << "'" << std::endl;
-      ok = false;
-    }
+  NS_TEST_ASSERT_EQUAL (msg, "hello world");
+
+
+  Ptr<const Packet> p = Create<Packet> (1000);
+
+  p->AddMtag (ATestTag<1> ());
+  CHECK (p, 1, E (1, 0, 1000));
+  Ptr<const Packet> copy = p->Copy ();
+  CHECK (copy, 1, E (1, 0, 1000));
+
+  p->AddMtag (ATestTag<2> ());
+  CHECK (p, 2, E (1, 0, 1000), E(2, 0, 1000));
+  CHECK (copy, 1, E (1, 0, 1000));
 
-  return ok;
+  Ptr<Packet> frag0 = p->CreateFragment (0, 10);
+  Ptr<Packet> frag1 = p->CreateFragment (10, 90);
+  Ptr<const Packet> frag2 = p->CreateFragment (100, 900);
+  frag0->AddMtag (ATestTag<3> ());
+  CHECK (frag0, 3, E (1, 0, 10), E(2, 0, 10), E (3, 0, 10));
+  frag1->AddMtag (ATestTag<4> ());
+  CHECK (frag1, 3, E (1, 0, 90), E(2, 0, 90), E (4, 0, 90));
+  frag2->AddMtag (ATestTag<5> ());
+  CHECK (frag2, 3, E (1, 0, 900), E(2, 0, 900), E (5, 0, 900));
+
+  frag1->AddAtEnd (frag2);
+  CHECK (frag1, 6, E (1, 0, 90), E(2, 0, 90), E (4, 0, 90), E (1, 90, 990), E(2, 90, 990), E (5, 90, 990));
+
+  CHECK (frag0, 3, E (1, 0, 10), E(2, 0, 10), E (3, 0, 10));
+  frag0->AddAtEnd (frag1);
+  CHECK (frag0, 9, 
+         E (1, 0, 10), E(2, 0, 10), E (3, 0, 10),
+         E (1, 10, 100), E(2, 10, 100), E (4, 10, 100), 
+         E (1, 100, 1000), E(2, 100, 1000), E (5, 100, 1000));
+
+  return result;
 }
 
 
-static PacketTest gPacketTest;
+static PacketTest g_packetTest;
 
 }; // namespace ns3