[XLA] Optimize ShapeTree<T>
This optimizes ShapeTree quite significantly. In particular this optimizes for the common case of querying/iterating, copying and moving ShapeTrees. * Allocate all ShapeTreeNodes inside a single, owned, vector. This reduces the number of memory allocations and improves cache performance. * Instead of storing children nodes as unique_ptrs, store them as indices into the owning container's vector. This allows cheap copy-construction (a std::vector POD copy) and doesn't change the fast path (dereferencing a pointer is just as fast as dereferencing a base + offset). * Instead of a unique_ptr<Shape>, use a shared_ptr<Shape>. This removes a load of copy-construction overhead at the cost of a shared_ptr over a unique_ptr (one extra allocation). * Instead of computing ShapeIndexes on-demand in the iterators/ForEach*, precompute them during construction time. This adds a few more bytes per ShapeTree, but now we can... * ... store a std::pair<ShapeIndex, T> as the ShapeTreeNode's data element. This allows us to provide a std::pair<K,V>&, STL-like interface from iterators without going through any of the previous unique_ptr hacks around storage lifetimes. * Because we no longer need to iterate from the beginning to build up the ShapeIndex, we can now offer a ::find() function to return an iterator for a ShapeIndex in O(K) time. As the iteration order is guaranteed to be pre-order, this can be used (and will be, later) to speed up the fast-path of mutating a subtree of a ShapeTree from tf2xla::ExtractSubBuffers. * Similarly because we now have a very standard, cheap STL interface with no performance cliffs, we can hopefully improve ShapedBuffer's copy and move constructors to be cheaper. PiperOrigin-RevId: 197548717
Loading
Please sign in to comment