Commit f2a6408b authored by Yong Tang's avatar Yong Tang Committed by drpngx
Browse files

Fix incorrect error message in `conv2d_transpose` (#13089)



This fix fixes the incorrect error message in `conv2d_transpose`
where `value.get_shape()[3]` should be changed to `value.get_shape()[axis]`
for `input_depth` in `NCHW` format.

The incorrectness could be seen from the following where `4 != 4`
is quite confusing:

```python
ubuntu@ubuntu:~$ cat v.py
import tensorflow as tf

strides = [1, 1, 1, 1]
input_shape = [2, 3, 6, 4]
output_shape = [2, 2, 6, 4]
filter_shape = [3, 3, 2, 4]

input = tf.constant(1.0, shape=input_shape, name="input", dtype=tf.float32)
filter = tf.constant(1.0, shape=filter_shape, name="filter", dtype=tf.float32)

output = tf.nn.conv2d_transpose(input, filter, output_shape, strides=strides, padding="SAME", data_format="NCHW")
ubuntu@ubuntu:~$ python v.py
Traceback (most recent call last):
  File "v.py", line 16, in <module>
    output = tf.nn.conv2d_transpose(input, filter, output_shape, strides=strides, padding="SAME", data_format="NCHW")
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/nn_ops.py", line 1026, in conv2d_transpose
    )[3]))
ValueError: input channels does not match filter's input channels, 4 != 4
```

Signed-off-by: default avatarYong Tang <yong.tang.github@outlook.com>
parent 942193ae
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment