Fix taking higher-order derivatives of cond_v2.
The problem: When we build the N-th derivative of an op or set of ops, we will likely end up reconstructing the previous (N-1)-th derivatives (we could theoretically avoid this by cleverly finding and reusing previously-constructed gradients as we traverse the forward pass). In the case of the If op, this means that we end up constructing the same gradient functions multiple times when taking higher-order derivatives. Prior to this change, we would always generate the same function name for the same grad function. This usually worked because the two functions would be identical, and we already silently dedup identical functions (this is to ease importing graphs with functions). However, it occasionally didn't work because we ended up generating two different FunctionDefs with the same name (I'm not sure why the FunctionDefs were different, but I'm guessing it's the unordered_map in the TF_GraphToFunction implementation). The solution: Rather than depend on the subtle deduping behavior, I made the cond_v2 implementation find unique names for all grad functions. This will result in more functions being generated, but I think it makes the behavior more obvious. In addition, this change properly adds the If branch functions to the graph. PiperOrigin-RevId: 199560887
Loading
Please sign in to comment