Reland improve fusion logic of (a dot b) * alpha
The previous fusion approach didn't work because a multiplication by a scalar value
will be changed into an explicit broadcast.
Another issue that is fixed in this CL is retrieving the constant value from
the literal. This depends on the PrimitiveType, before we always assumed it to be double.
Also when checking ImplementedAsGemm() we should not call it recursively, but instead just the check related to kDot.
Finally add an execution test and adjust the fusion logic test.
The fix for the issue that caused the revert is that we check earlier that consumer->operand_count() is 2.
Also, we fix the call to Get() to pass {} instead of {0}.
And we handle an output fusion node in GemmThunk to extract the dimension numbers from the dot operation.
PiperOrigin-RevId: 196631031
Loading
Please sign in to comment