Support compiling Switch/Merge by constant folding Switch predicates
This CL adds limited support for clustering Switch and Merge nodes. We compile Switch nodes by versioning the computation on the value of the Switch condition. This lets us statically track deadness through the cluster and resolve it entirely in the TF->XLA graph compiler without touching XLA.

This CL is organized as follows:

* jit/mark_for_compilation_pass is taught to not auto-cluster Switch and Merge even though these operations have XlaOpKernels. We also make a minor fix without which we would always reject Merge nodes, even when they were assigned to an XLA_* device.
* jit/partially_decluster_pass gets a bugfix without which we would decluster nodes that were assigned to XLA_* devices to avoid Device->Host copies.
* tf2xla/graph_compiler and tf2xla/kernels/control_flow_ops contain the main part of the change, where we introduce XlaOpKernels for Merge and Switch and teach the TF->XLA graph compiler to track deadness.

PiperOrigin-RevId: 232336750
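
As a rough illustration of what "versioning the computation on the value of the Switch condition" can mean, the following is a minimal Python sketch, not the code in this CL. It assumes a toy graph representation: `Node`, `DEAD`, `evaluate`, and the op names `Const`, `Pred`, `Add`, and `Neg` are hypothetical names invented for the example. Once the predicate is a known constant, each Switch output is either live or dead, ordinary ops are dead if any input is dead, and Merge forwards its first live input.

```python
from dataclasses import dataclass

DEAD = object()  # sentinel for a value that is statically known to be dead


@dataclass(frozen=True)
class Node:
    """Hypothetical toy graph node; not a TensorFlow data structure."""
    name: str
    op: str              # "Const", "Pred", "Switch", "Merge", "Add" or "Neg"
    inputs: tuple = ()   # each input is (producer_name, output_index)
    value: object = None


def evaluate(nodes, pred_value):
    """Walk nodes in topological order for one fixed predicate value."""
    outputs = {}
    for node in nodes:
        args = [outputs[key] for key in node.inputs]
        if node.op == "Const":
            result = [node.value]
        elif node.op == "Pred":
            # The predicate is treated as a compile-time constant.
            result = [pred_value]
        elif node.op == "Switch":
            data, pred = args
            if data is DEAD or pred is DEAD:
                result = [DEAD, DEAD]
            else:
                # Output 0 is the false branch, output 1 is the true branch.
                result = [DEAD, data] if pred else [data, DEAD]
        elif node.op == "Merge":
            # Merge is live as soon as any input is live.
            live = [a for a in args if a is not DEAD]
            result = [live[0] if live else DEAD]
        elif any(a is DEAD for a in args):
            # Every other op is dead if any of its inputs is dead.
            result = [DEAD]
        elif node.op == "Add":
            result = [args[0] + args[1]]
        elif node.op == "Neg":
            result = [-args[0]]
        else:
            raise ValueError(f"unknown op: {node.op}")
        for index, value in enumerate(result):
            outputs[(node.name, index)] = value
    return outputs


graph = [
    Node("x", "Const", value=3),
    Node("pred", "Pred"),
    Node("sw", "Switch", inputs=(("x", 0), ("pred", 0))),
    Node("neg", "Neg", inputs=(("sw", 0),)),            # false branch
    Node("add", "Add", inputs=(("sw", 1), ("x", 0))),   # true branch
    Node("merge", "Merge", inputs=(("neg", 0), ("add", 0))),
]

# One "version" of the computation per possible predicate value.
versions = {pred: evaluate(graph, pred) for pred in (False, True)}
```

In this sketch, the version for a true predicate sees `neg` as dead and `merge` forwards the result of `add`, while the version for a false predicate does the opposite. No deadness information survives past `evaluate`, which mirrors the idea of resolving deadness before anything reaches XLA.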