BREAKING_CHANGE: Remove event_ndims in Bijector, and require...
BREAKING_CHANGE: Remove event_ndims in Bijector, and require `log_det_jacobian` methods to take event_ndims. The class level event_ndims parameter is being deprecated in favor of passing it in to the `log_det_jacobian` methods. Specific changes: - `log_det_jacobian` signatures are now `log_det_jacobian(input, event_ndims)` - Constructors no long have event_ndims passed in (e.g. Affine() vs. Affine(event_ndims=0)). - All bijectors must specify a subset of [forward_min_event_ndims, inverse_min_event_ndims]. This is the minimal dimensionality the bijector operates on, with it being "broadcasted" to any passed in event_ndims (e.g. Exp has forward_min_event_ndims = 0. That means it operates on scalars. However, we can use the bijector on any event_ndims > 0 (i.e. we've broadcasted the transformation to work on any amount of event_ndims > 0), and jacobian reduction will work in those cases. As a result of this change, all bijectors should "broadcast" (e.g. Sigmoid now works on any number of event_ndims). Other changes (internal and documentation): - Added clarifications on Jacobian Determinant vs. Jacobian Matrix. - Added clarifications on min_event_ndims, and what the jacobian reduction is over. - Changed caching of ildj to be keyed on event_ndims. - Several bug fixes to bugs unearthed while writing this code (e.g. transformed distribution shape computation being incorrect) PiperOrigin-RevId: 192504919
Loading
Please sign in to comment