From 4d7fd6a2f131ba7adc972275f2d240c9bc08aacf Mon Sep 17 00:00:00 2001
From: Phil <phil@secdev.org>
Date: Sun, 14 Jun 2009 17:32:56 +0200
Subject: [PATCH] Have IP comparisons with Net() objects work. Also added "xxx
 in Net()" operation

---
 scapy/base_classes.py | 56 ++++++++++++++++++++++++++++++-------------
 1 file changed, 40 insertions(+), 16 deletions(-)

diff --git a/scapy/base_classes.py b/scapy/base_classes.py
index 6d36053c..c2fa4b73 100644
--- a/scapy/base_classes.py
+++ b/scapy/base_classes.py
@@ -46,28 +46,34 @@ class Net(Gen):
     """Generate a list of IPs from a network address or a name"""
     name = "ip"
     ipaddress = re.compile(r"^(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)(/[0-3]?[0-9])?$")
-    def __init__(self, net):
-        self.repr=net
 
+    @staticmethod
+    def _parse_digit(a,netmask):
+        netmask = min(8,max(netmask,0))
+        if a == "*":
+            a = (0,256)
+        elif a.find("-") >= 0:
+            x,y = map(int,a.split("-"))
+            if x > y:
+                y = x
+            a = (x &  (0xffL<<netmask) , max(y, (x | (0xffL>>(8-netmask))))+1)
+        else:
+            a = (int(a) & (0xffL<<netmask),(int(a) | (0xffL>>(8-netmask)))+1)
+        return a
+
+    @classmethod
+    def _parse_net(cls, net):
         tmp=net.split('/')+["32"]
-        if not self.ipaddress.match(net):
+        if not cls.ipaddress.match(net):
             tmp[0]=socket.gethostbyname(tmp[0])
         netmask = int(tmp[1])
+        return map(lambda x,y: cls._parse_digit(x,y), tmp[0].split("."), map(lambda x,nm=netmask: x-nm, (8,16,24,32))),netmask
+
+    def __init__(self, net):
+        self.repr=net
+        self.parsed,self.netmask = self._parse_net(net)
 
-        def parse_digit(a,netmask):
-            netmask = min(8,max(netmask,0))
-            if a == "*":
-                a = (0,256)
-            elif a.find("-") >= 0:
-                x,y = map(int,a.split("-"))
-                if x > y:
-                    y = x
-                a = (x &  (0xffL<<netmask) , max(y, (x | (0xffL>>(8-netmask))))+1)
-            else:
-                a = (int(a) & (0xffL<<netmask),(int(a) | (0xffL>>(8-netmask)))+1)
-            return a
 
-        self.parsed = map(lambda x,y: parse_digit(x,y), tmp[0].split("."), map(lambda x,nm=netmask: x-nm, (8,16,24,32)))
                                                                                                
     def __iter__(self):
         for d in xrange(*self.parsed[3]):
@@ -83,6 +89,24 @@ class Net(Gen):
                           
     def __repr__(self):
         return "Net(%r)" % self.repr
+    def __eq__(self, other):
+        if hasattr(other, "parsed"):
+            p2 = other.parsed
+        else:
+            p2,nm2 = self._parse_net(other)
+        return self.parsed == p2
+    def __contains__(self, other):
+        if hasattr(other, "parsed"):
+            p2 = other.parsed
+        else:
+            p2,nm2 = self._parse_net(other)
+        for (a1,b1),(a2,b2) in zip(self.parsed,p2):
+            if a1 > a2 or b1 < b2:
+                return False
+        return True
+    def __rcontains__(self, other):        
+        return self in self.__class__(other)
+        
 
 class OID(Gen):
     name = "OID"
-- 
GitLab