From d8b8156b1e56d8b7538f04acf521a104ae62eb50 Mon Sep 17 00:00:00 2001
From: Phil <phil@secdev.org>
Date: Thu, 30 Apr 2009 16:08:52 +0200
Subject: [PATCH] Added automata IOevent transition conditions

---
 scapy/automaton.py | 119 ++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 102 insertions(+), 17 deletions(-)

diff --git a/scapy/automaton.py b/scapy/automaton.py
index 338a1276..10839ef7 100644
--- a/scapy/automaton.py
+++ b/scapy/automaton.py
@@ -3,14 +3,30 @@
 ## Copyright (C) Philippe Biondi <phil@secdev.org>
 ## This program is published under a GPLv2 license
 
-import types,itertools,time
+import types,itertools,time,os
 from select import select
+from collections import deque
+import thread
 from config import conf
 from utils import do_graph
 from error import log_interactive
 from plist import PacketList
 from data import MTU
 
+class ObjectPipe:
+    def __init__(self):
+        self.rd,self.wr = os.pipe()
+        self.queue = deque()
+    def fileno(self):
+        return self.rd
+    def send(self, obj):
+        self.queue.append(obj)
+        os.write(self.wr,"X")
+    def recv(self, n=0):
+        os.read(self.rd,1)
+        return self.queue.popleft()
+
+
 ##############
 ## Automata ##
 ##############
@@ -21,6 +37,7 @@ class ATMT:
     CONDITION = "Condition"
     RECV = "Receive condition"
     TIMEOUT = "Timeout condition"
+    IOEVENT = "I/O event"
 
     class NewStateRequested(Exception):
         def __init__(self, state_func, automaton, *args, **kargs):
@@ -89,6 +106,16 @@ class ATMT:
             return f
         return deco
     @staticmethod
+    def ioevent(state, name, prio=0):
+        def deco(f, state=state):
+            f.atmt_type = ATMT.IOEVENT
+            f.atmt_state = state.atmt_state
+            f.atmt_condname = f.func_name
+            f.atmt_ioname = name
+            f.atmt_prio = prio
+            return f
+        return deco
+    @staticmethod
     def timeout(state, timeout):
         def deco(f, state=state, timeout=timeout):
             f.atmt_type = ATMT.TIMEOUT
@@ -106,9 +133,11 @@ class Automaton_metaclass(type):
         cls.state = None
         cls.recv_conditions={}
         cls.conditions={}
+        cls.ioevents={}
         cls.timeout={}
         cls.actions={}
         cls.initial_states=[]
+        cls.ionames = []
 
         members = {}
         classes = [cls]
@@ -127,11 +156,12 @@ class Automaton_metaclass(type):
                 s = m.atmt_state
                 cls.states[s] = m
                 cls.recv_conditions[s]=[]
+                cls.ioevents[s]=[]
                 cls.conditions[s]=[]
                 cls.timeout[s]=[]
                 if m.atmt_initial:
                     cls.initial_states.append(m)
-            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT]:
+            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]:
                 cls.actions[m.atmt_condname] = []
     
         for m in decorated:
@@ -139,6 +169,9 @@ class Automaton_metaclass(type):
                 cls.conditions[m.atmt_state].append(m)
             elif m.atmt_type == ATMT.RECV:
                 cls.recv_conditions[m.atmt_state].append(m)
+            elif m.atmt_type == ATMT.IOEVENT:
+                cls.ioevents[m.atmt_state].append(m)
+                cls.ionames.append(m.atmt_ioname)
             elif m.atmt_type == ATMT.TIMEOUT:
                 cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
             elif m.atmt_type == ATMT.ACTION:
@@ -150,7 +183,8 @@ class Automaton_metaclass(type):
             v.sort(lambda (t1,f1),(t2,f2): cmp(t1,t2))
             v.append((None, None))
         for v in itertools.chain(cls.conditions.itervalues(),
-                                 cls.recv_conditions.itervalues()):
+                                 cls.recv_conditions.itervalues(),
+                                 cls.ioevents.itervalues()):
             v.sort(lambda c1,c2: cmp(c1.atmt_prio,c2.atmt_prio))
         for condname,actlst in cls.actions.iteritems():
             actlst.sort(lambda c1,c2: cmp(c1.atmt_cond[condname], c2.atmt_cond[condname]))
@@ -177,7 +211,9 @@ class Automaton_metaclass(type):
                     s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state,n)
             
 
