diff --git a/scapy.py b/scapy.py index 3b243ddb7c7e3046b6669585d660d14809765373..11eb16d5d2e879087a9d519082a5e92be0dae27c 100755 --- a/scapy.py +++ b/scapy.py @@ -11667,78 +11667,88 @@ class Automaton: class TFTP_read(Automaton): - def parse_args(self, filename, server, port=69, **kargs): + def parse_args(self, filename, server, sport = None, port=69, **kargs): Automaton.parse_args(self, **kargs) self.filename = filename self.server = server self.port = port + self.sport = sport + + def master_filter(self, pkt): + return ( IP in pkt and pkt[IP].src == self.server and UDP in pkt + and pkt[UDP].dport == self.my_tid + and (self.server_tid is None or pkt[UDP].sport == self.server_tid) ) + # BEGIN @ATMT.state(initial=1) - def state_BEGIN(self): + def BEGIN(self): self.blocksize=512 - self.my_tid = RandShort()._fix() + self.my_tid = self.sport or RandShort()._fix() bind_bottom_up(UDP, TFTP, dport=self.my_tid) self.server_tid = None - self.send(IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.port)/TFTP()/TFTP_RRQ(filename=self.filename, mode="octet")) - self.awaiting = 1 - self.current_ack = None self.res = "" - @ATMT.condition(state_BEGIN) - def on_begin(self): - raise self.state_RECEIVING() - # RECEIVING + self.l3 = IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.port)/TFTP() + self.last_packet = self.l3/TFTP_RRQ(filename=self.filename, mode="octet") + self.send(self.last_packet) + self.awaiting=1 + + raise self.WAITING() + + # WAITING @ATMT.state() - def state_RECEIVING(self): + def WAITING(self): pass - @ATMT.receive_condition(state_RECEIVING) - def receiving(self, pkt): - if IP in pkt and pkt[IP].src == self.server and UDP in pkt and pkt[UDP].dport == self.my_tid: - if self.awaiting == 1: + + + @ATMT.receive_condition(WAITING) + def receive_data(self, pkt): + if TFTP_DATA in pkt and pkt[TFTP_DATA].block == self.awaiting: + if self.server_tid is None: self.server_tid = pkt[UDP].sport - if pkt[UDP].sport == self.server_tid: - self.pkt = pkt - raise self.state_RECEIVED() - @ATMT.timeout(state_RECEIVING, 3) - def recv_timeout_ack(self): - raise self.state_RECEIVING() - @ATMT.action(recv_timeout_ack) - def action_timeout(self): - if self.current_ack is not None: - self.send(self.current_ack) + self.l3[UDP].dport = self.server_tid + raise self.RECEIVING(pkt) + + @ATMT.receive_condition(WAITING, prio=1) + def receive_error(self, pkt): + if TFTP_ERROR in pkt: + raise self.ERROR(pkt) + + + @ATMT.timeout(WAITING, 3) + def timeout_waiting(self): + raise self.WAITING() + @ATMT.action(timeout_waiting) + def retransmit_last_packet(self): + self.send(self.last_packet) + + @ATMT.action(receive_data) +# @ATMT.action(receive_error) + def send_ack(self): + self.last_packet = self.l3 / TFTP_ACK(block = self.awaiting) + self.send(self.last_packet) + # RECEIVED @ATMT.state() - def state_RECEIVED(self): - pass - @ATMT.condition(state_RECEIVED) - def received_error(self): - if TFTP_ERROR in self.pkt: - raise self.state_ERROR() - @ATMT.condition(state_RECEIVED) - def received_ok(self): - if TFTP_DATA in self.pkt and self.pkt[TFTP_DATA].block == self.awaiting: - self.current_ack=IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.server_tid)/TFTP()/TFTP_ACK(block=self.awaiting) - self.awaiting += 1 - received = self.pkt[Raw].load - self.res += received - if len(received) == self.blocksize: - raise self.state_RECEIVING() - raise self.state_END() - - @ATMT.action(received_ok) - def received_data(self): - self.send(self.current_ack) + def RECEIVING(self, pkt): + recvd = pkt[Raw].load + self.res += recvd + self.awaiting += 1 + if len(recvd) == self.blocksize: + raise self.WAITING() + raise self.END() # ERROR @ATMT.state(error=1) - def state_ERROR(self): - pass + def ERROR(self,pkt): + split_bottom_up(UDP, TFTP, dport=self.my_tid) + return pkt[TFTP_ERROR].summary() #END @ATMT.state(final=1) - def state_END(self): + def END(self): split_bottom_up(UDP, TFTP, dport=self.my_tid) return self.res