From 33c3b1a112faf0f223cf0d983aac85b95ff2c248 Mon Sep 17 00:00:00 2001
From: Phil <phil@secdev.org>
Date: Mon, 28 Jul 2008 16:13:00 +0200
Subject: [PATCH] Moved base classes into base_classes.py to avoid some cycly
 dependency problems

---
 scapy/all.py          |   1 +
 scapy/asn1fields.py   |  18 +++--
 scapy/base_classes.py | 168 ++++++++++++++++++++++++++++++++++++++++++
 scapy/config.py       |   5 +-
 scapy/fields.py       |  13 ++--
 scapy/layers/dhcp.py  |   1 +
 scapy/packet.py       | 155 +-------------------------------------
 scapy/plist.py        |   5 +-
 8 files changed, 196 insertions(+), 170 deletions(-)
 create mode 100644 scapy/base_classes.py

diff --git a/scapy/all.py b/scapy/all.py
index e78256ae..f9268d01 100644
--- a/scapy/all.py
+++ b/scapy/all.py
@@ -1,4 +1,5 @@
 
+from base_classes import *
 from config import *
 from dadict import *
 from data import *
diff --git a/scapy/asn1fields.py b/scapy/asn1fields.py
index 01d16879..67ce760b 100644
--- a/scapy/asn1fields.py
+++ b/scapy/asn1fields.py
@@ -237,12 +237,12 @@ class ASN1F_SEQUENCE_OF(ASN1F_SEQUENCE):
             try:
                 p = self.asn1pkt(s1)
             except ASN1F_badsequence,e:
-                lst.append(Raw(s1))
+                lst.append(packet.Raw(s1))
                 break
             lst.append(p)
-            if Raw in p:
-                s1 = p[Raw].load
-                del(p[Raw].underlayer.payload)
+            if packet.Raw in p:
+                s1 = p[packet.Raw].load
+                del(p[packet.Raw].underlayer.payload)
             else:
                 break
         self.set_val(pkt, lst)
@@ -265,8 +265,8 @@ class ASN1F_PACKET(ASN1F_field):
         try:
             c = cls(x)
         except ASN1F_badsequence:
-            c = Raw(x)
-        cpad = c[Padding]
+            c = packet.Raw(x)
+        cpad = c[packet.Padding]
         x = ""
         if cpad is not None:
             x = cpad.load
@@ -287,10 +287,10 @@ class ASN1F_CHOICE(ASN1F_PACKET):
         self.default=default
     def m2i(self, pkt, x):
         if len(x) == 0:
-            return Raw(),""
+            return packet.Raw(),""
             raise ASN1_Error("ASN1F_CHOICE: got empty string")
         if ord(x[0]) not in self.choice:
-            return Raw(x),"" # XXX return RawASN1 packet ? Raise error 
+            return packet.Raw(x),"" # XXX return RawASN1 packet ? Raise error 
             raise ASN1_Error("Decoding Error: choice [%i] not found in %r" % (ord(x[0]), self.choice.keys()))
 
         z = ASN1F_PACKET.extract_packet(self, self.choice[ord(x[0])], x)
@@ -299,3 +299,5 @@ class ASN1F_CHOICE(ASN1F_PACKET):
         return RandChoice(*map(lambda x:fuzz(x()), self.choice.values()))
             
     
