Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions integration_test/circt-synth/divmod.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,36 @@ hw.module @divmod_mix_constant(in %in: i1, in %lhs: i1, in %rhs: i1, out out_div
hw.output %0, %1, %2, %3 : i4, i4, i4, i4
}

// RUN: circt-lec %t.mlir %s -c1=divmodu_constants -c2=divmodu_constants --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVMODU_CONSTANTS
// COMB_DIVMODU_CONSTANTS: c1 == c2
hw.module @divmodu_constants(in %lhs: i8, out out_divu_7: i8, out out_divu_10: i8, out out_modu_3: i8) {
%c7_i8 = hw.constant 7 : i8
%c10_i8 = hw.constant 10 : i8
%c3_i8 = hw.constant 3 : i8

%0 = comb.divu %lhs, %c7_i8 : i8
%1 = comb.divu %lhs, %c10_i8 : i8
%2 = comb.modu %lhs, %c3_i8 : i8
hw.output %0, %1, %2 : i8, i8, i8
}

// RUN: circt-lec %t.mlir %s -c1=divmods_constants -c2=divmods_constants --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVMODS_CONSTANTS
// COMB_DIVMODS_CONSTANTS: c1 == c2
hw.module @divmods_constants(in %lhs: i8, out out_divs_3: i8, out out_divs_neg3: i8, out out_mods_3: i8, out out_mods_neg3: i8) {
%c3_i8 = hw.constant 3 : i8
%c-3_i8 = hw.constant -3 : i8

%0 = comb.divs %lhs, %c3_i8 : i8
%1 = comb.divs %lhs, %c-3_i8 : i8
%2 = comb.mods %lhs, %c3_i8 : i8
%3 = comb.mods %lhs, %c-3_i8 : i8
hw.output %0, %1, %2, %3 : i8, i8, i8, i8
}

// RUN: circt-lec %t.mlir %s -c1=const_divmod_mods_neg1_i3 -c2=const_divmod_mods_neg1_i3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS_NEG1_I3
// COMB_MODS_NEG1_I3: c1 == c2
hw.module @const_divmod_mods_neg1_i3(in %lhs: i3, out out: i3) {
%c_neg1_i3 = hw.constant -1 : i3
%0 = comb.mods %lhs, %c_neg1_i3 : i3
hw.output %0 : i3
}
149 changes: 149 additions & 0 deletions lib/Conversion/CombToSynth/CombToSynth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DivisionByConstantInfo.h"
#include <array>

#define DEBUG_TYPE "comb-to-synth"
Expand Down Expand Up @@ -269,6 +270,70 @@ static LogicalResult emulateBinaryOpForUnknownBits(
return success();
}

static Value createLShrByConstant(OpBuilder &builder, Location loc, Value value,
unsigned amount) {
if (amount == 0)
return value;
return builder.createOrFold<comb::ShrUOp>(
loc, value,
hw::ConstantOp::create(
builder, loc,
APInt(value.getType().getIntOrFloatBitWidth(), amount)));
}

static Value createAShrByConstant(OpBuilder &builder, Location loc, Value value,
unsigned amount) {
if (amount == 0)
return value;
return builder.createOrFold<comb::ShrSOp>(
loc, value,
hw::ConstantOp::create(
builder, loc,
APInt(value.getType().getIntOrFloatBitWidth(), amount)));
}

template <bool isSigned>
static Value createMulHigh(OpBuilder &builder, Location loc, Value lhs,
const APInt &rhs) {
unsigned width = lhs.getType().getIntOrFloatBitWidth();
auto destTy = builder.getIntegerType(width << 1);
Value wideLhs = isSigned ? comb::createOrFoldSExt(builder, loc, lhs, destTy)
: comb::createZExt(builder, loc, lhs, width << 1);
Value wideRhs = hw::ConstantOp::create(
builder, loc, isSigned ? rhs.sext(width << 1) : rhs.zext(width << 1));
Value product = builder.createOrFold<comb::MulOp>(
loc, ValueRange{wideLhs, wideRhs}, true);
return builder.createOrFold<comb::ExtractOp>(loc, product, width, width);
}

static Value lowerUnsignedDivByConstant(OpBuilder &builder, Location loc,
Value lhs, const APInt &divisor) {
auto info = llvm::UnsignedDivisionByConstantInfo::get(divisor);
Value q = createLShrByConstant(builder, loc, lhs, info.PreShift);
q = createMulHigh<false>(builder, loc, q, info.Magic);
if (info.IsAdd) {
Value diff = builder.createOrFold<comb::SubOp>(loc, lhs, q);
diff = createLShrByConstant(builder, loc, diff, 1);
q = builder.createOrFold<comb::AddOp>(loc, q, diff);
}
return createLShrByConstant(builder, loc, q, info.PostShift);
}

