From 8ac38c611c2a6d2eedef2e568ee5eaee9b4bf1c0 Mon Sep 17 00:00:00 2001
From: Pierre LALET <pierre.lalet@cea.fr>
Date: Mon, 25 Sep 2017 11:07:38 +0200
Subject: [PATCH] Allow filters in the Packet slice / .getlayer() API

With a packet `pkt = IP() / IP(ttl=3) / IP()`, this will allow:

  - pkt.getlayer(IP, ttl=3)
  - pkt[IP, {"ttl": 3}]

This is particularly useful with Dot11Elt() layers, when you want a
specific value (the SSID or supported rates for example).
---
 scapy/contrib/bgp.py   | 13 ++++++-------
 scapy/layers/eap.py    | 13 ++++++-------
 scapy/layers/ntp.py    | 13 ++++++-------
 scapy/layers/radius.py | 14 ++++++--------
 scapy/packet.py        | 34 +++++++++++++++++++---------------
 test/regression.uts    |  5 +++++
 6 files changed, 48 insertions(+), 44 deletions(-)

diff --git a/scapy/contrib/bgp.py b/scapy/contrib/bgp.py
index a0a5fdad..48fcace7 100644
--- a/scapy/contrib/bgp.py
+++ b/scapy/contrib/bgp.py
@@ -619,16 +619,15 @@ class BGPCapability(six.with_metaclass(_BGPCapability_metaclass, Packet)):
             ret = 1
         return ret
 
-    def getlayer(self, cls, nb=1, _track=None):
-        layer = None
+    def getlayer(self, cls, nb=1, _track=None, **flt):
         if cls == BGPCapability:
             for cap_class in _capabilities_registry:
-                if isinstance(self, _capabilities_registry[cap_class]):
-                    layer = self
-                    break
+                if isinstance(self, _capabilities_registry[cap_class]) and \
+                   all(self.getfieldval(fldname) == fldvalue
+                       for fldname, fldvalue in flt.iteritems()):
+                    return self
         else:
-            layer = Packet.getlayer(self, cls, nb, _track)
-        return layer
+            return Packet.getlayer(self, cls, nb, _track, **flt)
 
     def post_build(self, p, pay):
         length = 0
diff --git a/scapy/layers/eap.py b/scapy/layers/eap.py
index b441ff44..d86b78a1 100644
--- a/scapy/layers/eap.py
+++ b/scapy/layers/eap.py
@@ -248,16 +248,15 @@ class EAP(Packet):
             ret = 1
         return ret
 
-    def getlayer(self, cls, nb=1, _track=None):
-        layer = None
+    def getlayer(self, cls, nb=1, _track=None, **flt):
         if cls == EAP:
             for eap_class in EAP.registered_methods.values():
-                if isinstance(self, eap_class):
-                    layer = self
-                    break
+                if isinstance(self, eap_class) and \
+                   all(self.getfieldval(fldname) == fldvalue
+                       for fldname, fldvalue in flt.iteritems()):
+                    return self
         else:
-            layer = Packet.getlayer(self, cls, nb, _track)
-        return layer
+            return Packet.getlayer(self, cls, nb, _track, **flt)
 
     def answers(self, other):
         if isinstance(other, EAP):
diff --git a/scapy/layers/ntp.py b/scapy/layers/ntp.py
index 2e75067e..a9512d5f 100644
--- a/scapy/layers/ntp.py
+++ b/scapy/layers/ntp.py
@@ -255,21 +255,20 @@ class NTP(Packet):
             ret = 1
         return ret
 
-    def getlayer(self, cls, nb=1, _track=None):
+    def getlayer(self, cls, nb=1, _track=None, **flt):
         ntp_classes = [
             get_cls("NTPHeader"),
             get_cls("NTPControl"),
             get_cls("NTPPrivate")
         ]
-        layer = None
         if cls == NTP:
             for ntp_class in ntp_classes:
-                if isinstance(self, ntp_class):
-                    layer = self
-                    break
+                if isinstance(self, ntp_class) and \
+                   all(self.getfieldval(fldname) == fldvalue
+                       for fldname, fldvalue in flt.iteritems()):
+                    return self
         else:
-            layer = Packet.getlayer(self, cls, nb, _track)
-        return layer
+            return Packet.getlayer(self, cls, nb, _track, **flt)
 
     def mysummary(self):
         return self.sprintf("NTP v%ir,NTP.version%, %NTP.mode%")
diff --git a/scapy/layers/radius.py b/scapy/layers/radius.py
index 9b7507fb..118233c3 100644
--- a/scapy/layers/radius.py
+++ b/scapy/layers/radius.py
@@ -278,17 +278,15 @@ class RadiusAttribute(Packet):
             return True
         return False
 
