[transform_ext] Transform op to convert const dense_resource ops to function args#170
[transform_ext] Transform op to convert const dense_resource ops to function args#170tkarna wants to merge 1 commit into
Conversation
|
High-level question, what happens to the original data from the dense resource? |
The const dense_resource op is removed so the data is lost. We can access the weight data from the torch model and then pass that buffer to the mlir kernel. |
|
To motivate this transformation: The const dense_resource ops are problematic for at least two reasons:
The present PR is just a workaround though. The payload analysis and arg ordering uses linalg.matmul ops as "anchor ops" that is used to infer the roles of the const weight arrays. This approach works for gemms and MLPs but does not generalize well: for example convolution ops can be linag.generic ops in which case we'd need to analyze whether the op is a convolution or, say, elementwise post op. Such analysis can get arbitrarily complex. In long term, we should probably handle the const model data at higher level. At torch model level we can identify the model layers and their weights so we can easily identify their roles. We could for example modify torch-mlir to directly generate the weights as func args with proper annotation/ordering or use torch.compile instead (?). |
Adds
convert_const_resources_to_argstransform op that finds all arith.constant dense_resource ops and moves them to function arguments.For example, KernelBench level2-9 case has matmul weights and bias encoded as dense_resource ops:
The weight and bias tensors are appended to the function arguments:
In general the payload function can have arbitrarily many dense_resource ops in any arbitrary order. They need to be identified to be able to pass in the right buffers. For now, the arguments are ordered by matmul ops and where they appear in the matmul producer/consumer chain: matmul_0_A, matmul_0_B, matmul_0_epilogue, matmul_1_A, etc. If dense_resource is not associated with a matmul op an error is raised.