+# This import must come in last to avoid problems with cyclic dependencies
+import packet
diff --git a/scapy/base_classes.py b/scapy/base_classes.py
new file mode 100644
index 00000000..bf151bd5
--- /dev/null
+++ b/scapy/base_classes.py
@@ -0,0 +1,168 @@
+###############
+## Generators ##
+################
+
+import re,random,socket
+import config
+import error
+
+class Gen(object):
+    def __iter__(self):
+        return iter([])
+    
+class SetGen(Gen):
+    def __init__(self, set, _iterpacket=1):
+        self._iterpacket=_iterpacket
+        if type(set) is list:
+            self.set = set
+        elif isinstance(set, BasePacketList):
+            self.set = list(set)
+        else:
+            self.set = [set]
+    def transf(self, element):
+        return element
+    def __iter__(self):
+        for i in self.set:
+            if (type(i) is tuple) and (len(i) == 2) and type(i[0]) is int and type(i[1]) is int:
+                if  (i[0] <= i[1]):
+                    j=i[0]
+                    while j <= i[1]:
+                        yield j
+                        j += 1
+            elif isinstance(i, Gen) and (self._iterpacket or not isinstance(i,BasePacket)):
+                for j in i:
+                    yield j
+            else:
+                yield i
+    def __repr__(self):
+        return "<SetGen %s>" % self.set.__repr__()
+
+class Net(Gen):
+    """Generate a list of IPs from a network address or a name"""
+    name = "ip"
+    ipaddress = re.compile(r"^(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)(/[0-3]?[0-9])?$")
+    def __init__(self, net):
+        self.repr=net
+
+        tmp=net.split('/')+["32"]
+        if not self.ipaddress.match(net):
+            tmp[0]=socket.gethostbyname(tmp[0])
+        netmask = int(tmp[1])
+
+        def parse_digit(a,netmask):
+            netmask = min(8,max(netmask,0))
+            if a == "*":
+                a = (0,256)
+            elif a.find("-") >= 0:
+                x,y = map(int,a.split("-"))
+                if x > y:
+                    y = x
+                a = (x &  (0xffL<<netmask) , max(y, (x | (0xffL>>(8-netmask))))+1)
+            else:
+                a = (int(a) & (0xffL<<netmask),(int(a) | (0xffL>>(8-netmask)))+1)
+            return a
+
+        self.parsed = map(lambda x,y: parse_digit(x,y), tmp[0].split("."), map(lambda x,nm=netmask: x-nm, (8,16,24,32)))
+                                                                                               
+    def __iter__(self):
+        for d in xrange(*self.parsed[3]):
+            for c in xrange(*self.parsed[2]):
+                for b in xrange(*self.parsed[1]):
+                    for a in xrange(*self.parsed[0]):
+                        yield "%i.%i.%i.%i" % (a,b,c,d)
+    def choice(self):
+        ip = []
+        for v in self.parsed:
+            ip.append(str(random.randint(v[0],v[1]-1)))
+        return ".".join(ip) 
+                          
+    def __repr__(self):
+        return "Net(%r)" % self.repr
+
+class OID(Gen):
+    name = "OID"
+    def __init__(self, oid):
+        self.oid = oid        
+        self.cmpt = []
+        fmt = []        
+        for i in oid.split("."):
+            if "-" in i:
+                fmt.append("%i")
+                self.cmpt.append(tuple(map(int, i.split("-"))))
+            else:
+                fmt.append(i)
+        self.fmt = ".".join(fmt)
+    def __repr__(self):
+        return "OID(%r)" % self.oid
+    def __iter__(self):        
+        ii = [k[0] for k in self.cmpt]
+        while 1:
+            yield self.fmt % tuple(ii)
+            i = 0
+            while 1:
+                if i >= len(ii):
+                    raise StopIteration
+                if ii[i] < self.cmpt[i][1]:
+                    ii[i]+=1
+                    break
+                else:
+                    ii[i] = self.cmpt[i][0]
+                i += 1
+
+
+ 
+######################################
+## Packet abstract and base classes ##
+######################################
+
+class Packet_metaclass(type):
+    def __new__(cls, name, bases, dct):
+        newcls = super(Packet_metaclass, cls).__new__(cls, name, bases, dct)
+        for f in newcls.fields_desc:
+            f.register_owner(newcls)
+        config.conf.layers.register(newcls)
+        return newcls
+    def __getattr__(self, attr):
+        for k in self.fields_desc:
+            if k.name == attr:
+                return k
+        raise AttributeError(attr)
+
+class NewDefaultValues(Packet_metaclass):
+    """NewDefaultValues metaclass. Example usage:
+    class MyPacket(Packet):
+        fields_desc = [ StrField("my_field", "my default value"),  ]
+        
+    class MyPacket_variant(MyPacket):
+        __metaclass__ = NewDefaultValues
+        my_field = "my new default value"
+    """    
+    def __new__(cls, name, bases, dct):
+        fields = None
+        for b in bases:
+            if hasattr(b,"fields_desc"):
+                fields = b.fields_desc
+                break
+        if fields is None:
+            raise error.Scapy_Exception("No fields_desc in superclasses")
+
+        new_fields = []
+        for f in fields:
+            if f.name in dct:
+                f = f.copy()
+                f.default = dct[f.name]
+                del(dct[f.name])
+            new_fields.append(f)
+        dct["fields_desc"] = new_fields
+        return super(NewDefaultValues, cls).__new__(cls, name, bases, dct)
+
+class BasePacket(Gen):
+    pass
+
+
+#############################
+## Packet list base classe ##
+#############################
+
+class BasePacketList:
+    pass
diff --git a/scapy/config.py b/scapy/config.py
index 7e8d0df8..fdf12aa2 100644
--- a/scapy/config.py
+++ b/scapy/config.py
@@ -2,6 +2,7 @@ import os
 from arch import *
 from data import *
 from themes import *
+import base_classes
 
 ############
 ## Config ##
@@ -69,11 +70,11 @@ class Num2Layer:
         self.layer2num[layer] = num
 
     def __getitem__(self, item):
-        if isinstance(item, Packet_metaclass):
+        if isinstance(item, base_classes.Packet_metaclass):
             return self.layer2num[item]
         return self.num2layer[item]
     def __contains__(self, item):
-        if isinstance(item, Packet_metaclass):
+        if isinstance(item, base_classes.Packet_metaclass):
             return item in self.layer2num
         return item in self.num2layer
     def get(self, item, default=None):
diff --git a/scapy/fields.py b/scapy/fields.py
index ba7d6074..637e706a 100644
--- a/scapy/fields.py
+++ b/scapy/fields.py
@@ -1,6 +1,7 @@
 import struct,copy,socket
 from data import *
 from utils import *
+from base_classes import BasePacket
 
 ############
 ## Fields ##
@@ -66,7 +67,7 @@ class Field:
         if type(x) is list:
             x = x[:]
             for i in xrange(len(x)):
-                if isinstance(x[i], Packet):
+                if isinstance(x[i], BasePacket):
                     x[i] = x[i].copy()
         return x
     def __repr__(self):
@@ -468,8 +469,8 @@ class PacketField(StrField):
     def getfield(self, pkt, s):
         i = self.m2i(pkt, s)
         remain = ""
-        if i.haslayer(Padding):
-            r = i.getlayer(Padding)
+        if i.haslayer(packet.Padding):
+            r = i.getlayer(packet.Padding)
             del(r.underlayer.payload)
             remain = r.load
         return remain,i
@@ -534,8 +535,8 @@ class PacketListField(PacketField):
                     break
                 c -= 1
             p = self.m2i(pkt,remain)
-            if Padding in p:
-                pad = p[Padding]
+            if packet.Padding in p:
+                pad = p[packet.Padding]
                 remain = pad.load
                 del(pad.underlayer.payload)
             else:
@@ -937,3 +938,5 @@ class FloatField(BitField):
         b = sec+frac
         return s,b    
 
+# This import must come in last to avoid problems with cyclic dependencies
+import packet
diff --git a/scapy/layers/dhcp.py b/scapy/layers/dhcp.py
index 84c12da2..d414e149 100644
--- a/scapy/layers/dhcp.py
+++ b/scapy/layers/dhcp.py
@@ -4,6 +4,7 @@ from scapy.packet import *
 from scapy.fields import *
 from scapy.ansmachine import *
 from scapy.layers.inet import UDP
+from scapy.base_classes import Net
 
 dhcpmagic="c\x82Sc"
 
diff --git a/scapy/packet.py b/scapy/packet.py
index de4407db..821b1b57 100644
--- a/scapy/packet.py
+++ b/scapy/packet.py
@@ -1,161 +1,10 @@
 import re,time,itertools,os,random,socket
 from fields import *
 from config import conf
