Introduce an "indexed array" analysis
Context: we want to optimize computations hanging off of an embedding lookup from a constant array. For instance, consider:

  embedding = gather from a constant array using non-constant indices
  embedding_reshaped = reshape embedding
  embedding_reshaped_transposed = transpose embedding_reshaped
  result = dot(embedding_reshaped_transposed, constant)

In the graph above, depending on how the details work out, we may be able to fold `result` into a gather from a precomputed constant array. However, it is inconvenient to get there by incremental rewrites -- it is probably not profitable to rewrite embedding_reshaped or embedding_reshaped_transposed [0] as embedding lookups, but we only get to "see" that the dot can be rewritten after rewriting the reshape and the transpose.

This analysis aims to make the optimization above more straightforward by allowing a transformation pass (that uses this analysis) to query the analysis to see if `result` _can_ be represented as an embedding lookup. If it can, the pass can then apply some profitability heuristics to decide whether it is worth rewriting it as one. This suggested workflow gives us separation of concerns (the legality of the rewrite is computed separately from its profitability) and, more importantly, lets us "look ahead" and analyze the dot without rewriting its operands.

The implementation is far from complete (most of the interesting bits are TODO) but I wanted to get an early design review before I spend too much time on this.

[0] Under the assumption that transposing or reshaping are not expensive enough to pay the price of keeping around a new potentially large constant (in particular, some of these may have been equivalent to free bitcasts).

PiperOrigin-RevId: 197064648
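As a rough illustration of the fold described above, here is a minimal sketch in plain Python. It is not the analysis in this change: the `gather` and `dot` helpers and all the data are hypothetical stand-ins, and the reshape and transpose steps are left out to keep the example small. It only shows the simplest case, where a dot applied to gathered rows equals a gather from a precomputed product.

```python
# Illustrative only: plain-Python stand-ins for the gather and dot ops.
# These helpers are hypothetical and are not part of the change itself.

def gather(table, indices):
    """Row lookup: returns the rows of `table` selected by `indices`."""
    return [table[i] for i in indices]

def dot(a, b):
    """Matrix product of two matrices stored as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

table = [[1, 2], [3, 4], [5, 6]]   # constant embedding table, 3 x 2
weights = [[1, 0, 2], [0, 1, 3]]   # constant right-hand side, 2 x 3
indices = [2, 0, 2]                # stands in for the non-constant indices

# Original graph: look up rows first, then multiply.
unfolded = dot(gather(table, indices), weights)

# Folded graph: precompute dot(table, weights) once, then look up rows.
precomputed = dot(table, weights)
folded = gather(precomputed, indices)

assert unfolded == folded
```

Whether the folded form is worth it depends on the size of the precomputed constant relative to the work saved, which is the profitability question the message keeps separate from legality.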