static Value lowerSignedDivByConstant(OpBuilder &builder, Location loc,
Value lhs, const APInt &divisor) {
unsigned width = lhs.getType().getIntOrFloatBitWidth();
auto info = llvm::SignedDivisionByConstantInfo::get(divisor);
Value q = createMulHigh<true>(builder, loc, lhs, info.Magic);
if (divisor.isStrictlyPositive() && info.Magic.isNegative())
q = builder.createOrFold<comb::AddOp>(loc, q, lhs);
else if (divisor.isNegative() && info.Magic.isStrictlyPositive())
q = builder.createOrFold<comb::SubOp>(loc, q, lhs);
q = createAShrByConstant(builder, loc, q, info.ShiftAmount);
Value signBit = builder.createOrFold<comb::ExtractOp>(loc, q, width - 1, 1);
Value signPadded = comb::createZExt(builder, loc, signBit, width);
return builder.createOrFold<comb::AddOp>(loc, q, signPadded);
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -969,6 +1034,21 @@ struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
return success();

// Lower div if rhs is a constant value
if (auto rhsConst = adaptor.getRhs().getDefiningOp<hw::ConstantOp>()) {
APInt divisor = rhsConst.getValue();
if (divisor.isZero()) {
replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
op.getType(), 0);
return success();
}
replaceOpAndCopyNamehint(rewriter, op,
lowerUnsignedDivByConstant(rewriter, op.getLoc(),
adaptor.getLhs(),
divisor));
return success();
}

// When rhs is not power of two and the number of unknown bits are small,
// create a mux tree that emulates all possible cases.
return emulateBinaryOpForUnknownBits(
Expand All @@ -991,6 +1071,26 @@ struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
return success();

// Lower mod where rhs is constant value
if (auto rhsConst = adaptor.getRhs().getDefiningOp<hw::ConstantOp>()) {
APInt divisor = rhsConst.getValue();
if (divisor.isZero()) {
replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
op.getType(), 0);
return success();
}
auto loc = op.getLoc();
Value q =
lowerUnsignedDivByConstant(rewriter, loc, adaptor.getLhs(), divisor);
Value product =
rewriter.createOrFold<comb::MulOp>(loc, q, adaptor.getRhs());
Value remainder =
rewriter.createOrFold<comb::SubOp>(loc, adaptor.getLhs(), product);
replaceOpAndCopyNamehint(rewriter, op, remainder);

return success();
}

// When rhs is not power of two and the number of unknown bits are small,
// create a mux tree that emulates all possible cases.
return emulateBinaryOpForUnknownBits(
Expand All @@ -1012,6 +1112,36 @@ struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
ConversionPatternRewriter &rewriter) const override {
// Currently only lower with emulation.
// TODO: Implement a signed division lowering at least for power of two.

if (auto rhsConst = adaptor.getRhs().getDefiningOp<hw::ConstantOp>()) {
APInt divisor = rhsConst.getValue();
unsigned width = op.getType().getIntOrFloatBitWidth();
if (divisor.isZero()) {
replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
op.getType(), 0);
return success();
}
if (divisor.isOne()) {
replaceOpAndCopyNamehint(rewriter, op, adaptor.getLhs());
return success();
}
if (divisor.isAllOnes()) {
replaceOpAndCopyNamehint(
rewriter, op,
rewriter.createOrFold<comb::SubOp>(
op.getLoc(),
hw::ConstantOp::create(rewriter, op.getLoc(),
APInt::getZero(width)),
adaptor.getLhs()));
return success();
}
replaceOpAndCopyNamehint(rewriter, op,
lowerSignedDivByConstant(rewriter, op.getLoc(),
adaptor.getLhs(),
divisor));
return success();
}

return emulateBinaryOpForUnknownBits(
rewriter, maxEmulationUnknownBits, op,
[](const APInt &lhs, const APInt &rhs) {
Expand All @@ -1030,6 +1160,25 @@ struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
ConversionPatternRewriter &rewriter) const override {
// Currently only lower with emulation.
// TODO: Implement a signed modulus lowering at least for power of two.

if (auto rhsConst = adaptor.getRhs().getDefiningOp<hw::ConstantOp>()) {
APInt divisor = rhsConst.getValue();
if (divisor.isZero() || divisor.isOne() || divisor.isAllOnes()) {
replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
op.getType(), 0);
return success();
}
auto loc = op.getLoc();
Value q =
lowerSignedDivByConstant(rewriter, loc, adaptor.getLhs(), divisor);
Value product =
rewriter.createOrFold<comb::MulOp>(loc, q, adaptor.getRhs());
Value remainder =
rewriter.createOrFold<comb::SubOp>(loc, adaptor.getLhs(), product);
replaceOpAndCopyNamehint(rewriter, op, remainder);
return success();
}

return emulateBinaryOpForUnknownBits(
rewriter, maxEmulationUnknownBits, op,
[](const APInt &lhs, const APInt &rhs) {
Expand Down
Loading
Loading