diff --git a/include/circt/Dialect/Synth/SynthAttributes.td b/include/circt/Dialect/Synth/SynthAttributes.td index 9844a3e554f2..e6d1093a71de 100644 --- a/include/circt/Dialect/Synth/SynthAttributes.td +++ b/include/circt/Dialect/Synth/SynthAttributes.td @@ -76,12 +76,14 @@ def MappingCostAttr : AttrDef { let summary = "Simplified timing and area cost for tech mapping"; let parameters = (ins "::mlir::FloatAttr":$area, - "::mlir::ArrayAttr":$arcs, - "::mlir::DictionaryAttr":$inputCaps + OptionalParameter<"::mlir::ArrayAttr">:$arcs, + OptionalParameter<"::mlir::DictionaryAttr">:$inputCaps ); + let genVerifyDecl = 1; let assemblyFormat = - "`<` `area` `=` $area `,` `arcs` `=` $arcs `,` " - "`input_caps` `=` $inputCaps `>`"; + "`<` `area` `=` $area " + "(`,` `arcs` `=` $arcs^)? " + "(`,` `input_caps` `=` $inputCaps^)? `>`"; } #endif // CIRCT_DIALECT_SYNTH_SYNTHATTRIBUTES_TD diff --git a/include/circt/Dialect/Synth/SynthOps.h b/include/circt/Dialect/Synth/SynthOps.h index aae1b0bee7f1..51f7b24c8134 100644 --- a/include/circt/Dialect/Synth/SynthOps.h +++ b/include/circt/Dialect/Synth/SynthOps.h @@ -13,6 +13,7 @@ #ifndef CIRCT_DIALECT_SYNTH_SYNTHOPS_H #define CIRCT_DIALECT_SYNTH_SYNTHOPS_H +#include "circt/Dialect/Synth/SynthAttributes.h" #include "circt/Dialect/Synth/SynthDialect.h" #include "circt/Dialect/Synth/SynthOpInterfaces.h" #include "circt/Support/LLVM.h" diff --git a/include/circt/Dialect/Synth/SynthOps.td b/include/circt/Dialect/Synth/SynthOps.td index b6b610770af9..0f79b562821a 100644 --- a/include/circt/Dialect/Synth/SynthOps.td +++ b/include/circt/Dialect/Synth/SynthOps.td @@ -305,6 +305,29 @@ def GambleOp : SymmetricThreeInputOp<"gamble", "evaluateGambleLogic"> { }]; } +def CutRewritePatternOp : SynthOp<"cut_rewrite_pattern", [ + IsolatedFromAbove, + SingleBlockImplicitTerminator<"YieldOp"> +]> { + let summary = "Declarative cut rewrite pattern"; + + let arguments = (ins + TypeAttrOf:$function_type, + MappingCostAttr:$cost + ); + + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def YieldOp : SynthOp<"yield", + [Pure, Terminator]> { + let summary = "Yield synth operations"; + + let arguments = (ins Variadic:$operands); + let assemblyFormat = "$operands attr-dict `:` type($operands)"; +} #endif // CIRCT_DIALECT_SYNTH_SYNTHOPS_TD diff --git a/lib/Dialect/Synth/CMakeLists.txt b/lib/Dialect/Synth/CMakeLists.txt index a4f378823fb4..da3ee219c7fc 100644 --- a/lib/Dialect/Synth/CMakeLists.txt +++ b/lib/Dialect/Synth/CMakeLists.txt @@ -5,6 +5,7 @@ ##===----------------------------------------------------------------------===// add_circt_dialect_library(CIRCTSynth + SynthAttributes.cpp SynthDialect.cpp SynthOpInterfaces.cpp SynthOps.cpp diff --git a/lib/Dialect/Synth/SynthAttributes.cpp b/lib/Dialect/Synth/SynthAttributes.cpp new file mode 100644 index 000000000000..294804b829b2 --- /dev/null +++ b/lib/Dialect/Synth/SynthAttributes.cpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Synth/SynthAttributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace circt; +using namespace circt::synth; +using namespace mlir; + +//===----------------------------------------------------------------------===// +// MappingCostAttr +//===----------------------------------------------------------------------===// + +LogicalResult +MappingCostAttr::verify(llvm::function_ref emitError, + FloatAttr area, ArrayAttr arcs, + DictionaryAttr inputCaps) { + if (arcs) + for (auto attr : arcs) + if (!isa(attr)) + return emitError() + << "expected arcs to contain synth.linear_timing_arc"; + + if (inputCaps) + for (auto entry : inputCaps) + if (!isa(entry.getValue())) + return emitError() + << "expected input_caps values to be floating-point attributes"; + + return success(); +} diff --git a/lib/Dialect/Synth/SynthOps.cpp b/lib/Dialect/Synth/SynthOps.cpp index 166018b6342d..e6ee980a079d 100644 --- a/lib/Dialect/Synth/SynthOps.cpp +++ b/lib/Dialect/Synth/SynthOps.cpp @@ -10,6 +10,7 @@ #include "circt/Dialect/Comb/CombOps.h" #include "circt/Dialect/HW/HWOps.h" #include "circt/Dialect/HW/HWTypes.h" +#include "circt/Dialect/Synth/SynthAttributes.h" #include "circt/Support/CustomDirectiveImpl.h" #include "circt/Support/Naming.h" #include "circt/Support/SATSolver.h" @@ -19,7 +20,10 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" @@ -626,3 +630,103 @@ void GambleOp::emitCNFWithoutInversion( // out = allSet | ~orSet circt::addOrClauses(outVar, {allSet, -orSet}, addClause); } + +//===----------------------------------------------------------------------===// +// CutRewritePatternOp +//===----------------------------------------------------------------------===// + +ParseResult CutRewritePatternOp::parse(OpAsmParser &parser, + OperationState &result) { + + SmallVector entryArgs; + SmallVector resultTypes; + SmallVector resultAttrs; + bool isVariadic = false; + + if (function_interface_impl::parseFunctionSignatureWithArguments( + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + auto inputTypes = llvm::map_to_vector( + entryArgs, [](auto &arg) -> Type { return arg.type; }); + auto functionType = + parser.getBuilder().getFunctionType(inputTypes, resultTypes); + + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(functionType)); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + return parser.parseRegion(*result.addRegion(), entryArgs, + /*enableNameShadowing=*/false); +} + +void CutRewritePatternOp::print(OpAsmPrinter &p) { + auto functionType = getFunctionType(); + call_interface_impl::printFunctionSignature( + p, functionType.getInputs(), /*argAttrs=*/{}, /*isVariadic=*/false, + functionType.getResults(), /*resultAttrs=*/{}, &getBody(), + /*printEmptyResult=*/false); + + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), + {getFunctionTypeAttrName()}); + + p << ' '; + p.printRegion(getBody(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); +} + +LogicalResult CutRewritePatternOp::verify() { + auto functionType = getFunctionType(); + + if (functionType.getNumResults() != 1) + return emitError() << "requires exactly one result"; + + for (auto type : functionType.getInputs()) + if (!type.isInteger(1)) + return emitError() << "argument type must be i1, but got " << type; + + for (auto type : functionType.getResults()) + if (!type.isInteger(1)) + return emitError() << "result type must be i1, but got " << type; + + // Check outputs. + auto *terminator = this->getBody().front().getTerminator(); + if (terminator->getOperands().size() != functionType.getNumResults()) + return emitError() << "result type doesn't match with the terminator"; + + for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(), + functionType.getResults())) + if (rhs != lhs) + return emitError() << rhs << " is expected but got " << lhs; + + auto blockArgs = this->getBody().front().getArguments(); + if (blockArgs.size() != functionType.getNumInputs()) + return emitError() << "operand type doesn't match with the block arg"; + + for (auto [blockArg, inputType] : + llvm::zip(blockArgs, functionType.getInputs())) + if (blockArg.getType() != inputType) + return emitError() << inputType << " is expected but got " + << blockArg.getType(); + + auto cost = getCost(); + if (auto arcs = cost.getArcs()) + for (auto attr : arcs) { + // TODO: LinearTimingArcAttr pin names are currently ignored for cut + // rewrite patterns. In the future MappingCostAttr should use a 2D + // ArrayAttr container for arcs. Names may be dropped from + // LinearTimingArcAttr in a future cleanup. + if (!isa(attr)) + return emitError() + << "mapping cost arcs must use synth.linear_timing_arc"; + } + + if (auto inputCaps = cost.getInputCaps()) + if (inputCaps.size() != functionType.getNumInputs()) + return emitError() + << "input_caps size must match the number of arguments"; + + return success(); +} diff --git a/lib/Dialect/Synth/Transforms/TechMapper.cpp b/lib/Dialect/Synth/Transforms/TechMapper.cpp index 405c3a4987d6..a096658510f7 100644 --- a/lib/Dialect/Synth/Transforms/TechMapper.cpp +++ b/lib/Dialect/Synth/Transforms/TechMapper.cpp @@ -222,8 +222,14 @@ struct TechMapperPass : public impl::TechMapperBase { } llvm::DenseMap delayByInput; - for (auto attr : mappingCost.getArcs()) { - auto arc = cast(attr); + auto arcs = mappingCost.getArcs(); + if (!arcs) { + hwModule.emitError("expected synth.mapping_cost to contain arcs"); + signalPassFailure(); + return; + } + for (auto attr : arcs) { + auto arc = dyn_cast(attr); if (!arc) { hwModule.emitError( "expected synth.linear_timing_arc in synth.mapping_cost arcs"); diff --git a/test/Dialect/Synth/errors.mlir b/test/Dialect/Synth/errors.mlir index 91c6e764d973..ea170b7046cb 100644 --- a/test/Dialect/Synth/errors.mlir +++ b/test/Dialect/Synth/errors.mlir @@ -5,3 +5,41 @@ hw.module @test(out result : i1) { %0 = synth.choice : i1 hw.output %0 : i1 } + +// ----- + +// expected-error @below {{argument type must be i1, but got 'i2'}} +synth.cut_rewrite_pattern (%a: i2) -> i1 attributes {cost = #synth.mapping_cost} { + %0 = comb.extract %a from 0 : (i2) -> i1 + synth.yield %0 : i1 +} + +// ----- + +// expected-error @below {{result type must be i1, but got 'i2'}} +synth.cut_rewrite_pattern (%a: i1) -> i2 attributes {cost = #synth.mapping_cost} { + %0 = hw.constant 0 : i2 + synth.yield %0 : i2 +} + +// ----- + +// expected-error @below {{requires exactly one result}} +synth.cut_rewrite_pattern (%a: i1) -> (i1, i1) attributes {cost = #synth.mapping_cost} { + synth.yield %a, %a : i1, i1 +} + +// ----- + +// expected-error @below {{result type doesn't match with the terminator}} +synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost} { + "synth.yield"() : () -> () +} + +// ----- + +// expected-error @below {{'i1' is expected but got 'i2'}} +synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost} { + %0 = hw.constant 0 : i2 + synth.yield %0 : i2 +} diff --git a/test/Dialect/Synth/round-trip.mlir b/test/Dialect/Synth/round-trip.mlir index 03d5f72d78cf..73c6f1763bab 100644 --- a/test/Dialect/Synth/round-trip.mlir +++ b/test/Dialect/Synth/round-trip.mlir @@ -54,3 +54,25 @@ hw.module @mux_inv(in %c: i4, in %a: i4, in %b: i4) { hw.module @gamble(in %x: i1, in %y: i1, in %z: i1) { %0 = synth.gamble %x, not %y, %z : i1 } + +// CHECK-LABEL: synth.cut_rewrite_pattern +// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1, %{{.*}}: i1) -> i1 +// CHECK-SAME: attributes {cost = #synth.mapping_cost> +synth.cut_rewrite_pattern (%a: i1, %b: i1, %c: i1) -> i1 attributes { + cost = #synth.mapping_cost> + ]> +} { + %0 = synth.aig.and_inv %a, not %b, %c : i1 + synth.yield %0 : i1 +} + +// CHECK-LABEL: synth.cut_rewrite_pattern +// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1) -> i1 attributes {cost = #synth.mapping_cost i1 attributes { + cost = #synth.mapping_cost +} { + %0 = synth.aig.and_inv %a, %b : i1 + synth.yield %0 : i1 +}