From f30d24fab028f98d1ba89e8bc5d6b5b86b632de3 Mon Sep 17 00:00:00 2001
From: Florian Maury <florian.maury@ssi.gouv.fr>
Date: Mon, 2 Jan 2017 16:26:56 +0100
Subject: [PATCH] Add next_cls_cb attribute to PacketListField

  - this feature adds the ability to have a PacketListField
    of heterogeneous Packet types with dynamic discovery
    of the next type. This discovery can be based on any elements
    including previously parsed packets, underlayers, remaining
    bytes (look ahead), and last parsed packet.
  - this feature also adds the ability to parse PacketListFields
    where neither the length nor the number of elements can be
    predicted before parsing. This could be done previously using
    a length_from callback that did significant peeks into the
    string to parse, but it felt clumsy.
---
 scapy/fields.py |  20 +++--
 test/fields.uts | 190 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 205 insertions(+), 5 deletions(-)

diff --git a/scapy/fields.py b/scapy/fields.py
index fc41b760..a7b511a8 100644
--- a/scapy/fields.py
+++ b/scapy/fields.py
@@ -435,15 +435,15 @@ class PacketLenField(PacketField):
 
 
 class PacketListField(PacketField):
-    __slots__ = ["count_from", "length_from"]
+    __slots__ = ["count_from", "length_from", "next_cls_cb"]
     islist = 1
-    def __init__(self, name, default, cls, count_from=None, length_from=None):
+    def __init__(self, name, default, cls=None, count_from=None, length_from=None, next_cls_cb=None):
         if default is None:
             default = []  # Create a new list for each instance
         PacketField.__init__(self, name, default, cls)
         self.count_from = count_from
         self.length_from = length_from
-
+        self.next_cls_cb = next_cls_cb
 
     def any2i(self, pkt, x):
         if not isinstance(x, list):
@@ -462,11 +462,14 @@ class PacketListField(PacketField):
         else:
             return [p if isinstance(p, bytes) else p.copy() for p in x]
     def getfield(self, pkt, s):
-        c = l = None
+        c = l = cls = None
         if self.length_from is not None:
             l = self.length_from(pkt)
         elif self.count_from is not None:
             c = self.count_from(pkt)
+        if self.next_cls_cb is not None:
+            cls = self.next_cls_cb(pkt, [], None, s)
+            c = 1
 
         lst = []
         ret = b""
@@ -479,7 +482,10 @@ class PacketListField(PacketField):
                     break
                 c -= 1
             try:
-                p = self.m2i(pkt,remain)
+                if cls is not None:
+                    p = cls(remain)
+                else:
+                    p = self.m2i(pkt, remain)
             except Exception:
                 if conf.debug_dissector:
                     raise
@@ -490,6 +496,10 @@ class PacketListField(PacketField):
                     pad = p[conf.padding_layer]
                     remain = pad.load
                     del(pad.underlayer.payload)
+                    if self.next_cls_cb is not None:
+                        cls = self.next_cls_cb(pkt, lst, p, remain)
+                        if cls is not None:
+                            c += 1
                 else:
                     remain = b""
             lst.append(p)
diff --git a/test/fields.uts b/test/fields.uts
index 815d3676..80bc4595 100644
--- a/test/fields.uts
+++ b/test/fields.uts
@@ -328,6 +328,196 @@ assert( str(a) == str(b) )
 assert TCPOptionsField("test", "").getfield(TCP(dataofs=0), "") == ('', [])
 
 
