Removed redundant str() call in _maybe_define_function.

It was causing a significant slowdown when either of the input tensors
had fewer than 1000 elements. 1000 is the NumPy threshold above which
str() abbreviates an array with "..."; below it, the entire tensor is
formatted.

Reproducer:

    def square(X):
      return tf.matmul(X, X, transpose_a=True)

    square_f = tf.function(square)

    >>> X = tf.zeros([42, 24])  # 1008 elements.
    >>> _ = square_f(X)
    >>> %timeit square_f(X)
    10000 loops, best of 3: 167 µs per loop

    >>> X = tf.zeros([42, 1])  # 42 elements.
    >>> _ = square_f(X)
    >>> %timeit square_f(X)
    100 loops, best of 3: 1.54 ms per loop

PiperOrigin-RevId: 237459548
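
The NumPy heuristic mentioned above can be checked in isolation. The following is a minimal sketch, not part of the change itself: it uses plain NumPy arrays in place of tensors, pins the print threshold to 1000 (NumPy's documented default) so the result does not depend on local settings, and the helper name `is_summarized` is hypothetical.

```python
import numpy as np


def is_summarized(shape, threshold=1000):
    """Return True if str() abbreviates a zeros array of `shape` with '...'.

    `threshold` mirrors NumPy's print option of the same name; 1000 is the
    documented default and is set explicitly here as an assumption.
    """
    arr = np.zeros(shape)
    with np.printoptions(threshold=threshold):
        text = str(arr)
    return "..." in text


# Same shapes as the reproducer: 1008 elements vs. 42 elements.
big_is_summarized = is_summarized((42, 24))
small_is_summarized = is_summarized((42, 1))
```

With these shapes the 1008-element array is expected to be abbreviated while the 42-element array is formatted in full, which matches the slow case in the reproducer.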