Rely on call op placement for the default device inside a graph function
Device placements inside a function will follow the placement of the call operation unless a device scope is opened inside the function body. This means we no longer need device stacks in defun cache keys, and when functions get serialized in SavedModels the device won't be hard-coded.
This change requires adding the current distribution strategy stack to the function cache key. Distribution strategies rely on functions being retraced for each new device (e.g. to access different variables on different devices); before this CL, that retracing happened because the function was called with different devices set. The cache key addition does slow things down a bit, but (on my machine at least) the slowdown is more than offset by the gains from not specializing on the device stack.
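To illustrate the cache key change described above, here is a toy pure-Python model. It is not TensorFlow code: `ToyDefun`, `trace_count`, and the `device` and `strategy_stack` arguments are hypothetical names invented for this sketch. It only shows the idea that the key includes the distribution strategy stack but not the device, while the device is taken from the call site.

```python
class ToyDefun:
    """Toy stand-in for a traced-function cache (hypothetical, not the real defun)."""

    def __init__(self, python_fn):
        self._fn = python_fn
        self._cache = {}
        self.trace_count = 0  # how many times we "traced" the function

    def _cache_key(self, args, strategy_stack):
        # Assumed key layout: argument types plus the distribution strategy stack.
        # The device is deliberately NOT part of the key.
        arg_signature = tuple(type(a).__name__ for a in args)
        return (arg_signature, tuple(strategy_stack))

    def _trace(self):
        self.trace_count += 1
        fn = self._fn

        def traced(call_device, *args):
            # The body has no device of its own; it reports the caller's device.
            return fn(*args), call_device

        return traced

    def __call__(self, *args, device="", strategy_stack=()):
        key = self._cache_key(args, strategy_stack)
        if key not in self._cache:
            self._cache[key] = self._trace()
        return self._cache[key](device, *args)


add_one = ToyDefun(lambda x: x + 1)
on_cpu = add_one(1, device="/cpu:0")
on_gpu = add_one(2, device="/gpu:0")  # different device, same trace reused
with_strategy = add_one(3, device="/gpu:0", strategy_stack=("mirrored",))  # new key, retraced
print(on_cpu, on_gpu, with_strategy, add_one.trace_count)
```

In this toy model, calling on a new device alone does not retrace, so anything that needs a per-device trace has to show up in the key some other way, which is the role the strategy stack plays here.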
Baseline before this CL:
entry {
  name: "MicroBenchmarks.benchmark_defun_without_signature"
  iters: 30000
  wall_time: 88.3192300797
  extras {
    key: "examples_per_sec"
    value {
      double_value: 11322.5624714
    }
  }
}
After this CL (includes distribution strategies in cache key):
entry {
  name: "MicroBenchmarks.benchmark_defun_without_signature"
  iters: 30000
  wall_time: 84.1960986455
  extras {
    key: "examples_per_sec"
    value {
      double_value: 11877.0348756
    }
  }
}
Hypothetical case where neither distribution strategies nor devices are part of the cache key (i.e. the maximum speedup available from optimizing the distribution strategies cache key addition):
entry {
  name: "MicroBenchmarks.benchmark_defun_without_signature"
  iters: 30000
  wall_time: 72.5416739782
  extras {
    key: "examples_per_sec"
    value {
      double_value: 13785.1795411
    }
  }
}
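For a quick comparison of the three runs above, the relative wall_time reductions can be worked out as follows. This is a small arithmetic sketch using only the wall_time values reported above. The variable and function names are made up for illustration, and the unit of wall_time is not stated here, so only unit-free ratios are computed.

```python
# wall_time values copied from the three benchmark entries above.
baseline = 88.3192300797      # before this CL
after_cl = 84.1960986455      # after this CL (strategies in cache key)
hypothetical = 72.5416739782  # neither strategies nor devices in the key


def relative_reduction(before, after):
    """Fraction of wall_time saved going from `before` to `after`."""
    return (before - after) / before


actual = relative_reduction(baseline, after_cl)
ceiling = relative_reduction(baseline, hypothetical)
share_of_ceiling = actual / ceiling

print(f"reduction from this CL: {actual:.1%}")
print(f"reduction in the hypothetical case: {ceiling:.1%}")
print(f"share of the hypothetical gain realized: {share_of_ceiling:.0%}")
```

By these numbers this CL cuts wall_time by roughly 4.7%, against roughly 17.9% in the hypothetical case, so about a quarter of the possible gain is realized.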
PiperOrigin-RevId: 216777533