From 975dcc5704de5cd26ce8c8c1ab5cddccecf72b24 Mon Sep 17 00:00:00 2001
From: Phil <phil@secdev.org>
Date: Thu, 30 Apr 2009 16:12:41 +0200
Subject: [PATCH] Reorganized Automaton methods

---
 scapy/automaton.py | 173 +++++++++++++++++++++++----------------------
 1 file changed, 88 insertions(+), 85 deletions(-)

diff --git a/scapy/automaton.py b/scapy/automaton.py
index f268f809..b9a41c1b 100644
--- a/scapy/automaton.py
+++ b/scapy/automaton.py
@@ -261,6 +261,20 @@ class Automaton_metaclass(type):
 class Automaton:
     __metaclass__ = Automaton_metaclass
 
+    ## Methods to overload
+    def parse_args(self, debug=0, store=1, **kargs):
+        self.debug_level=debug
+        self.socket_kargs = kargs
+        self.store_packets = store        
+
+    def master_filter(self, pkt):
+        return True
+
+    def my_send(self, pkt):
+        self.send_sock.send(pkt)
+
+
+    ## Utility classes and exceptions
     class _IO_wrapper:
         def __init__(self,rd,wr):
             self.rd = rd
@@ -277,7 +291,30 @@ class Automaton:
             return self.wr.send(msg)
         def write(self, msg):
             return self.wr.send(msg)
-    
+
+    class ErrorState(Exception):
+        def __init__(self, msg, result=None):
+            Exception.__init__(self, msg)
+            self.result = result
+    class Stuck(ErrorState):
+        pass
+
+    class Breakpoint(Exception):
+        def __init__(self, msg, breakpoint):
+            Exception.__init__(self, msg)
+            self.breakpoint = breakpoint
+
+    class CommandMessage(Exception):
+        pass
+
+
+    ## Services
+    def debug(self, lvl, msg):
+        if self.debug_level >= lvl:
+            log_interactive.debug(msg)
+            
+
+    ## Internals
     def __init__(self, *args, **kargs):
         self.running = False
         self.threadid = None
@@ -303,37 +340,6 @@ class Automaton:
         
         self.parse_args(*args, **kargs)
 
-    def debug(self, lvl, msg):
-        if self.debug_level >= lvl:
-            log_interactive.debug(msg)
-            
-
-
-
-    class ErrorState(Exception):
-        def __init__(self, msg, result=None):
-            Exception.__init__(self, msg)
-            self.result = result
-    class Stuck(ErrorState):
-        pass
-
-    class Breakpoint(Exception):
-        def __init__(self, msg, breakpoint):
-            Exception.__init__(self, msg)
-            self.breakpoint = breakpoint
-
-    class CommandMessage(Exception):
-        pass
-
-    def parse_args(self, debug=0, store=1, **kargs):
-        self.debug_level=debug
-        self.socket_kargs = kargs
-        self.store_packets = store
-        
-
-    def master_filter(self, pkt):
-        return True
-
     def run_condition(self, cond, *args, **kargs):
         try:
             cond(self,*args, **kargs)
@@ -347,50 +353,11 @@ class Automaton:
             raise
         else:
             self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))
-            
-
-    def add_interception_points(self, *ipts):
-        for ipt in ipts:
-            if hasattr(ipt,"atmt_state"):
-                ipt = ipt.atmt_state
-            self.interception_points.add(ipt)
-        
-    def remove_interception_points(self, *ipts):
-        for ipt in ipts:
-            if hasattr(ipt,"atmt_state"):
-                ipt = ipt.atmt_state
-            if ipt in self.interception_points:
-                self.interception_points.remove(ipt)
-
-    def add_breakpoints(self, *bps):
-        for bp in bps:
-            if hasattr(bp,"atmt_state"):
-                bp = bp.atmt_state
-            self.breakpoints.add(bp)
 
-    def remove_breakpoints(self, *bps):
-        for bp in bps:
-            if hasattr(bp,"atmt_state"):
-                bp = bp.atmt_state
-            if bp in self.breakpoints:
-                self.breakpoints.remove(pb)
-
-    def start(self, *args, **kargs):
-        self.running = True
-
-        # Update default parameters
-        a = args+self.init_args[len(args):]
-        k = self.init_kargs
-        k.update(kargs)
-        self.parse_args(*a,**k)
-
-        # Start the automaton
-        self.state=self.initial_states[0](self)
-        self.send_sock = conf.L3socket()
-        self.listen_sock = conf.L2listen(**self.socket_kargs)
-        self.packets = PacketList(name="session[%s]"%self.__class__.__name__)
-
-        self.threadid = thread.start_new_thread(self.do_control, ())
+    def __iter__(self):
+        if not self.running:
+            self.start()
+        return self        
 
     def do_control(self):
         singlestep = True
@@ -429,7 +396,6 @@ class Automaton:
         try:
             self.debug(1, "## state=[%s]" % self.state.state)
 
-
             # Entering a new state. First, call new state function
             if self.state.state in self.breakpoints and self.state.state != self.breakpointed: 
                 self.breakpointed = self.state.state
@@ -504,6 +470,52 @@ class Automaton:
             self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
             self.state = state_req
             return state_req
+
+
+    ## Public API
+    def add_interception_points(self, *ipts):
+        for ipt in ipts:
+            if hasattr(ipt,"atmt_state"):
+                ipt = ipt.atmt_state
+            self.interception_points.add(ipt)
+        
+    def remove_interception_points(self, *ipts):
+        for ipt in ipts:
+            if hasattr(ipt,"atmt_state"):
+                ipt = ipt.atmt_state
+            if ipt in self.interception_points:
+                self.interception_points.remove(ipt)
+
+    def add_breakpoints(self, *bps):
+        for bp in bps:
+            if hasattr(bp,"atmt_state"):
+                bp = bp.atmt_state
+            self.breakpoints.add(bp)
+
+    def remove_breakpoints(self, *bps):
+        for bp in bps:
+            if hasattr(bp,"atmt_state"):
+                bp = bp.atmt_state
+            if bp in self.breakpoints:
+                self.breakpoints.remove(pb)
+
+    def start(self, *args, **kargs):
+        self.running = True
+
+        # Update default parameters
+        a = args+self.init_args[len(args):]
+        k = self.init_kargs
+        k.update(kargs)
+        self.parse_args(*a,**k)
+
+        # Start the automaton
+        self.state=self.initial_states[0](self)
+        self.send_sock = conf.L3socket()
+        self.listen_sock = conf.L2listen(**self.socket_kargs)
+        self.packets = PacketList(name="session[%s]"%self.__class__.__name__)
+
+        self.threadid = thread.start_new_thread(self.do_control, ())
+
         
     def run(self, resume=None, wait=True):
         if resume is None:
@@ -529,7 +541,6 @@ class Automaton:
 
     def stop(self):
         self.cmdin.send(Message(type=_ATMT_Command.STOP))
-        
 
     def accept_packet(self, pkt=None):
         rsm = Message()
@@ -544,14 +555,6 @@ class Automaton:
         return self.run(resume=rsm)
 
     
-    def __iter__(self):
-        if not self.running:
-            self.start()
-        return self
-
-    def my_send(self, pkt):
-        self.send_sock.send(pkt)
-
     def send(self, pkt):
         if self.state.state in self.interception_points:
             self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary())
-- 
GitLab