This CL implements the following behavior for stop_gradient in while-loop gradient computation.

Consider a while loop with inputs xs = (x1, x2, ..., xn) and outputs ys = (y1, y2, ..., yn). If every y is either stopped or unused, the entire while loop is skipped in the gradient graph, and the gradient for the xs is None for downstream backprop. If any of the ys has a real gradient, we backprop 0 for each y that is stopped or unused.

Note that this CL does not address the use of stop_gradient inside the loop body. To support that properly, we probably need a fixed-point static graph analysis to compute the dependencies between the ys and the xs. We will address that in a separate CL.

Change: 131790233
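The rule for loop outputs can be illustrated with a minimal, framework-free sketch. This is not the code in this CL: the function name resolve_loop_output_grads is hypothetical, gradients are modeled as plain floats, and None stands for an output that is stopped or unused.

```python
from typing import List, Optional, Sequence


def resolve_loop_output_grads(
    grad_ys: Sequence[Optional[float]],
) -> Optional[List[float]]:
    """Hypothetical helper: decide what to backprop into a while loop.

    grad_ys[i] is the incoming gradient for output y_i, or None if y_i is
    stopped (stop_gradient) or not used downstream.
    """
    # Case 1: no output carries a real gradient, so the whole loop is
    # skipped and the inputs receive None.
    if all(g is None for g in grad_ys):
        return None
    # Case 2: at least one real gradient exists, so every stopped or
    # unused output gets an explicit zero instead of None.
    return [0.0 if g is None else g for g in grad_ys]
```

For example, resolve_loop_output_grads([None, None]) returns None, while resolve_loop_output_grads([None, 0.5, None]) returns [0.0, 0.5, 0.0].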