+from base_classes import BasePacket,Gen,SetGen,Packet_metaclass
 
 
-################
-## Generators ##
-################
-
-class Gen(object):
-    def __iter__(self):
-        return iter([])
-    
-class SetGen(Gen):
-    def __init__(self, set, _iterpacket=1):
-        self._iterpacket=_iterpacket
-        if type(set) is list:
-            self.set = set
-        elif isinstance(set, PacketList):
-            self.set = list(set)
-        else:
-            self.set = [set]
-    def transf(self, element):
-        return element
-    def __iter__(self):
-        for i in self.set:
-            if (type(i) is tuple) and (len(i) == 2) and type(i[0]) is int and type(i[1]) is int:
-                if  (i[0] <= i[1]):
-                    j=i[0]
-                    while j <= i[1]:
-                        yield j
-                        j += 1
-            elif isinstance(i, Gen) and (self._iterpacket or not isinstance(i,Packet)):
-                for j in i:
-                    yield j
-            else:
-                yield i
-    def __repr__(self):
-        return "<SetGen %s>" % self.set.__repr__()
-
-class Net(Gen):
-    """Generate a list of IPs from a network address or a name"""
-    name = "ip"
-    ipaddress = re.compile(r"^(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)(/[0-3]?[0-9])?$")
-    def __init__(self, net):
-        self.repr=net
-
-        tmp=net.split('/')+["32"]
-        if not self.ipaddress.match(net):
-            tmp[0]=socket.gethostbyname(tmp[0])
-        netmask = int(tmp[1])
-
-        def parse_digit(a,netmask):
-            netmask = min(8,max(netmask,0))
-            if a == "*":
-                a = (0,256)
-            elif a.find("-") >= 0:
-                x,y = map(int,a.split("-"))
-                if x > y:
-                    y = x
-                a = (x &  (0xffL<<netmask) , max(y, (x | (0xffL>>(8-netmask))))+1)
-            else:
-                a = (int(a) & (0xffL<<netmask),(int(a) | (0xffL>>(8-netmask)))+1)
-            return a
-
-        self.parsed = map(lambda x,y: parse_digit(x,y), tmp[0].split("."), map(lambda x,nm=netmask: x-nm, (8,16,24,32)))
-                                                                                               
-    def __iter__(self):
-        for d in xrange(*self.parsed[3]):
-            for c in xrange(*self.parsed[2]):
-                for b in xrange(*self.parsed[1]):
-                    for a in xrange(*self.parsed[0]):
-                        yield "%i.%i.%i.%i" % (a,b,c,d)
-    def choice(self):
-        ip = []
-        for v in self.parsed:
-            ip.append(str(random.randint(v[0],v[1]-1)))
-        return ".".join(ip) 
-                          
-    def __repr__(self):
-        return "Net(%r)" % self.repr
-
-class OID(Gen):
-    name = "OID"
-    def __init__(self, oid):
-        self.oid = oid        
-        self.cmpt = []
-        fmt = []        
-        for i in oid.split("."):
-            if "-" in i:
-                fmt.append("%i")
-                self.cmpt.append(tuple(map(int, i.split("-"))))
-            else:
-                fmt.append(i)
-        self.fmt = ".".join(fmt)
-    def __repr__(self):
-        return "OID(%r)" % self.oid
-    def __iter__(self):        
-        ii = [k[0] for k in self.cmpt]
-        while 1:
-            yield self.fmt % tuple(ii)
-            i = 0
-            while 1:
-                if i >= len(ii):
-                    raise StopIteration
-                if ii[i] < self.cmpt[i][1]:
-                    ii[i]+=1
-                    break
-                else:
-                    ii[i] = self.cmpt[i][0]
-                i += 1
- 
-###########################
-## Packet abstract class ##
-###########################
-
-class Packet_metaclass(type):
-    def __new__(cls, name, bases, dct):
-        newcls = super(Packet_metaclass, cls).__new__(cls, name, bases, dct)
-        for f in newcls.fields_desc:
-            f.register_owner(newcls)
-        conf.layers.register(newcls)
-        return newcls
-    def __getattr__(self, attr):
-        for k in self.fields_desc:
-            if k.name == attr:
-                return k
-        raise AttributeError(attr)
-
-class NewDefaultValues(Packet_metaclass):
-    """NewDefaultValues metaclass. Example usage:
-    class MyPacket(Packet):
-        fields_desc = [ StrField("my_field", "my default value"),  ]
-        
-    class MyPacket_variant(MyPacket):
-        __metaclass__ = NewDefaultValues
-        my_field = "my new default value"
-    """    
-    def __new__(cls, name, bases, dct):
-        fields = None
-        for b in bases:
-            if hasattr(b,"fields_desc"):
-                fields = b.fields_desc
-                break
-        if fields is None:
-            raise Scapy_Exception("No fields_desc in superclasses")
-
-        new_fields = []
-        for f in fields:
-            if f.name in dct:
-                f = f.copy()
-                f.default = dct[f.name]
-                del(dct[f.name])
-            new_fields.append(f)
-        dct["fields_desc"] = new_fields
-        return super(NewDefaultValues, cls).__new__(cls, name, bases, dct)
-
-class Packet(Gen):
+class Packet(BasePacket):
     __metaclass__ = Packet_metaclass
     name=None
 
diff --git a/scapy/plist.py b/scapy/plist.py
index 7560b095..6cd7b976 100644
--- a/scapy/plist.py
+++ b/scapy/plist.py
@@ -1,11 +1,12 @@
 import os,socket
 from config import conf
+from base_classes import BasePacket,BasePacketList
 
 #############
 ## Results ##
 #############
 
-class PacketList:
+class PacketList(BasePacketList):
     res = []
     def __init__(self, res=None, name="PacketList", stats=None):
         """create a packet list from a list of packets
@@ -56,7 +57,7 @@ class PacketList:
     def __getattr__(self, attr):
         return getattr(self.res, attr)
     def __getitem__(self, item):
-        if isinstance(item,type) and issubclass(item,Packet):
+        if isinstance(item,type) and issubclass(item,BasePacket):
             return self.__class__(filter(lambda x: item in self._elt2pkt(x),self.res),
                                   name="%s from %s"%(item.__name__,self.listname))
         if type(item) is slice:
-- 
GitLab