diff --git a/scapy.py b/scapy.py index 11eb16d5d2e879087a9d519082a5e92be0dae27c..49014585eeca74b1c2ff226c9a727c260acffa68 100755 --- a/scapy.py +++ b/scapy.py @@ -11756,67 +11756,83 @@ class TFTP_read(Automaton): class TFTP_write(Automaton): - def parse_args(self, filename, data, server, dport=69,**kargs): + def parse_args(self, filename, data, server, sport=None, port=69,**kargs): Automaton.parse_args(self, **kargs) self.filename = filename self.server = server - self.dport = dport + self.port = port + self.sport = sport self.blocksize = 512 - self.data = [ data[i*self.blocksize:(i+1)*self.blocksize] - for i in range( (len(data)+self.blocksize-1)/self.blocksize ) ] #XXX: fails if len(data)%bsize=0 + self.origdata = data + def master_filter(self, pkt): + return ( IP in pkt and pkt[IP].src == self.server and UDP in pkt + and pkt[UDP].dport == self.my_tid + and (self.server_tid is None or pkt[UDP].sport == self.server_tid) ) + # BEGIN @ATMT.state(initial=1) - def state_BEGIN(self): - self.my_tid = RandShort()._fix() + def BEGIN(self): + self.data = [ self.origdata[i*self.blocksize:(i+1)*self.blocksize] + for i in range( len(self.origdata)/self.blocksize+1) ] + self.my_tid = self.sport or RandShort()._fix() bind_bottom_up(UDP, TFTP, dport=self.my_tid) self.server_tid = None - self.current_ack = None + + self.l3 = IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.port)/TFTP() + self.last_packet = self.l3/TFTP_WRQ(filename=self.filename, mode="octet") + self.send(self.last_packet) self.res = "" - @ATMT.condition(state_BEGIN) - def on_begin(self): - raise self.state_WAIT_ACK() - @ATMT.action(on_begin) - def send_wrq(self): - self.send(IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.dport)/TFTP()/TFTP_WRQ(filename=self.filename)) - self.awaiting = 0 - - # WAIT_ACK + self.awaiting=0 + + raise self.WAITING_ACK() + + # WAITING_ACK @ATMT.state() - def state_WAIT_ACK(self): + def WAITING_ACK(self): pass - @ATMT.condition(state_WAIT_ACK) - def no_more_data(self): - if not self.data: - raise self.state_END() - @ATMT.receive_condition(state_WAIT_ACK) - def wait_ack(self,pkt): - if IP in pkt and pkt[IP].src == self.server and UDP in pkt and pkt[UDP].dport == self.my_tid: - if self.awaiting == 0: + + @ATMT.receive_condition(WAITING_ACK) + def received_ack(self,pkt): + if TFTP_ACK in pkt and pkt[TFTP_ACK].block == self.awaiting: + if self.server_tid is None: self.server_tid = pkt[UDP].sport - if pkt[UDP].sport == self.server_tid: - if TFTP_ERROR in pkt: - self.errormsg = pkt[TFTP_ERROR].sprintf("TFTP ERROR %ir,errorcode%: %errormsg%") - raise self.state_ERROR() - if TFTP_ACK in pkt: - if pkt[TFTP_ACK].block == self.awaiting: - raise self.state_WAIT_ACK() - @ATMT.action(wait_ack) - def got_ack(self): + self.l3[UDP].dport = self.server_tid + raise self.SEND_DATA() + + @ATMT.receive_condition(WAITING_ACK) + def received_error(self, pkt): + if TFTP_ERROR in pkt: + raise self.ERROR(pkt) + + @ATMT.timeout(WAITING_ACK, 3) + def timeout_waiting(self): + raise self.WAITING_ACK() + @ATMT.action(timeout_waiting) + def retransmit_last_packet(self): + self.send(self.last_packet) + + # SEND_DATA + @ATMT.state() + def SEND_DATA(self): self.awaiting += 1 - self.send( IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.server_tid) - /TFTP()/TFTP_DATA(block=self.awaiting)/self.data.pop(0) ) + self.last_packet = self.l3/TFTP_DATA(block=self.awaiting)/self.data.pop(0) + self.send(self.last_packet) + if self.data: + raise self.WAITING_ACK() + raise self.END() + # ERROR @ATMT.state(error=1) - def state_ERROR(self): + def ERROR(self,pkt): split_bottom_up(UDP, TFTP, dport=self.my_tid) - return self.errormsg + return pkt[TFTP_ERROR].summary() # END @ATMT.state(final=1) - def state_END(self): + def END(self): split_bottom_up(UDP, TFTP, dport=self.my_tid)