Commit da5513d0 authored by Sanjoy Das's avatar Sanjoy Das Committed by TensorFlower Gardener
Browse files

Support compiling Switch/Merge by constant folding Switch predicates

This CL adds limited support for clustering Switch and Merge nodes.  We compiled
Switch nodes by versioning the computation on the value of the Switch condition.
This lets us statically track the deadness through the cluster and allows us to
resolve deadness entirely in the TF->XLA graph compiler without touching XLA.

This CL is organized as follows:

 * jit/mark_for_compilation_pass is taught to not auto-cluster Switch and Merge
   even though these operations have XlaOpKernels.  We also make a minor fix
   without which we would always reject Merge nodes, even when it was assigned
   to an XLA_* device.

 * jit/partially_decluster_pass gets a bugfix without which we would decluster
   nodes that were assigned to XLA_* devices to avoid Device->Host copies.

 * tf2xla/graph_compiler and tf2xla/kernels/control_flow_ops contains the main
   part of the change where we introduce XlaOpKernels for Merge and Switch, and
   teach the TF->XLA graph compiler to track deadness.

PiperOrigin-RevId: 232336750
parent 187f8762
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment