Commit a17a64d0 authored by Saurabh Saxena's avatar Saurabh Saxena Committed by TensorFlower Gardener
Browse files

Loosens constraints on maximum_iterations in while_v2 for XLA.

maximum_iterations no longer needs to be a TF-graph-compile-time constant but should be a compile-time constant when building the XLA graph (when values of placeholders are also available).

There are broadly 2 parts of this change.
1. while_v2: This removes the _maximum_iterations attr from the functional While op generated by while_v2 and makes it an input instead. This frees maximum_iterations from being a graph-building time constant. Its index in the list of op inputs is preserved in both the forward and backwards graph so it should be ok for n'th order derivatives.

2. The XLA changes are needed to propagate constants inside the loop body e.g. when building the gradient of a While inside a while_loop. They are 2-fold:
a. This updates the const_analysis pass to inspect the cond and body functions of While for compile time constant requirements.
b. When compiling the XLA While op in while_op.cc if there are inputs that are compile time constants and the corresponding loop variables are loop invariants, the constants get propagated into the loop body.

Makes XlaCompiler::FindFunctionBody public to allow accessing the body of the While in while_op.cc. If the body has been specialized using the PropagateConstIntoFunctionalNodes pass the rewritten function only exists in XlaCompiler::local_flib_def_.

PiperOrigin-RevId: 235122516
parent 796ce572
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment