Commit 7ffbedee authored by Yong Tang's avatar Yong Tang
Browse files

Add uint16 support for py_func



In tf most of the numeric data types are supported though uint16 support
is not:
```
$ python
>>> import tensorflow as tf
>>> def sum_func(x, y):
...   return x + y
...
>>> x = tf.constant(1, dtype=tf.uint16)
>>> y = tf.constant(2, dtype=tf.uint16)
>>> z = tf.py_func(sum_func, [x, y], tf.uint16)
>>> tf.Session().run(z)
...
...
tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported numpy type 4
	 [[Node: PyFunc = PyFunc[Tin=[DT_UINT16, DT_UINT16], Tout=[DT_UINT16], token="pyfunc_0", _device="/job:localhost/replica:0/task:0/device:CPU:0"](Const, Const_1)]]
...
```

The reason is that there is no conversion between numpy uint16 and tf.uint16.

This fix adds the support so that py_func could process tf.uint16 data types.

This fix also adds test cases for different data types with py_func to
increase the test coverage.

Signed-off-by: default avatarYong Tang <yong.tang.github@outlook.com>
parent e1fa4112
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment