From 3f3ef9c902cfcd6eacf82f81939ea37136394085 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 15 Nov 2024 11:29:58 +0800 Subject: [PATCH] simple dot transpose pattern --- .../transforms/disc_algebraic_simplifier.cc | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc b/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc index 9770b32e4a5..da61ccae002 100644 --- a/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc +++ b/tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc @@ -501,6 +501,38 @@ struct TrunciSimplifierPattern : public OpRewritePattern { } }; +// Simplifier dot and transpose op pattern, an example as following: +// from: Dot(A^T, B)^T +// to: Dot(B^T, A) +struct SimplifierDotTransposePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::TransposeOp op, + PatternRewriter& rewriter) const override { + auto loc = op->getLoc(); + Value input = op->getOperand(0); + auto dotOp = input.getDefiningOp(); + if (!dotOp) { + return failure(); + } + auto lhs = dotOp.getLhs(); + auto rhs = dotOp.getRhs(); + auto lhsTranspose = lhs.getDefiningOp(); + if (!lhsTranspose) { + llvm::dbgs() << "hls should be transposed tensor\n"; + return failure(); + } + auto rhsTranspose = rewriter.create( + loc, rhs, lhsTranspose.getPermutation()); + auto newDotOp = rewriter.create( + loc, op.getType(), rhsTranspose, lhsTranspose.getOperand(), + dotOp.getPrecisionConfigAttr()); + rewriter.replaceOp(op, newDotOp.getResult()); + newDotOp->getParentOp()->dump(); + return success(); + } +}; + // Simplify index_cast pattern. An examples as following: // %0 = arith.index_cast %arg0 : index to i32 // %1 = arith.index_cast %0 : i32 to index @@ -718,7 +750,8 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) { SimplifierFromElementsPattern, TrunciSimplifierPattern, IndexCastSimplifierPattern, - SimplifierGetDimensionSizePattern + SimplifierGetDimensionSizePattern, + SimplifierDotTransposePattern >(patterns.getContext()); if (isMemIntensiveOptExperimentalEnabled()) { // Will be enabled by default after a set of robustness testing.