diff --git a/scapy/automaton.py b/scapy/automaton.py index 10839ef79973df2549c9a54142626b5617346267..f35dc0d5153c8375372736e6b72e8d016dbbf71b 100644 --- a/scapy/automaton.py +++ b/scapy/automaton.py @@ -27,6 +27,15 @@ class ObjectPipe: return self.queue.popleft() +class Message: + def __init__(self, **args): + self.__dict__.update(args) + def __repr__(self): + return "<Message %s>" % " ".join("%s=%r"%(k,v) + for (k,v) in self.__dict__.iteritems() + if not k.startswith("_")) + + ############## ## Automata ## ############## @@ -57,6 +66,8 @@ class ATMT: return self def run(self): return self.func(self.automaton, *self.args, **self.kargs) + def __repr__(self): + return "NewStateRequested(%s)" % self.state @staticmethod def state(initial=0,final=0,error=0): @@ -125,6 +136,16 @@ class ATMT: return f return deco +class _ATMT_Command: + RUN = "RUN" + NEXT = "NEXT" + STOP = "STOP" + END = "END" + SINGLESTEP = "SINGLESTEP" + INTERCEPT = "INTERCEPT" + ACCEPT = "ACCEPT" + REPLACE = "REPLACE" + REJECT = "REJECT" class Automaton_metaclass(type): def __new__(cls, name, bases, dct): @@ -239,9 +260,6 @@ class Automaton_metaclass(type): class Automaton: __metaclass__ = Automaton_metaclass - class _IO: - pass - class _IO_wrapper: def __init__(self,rd,wr): self.rd = rd @@ -258,18 +276,20 @@ class Automaton: return self.wr.send(msg) def write(self, msg): return self.wr.send(msg) - - def __init__(self, *args, **kargs): self.running = False + self.threadid = None self.breakpointed = None self.breakpoints = set() + self.interception_points = set() self.debug_level=0 self.init_args=args self.init_kargs=kargs - self.io = self._IO() - self.oi = self._IO() + self.io = type.__new__(type, "IOnamespace",(),{}) + self.oi = type.__new__(type, "IOnamespace",(),{}) + self.cmdin = ObjectPipe() + self.cmdout = ObjectPipe() self.ioin = {} self.ioout = {} for n in self.ionames: @@ -301,6 +321,9 @@ class Automaton: Exception.__init__(self, msg) self.breakpoint = breakpoint + class CommandMessage(Exception): + pass + def parse_args(self, debug=0, store=1, **kargs): self.debug_level=debug self.socket_kargs = kargs @@ -325,6 +348,19 @@ class Automaton: self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) + def add_interception_points(self, *ipts): + for ipt in ipts: + if hasattr(ipt,"atmt_state"): + ipt = ipt.atmt_state + self.interception_points.add(ipt) + + def remove_interception_points(self, *ipts): + for ipt in ipts: + if hasattr(ipt,"atmt_state"): + ipt = ipt.atmt_state + if ipt in self.interception_points: + self.interception_points.remove(ipt) + def add_breakpoints(self, *bps): for bp in bps: if hasattr(bp,"atmt_state"): @@ -353,12 +389,48 @@ class Automaton: self.listen_sock = conf.L2listen(**self.socket_kargs) self.packets = PacketList(name="session[%s]"%self.__class__.__name__) - def next(self): - if not self.running: - self.start() + self.threadid = thread.start_new_thread(self.do_control, ()) + + def do_control(self): + singlestep = True + self.debug(3, "Starting control thread [tid=%i]" % self.threadid) + stop = False + while not stop: + c = self.cmdin.recv() + self.debug(5, "Received command %s" % c.type) + if c.type == _ATMT_Command.RUN: + singlestep = False + elif c.type == _ATMT_Command.NEXT: + singlestep = True + elif c.type == _ATMT_Command.STOP: + break + while True: + try: + state = self.do_next() + except KeyboardInterrupt: + self.debug(1,"Interrupted by user") + stop=True + break + except self.CommandMessage: + break + except StopIteration,e: + c = Message(type=_ATMT_Command.END, result=e.args[0]) + self.cmdout.send(c) + stop=True + break + if singlestep: + c = Message(type=_ATMT_Command.SINGLESTEP,state=state) + self.cmdout.send(c) + break + self.debug(3, "Stopping control thread (tid=%i)"%self.threadid) + self.threadid = None + + + def do_next(self): try: self.debug(1, "## state=[%s]" % self.state.state) + # Entering a new state. First, call new state function if self.state.state in self.breakpoints and self.state.state != self.breakpointed: self.breakpointed = self.state.state @@ -393,7 +465,7 @@ class Automaton: next_timeout,timeout_func = expirations.next() t0 = time.time() - fds = [] + fds = [self.cmdin] if len(self.recv_conditions[self.state.state]) > 0: fds.append(self.listen_sock) for ioev in self.ioevents[self.state.state]: @@ -409,8 +481,11 @@ class Automaton: else: remain = next_timeout-t + self.debug(5, "Select on %r" % fds) r,_,_ = select(fds,[],[],remain) for fd in r: + if fd == self.cmdin: + raise self.CommandMessage() if fd == self.listen_sock: pkt = self.listen_sock.recv(MTU) if pkt is not None: @@ -431,26 +506,43 @@ class Automaton: self.state = state_req return state_req - - - def run(self, *args, **kargs): - if not self.running: - self.start(*args, **kargs) - - while 1: + def run(self, resume=None, wait=True): + if resume is None: + resume = Message(type = _ATMT_Command.RUN) + self.cmdin.send(resume) + if wait: try: - self.next() + c = self.cmdout.recv() except KeyboardInterrupt: - self.debug(1,"Interrupted by user") - break - except StopIteration,e: - return e.args[0] + return + if c.type == _ATMT_Command.END: + return c.result + elif c.type == _ATMT_Command.INTERCEPT: + print "Packet intercepted" + return c.pkt + elif c.type == _ATMT_Command.SINGLESTEP: + return c.state - cont = run + def next(self): + return self.run(resume = Message(type=_ATMT_Command.NEXT)) - def run_bg(self, *args, **kargs): - self.threadid = thread.start_new_thread(self.run, args, kargs) + def stop(self): + self.cmdin.send(Message(type=_ATMT_Command.STOP)) + + def accept_packet(self, pkt=None): + rsm = Message() + if pkt is None: + rsm.type = _ATMT_Command.ACCEPT + else: + rsm.type = _ATMT_Command.REPLACE + rsm.pkt = pkt + return self.run(resume=rsm) + def reject_packet(self): + rsm = Message(type = _ATMT_Command.REJECT) + return self.run(resume=rsm) + + def __iter__(self): if not self.running: self.start() @@ -460,6 +552,21 @@ class Automaton: self.send_sock.send(pkt) def send(self, pkt): + if self.state.state in self.interception_points: + self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary()) + cmd = Message(type = _ATMT_Command.INTERCEPT, pkt = pkt) + self.cmdout.send(cmd) + cmd = self.cmdin.recv() + if cmd.type == _ATMT_Command.REJECT: + self.debug(3,"INTERCEPT: packet rejected") + return + elif cmd.type == _ATMT_Command.REPLACE: + self.debug(3,"INTERCEPT: packet replaced") + pkt = cmd.pkt + elif cmd.type == _ATMT_Command.ACCEPT: + self.debug(3,"INTERCEPT: packet accepted") + else: + self.debug(1,"INTERCEPT: unkown verdict: %r" % cmd.type) self.my_send(pkt) self.debug(3,"SENT : %s" % pkt.summary()) self.packets.append(pkt.copy())