-        for c,k,v in [("purple",k,v) for k,v in self.conditions.items()]+[("red",k,v) for k,v in self.recv_conditions.items()]:
+        for c,k,v in ([("purple",k,v) for k,v in self.conditions.items()]+
+                      [("red",k,v) for k,v in self.recv_conditions.items()]+
+                      [("orange",k,v) for k,v in self.ioevents.items()]):
             for f in v:
                 for n in f.func_code.co_names+f.func_code.co_consts:
                     if n in self.states:
@@ -203,6 +239,28 @@ class Automaton_metaclass(type):
 class Automaton:
     __metaclass__ = Automaton_metaclass
 
+    class _IO:
+        pass
+
+    class _IO_wrapper:
+        def __init__(self,rd,wr):
+            self.rd = rd
+            self.wr = wr
+        def fileno(self):
+            if type(self.rd) is int:
+                return self.rd
+            return self.rd.fileno()
+        def recv(self, n=None):
+            return self.rd.recv(n)
+        def read(self, n=None):
+            return self.rd.recv(n)        
+        def send(self, msg):
+            return self.wr.send(msg)
+        def write(self, msg):
+            return self.wr.send(msg)
+
+            
+    
     def __init__(self, *args, **kargs):
         self.running = False
         self.breakpointed = None
@@ -210,6 +268,18 @@ class Automaton:
         self.debug_level=0
         self.init_args=args
         self.init_kargs=kargs
+        self.io = self._IO()
+        self.oi = self._IO()
+        self.ioin = {}
+        self.ioout = {}
+        for n in self.ionames:
+            self.ioin[n] = ioin = ObjectPipe()
+            self.ioout[n] = ioout = ObjectPipe()
+            ioin.ioname = n
+            ioout.ioname = n
+            setattr(self.io, n, self._IO_wrapper(ioout,ioin))
+            setattr(self.oi, n, self._IO_wrapper(ioin,ioout))
+        
         self.parse_args(*args, **kargs)
 
     def debug(self, lvl, msg):
@@ -313,8 +383,9 @@ class Automaton:
                 self.run_condition(cond, *state_output)
 
             # If still there and no conditions left, we are stuck!
-            if ( len(self.recv_conditions[self.state.state]) == 0
-                 and len(self.timeout[self.state.state]) == 1 ):
+            if ( len(self.recv_conditions[self.state.state]) == 0 and
+                 len(self.ioevents[self.state.state]) == 0 and
+                 len(self.timeout[self.state.state]) == 1 ):
                 raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output)
 
             # Finally listen and pay attention to timeouts
@@ -322,6 +393,11 @@ class Automaton:
             next_timeout,timeout_func = expirations.next()
             t0 = time.time()
             
+            fds = []
+            if len(self.recv_conditions[self.state.state]) > 0:
+                fds.append(self.listen_sock)
+            for ioev in self.ioevents[self.state.state]:
+                fds.append(self.ioin[ioev.atmt_ioname])
             while 1:
                 t = time.time()-t0
                 if next_timeout is not None:
@@ -333,16 +409,22 @@ class Automaton:
                 else:
                     remain = next_timeout-t
 
-                r,_,_ = select([self.listen_sock],[],[],remain)
-                if self.listen_sock in r:
-                    pkt = self.listen_sock.recv(MTU)
-                    if pkt is not None:
-                        if self.master_filter(pkt):
-                            self.debug(3, "RECVD: %s" % pkt.summary())
-                            for rcvcond in self.recv_conditions[self.state.state]:
-                                self.run_condition(rcvcond, pkt, *state_output)
-                        else:
-                            self.debug(4, "FILTR: %s" % pkt.summary())
+                r,_,_ = select(fds,[],[],remain)
+                for fd in r:
+                    if fd == self.listen_sock:
+                        pkt = self.listen_sock.recv(MTU)
+                        if pkt is not None:
+                            if self.master_filter(pkt):
+                                self.debug(3, "RECVD: %s" % pkt.summary())
+                                for rcvcond in self.recv_conditions[self.state.state]:
+                                    self.run_condition(rcvcond, pkt, *state_output)
+                            else:
+                                self.debug(4, "FILTR: %s" % pkt.summary())
+                    else:
+                        self.debug(3, "IOEVENT on %s" % fd.ioname)
+                        for ioevt in self.ioevents[self.state.state]:
+                            if ioevt.atmt_ioname == fd.ioname:
+                                self.run_condition(ioevt, fd, *state_output)
 
         except ATMT.NewStateRequested,state_req:
             self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
@@ -365,7 +447,10 @@ class Automaton:
                 return e.args[0]
 
     cont = run
-    
+
+    def run_bg(self, *args, **kargs):
+        self.threadid = thread.start_new_thread(self.run, args, kargs)
+        
     def __iter__(self):
         if not self.running:
             self.start()
-- 
GitLab