Fix incorrect error message in `conv2d_transpose` (#13089)
This fix fixes the incorrect error message in `conv2d_transpose`
where `value.get_shape()[3]` should be changed to `value.get_shape()[axis]`
for `input_depth` in `NCHW` format.
The incorrectness could be seen from the following where `4 != 4`
is quite confusing:
```python
ubuntu@ubuntu:~$ cat v.py
import tensorflow as tf
strides = [1, 1, 1, 1]
input_shape = [2, 3, 6, 4]
output_shape = [2, 2, 6, 4]
filter_shape = [3, 3, 2, 4]
input = tf.constant(1.0, shape=input_shape, name="input", dtype=tf.float32)
filter = tf.constant(1.0, shape=filter_shape, name="filter", dtype=tf.float32)
output = tf.nn.conv2d_transpose(input, filter, output_shape, strides=strides, padding="SAME", data_format="NCHW")
ubuntu@ubuntu:~$ python v.py
Traceback (most recent call last):
File "v.py", line 16, in <module>
output = tf.nn.conv2d_transpose(input, filter, output_shape, strides=strides, padding="SAME", data_format="NCHW")
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/nn_ops.py", line 1026, in conv2d_transpose
)[3]))
ValueError: input channels does not match filter's input channels, 4 != 4
```
Signed-off-by:
Yong Tang <yong.tang.github@outlook.com>
Loading
Please sign in to comment