diff --git a/integration_test/circt-synth/divmod.mlir b/integration_test/circt-synth/divmod.mlir index e951f7aebf0f..d71953328d4a 100644 --- a/integration_test/circt-synth/divmod.mlir +++ b/integration_test/circt-synth/divmod.mlir @@ -51,3 +51,36 @@ hw.module @divmod_mix_constant(in %in: i1, in %lhs: i1, in %rhs: i1, out out_div hw.output %0, %1, %2, %3 : i4, i4, i4, i4 } +// RUN: circt-lec %t.mlir %s -c1=divmodu_constants -c2=divmodu_constants --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVMODU_CONSTANTS +// COMB_DIVMODU_CONSTANTS: c1 == c2 +hw.module @divmodu_constants(in %lhs: i8, out out_divu_7: i8, out out_divu_10: i8, out out_modu_3: i8) { + %c7_i8 = hw.constant 7 : i8 + %c10_i8 = hw.constant 10 : i8 + %c3_i8 = hw.constant 3 : i8 + + %0 = comb.divu %lhs, %c7_i8 : i8 + %1 = comb.divu %lhs, %c10_i8 : i8 + %2 = comb.modu %lhs, %c3_i8 : i8 + hw.output %0, %1, %2 : i8, i8, i8 +} + +// RUN: circt-lec %t.mlir %s -c1=divmods_constants -c2=divmods_constants --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVMODS_CONSTANTS +// COMB_DIVMODS_CONSTANTS: c1 == c2 +hw.module @divmods_constants(in %lhs: i8, out out_divs_3: i8, out out_divs_neg3: i8, out out_mods_3: i8, out out_mods_neg3: i8) { + %c3_i8 = hw.constant 3 : i8 + %c-3_i8 = hw.constant -3 : i8 + + %0 = comb.divs %lhs, %c3_i8 : i8 + %1 = comb.divs %lhs, %c-3_i8 : i8 + %2 = comb.mods %lhs, %c3_i8 : i8 + %3 = comb.mods %lhs, %c-3_i8 : i8 + hw.output %0, %1, %2, %3 : i8, i8, i8, i8 +} + +// RUN: circt-lec %t.mlir %s -c1=const_divmod_mods_neg1_i3 -c2=const_divmod_mods_neg1_i3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS_NEG1_I3 +// COMB_MODS_NEG1_I3: c1 == c2 +hw.module @const_divmod_mods_neg1_i3(in %lhs: i3, out out: i3) { + %c_neg1_i3 = hw.constant -1 : i3 + %0 = comb.mods %lhs, %c_neg1_i3 : i3 + hw.output %0 : i3 +} diff --git a/lib/Conversion/CombToSynth/CombToSynth.cpp b/lib/Conversion/CombToSynth/CombToSynth.cpp index 55c5ccb1e4cb..be4f29b8ad8a 100644 --- a/lib/Conversion/CombToSynth/CombToSynth.cpp +++ b/lib/Conversion/CombToSynth/CombToSynth.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DivisionByConstantInfo.h" #include #define DEBUG_TYPE "comb-to-synth" @@ -269,6 +270,70 @@ static LogicalResult emulateBinaryOpForUnknownBits( return success(); } +static Value createLShrByConstant(OpBuilder &builder, Location loc, Value value, + unsigned amount) { + if (amount == 0) + return value; + return builder.createOrFold( + loc, value, + hw::ConstantOp::create( + builder, loc, + APInt(value.getType().getIntOrFloatBitWidth(), amount))); +} + +static Value createAShrByConstant(OpBuilder &builder, Location loc, Value value, + unsigned amount) { + if (amount == 0) + return value; + return builder.createOrFold( + loc, value, + hw::ConstantOp::create( + builder, loc, + APInt(value.getType().getIntOrFloatBitWidth(), amount))); +} + +template +static Value createMulHigh(OpBuilder &builder, Location loc, Value lhs, + const APInt &rhs) { + unsigned width = lhs.getType().getIntOrFloatBitWidth(); + auto destTy = builder.getIntegerType(width << 1); + Value wideLhs = isSigned ? comb::createOrFoldSExt(builder, loc, lhs, destTy) + : comb::createZExt(builder, loc, lhs, width << 1); + Value wideRhs = hw::ConstantOp::create( + builder, loc, isSigned ? rhs.sext(width << 1) : rhs.zext(width << 1)); + Value product = builder.createOrFold( + loc, ValueRange{wideLhs, wideRhs}, true); + return builder.createOrFold(loc, product, width, width); +} + +static Value lowerUnsignedDivByConstant(OpBuilder &builder, Location loc, + Value lhs, const APInt &divisor) { + auto info = llvm::UnsignedDivisionByConstantInfo::get(divisor); + Value q = createLShrByConstant(builder, loc, lhs, info.PreShift); + q = createMulHigh(builder, loc, q, info.Magic); + if (info.IsAdd) { + Value diff = builder.createOrFold(loc, lhs, q); + diff = createLShrByConstant(builder, loc, diff, 1); + q = builder.createOrFold(loc, q, diff); + } + return createLShrByConstant(builder, loc, q, info.PostShift); +} + +static Value lowerSignedDivByConstant(OpBuilder &builder, Location loc, + Value lhs, const APInt &divisor) { + unsigned width = lhs.getType().getIntOrFloatBitWidth(); + auto info = llvm::SignedDivisionByConstantInfo::get(divisor); + Value q = createMulHigh(builder, loc, lhs, info.Magic); + if (divisor.isStrictlyPositive() && info.Magic.isNegative()) + q = builder.createOrFold(loc, q, lhs); + else if (divisor.isNegative() && info.Magic.isStrictlyPositive()) + q = builder.createOrFold(loc, q, lhs); + q = createAShrByConstant(builder, loc, q, info.ShiftAmount); + Value signBit = builder.createOrFold(loc, q, width - 1, 1); + Value signPadded = comb::createZExt(builder, loc, signBit, width); + return builder.createOrFold(loc, q, signPadded); +} + //===----------------------------------------------------------------------===// // Conversion patterns //===----------------------------------------------------------------------===// @@ -969,6 +1034,21 @@ struct CombDivUOpConversion : DivModOpConversionBase { if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter))) return success(); + // Lower div if rhs is a constant value + if (auto rhsConst = adaptor.getRhs().getDefiningOp()) { + APInt divisor = rhsConst.getValue(); + if (divisor.isZero()) { + replaceOpWithNewOpAndCopyNamehint(rewriter, op, + op.getType(), 0); + return success(); + } + replaceOpAndCopyNamehint(rewriter, op, + lowerUnsignedDivByConstant(rewriter, op.getLoc(), + adaptor.getLhs(), + divisor)); + return success(); + } + // When rhs is not power of two and the number of unknown bits are small, // create a mux tree that emulates all possible cases. return emulateBinaryOpForUnknownBits( @@ -991,6 +1071,26 @@ struct CombModUOpConversion : DivModOpConversionBase { if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter))) return success(); + // Lower mod where rhs is constant value + if (auto rhsConst = adaptor.getRhs().getDefiningOp()) { + APInt divisor = rhsConst.getValue(); + if (divisor.isZero()) { + replaceOpWithNewOpAndCopyNamehint(rewriter, op, + op.getType(), 0); + return success(); + } + auto loc = op.getLoc(); + Value q = + lowerUnsignedDivByConstant(rewriter, loc, adaptor.getLhs(), divisor); + Value product = + rewriter.createOrFold(loc, q, adaptor.getRhs()); + Value remainder = + rewriter.createOrFold(loc, adaptor.getLhs(), product); + replaceOpAndCopyNamehint(rewriter, op, remainder); + + return success(); + } + // When rhs is not power of two and the number of unknown bits are small, // create a mux tree that emulates all possible cases. return emulateBinaryOpForUnknownBits( @@ -1012,6 +1112,36 @@ struct CombDivSOpConversion : DivModOpConversionBase { ConversionPatternRewriter &rewriter) const override { // Currently only lower with emulation. // TODO: Implement a signed division lowering at least for power of two. + + if (auto rhsConst = adaptor.getRhs().getDefiningOp()) { + APInt divisor = rhsConst.getValue(); + unsigned width = op.getType().getIntOrFloatBitWidth(); + if (divisor.isZero()) { + replaceOpWithNewOpAndCopyNamehint(rewriter, op, + op.getType(), 0); + return success(); + } + if (divisor.isOne()) { + replaceOpAndCopyNamehint(rewriter, op, adaptor.getLhs()); + return success(); + } + if (divisor.isAllOnes()) { + replaceOpAndCopyNamehint( + rewriter, op, + rewriter.createOrFold( + op.getLoc(), + hw::ConstantOp::create(rewriter, op.getLoc(), + APInt::getZero(width)), + adaptor.getLhs())); + return success(); + } + replaceOpAndCopyNamehint(rewriter, op, + lowerSignedDivByConstant(rewriter, op.getLoc(), + adaptor.getLhs(), + divisor)); + return success(); + } + return emulateBinaryOpForUnknownBits( rewriter, maxEmulationUnknownBits, op, [](const APInt &lhs, const APInt &rhs) { @@ -1030,6 +1160,25 @@ struct CombModSOpConversion : DivModOpConversionBase { ConversionPatternRewriter &rewriter) const override { // Currently only lower with emulation. // TODO: Implement a signed modulus lowering at least for power of two. + + if (auto rhsConst = adaptor.getRhs().getDefiningOp()) { + APInt divisor = rhsConst.getValue(); + if (divisor.isZero() || divisor.isOne() || divisor.isAllOnes()) { + replaceOpWithNewOpAndCopyNamehint(rewriter, op, + op.getType(), 0); + return success(); + } + auto loc = op.getLoc(); + Value q = + lowerSignedDivByConstant(rewriter, loc, adaptor.getLhs(), divisor); + Value product = + rewriter.createOrFold(loc, q, adaptor.getRhs()); + Value remainder = + rewriter.createOrFold(loc, adaptor.getLhs(), product); + replaceOpAndCopyNamehint(rewriter, op, remainder); + return success(); + } + return emulateBinaryOpForUnknownBits( rewriter, maxEmulationUnknownBits, op, [](const APInt &lhs, const APInt &rhs) { diff --git a/test/Conversion/CombToSynth/comb-to-aig-arith.mlir b/test/Conversion/CombToSynth/comb-to-aig-arith.mlir index 34163521d4a7..412054c62ab6 100644 --- a/test/Conversion/CombToSynth/comb-to-aig-arith.mlir +++ b/test/Conversion/CombToSynth/comb-to-aig-arith.mlir @@ -1,6 +1,7 @@ // RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-synth{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux},cse))" | FileCheck %s // RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-synth{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add},cse))" | FileCheck %s --check-prefix=ALLOW_ADD // RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-synth{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.icmp}))" | FileCheck %s --check-prefix=ALLOW_ICMP +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-synth{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add,comb.mul,comb.sub},cse))" | FileCheck %s --check-prefix=ALLOW_ARITH // CHECK-LABEL: @parity hw.module @parity(in %arg0: i4, out out: i1) { @@ -447,7 +448,6 @@ hw.module @divmod(in %in: i1, in %rhs: i2, out out_divu: i2, out out_modu: i2, o hw.output %0, %1, %2, %3 : i2, i2, i2, i2 } - // CHECK-LABEL: @divmodu_power_of_two // ALLOW_ICMP-LABEL: @divmodu_power_of_two hw.module @divmodu_power_of_two(in %lhs: i8, out out_divu: i8, out out_modu: i8) { @@ -482,3 +482,176 @@ hw.module @mul_two_bit(in %lhs: i2, in %rhs: i2, out out: i2) { %0 = comb.mul %lhs, %rhs : i2 hw.output %0 : i2 } + +// CHECK-LABEL: @divu_const_3 +// CHECK-NOT: comb.divu +// ALLOW_ARITH-LABEL: @divu_const_3 +// ALLOW_ARITH-NOT: comb.divu +// ALLOW_ARITH: hw.constant 171 : i16 +// ALLOW_ARITH: comb.mul +// ALLOW_ARITH: comb.extract +hw.module @divu_const_3(in %lhs: i8, out out: i8) { + %rhs = hw.constant 3 : i8 + %div = comb.divu %lhs, %rhs : i8 + hw.output %div : i8 +} + +// CHECK-LABEL: @divu_const_7 +// CHECK-NOT: comb.divu +// ALLOW_ARITH-LABEL: @divu_const_7 +// ALLOW_ARITH-NOT: comb.divu +// ALLOW_ARITH: hw.constant 37 : i16 +// ALLOW_ARITH: comb.mul +// ALLOW_ARITH: comb.sub +// ALLOW_ARITH: comb.add +// ALLOW_ARITH: comb.extract +hw.module @divu_const_7(in %lhs: i8, out out: i8) { + %rhs = hw.constant 7 : i8 + %div = comb.divu %lhs, %rhs : i8 + hw.output %div : i8 +} + +// CHECK-LABEL: @divu_const_10 +// CHECK-NOT: comb.divu +// ALLOW_ARITH-LABEL: @divu_const_10 +// ALLOW_ARITH-NOT: comb.divu +// ALLOW_ARITH: hw.constant 205 : i16 +// ALLOW_ARITH: comb.mul +// ALLOW_ARITH: comb.extract +hw.module @divu_const_10(in %lhs: i8, out out: i8) { + %rhs = hw.constant 10 : i8 + %div = comb.divu %lhs, %rhs : i8 + hw.output %div : i8 +} + +// CHECK-LABEL: @modu_const_3 +// CHECK-NOT: comb.modu +// ALLOW_ARITH-LABEL: @modu_const_3 +// ALLOW_ARITH-NOT: comb.modu +// ALLOW_ARITH: hw.constant 3 : i8 +// ALLOW_ARITH: hw.constant 171 : i16 +// ALLOW_ARITH: comb.mul +// ALLOW_ARITH: comb.sub +hw.module @modu_const_3(in %lhs: i8, out out: i8) { + %rhs = hw.constant 3 : i8 + %mod = comb.modu %lhs, %rhs : i8 + hw.output %mod : i8 +} + +// CHECK-LABEL: @divs_const_3 +// CHECK-NOT: comb.divs +// ALLOW_ARITH-LABEL: @divs_const_3 +// ALLOW_ARITH-NOT: comb.divs +// ALLOW_ARITH: hw.constant 86 : i16 +// ALLOW_ARITH: comb.mul +// ALLOW_ARITH: comb.add +hw.module @divs_const_3(in %lhs: i8, out out: i8) { + %rhs = hw.constant 3 : i8 + %div = comb.divs %lhs, %rhs : i8 + hw.output %div : i8 +} + +// CHECK-LABEL: @divs_const_neg3 +// CHECK-NOT: comb.divs +// ALLOW_ARITH-LABEL: @divs_const_neg3 +// ALLOW_ARITH-NOT: comb.divs +// ALLOW_ARITH: hw.constant 85 : i16 +// ALLOW_ARITH: comb.mul +// ALLOW_ARITH: comb.sub +// ALLOW_ARITH: comb.add +hw.module @divs_const_neg3(in %lhs: i8, out out: i8) { + %rhs = hw.constant -3 : i8 + %div = comb.divs %lhs, %rhs : i8 + hw.output %div : i8 +} + +// CHECK-LABEL: @mods_const_3 +// CHECK-NOT: comb.mods +// ALLOW_ARITH-LABEL: @mods_const_3 +// ALLOW_ARITH-NOT: comb.mods +// ALLOW_ARITH: hw.constant 3 : i8 +// ALLOW_ARITH: hw.constant 86 : i16 +// ALLOW_ARITH: comb.mul +// ALLOW_ARITH: comb.add +// ALLOW_ARITH: comb.sub +hw.module @mods_const_3(in %lhs: i8, out out: i8) { + %rhs = hw.constant 3 : i8 + %mod = comb.mods %lhs, %rhs : i8 + hw.output %mod : i8 +} + +// CHECK-LABEL: @mods_const_neg3 +// CHECK-NOT: comb.mods +// ALLOW_ARITH-LABEL: @mods_const_neg3 +// ALLOW_ARITH-NOT: comb.mods +// ALLOW_ARITH: hw.constant -3 : i8 +// ALLOW_ARITH: hw.constant 85 : i16 +// ALLOW_ARITH: comb.mul +// ALLOW_ARITH: comb.sub +// ALLOW_ARITH: comb.add +hw.module @mods_const_neg3(in %lhs: i8, out out: i8) { + %rhs = hw.constant -3 : i8 + %mod = comb.mods %lhs, %rhs : i8 + hw.output %mod : i8 +} + +// CHECK-LABEL: @divs_const_1 +// CHECK-NOT: comb.divs +// CHECK-NOT: comb.mul +// CHECK-NOT: comb.replicate +// CHECK: hw.output %lhs +// ALLOW_ARITH-LABEL: @divs_const_1 +// ALLOW_ARITH-NOT: comb.divs +// ALLOW_ARITH-NOT: comb.mul +// ALLOW_ARITH-NOT: comb.replicate +// ALLOW_ARITH: hw.output %lhs +hw.module @divs_const_1(in %lhs: i8, out out: i8) { + %rhs = hw.constant 1 : i8 + %div = comb.divs %lhs, %rhs : i8 + hw.output %div : i8 +} + +// CHECK-LABEL: @divs_const_neg1 +// CHECK-NOT: comb.divs +// CHECK: hw.constant -1 : i8 +// ALLOW_ARITH-LABEL: @divs_const_neg1 +// ALLOW_ARITH-NOT: comb.divs +// ALLOW_ARITH: hw.constant 0 : i8 +// ALLOW_ARITH: comb.sub +hw.module @divs_const_neg1(in %lhs: i8, out out: i8) { + %rhs = hw.constant -1 : i8 + %div = comb.divs %lhs, %rhs : i8 + hw.output %div : i8 +} + +// CHECK-LABEL: @divu_const_0 +// CHECK: hw.constant 0 : i8 +// ALLOW_ARITH-LABEL: @divu_const_0 +// ALLOW_ARITH: hw.constant 0 : i8 +hw.module @divu_const_0(in %lhs: i8, out out: i8) { + %rhs = hw.constant 0 : i8 + %div = comb.divu %lhs, %rhs : i8 + hw.output %div : i8 +} + +// CHECK-LABEL: @modu_const_0 +// CHECK: hw.constant 0 : i8 +// ALLOW_ARITH-LABEL: @modu_const_0 +// ALLOW_ARITH: hw.constant 0 : i8 +hw.module @modu_const_0(in %lhs: i8, out out: i8) { + %rhs = hw.constant 0 : i8 + %mod = comb.modu %lhs, %rhs : i8 + hw.output %mod : i8 +} + +// CHECK-LABEL: @const_divmod_mods_neg1_i3 +// CHECK-NOT: comb.mods +// CHECK: hw.constant 0 : i3 +// ALLOW_ARITH-LABEL: @const_divmod_mods_neg1_i3 +// ALLOW_ARITH-NOT: comb.mods +// ALLOW_ARITH: hw.constant 0 : i3 +hw.module @const_divmod_mods_neg1_i3(in %lhs: i3, out out: i3) { + %c_neg1_i3 = hw.constant -1 : i3 + %0 = comb.mods %lhs, %c_neg1_i3 : i3 + hw.output %0 : i3 +}