diff --git a/scapy/automaton.py b/scapy/automaton.py index f7a644151f55aa8eed7bf9a2afdefe987785c3f6..4b4b1fd8b4086f337bb7d38296c62007ac8417c5 100644 --- a/scapy/automaton.py +++ b/scapy/automaton.py @@ -155,6 +155,7 @@ class ATMT: class _ATMT_Command: RUN = "RUN" NEXT = "NEXT" + FREEZE = "FREEZE" STOP = "STOP" END = "END" EXCEPTION = "EXCEPTION" @@ -327,19 +328,30 @@ class Automaton: def write(self, msg): return self.wr.send(msg) - class ErrorState(Exception): + + class AutomatonException(Exception): + pass + + class ErrorState(AutomatonException): def __init__(self, msg, result=None): Exception.__init__(self, msg) self.result = result class Stuck(ErrorState): pass - class Breakpoint(Exception): - def __init__(self, msg, breakpoint): + class AutomatonStopped(AutomatonException): + def __init__(self, msg, state=None): Exception.__init__(self, msg) - self.breakpoint = breakpoint + self.state = state + + class Breakpoint(AutomatonStopped): + pass + class InterceptionPoint(AutomatonStopped): + def __init__(self, msg, state, packet): + Automaton.AutomatonStopped.__init__(self, msg, state) + self.packet = packet - class CommandMessage(Exception): + class CommandMessage(AutomatonException): pass @@ -351,7 +363,7 @@ class Automaton: def send(self, pkt): if self.state.state in self.interception_points: self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary()) - cmd = Message(type = _ATMT_Command.INTERCEPT, pkt = pkt) + cmd = Message(type = _ATMT_Command.INTERCEPT, state=self.state, pkt=pkt) self.cmdout.send(cmd) cmd = self.cmdin.recv() if cmd.type == _ATMT_Command.REJECT: @@ -431,13 +443,18 @@ class Automaton: else: self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) + def _do_start(self, *args, **kargs): + + thread.start_new_thread(self._do_control, args, kargs) + + def _do_control(self, *args, **kargs): with self.running: self.threadid = thread.get_ident() # Update default parameters a = args+self.init_args[len(args):] - k = self.init_kargs + k = self.init_kargs.copy() k.update(kargs) self.parse_args(*a,**k) @@ -457,6 +474,8 @@ class Automaton: singlestep = False elif c.type == _ATMT_Command.NEXT: singlestep = True + elif c.type == _ATMT_Command.FREEZE: + continue elif c.type == _ATMT_Command.STOP: break while True: @@ -486,7 +505,7 @@ class Automaton: if self.state.state in self.breakpoints and self.state.state != self.breakpointed: self.breakpointed = self.state.state raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state, - breakpoint = self.state.state) + state = self.state.state) self.breakpointed = None state_output = self.state.run() if self.state.error: @@ -567,8 +586,7 @@ class Automaton: for ipt in ipts: if hasattr(ipt,"atmt_state"): ipt = ipt.atmt_state - if ipt in self.interception_points: - self.interception_points.remove(ipt) + self.interception_points.discard(ipt) def add_breakpoints(self, *bps): for bp in bps: @@ -580,11 +598,11 @@ class Automaton: for bp in bps: if hasattr(bp,"atmt_state"): bp = bp.atmt_state - if bp in self.breakpoints: - self.breakpoints.remove(pb) + self.breakpoints.discard(pb) def start(self, *args, **kargs): - thread.start_new_thread(self._do_control, args, kargs) + if not self.running.locked(): + self._do_start(*args, **kargs) def run(self, resume=None, wait=True): if resume is None: @@ -594,14 +612,14 @@ class Automaton: try: c = self.cmdout.recv() except KeyboardInterrupt: + self.cmdin.send(Message(type = _ATMT_Command.FREEZE)) return if c.type == _ATMT_Command.END: return c.result elif c.type == _ATMT_Command.INTERCEPT: - print "Packet intercepted" - return c.pkt + raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt) elif c.type == _ATMT_Command.SINGLESTEP: - return c.state + raise self.Breakpoint("singlestep", state=c.state.state) elif c.type == _ATMT_Command.EXCEPTION: raise c.exception @@ -612,24 +630,25 @@ class Automaton: return self.run(resume = Message(type=_ATMT_Command.NEXT)) def stop(self): - self.cmdin.send(Message(type=_ATMT_Command.STOP)) + if self.running.locked(): + self.cmdin.send(Message(type=_ATMT_Command.STOP)) - def restart(self): + def restart(self, *args, **kargs): self.stop() - self.start() + self.start(*args, **kargs) - def accept_packet(self, pkt=None): + def accept_packet(self, pkt=None, wait=True): rsm = Message() if pkt is None: rsm.type = _ATMT_Command.ACCEPT else: rsm.type = _ATMT_Command.REPLACE rsm.pkt = pkt - return self.run(resume=rsm) + return self.run(resume=rsm, wait=wait) - def reject_packet(self): + def reject_packet(self, wait=True): rsm = Message(type = _ATMT_Command.REJECT) - return self.run(resume=rsm) + return self.run(resume=rsm, wait=wait)