Fix c++ gradients issue where multiple dependent outputs result in incorrect answer.
The issue is that we incorrectly calculate the pending num_expected_backprops for outputs nodes when one output transitively depends on another. this is because we use output nodes as an indicator of when we need to end our traversal. Instead we should only use output nodes that don't transitively get consumed by other output nodes as end indicators for our traversal. This change implements that fix. Fixes #13190 PiperOrigin-RevId: 170971937
Loading
Please sign in to comment