Fix shape inference for bitwise ops with broadcasting (#14678)
* Fix shape inference for bitwise ops with broadcasting This fix tries to address the issue raised in 14646 where shape inference for bitwise ops is incorrect with broadcasting. As was specified in 14646, in the following ``` >>> import tensorflow as tf >>> tf.bitwise.bitwise_and(tf.zeros([3,1], dtype=tf.int32), tf.zeros([1,3], dtype=tf.int32)) <tf.Tensor 'BitwiseAnd:0' shape=(3, 1) dtype=int32> ``` the result shape should be (3, 3), not (3, 1). This fix fixes the issue by changing `.SetShapeFn(shape_inference::UnchangedShape)` to `.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)` This fix fixes 14646. Signed-off-by:Yong Tang <yong.tang.github@outlook.com> * Add test cases for shape inference for bitwise ops with broadcasting This commit adds test cases for shape inference of bitwise ops with broadcasting Signed-off-by:
Yong Tang <yong.tang.github@outlook.com> * Update to add addtional assert for shape function for comment feedback. Signed-off-by:
Yong Tang <yong.tang.github@outlook.com>
Loading
Please sign in to comment