From 12c75d4bb983175fd66c389dad096141978a1366 Mon Sep 17 00:00:00 2001
From: Pierre LALET <pierre.lalet@cea.fr>
Date: Fri, 29 Sep 2017 14:27:51 +0200
Subject: [PATCH] Clean-up EAP, RadiusAttribute & BGPCapability
 .{get,has}layer()

---
 scapy/contrib/bgp.py   | 35 ++++++++++++-----------------------
 scapy/contrib/bgp.uts  |  6 ++++--
 scapy/layers/eap.py    | 30 +++++++++++-------------------
 scapy/layers/radius.py | 30 +++++++++++-------------------
 scapy/packet.py        | 14 ++++++++++----
 5 files changed, 48 insertions(+), 67 deletions(-)

diff --git a/scapy/contrib/bgp.py b/scapy/contrib/bgp.py
index 1f8a9f21..87ea38f7 100644
--- a/scapy/contrib/bgp.py
+++ b/scapy/contrib/bgp.py
@@ -607,29 +607,18 @@ class BGPCapability(six.with_metaclass(_BGPCapability_metaclass, Packet)):
 
     # Every BGP capability object inherits from BGPCapability.
     def haslayer(self, cls):
-        ret = 0
-        if cls == BGPCapability:
-            # If cls is BGPCap (the parent class), check that the object is an
-            # instance of an existing BGP capability class.
-            for cap_class in _capabilities_registry:
-                if isinstance(self, _capabilities_registry[cap_class]):
-                    ret = 1
-                    break
-        elif cls in _capabilities_registry and isinstance(self, cls):
-            ret = 1
-        return ret
-
-    def getlayer(self, cls, nb=1, _track=None, **flt):
-        if cls == BGPCapability:
-            for cap_class in _capabilities_registry:
-                if isinstance(self, _capabilities_registry[cap_class]) and \
-                   all(self.getfieldval(fldname) == fldvalue
-                       for fldname, fldvalue in flt.iteritems()):
-                    return self
-        else:
-            return super(BGPCapability, self).getlayer(
-                cls, nb=nb, _track=_track, **flt
-            )
+        if cls == "BGPCapability":
+            if isinstance(self, BGPCapability):
+                return True
+        if issubclass(cls, BGPCapability):
+            if isinstance(self, cls):
+                return True
+        return super(BGPCapability, self).haslayer(cls)
+
+    def getlayer(self, cls, nb=1, _track=None, _subclass=True, **flt):
+        return super(BGPCapability, self).getlayer(
+            cls, nb=nb, _track=_track, _subclass=True, **flt
+        )
 
     def post_build(self, p, pay):
         length = 0
diff --git a/scapy/contrib/bgp.uts b/scapy/contrib/bgp.uts
index 521abe07..55eb506b 100644
--- a/scapy/contrib/bgp.uts
+++ b/scapy/contrib/bgp.uts
@@ -119,10 +119,12 @@ except _BGPInvalidDataException:
   True
 
 = BGPCapability - Test haslayer()
-BGPCapFourBytesASN().haslayer(BGPCapability) == True
+assert BGPCapFourBytesASN().haslayer(BGPCapability)
+assert BGPCapability in BGPCapFourBytesASN()
 
 = BGPCapability - Test getlayer()
-isinstance(BGPCapFourBytesASN().getlayer(BGPCapability), BGPCapFourBytesASN)
+assert isinstance(BGPCapFourBytesASN().getlayer(BGPCapability), BGPCapFourBytesASN)
+assert isinstance(BGPCapFourBytesASN()[BGPCapability], BGPCapFourBytesASN)
 
 
 ############################ BGPCapMultiprotocol ##############################
diff --git a/scapy/layers/eap.py b/scapy/layers/eap.py
index 0213d2ff..f5b49764 100644
--- a/scapy/layers/eap.py
+++ b/scapy/layers/eap.py
@@ -238,25 +238,17 @@ class EAP(Packet):
         return cls
 
     def haslayer(self, cls):
