Convert NumPy arrays to Tensors when they're arguments to a defun
Previously they were counted in the cache key as if they were Tensors, but were not fed as placeholders, leading to stale values when the trace was reused.
There is an 8%ish performance impact from the tuple comprehension on the defun no-signature-call microbenchmarks. I don't see a much faster way to do this without rewriting it in C, but I'm open to ideas. I've avoided re-packing the input tuple unless there's actually a numpy array, so this CL will slow down NumPy defun calls more (in addition to the convert_to_tensor overhead).
After:
entry {
name: "MicroBenchmarks.benchmark_defun_with_signature"
iters: 30000
wall_time: 134.219272931
extras {
key: "examples_per_sec"
value {
double_value: 7450.49483699
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_with_signature_and_kwargs"
iters: 30000
wall_time: 142.88717111
extras {
key: "examples_per_sec"
value {
double_value: 6998.52892485
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_without_signature"
iters: 30000
wall_time: 76.2096961339
extras {
key: "examples_per_sec"
value {
double_value: 13121.6898994
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_without_signature_and_with_kwargs"
iters: 30000
wall_time: 81.8309704463
extras {
key: "examples_per_sec"
value {
double_value: 12220.3121208
}
}
}
Before:
entry {
name: "MicroBenchmarks.benchmark_defun_with_signature"
iters: 30000
wall_time: 129.392266273
extras {
key: "examples_per_sec"
value {
double_value: 7728.43716862
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_with_signature_and_kwargs"
iters: 30000
wall_time: 141.65956974
extras {
key: "examples_per_sec"
value {
double_value: 7059.1771656
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_without_signature"
iters: 30000
wall_time: 70.6333637238
extras {
key: "examples_per_sec"
value {
double_value: 14157.6154282
}
}
}
entry {
name: "MicroBenchmarks.benchmark_defun_without_signature_and_with_kwargs"
iters: 30000
wall_time: 78.4090677897
extras {
key: "examples_per_sec"
value {
double_value: 12753.6269489
}
}
}
PiperOrigin-RevId: 212491803
Loading
Please sign in to comment