[TF:XLA] Create XLA clusters with heterogeneous device assignment

This CL teaches the TF/XLA bridge to create, compile and run clusters that
have nodes placed on multiple devices.  XLA does not support multi-device
compilation, so we use a single XLA backend to compile the whole cluster, and
we choose that backend based on some heuristics.  The main motivation is to
avoid breaking up XLA:GPU clusters because of a few stray TF nodes placed on
the CPU.

This CL is organized as follows:

 * jit/xla_cluster_util implements the heuristic that chooses the XLA compiler
   backend for a cluster containing nodes placed on multiple devices.

 * jit/build_xla_ops_pass is taught to lower an XLA cluster into a
   _XlaCompile/_XlaRun/PartitionedCall triplet instead of the
   _XlaCompile/_XlaRun/"TF call" triplet that it lowered to before.  We need a
   PartitionedCall op because normal TF calls do not support callees with
   heterogeneous device assignment.

 * jit/mark_for_compilation_pass is taught to create these heterogeneous
   clusters in the first place.

 * jit/compilation_passes_test_main is added to make the unit tests run with
   --tf_xla_cpu_global_jit=true as a command line argument.  This is necessary
   because we've changed how we figure out that clusters placed on CPU should
   not be compiled unless tf_xla_cpu_global_jit is true.  I also noticed that
   the `is_compilable` lambda is making the structure of MarkForCompilationPass
   more convoluted than it needs to be; fixing it is a TODO for me.

 * kernels/xla_ops needs a bailout for a situation that should not happen for
   any "real" models, but does happen for some unit tests.

 * common_runtime/function now pipes `override_device` through
   ExpandInlineFunctions and teaches InlineFunctionBody to not drop the
   assigned device on the floor when override_device is set.

 * common_runtime/graph_optimizer tells ExpandInlineFunctions to override the
   device when optimizing the graph.  This is a tricky change, but to me it
   seemed justified because `override_device=true` preserves the behavior of
   the function call (all nodes run on the caller's device).  I needed to make
   this change because some models have function calls where the function call
   nodes are placed on the CPU but some nodes in the callee are placed on the
   GPU.  Since EncapsulateSubgraphsPass runs the inliner on the extracted
   clusters, inlining such a call without overriding the device introduces
   nodes placed on the GPU into the cluster, and this causes us to
   (incorrectly) re-infer that the cluster should run on the GPU in
   build_xla_ops_pass.

 * common_runtime/optimization_registry has a minor fix to help debugging.

 * There are changes to two tests because more clusters are now getting
   compiled by XLA than before.

PiperOrigin-RevId: 235946559
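
For readers unfamiliar with the idea behind the jit/xla_cluster_util item above, here is a minimal sketch of what a backend-selection heuristic for a mixed-device cluster could look like. This message does not spell out the actual heuristic, so everything below is an assumption made for illustration: the names `device_type` and `pick_backend_for_cluster` are hypothetical, the rule "prefer the single GPU, otherwise the single CPU, otherwise give up" is a guess consistent with the stated motivation, and the sketch is in Python for brevity rather than the C++ the pass is written in.

```python
from typing import Iterable, Optional


def device_type(assigned_device: str) -> str:
    """Return the device type ('CPU', 'GPU', ...) from a TF-style device
    string such as '/job:localhost/replica:0/task:0/device:GPU:0'."""
    marker = "/device:"
    idx = assigned_device.rfind(marker)
    if idx < 0:
        raise ValueError(f"unrecognized device string: {assigned_device!r}")
    # The remainder looks like 'GPU:0'; keep only the type component.
    return assigned_device[idx + len(marker):].split(":")[0].upper()


def pick_backend_for_cluster(devices: Iterable[str]) -> Optional[str]:
    """Hypothetical heuristic: pick the one device the whole cluster is
    compiled for, or None if no unambiguous choice exists.

    Assumed rule (not taken from the real pass): a single GPU wins over
    any stray CPU placements; several distinct GPUs, several distinct
    CPUs, or an unknown device type mean we decline to choose.
    """
    unique = sorted(set(devices))
    gpus = [d for d in unique if device_type(d) == "GPU"]
    cpus = [d for d in unique if device_type(d) == "CPU"]
    if len(gpus) + len(cpus) != len(unique):
        return None  # some device is neither CPU nor GPU
    if len(gpus) > 1:
        return None  # cannot compile one cluster for several GPUs
    if len(gpus) == 1:
        return gpus[0]  # stray CPU nodes do not break up the GPU cluster
    if len(cpus) == 1:
        return cpus[0]
    return None  # empty cluster or several distinct CPUs
```

Under these assumptions, a cluster whose nodes are assigned to one GPU plus the host CPU would be compiled for the GPU, which is the case the motivation paragraph describes.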