Commit a659294d authored by Sanjoy Das's avatar Sanjoy Das Committed by TensorFlower Gardener
Browse files

[TF:XLA] Create XLA clusters with heterogeneous device assignment

This CL teaches the TF/XLA bridge to create, compile and run clusters that have
nodes placed on multiple devices.  XLA does not support multi-device compilation
so we only use a single XLA backend to compile the whole cluster, and we choose
which XLA backend to use based on some heuristics.

The main motivation is to avoid breaking up XLA:GPU clusters because of a few
stray TF nodes placed on the CPU.

This CL is organized as follows:

 * jit/xla_cluster_util implements the heuristic we use to choose the XLA
   compiler backend to use for a cluster containing nodes placed on multiple
   devices.

 * jit/build_xla_ops_pass is taught to lower an XLA cluster into a
   _XlaCompile/_XlaRun/PartitionedCall triplet instead of the
   XlaCompile/_XlaRun/"TF call" triplet that it used to lower to before.  We
   need a PartitionedCall op because normal TF calls do not support callees with
   heterogeneous device assignment.

 * jit/mark_for_compilation_pass is taught to create these heterogeneous
   clusters in the first place.

 * jit/compilation_passes_test_main is added to make the unit tests run with
   --tf_xla_cpu_global_jit=true as a command line argument.  This is necessary
   because we've changed how we figure out that clusters placed on CPU should
   not be compiled unless tf_xla_cpu_global_jit is true.  I also noticed that
   the `is_compilable` lambda is making the structure of MarkForCompilationPass
   more convoluted than it needs to me, and fixing it is a TODO for me.

 * kernels/xla_ops needs a bailout for a situation that should not happen for
   any "real" models, but does happen for some unit tests.

 * common_runtime/function now pipes `override_device` through
   ExpandInlineFunctions and teaches InlineFunctionBody to not drop the assigned
   device on the floor when override_device is set.

 * common_runtime/graph_optimizer tells ExpandInlineFunctions to override the
   device when optimizing the graph.  This is a tricky change, but to me it
   seemed justified because `override_device=true` will preserve behavior of the
   function call (all nodes run on the caller's device).  I needed to make this
   change because some models have function calls where the function call nodes
   are placed on the CPU but some nodes in the callee are placed on the GPU.
   Since EncapsulateSubgraphsPass runs the inliner on the extracted clusters
   this introduces nodes placed on the GPU into the cluster, and this causes us
   to (incorrectly) re-infer that the cluster should run on the GPU in
   build_xla_ops_pass.

 * common_runtime/optimization_registry has a minor fix to help debugging.

 * There changes to two tests because now more clusters are getting compiled by
   XLA than before.

PiperOrigin-RevId: 235946559
parent 50eae363
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment