From a49638ead6cbe6ef7829fa3030ed083df9b3136a Mon Sep 17 00:00:00 2001 From: Phil <phil@secdev.org> Date: Thu, 30 Apr 2009 16:13:09 +0200 Subject: [PATCH] Small redesign of automaton control loop and logic --- scapy/automaton.py | 104 +++++++++++++++++++++++---------------------- 1 file changed, 54 insertions(+), 50 deletions(-) diff --git a/scapy/automaton.py b/scapy/automaton.py index b9a41c1b..b5b1cb01 100644 --- a/scapy/automaton.py +++ b/scapy/automaton.py @@ -3,6 +3,7 @@ ## Copyright (C) Philippe Biondi <phil@secdev.org> ## This program is published under a GPLv2 license +from __future__ import with_statement import types,itertools,time,os from select import select from collections import deque @@ -316,7 +317,7 @@ class Automaton: ## Internals def __init__(self, *args, **kargs): - self.running = False + self.running = thread.allocate_lock() self.threadid = None self.breakpointed = None self.breakpoints = set() @@ -340,6 +341,8 @@ class Automaton: self.parse_args(*args, **kargs) + self.start() + def run_condition(self, cond, *args, **kargs): try: cond(self,*args, **kargs) @@ -355,41 +358,54 @@ class Automaton: self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) def __iter__(self): - if not self.running: - self.start() return self - def do_control(self): - singlestep = True - self.debug(3, "Starting control thread [tid=%i]" % self.threadid) - try: - while True: - c = self.cmdin.recv() - self.debug(5, "Received command %s" % c.type) - if c.type == _ATMT_Command.RUN: - singlestep = False - elif c.type == _ATMT_Command.NEXT: - singlestep = True - elif c.type == _ATMT_Command.STOP: - break + def do_control(self, *args, **kargs): + with self.running: + self.threadid = thread.get_ident() + + # Update default parameters + a = args+self.init_args[len(args):] + k = self.init_kargs + k.update(kargs) + self.parse_args(*a,**k) + + # Start the automaton + self.state=self.initial_states[0](self) + self.send_sock = conf.L3socket() + self.listen_sock = conf.L2listen(**self.socket_kargs) + self.packets = PacketList(name="session[%s]"%self.__class__.__name__) + + singlestep = True + self.debug(3, "Starting control thread [tid=%i]" % self.threadid) + try: while True: - try: - state = self.do_next() - except self.CommandMessage: - break - if singlestep: - c = Message(type=_ATMT_Command.SINGLESTEP,state=state) - self.cmdout.send(c) + c = self.cmdin.recv() + self.debug(5, "Received command %s" % c.type) + if c.type == _ATMT_Command.RUN: + singlestep = False + elif c.type == _ATMT_Command.NEXT: + singlestep = True + elif c.type == _ATMT_Command.STOP: break - except StopIteration,e: - c = Message(type=_ATMT_Command.END, result=e.args[0]) - self.cmdout.send(c) - except Exception,e: - self.debug(3, "Transfering exception [%s] from tid=%i"% (e,self.threadid)) - m = Message(type = _ATMT_Command.EXCEPTION, exception=e) - self.cmdout.send(m) - self.debug(3, "Stopping control thread (tid=%i)"%self.threadid) - self.threadid = None + while True: + try: + state = self.do_next() + except self.CommandMessage: + break + if singlestep: + c = Message(type=_ATMT_Command.SINGLESTEP,state=state) + self.cmdout.send(c) + break + except StopIteration,e: + c = Message(type=_ATMT_Command.END, result=e.args[0]) + self.cmdout.send(c) + except Exception,e: + self.debug(3, "Transfering exception [%s] from tid=%i"% (e,self.threadid)) + m = Message(type = _ATMT_Command.EXCEPTION, exception=e) + self.cmdout.send(m) + self.debug(3, "Stopping control thread (tid=%i)"%self.threadid) + self.threadid = None def do_next(self): @@ -404,10 +420,8 @@ class Automaton: self.breakpointed = None state_output = self.state.run() if self.state.error: - self.running = False raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output) if self.state.final: - self.running = False raise StopIteration(state_output) if state_output is None: @@ -470,6 +484,7 @@ class Automaton: self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state)) self.state = state_req return state_req + ## Public API @@ -500,22 +515,7 @@ class Automaton: self.breakpoints.remove(pb) def start(self, *args, **kargs): - self.running = True - - # Update default parameters - a = args+self.init_args[len(args):] - k = self.init_kargs - k.update(kargs) - self.parse_args(*a,**k) - - # Start the automaton - self.state=self.initial_states[0](self) - self.send_sock = conf.L3socket() - self.listen_sock = conf.L2listen(**self.socket_kargs) - self.packets = PacketList(name="session[%s]"%self.__class__.__name__) - - self.threadid = thread.start_new_thread(self.do_control, ()) - + thread.start_new_thread(self.do_control, args, kargs) def run(self, resume=None, wait=True): if resume is None: @@ -542,6 +542,10 @@ class Automaton: def stop(self): self.cmdin.send(Message(type=_ATMT_Command.STOP)) + def restart(self): + self.stop() + self.start() + def accept_packet(self, pkt=None): rsm = Message() if pkt is None: -- GitLab