Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions include/circt/Dialect/Synth/SynthAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ def MappingCostAttr : AttrDef<Synth_Dialect, "MappingCost"> {
let summary = "Simplified timing and area cost for tech mapping";
let parameters = (ins
"::mlir::FloatAttr":$area,
"::mlir::ArrayAttr":$arcs,
"::mlir::DictionaryAttr":$inputCaps
OptionalParameter<"::mlir::ArrayAttr">:$arcs,
OptionalParameter<"::mlir::DictionaryAttr">:$inputCaps
);
let genVerifyDecl = 1;
let assemblyFormat =
"`<` `area` `=` $area `,` `arcs` `=` $arcs `,` "
"`input_caps` `=` $inputCaps `>`";
"`<` `area` `=` $area "
"(`,` `arcs` `=` $arcs^)? "
"(`,` `input_caps` `=` $inputCaps^)? `>`";
}

#endif // CIRCT_DIALECT_SYNTH_SYNTHATTRIBUTES_TD
1 change: 1 addition & 0 deletions include/circt/Dialect/Synth/SynthOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef CIRCT_DIALECT_SYNTH_SYNTHOPS_H
#define CIRCT_DIALECT_SYNTH_SYNTHOPS_H

#include "circt/Dialect/Synth/SynthAttributes.h"
#include "circt/Dialect/Synth/SynthDialect.h"
#include "circt/Dialect/Synth/SynthOpInterfaces.h"
#include "circt/Support/LLVM.h"
Expand Down
23 changes: 23 additions & 0 deletions include/circt/Dialect/Synth/SynthOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,29 @@ def GambleOp : SymmetricThreeInputOp<"gamble", "evaluateGambleLogic"> {
}];
}

def CutRewritePatternOp : SynthOp<"cut_rewrite_pattern", [
IsolatedFromAbove,
SingleBlockImplicitTerminator<"YieldOp">
]> {
let summary = "Declarative cut rewrite pattern";

let arguments = (ins
TypeAttrOf<FunctionType>:$function_type,
MappingCostAttr:$cost
);

let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}

def YieldOp : SynthOp<"yield",
[Pure, Terminator]> {
let summary = "Yield synth operations";

let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = "$operands attr-dict `:` type($operands)";
}


#endif // CIRCT_DIALECT_SYNTH_SYNTHOPS_TD
1 change: 1 addition & 0 deletions lib/Dialect/Synth/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
##===----------------------------------------------------------------------===//

add_circt_dialect_library(CIRCTSynth
SynthAttributes.cpp
SynthDialect.cpp
SynthOpInterfaces.cpp
SynthOps.cpp
Expand Down
39 changes: 39 additions & 0 deletions lib/Dialect/Synth/SynthAttributes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Synth/SynthAttributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace circt;
using namespace circt::synth;
using namespace mlir;

//===----------------------------------------------------------------------===//
// MappingCostAttr
//===----------------------------------------------------------------------===//

LogicalResult
MappingCostAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
FloatAttr area, ArrayAttr arcs,
DictionaryAttr inputCaps) {
if (arcs)
for (auto attr : arcs)
if (!isa<LinearTimingArcAttr>(attr))
return emitError()
<< "expected arcs to contain synth.linear_timing_arc";

if (inputCaps)
for (auto entry : inputCaps)
if (!isa<FloatAttr>(entry.getValue()))
return emitError()
<< "expected input_caps values to be floating-point attributes";

return success();
}
104 changes: 104 additions & 0 deletions lib/Dialect/Synth/SynthOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWTypes.h"
#include "circt/Dialect/Synth/SynthAttributes.h"
#include "circt/Support/CustomDirectiveImpl.h"
#include "circt/Support/Naming.h"
#include "circt/Support/SATSolver.h"
Expand All @@ -19,7 +20,10 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
Expand Down Expand Up @@ -626,3 +630,103 @@ void GambleOp::emitCNFWithoutInversion(
// out = allSet | ~orSet
circt::addOrClauses(outVar, {allSet, -orSet}, addClause);
}

//===----------------------------------------------------------------------===//
// CutRewritePatternOp
//===----------------------------------------------------------------------===//

ParseResult CutRewritePatternOp::parse(OpAsmParser &parser,
OperationState &result) {

SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<Type> resultTypes;
SmallVector<DictionaryAttr> resultAttrs;
bool isVariadic = false;

if (function_interface_impl::parseFunctionSignatureWithArguments(
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
resultAttrs))
return failure();

auto inputTypes = llvm::map_to_vector(
entryArgs, [](auto &arg) -> Type { return arg.type; });
auto functionType =
parser.getBuilder().getFunctionType(inputTypes, resultTypes);

result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(functionType));
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();

return parser.parseRegion(*result.addRegion(), entryArgs,
/*enableNameShadowing=*/false);
}

void CutRewritePatternOp::print(OpAsmPrinter &p) {
auto functionType = getFunctionType();
call_interface_impl::printFunctionSignature(
p, functionType.getInputs(), /*argAttrs=*/{}, /*isVariadic=*/false,
functionType.getResults(), /*resultAttrs=*/{}, &getBody(),
/*printEmptyResult=*/false);

p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
{getFunctionTypeAttrName()});