+############
+############
++ PacketListField tests
+
+= Create a layer
+~ field lengthfield
+class TestPLF(Packet):
+    name="test"
+    fields_desc=[ FieldLenField("len", None, count_of="plist"),
+                  PacketListField("plist", None, IP, count_from=lambda pkt:pkt.len) ]
+
+= Test the PacketListField assembly
+~ field lengthfield
+x=TestPLF()
+str(x)
+_ == "\x00\x00"
+
+= Test the PacketListField assembly 2
+~ field lengthfield
+x=TestPLF()
+x.plist=[IP()/TCP(), IP()/UDP()]
+str(x)
+_.startswith('\x00\x02E')
+
+= Test disassembly
+~ field lengthfield
+x=TestPLF(plist=[IP()/TCP(seq=1234567), IP()/UDP()])
+TestPLF(str(x))
+_.show()
+IP in _ and TCP in _ and UDP in _ and _[TCP].seq == 1234567
+
+= Nested PacketListField
+~ field lengthfield
+y=IP()/TCP(seq=111111)/TestPLF(plist=[IP()/TCP(seq=222222),IP()/UDP()])
+TestPLF(plist=[y,IP()/TCP(seq=333333)])
+_.show()
+IP in _ and TCP in _ and UDP in _ and _[TCP].seq == 111111 and _[TCP:2].seq==222222 and _[TCP:3].seq == 333333
+
+= Complex packet
+~ field lengthfield ccc
+class TestPkt(Packet):
+    fields_desc = [ ByteField("f1",65),
+                    ShortField("f2",0x4244) ]
+    def extract_padding(self, p):
+        return "", p
+
+class TestPLF2(Packet):
+    fields_desc = [ FieldLenField("len1", None, count_of="plist",fmt="H", adjust=lambda pkt,x:x+2),
+                    FieldLenField("len2", None, length_of="plist",fmt="I", adjust=lambda pkt,x:(x+1)/2),
+                    PacketListField("plist", None, TestPkt, length_from=lambda x:(x.len2*2)/3*3) ]
+
+a=TestPLF2()
+str(a)
+assert( _ == "\x00\x02\x00\x00\x00\x00" )
+
+a.plist=[TestPkt(),TestPkt(f1=100)] 
+str(a)
+assert(_ == '\x00\x04\x00\x00\x00\x03ABDdBD')
+
+a /= "123456"
+b = TestPLF2(str(a))
+b.show()
+assert(b.len1 == 4 and b.len2 == 3)
+assert(b[TestPkt].f1 == 65 and b[TestPkt].f2 == 0x4244)
+assert(b[TestPkt:2].f1 == 100)
+assert(Raw in b and b[Raw].load == "123456")
+
+a.plist.append(TestPkt(f1=200))
+b = TestPLF2(str(a))
+b.show()
+assert(b.len1 == 5 and b.len2 == 5)
+assert(b[TestPkt].f1 == 65 and b[TestPkt].f2 == 0x4244)
+assert(b[TestPkt:2].f1 == 100)
+assert(b[TestPkt:3].f1 == 200)
+assert(b.getlayer(TestPkt,4) is None)
+assert(Raw in b and b[Raw].load == "123456")
+hexdiff(a,b)
+assert( str(a) == str(b) )
+
+= Create layers for heterogeneous PacketListField
+~ field lengthfield
+TestPLFH1 = type('TestPLFH1', (Packet,), {
+    'name': 'test1',
+    'fields_desc': [ByteField('data', 0)],
+    'guess_payload_class': lambda self, p: conf.padding_layer,
+    }
+)
+TestPLFH2 = type('TestPLFH2', (Packet,), {
+    'name': 'test2',
+    'fields_desc': [ShortField('data', 0)],
+    'guess_payload_class': lambda self, p: conf.padding_layer,
+    }
+)
+class TestPLFH3(Packet):
+    name = 'test3'
+    fields_desc = [
+        PacketListField(
+            'data', [],
+            next_cls_cb=lambda pkt, lst, p, remain: pkt.detect_next_packet(lst, p, remain)
+        )
+    ]
+    def detect_next_packet(self, lst, p, remain):
+        if len(remain) < 3:
+            return None
+        if isinstance(p, type(None)):
+            return TestPLFH1
+        if p.data & 3 == 1:
+            return TestPLFH1
+        if p.data & 3 == 2:
+            return TestPLFH2
+        return None
+
+= Test heterogeneous PacketListField
+~ field lengthfield
+
+p = TestPLFH3('\x02\x01\x01\xc1\x02\x80\x04toto')
+assert(isinstance(p.data[0], TestPLFH1))
+assert(p.data[0].data == 0x2)
+assert(isinstance(p.data[1], TestPLFH2))
+assert(p.data[1].data == 0x101)
+assert(isinstance(p.data[2], TestPLFH1))
+assert(p.data[2].data == 0xc1)
+assert(isinstance(p.data[3], TestPLFH1))
+assert(p.data[3].data == 0x2)
+assert(isinstance(p.data[4], TestPLFH2))
+assert(p.data[4].data == 0x8004)
+assert(isinstance(p.payload, conf.raw_layer))
+assert(p.payload.load == 'toto')
+
+p = TestPLFH3('\x02\x01\x01\xc1\x02\x80\x02to')
+assert(isinstance(p.data[0], TestPLFH1))
+assert(p.data[0].data == 0x2)
+assert(isinstance(p.data[1], TestPLFH2))
+assert(p.data[1].data == 0x101)
+assert(isinstance(p.data[2], TestPLFH1))
+assert(p.data[2].data == 0xc1)
+assert(isinstance(p.data[3], TestPLFH1))
+assert(p.data[3].data == 0x2)
+assert(isinstance(p.data[4], TestPLFH2))
+assert(p.data[4].data == 0x8002)
+assert(isinstance(p.payload, conf.raw_layer))
+assert(p.payload.load == 'to')
+
+= Create layers for heterogeneous PacketListField with memory
+~ field lengthfield
+TestPLFH4 = type('TestPLFH4', (Packet,), {
+    'name': 'test4',
+    'fields_desc': [ByteField('data', 0)],
+    'guess_payload_class': lambda self, p: conf.padding_layer,
+    }
+)
+TestPLFH5 = type('TestPLFH5', (Packet,), {
+    'name': 'test5',
+    'fields_desc': [ShortField('data', 0)],
+    'guess_payload_class': lambda self, p: conf.padding_layer,
+    }
+)
+class TestPLFH6(Packet):
+    __slots__ = ['_memory']
+    name = 'test6'
+    fields_desc = [
+        PacketListField(
+            'data', [],
+            next_cls_cb=lambda pkt, lst, p, remain: pkt.detect_next_packet(lst, p, remain)
+        )
+    ]
+    def detect_next_packet(self, lst, p, remain):
+        if isinstance(p, type(None)):
+            self._memory = [TestPLFH4] * 3 + [TestPLFH5]
+        try:
+            return self._memory.pop(0)
+        except IndexError:
+            return None
+
+= Test heterogeneous PacketListField with memory
+~ field lengthfield
+
+p = TestPLFH6('\x01\x02\x03\xc1\x02toto')
+assert(isinstance(p.data[0], TestPLFH4))
+assert(p.data[0].data == 0x1)
+assert(isinstance(p.data[1], TestPLFH4))
+assert(p.data[1].data == 0x2)
+assert(isinstance(p.data[2], TestPLFH4))
+assert(p.data[2].data == 0x3)
+assert(isinstance(p.data[3], TestPLFH5))
+assert(p.data[3].data == 0xc102)
+assert(isinstance(p.payload, conf.raw_layer))
+assert(p.payload.load == 'toto')
+
+
 ############
 ############
 + Tests on MultiFlagsField
-- 
GitLab