diff --git a/scapy.py b/scapy.py
index 11eb16d5d2e879087a9d519082a5e92be0dae27c..49014585eeca74b1c2ff226c9a727c260acffa68 100755
--- a/scapy.py
+++ b/scapy.py
@@ -11756,67 +11756,83 @@ class TFTP_read(Automaton):
 
 
 class TFTP_write(Automaton):
-    def parse_args(self, filename, data, server, dport=69,**kargs):
+    def parse_args(self, filename, data, server, sport=None, port=69,**kargs):
         Automaton.parse_args(self, **kargs)
         self.filename = filename
         self.server = server
-        self.dport = dport
+        self.port = port
+        self.sport = sport
         self.blocksize = 512
-        self.data = [ data[i*self.blocksize:(i+1)*self.blocksize]
-                      for i in range( (len(data)+self.blocksize-1)/self.blocksize ) ] #XXX: fails if len(data)%bsize=0
+        self.origdata = data
 
+    def master_filter(self, pkt):
+        return ( IP in pkt and pkt[IP].src == self.server and UDP in pkt
+                 and pkt[UDP].dport == self.my_tid
+                 and (self.server_tid is None or pkt[UDP].sport == self.server_tid) )
+        
 
     # BEGIN
     @ATMT.state(initial=1)
-    def state_BEGIN(self):
-        self.my_tid = RandShort()._fix()
+    def BEGIN(self):
+        self.data = [ self.origdata[i*self.blocksize:(i+1)*self.blocksize]
+                      for i in range( len(self.origdata)/self.blocksize+1) ] 
+        self.my_tid = self.sport or RandShort()._fix()
         bind_bottom_up(UDP, TFTP, dport=self.my_tid)
         self.server_tid = None
-        self.current_ack = None
+        
+        self.l3 = IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.port)/TFTP()
+        self.last_packet = self.l3/TFTP_WRQ(filename=self.filename, mode="octet")
+        self.send(self.last_packet)
         self.res = ""
-    @ATMT.condition(state_BEGIN)
-    def on_begin(self):
-        raise self.state_WAIT_ACK()
-    @ATMT.action(on_begin)
-    def send_wrq(self):
-        self.send(IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.dport)/TFTP()/TFTP_WRQ(filename=self.filename))
-        self.awaiting = 0
-
-    # WAIT_ACK
+        self.awaiting=0
+
+        raise self.WAITING_ACK()
+        
+    # WAITING_ACK
     @ATMT.state()
-    def state_WAIT_ACK(self):
+    def WAITING_ACK(self):
         pass
-    @ATMT.condition(state_WAIT_ACK)
-    def no_more_data(self):
-        if not self.data:
-            raise self.state_END()
-    @ATMT.receive_condition(state_WAIT_ACK)
-    def wait_ack(self,pkt):
-        if IP in pkt and pkt[IP].src == self.server and UDP in pkt and pkt[UDP].dport == self.my_tid:
-            if self.awaiting == 0:
+
+    @ATMT.receive_condition(WAITING_ACK)    
+    def received_ack(self,pkt):
+        if TFTP_ACK in pkt and pkt[TFTP_ACK].block == self.awaiting:
+            if self.server_tid is None:
                 self.server_tid = pkt[UDP].sport
-            if pkt[UDP].sport == self.server_tid:
-                if TFTP_ERROR in pkt:
-                    self.errormsg = pkt[TFTP_ERROR].sprintf("TFTP ERROR %ir,errorcode%: %errormsg%")
-                    raise self.state_ERROR()
-                if TFTP_ACK in pkt:
-                    if pkt[TFTP_ACK].block == self.awaiting:
-                        raise self.state_WAIT_ACK()
-    @ATMT.action(wait_ack)
-    def got_ack(self):
+                self.l3[UDP].dport = self.server_tid
+            raise self.SEND_DATA()
+
+    @ATMT.receive_condition(WAITING_ACK)
+    def received_error(self, pkt):
+        if TFTP_ERROR in pkt:
+            raise self.ERROR(pkt)
+
+    @ATMT.timeout(WAITING_ACK, 3)
+    def timeout_waiting(self):
+        raise self.WAITING_ACK()
+    @ATMT.action(timeout_waiting)
+    def retransmit_last_packet(self):
+        self.send(self.last_packet)
+    
+    # SEND_DATA
+    @ATMT.state()
+    def SEND_DATA(self):
         self.awaiting += 1
-        self.send( IP(dst=self.server)/UDP(sport=self.my_tid, dport=self.server_tid)
-                   /TFTP()/TFTP_DATA(block=self.awaiting)/self.data.pop(0) )
+        self.last_packet = self.l3/TFTP_DATA(block=self.awaiting)/self.data.pop(0)
+        self.send(self.last_packet)
+        if self.data:
+            raise self.WAITING_ACK()
+        raise self.END()
+    
 
     # ERROR
     @ATMT.state(error=1)
-    def state_ERROR(self):
+    def ERROR(self,pkt):
         split_bottom_up(UDP, TFTP, dport=self.my_tid)
-        return self.errormsg
+        return pkt[TFTP_ERROR].summary()
 
     # END
     @ATMT.state(final=1)
-    def state_END(self):
+    def END(self):
         split_bottom_up(UDP, TFTP, dport=self.my_tid)