p << ' ';
p.printRegion(getBody(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}

LogicalResult CutRewritePatternOp::verify() {
auto functionType = getFunctionType();

if (functionType.getNumResults() != 1)
return emitError() << "requires exactly one result";

for (auto type : functionType.getInputs())
if (!type.isInteger(1))
return emitError() << "argument type must be i1, but got " << type;

for (auto type : functionType.getResults())
if (!type.isInteger(1))
return emitError() << "result type must be i1, but got " << type;

// Check outputs.
auto *terminator = this->getBody().front().getTerminator();
if (terminator->getOperands().size() != functionType.getNumResults())
return emitError() << "result type doesn't match with the terminator";

for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(),
functionType.getResults()))
if (rhs != lhs)
return emitError() << rhs << " is expected but got " << lhs;

auto blockArgs = this->getBody().front().getArguments();
if (blockArgs.size() != functionType.getNumInputs())
return emitError() << "operand type doesn't match with the block arg";

for (auto [blockArg, inputType] :
llvm::zip(blockArgs, functionType.getInputs()))
if (blockArg.getType() != inputType)
return emitError() << inputType << " is expected but got "
<< blockArg.getType();

auto cost = getCost();
if (auto arcs = cost.getArcs())
for (auto attr : arcs) {
// TODO: LinearTimingArcAttr pin names are currently ignored for cut
// rewrite patterns. In the future MappingCostAttr should use a 2D
// ArrayAttr container for arcs. Names may be dropped from
// LinearTimingArcAttr in a future cleanup.
if (!isa<LinearTimingArcAttr>(attr))
return emitError()
<< "mapping cost arcs must use synth.linear_timing_arc";
}

if (auto inputCaps = cost.getInputCaps())
if (inputCaps.size() != functionType.getNumInputs())
return emitError()
<< "input_caps size must match the number of arguments";

return success();
}
10 changes: 8 additions & 2 deletions lib/Dialect/Synth/Transforms/TechMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,14 @@ struct TechMapperPass : public impl::TechMapperBase<TechMapperPass> {
}

llvm::DenseMap<StringAttr, DelayType> delayByInput;
for (auto attr : mappingCost.getArcs()) {
auto arc = cast<LinearTimingArcAttr>(attr);
auto arcs = mappingCost.getArcs();
if (!arcs) {
hwModule.emitError("expected synth.mapping_cost to contain arcs");
signalPassFailure();
return;
}
for (auto attr : arcs) {
auto arc = dyn_cast<LinearTimingArcAttr>(attr);
if (!arc) {
hwModule.emitError(
"expected synth.linear_timing_arc in synth.mapping_cost arcs");
Expand Down
38 changes: 38 additions & 0 deletions test/Dialect/Synth/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,41 @@ hw.module @test(out result : i1) {
%0 = synth.choice : i1
hw.output %0 : i1
}

// -----

// expected-error @below {{argument type must be i1, but got 'i2'}}
synth.cut_rewrite_pattern (%a: i2) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
%0 = comb.extract %a from 0 : (i2) -> i1
synth.yield %0 : i1
}

// -----

// expected-error @below {{result type must be i1, but got 'i2'}}
synth.cut_rewrite_pattern (%a: i1) -> i2 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
%0 = hw.constant 0 : i2
synth.yield %0 : i2
}

// -----

// expected-error @below {{requires exactly one result}}
synth.cut_rewrite_pattern (%a: i1) -> (i1, i1) attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
synth.yield %a, %a : i1, i1
}

// -----

// expected-error @below {{result type doesn't match with the terminator}}
synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
"synth.yield"() : () -> ()
}

// -----

// expected-error @below {{'i1' is expected but got 'i2'}}
synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
%0 = hw.constant 0 : i2
synth.yield %0 : i2
}
22 changes: 22 additions & 0 deletions test/Dialect/Synth/round-trip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,25 @@ hw.module @mux_inv(in %c: i4, in %a: i4, in %b: i4) {
hw.module @gamble(in %x: i1, in %y: i1, in %z: i1) {
%0 = synth.gamble %x, not %y, %z : i1
}

// CHECK-LABEL: synth.cut_rewrite_pattern
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1, %{{.*}}: i1) -> i1
// CHECK-SAME: attributes {cost = #synth.mapping_cost<area =
// CHECK-SAME: #synth.linear_timing_arc<"result", "b", 1, 0, <positive>>
synth.cut_rewrite_pattern (%a: i1, %b: i1, %c: i1) -> i1 attributes {
cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [
#synth.linear_timing_arc<"result", "b", 1, 0, #synth.polarity<positive>>
]>
} {
%0 = synth.aig.and_inv %a, not %b, %c : i1
synth.yield %0 : i1
}

// CHECK-LABEL: synth.cut_rewrite_pattern
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1) -> i1 attributes {cost = #synth.mapping_cost<area =
synth.cut_rewrite_pattern (%a: i1, %b: i1) -> i1 attributes {
cost = #synth.mapping_cost<area = 1.0 : f64>
} {
%0 = synth.aig.and_inv %a, %b : i1
synth.yield %0 : i1
}
Loading