Resolve distributed variables captured by defun at call time
Before this change, when was function is called in a distribution strategy context, it would capture the component variables from some device and always use these variables, even when the function is executed on a different device. This CL "reevaluates" distributed variables to get the correct variable at call time. These correct variables are then passed to the function. We don't handle distributed tensors. First, because the mechanics for handling distributed tensors are different from handling distributed variables, their support added significant complexity to already complex defuns. Second, there is no easy way for users have a function capture a distributed tensor or feed a distributed tensor explicitly. If this changes, we can support them (the code exists in this CL's history). We also don't handle distributed variables explicitly passed into the function for similar reasons. PiperOrigin-RevId: 207640908
Loading
Please sign in to comment