-    def getlayer(self, cls, nb=1, _track=None):
-        layer = None
+    def getlayer(self, cls, nb=1, _track=None, **flt):
         if cls == RadiusAttribute:
             for attr_class in RadiusAttribute.registered_attributes.values():
-                if isinstance(self, attr_class):
-                    layer = self
-                    break
+                if isinstance(self, attr_class) and \
+                   all(self.getfieldval(fldname) == fldvalue
+                       for fldname, fldvalue in flt.iteritems()):
+                    return self
         else:
-            layer = Packet.getlayer(self, cls, nb, _track)
-        return layer
-
+            return Packet.getlayer(self, cls, nb, _track, **flt)
 
     def post_build(self, p, pay):
         length = self.len
diff --git a/scapy/packet.py b/scapy/packet.py
index 326f15b1..50fe67a4 100644
--- a/scapy/packet.py
+++ b/scapy/packet.py
@@ -888,8 +888,12 @@ class Packet(six.with_metaclass(Packet_metaclass, BasePacket)):
                     if ret:
                         return ret
         return self.payload.haslayer(cls)
-    def getlayer(self, cls, nb=1, _track=None):
-        """Return the nb^th layer that is an instance of cls."""
+
+    def getlayer(self, cls, nb=1, _track=None, **flt):
+        """Return the nb^th layer that is an instance of cls, matching flt
+values.
+
+        """
         if isinstance(cls, int):
             nb = cls+1
             cls = None
@@ -898,13 +902,15 @@ class Packet(six.with_metaclass(Packet_metaclass, BasePacket)):
         else:
             ccls,fld = cls,None
         if cls is None or self.__class__ == cls or self.__class__.__name__ == ccls:
-            if nb == 1:
-                if fld is None:
-                    return self
+            if all(self.getfieldval(fldname) == fldvalue
+                   for fldname, fldvalue in flt.iteritems()):
+                if nb == 1:
+                    if fld is None:
+                        return self
+                    else:
+                        return self.getfieldval(fld)
                 else:
-                    return self.getfieldval(fld)
-            else:
-                nb -=1
+                    nb -=1
         for f in self.packetfields:
             fvalue_gen = self.getfieldval(f.name)
             if fvalue_gen is None:
@@ -918,7 +924,7 @@ class Packet(six.with_metaclass(Packet_metaclass, BasePacket)):
                     if ret is not None:
                         return ret
                     nb = track[0]
-        return self.payload.getlayer(cls,nb,_track=_track)
+        return self.payload.getlayer(cls, nb=nb, _track=_track, **flt)
 
     def firstlayer(self):
         q = self
@@ -930,13 +936,11 @@ class Packet(six.with_metaclass(Packet_metaclass, BasePacket)):
         if isinstance(cls, slice):
             lname = cls.start
             if cls.stop:
-                ret = self.getlayer(cls.start, cls.stop)
+                ret = self.getlayer(cls.start, nb=cls.stop, **(cls.step or {}))
             else:
-                ret = self.getlayer(cls.start)
-            if ret is None and cls.step is not None:
-                ret = cls.step
+                ret = self.getlayer(cls.start, **(cls.step or {}))
         else:
-            lname=cls
+            lname = cls
             ret = self.getlayer(cls)
         if ret is None:
             if isinstance(lname, Packet_metaclass):
@@ -1287,7 +1291,7 @@ class NoPayload(Packet):
         return isinstance(other, NoPayload) or isinstance(other, conf.padding_layer)
     def haslayer(self, cls):
         return 0
-    def getlayer(self, cls, nb=1, _track=None):
+    def getlayer(self, cls, nb=1, _track=None, **flt):
         if _track is not None:
             _track.append(nb)
         return None
diff --git a/test/regression.uts b/test/regression.uts
index 43e5543a..dff21aa8 100644
--- a/test/regression.uts
+++ b/test/regression.uts
@@ -8645,6 +8645,11 @@ Dot11(type=0, subtype=1).answers(query) == True
 assert Dot11Elt(info="scapy").summary() == "SSID='scapy'"
 assert Dot11Elt(ID=1).mysummary() == ""
 
+= Multiple Dot11Elt layers
+pkt = Dot11() / Dot11Beacon() / Dot11Elt(ID="Rates") / Dot11Elt(ID="SSID", info="Scapy")
+assert pkt[Dot11Elt::{"ID": 0}].info == "Scapy"
+assert pkt.getlayer(Dot11Elt, ID=0).info == "Scapy"
+
 = Dot11WEP - build
 ~ crypto
 conf.wepkey = ""
-- 
GitLab