-        ret = 0
-        if cls == EAP:
-            for eap_class in EAP.registered_methods.values():
-                if isinstance(self, eap_class):
-                    ret = 1
-                    break
-        elif cls in list(EAP.registered_methods.values()) and isinstance(self, cls):
-            ret = 1
-        return ret
-
-    def getlayer(self, cls, nb=1, _track=None, **flt):
-        if cls == EAP:
-            for eap_class in EAP.registered_methods.values():
-                if isinstance(self, eap_class) and \
-                   all(self.getfieldval(fldname) == fldvalue
-                       for fldname, fldvalue in flt.iteritems()):
-                    return self
-        else:
-            return super(EAP, self).getlayer(cls, nb=nb, _track=_track, **flt)
+        if cls == "EAP":
+            if isinstance(self, EAP):
+                return True
+        elif issubclass(cls, EAP):
+            if isinstance(self, cls):
+                return True
+        return super(EAP, self).haslayer(cls)
+
+    def getlayer(self, cls, nb=1, _track=None, _subclass=True, **flt):
+        return super(EAP, self).getlayer(cls, nb=nb, _track=_track,
+                                         _subclass=True, **flt)
 
     def answers(self, other):
         if isinstance(other, EAP):
diff --git a/scapy/layers/radius.py b/scapy/layers/radius.py
index 7fa569d5..ada1cb42 100644
--- a/scapy/layers/radius.py
+++ b/scapy/layers/radius.py
@@ -270,25 +270,17 @@ class RadiusAttribute(Packet):
         return cls
 
     def haslayer(self, cls):
-        if cls == RadiusAttribute:
-            for attr_class in RadiusAttribute.registered_attributes.values():
-                if isinstance(self, attr_class):
-                    return True
-        elif cls in RadiusAttribute.registered_attributes.values() and isinstance(self, cls):
-            return True
-        return False
-
-    def getlayer(self, cls, nb=1, _track=None, **flt):
-        if cls == RadiusAttribute:
-            for attr_class in RadiusAttribute.registered_attributes.values():
-                if isinstance(self, attr_class) and \
-                   all(self.getfieldval(fldname) == fldvalue
-                       for fldname, fldvalue in flt.iteritems()):
-                    return self
-        else:
-            return super(RadiusAttribute, self).getlayer(
-                cls, nb=nb, _track=_track, **flt
-            )
+        if cls == "RadiusAttribute":
+            if isinstance(self, RadiusAttribute):
+                return True
+        elif issubclass(cls, RadiusAttribute):
+            if isinstance(self, cls):
+                return True
+        return super(RadiusAttribute, self).haslayer(cls)
+
+    def getlayer(self, cls, nb=1, _track=None, _subclass=True, **flt):
+        return super(RadiusAttribute, self).getlayer(cls, nb=nb, _track=_track,
+                                                     _subclass=True, **flt)
 
     def post_build(self, p, pay):
         length = self.len
diff --git a/scapy/packet.py b/scapy/packet.py
index 50fe67a4..89184d17 100644
--- a/scapy/packet.py
+++ b/scapy/packet.py
@@ -889,11 +889,15 @@ class Packet(six.with_metaclass(Packet_metaclass, BasePacket)):
                         return ret
         return self.payload.haslayer(cls)
 
-    def getlayer(self, cls, nb=1, _track=None, **flt):
+    def getlayer(self, cls, nb=1, _track=None, _subclass=False, **flt):
         """Return the nb^th layer that is an instance of cls, matching flt
 values.
 
         """
+        if _subclass:
+            match = lambda cls1, cls2: issubclass(cls1, cls2)
+        else:
+            match = lambda cls1, cls2: cls1 == cls2
         if isinstance(cls, int):
             nb = cls+1
             cls = None
@@ -901,7 +905,7 @@ values.
             ccls,fld = cls.split(".",1)
         else:
             ccls,fld = cls,None
-        if cls is None or self.__class__ == cls or self.__class__.__name__ == ccls:
+        if cls is None or match(self.__class__, cls) or self.__class__.__name__ == ccls:
             if all(self.getfieldval(fldname) == fldvalue
                    for fldname, fldvalue in flt.iteritems()):
                 if nb == 1:
@@ -920,11 +924,13 @@ values.
             for fvalue in fvalue_gen:
                 if isinstance(fvalue, Packet):
                     track=[]
-                    ret = fvalue.getlayer(cls, nb, _track=track)
+                    ret = fvalue.getlayer(cls, nb=nb, _track=track,
+                                          _subclass=_subclass)
                     if ret is not None:
                         return ret
                     nb = track[0]
-        return self.payload.getlayer(cls, nb=nb, _track=_track, **flt)
+        return self.payload.getlayer(cls, nb=nb, _track=_track,
+                                     _subclass=_subclass, **flt)
 
     def firstlayer(self):
         q = self
-- 
GitLab