Basic while loop gradient functionality in C++
This change introduces the basic framework to create the gradient graph of a while loop using the C++ API. This supports building the gradient graph as long as the body function of the while loop contains no ops whose gradient function requires a stack. In other words, it doesn't support gradient functions that use the input values to the op (e.g. add will work, but multiply will not). It also doesn't support nested while loops, and doesn't detect all error cases. PiperOrigin-RevId: 170243281
Loading
Please sign in to comment