Commit 962c639b authored by Akshay Agrawal's avatar Akshay Agrawal Committed by TensorFlower Gardener
Browse files

Make functions defined with tfe.defun respect devices when executing.

Modifies GraphModeFunction to emit PartitionedCall ops instead of Call ops
so that the created functions can execute across devices. This should strictly
increase the set of functions that tfe.defun can faithfully execute.
Previous to this change, functions executed through tfe.defun would ignore
device annotations and only run on a single device. It is not yet possible to execute
a function across multiple processes.

Specifically, this CL:
(1) Adds a stateful version of PartitionedCall,
(2) Modifies `defun` to emit PartitionedCall or StatefulPartitionedCall by default,
(3) Makes `tf.gradients` aware of the existence of `(Stateful)PartitionedCall`,
(4) Fixes bugs in PartitionedCallOp related to the placement of
    resource-touching ops / which args and retvals are always on host memory, and
    also removes the requirement for args/retvals to be passed through the host.

PiperOrigin-RevId: 203164388
parent 69fe072c
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment