Add support for non-Tensor args in recompute_grad
Previously, the function decorated by recompute_grad had to have a signature that contained only positional arguments, and all those arguments had to be Tensors. Most "layers" users define however have non-Tensor arguments (for example, various hyperparameters) and often have keyword arguments as well. This change allows a user to use whatever function signature they wish while being explicit about which arguments are Tensors. PiperOrigin-RevId: 193600682
Loading
Please sign in to comment