Fix the threading model of gradient tapes.
The set of tapes needs to be global to enable multithreaded programming (where it is natural for tensors to cross threads during reduction operations), but each thread still needs to be able to pause recording locally while it does gradient-related bookkeeping (such as custom gradients or initialization).

Also removes a mutex from the thread-local structure. It is unnecessary because we always hold the GIL while calling across the Python-C boundary unless we explicitly release it.

PiperOrigin-RevId: 181246570