From db0b0d4aea4018b77543a8c749673bece2fb7ea2 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Mar 2026 18:15:35 -0500 Subject: [PATCH 01/18] Poseidon --- .github/workflows/poseidon.yml | 46 + enzyme/Enzyme/AdjointGenerator.h | 10 +- enzyme/Enzyme/CApi.cpp | 3 +- enzyme/Enzyme/CMakeLists.txt | 63 + enzyme/Enzyme/DiffeGradientUtils.cpp | 12 +- enzyme/Enzyme/DiffeGradientUtils.h | 3 +- enzyme/Enzyme/DifferentialUseAnalysis.h | 9 +- enzyme/Enzyme/Enzyme.cpp | 204 ++- enzyme/Enzyme/EnzymeLogic.cpp | 28 +- enzyme/Enzyme/EnzymeLogic.h | 9 +- enzyme/Enzyme/FunctionUtils.cpp | 37 +- enzyme/Enzyme/FunctionUtils.h | 6 +- enzyme/Enzyme/GradientUtils.cpp | 7 +- enzyme/Enzyme/GradientUtils.h | 1 + enzyme/Enzyme/Poseidon/Poseidon.cpp | 1268 +++++++++++++++++ enzyme/Enzyme/Poseidon/Poseidon.h | 61 + enzyme/Enzyme/Poseidon/PoseidonEvaluators.cpp | 974 +++++++++++++ enzyme/Enzyme/Poseidon/PoseidonEvaluators.h | 86 ++ .../Enzyme/Poseidon/PoseidonHerbieUtils.cpp | 895 ++++++++++++ enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.h | 77 + enzyme/Enzyme/Poseidon/PoseidonPrecUtils.cpp | 711 +++++++++ enzyme/Enzyme/Poseidon/PoseidonPrecUtils.h | 85 ++ enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp | 158 ++ enzyme/Enzyme/Poseidon/PoseidonProfUtils.h | 51 + enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp | 889 ++++++++++++ enzyme/Enzyme/Poseidon/PoseidonSolvers.h | 54 + enzyme/Enzyme/Poseidon/PoseidonTypes.cpp | 959 +++++++++++++ enzyme/Enzyme/Poseidon/PoseidonTypes.h | 252 ++++ enzyme/Enzyme/Poseidon/PoseidonUtils.cpp | 1022 +++++++++++++ enzyme/Enzyme/Poseidon/PoseidonUtils.h | 83 ++ .../Enzyme/Runtimes/FPProfiler/FPProfiler.cpp | 195 +++ enzyme/test/Enzyme/Poseidon/CMakeLists.txt | 12 + enzyme/test/Enzyme/Poseidon/add_noopt.ll | 24 + enzyme/test/Enzyme/Poseidon/add_prof.ll | 38 + .../test/Enzyme/Poseidon/fcmp_select_prof.ll | 187 +++ enzyme/test/Enzyme/Poseidon/fma_prof.ll | 44 + enzyme/test/Enzyme/Poseidon/structret_prof.ll | 44 + enzyme/test/lit.site.cfg.py.in | 1 + 38 files changed, 8574 insertions(+), 34 deletions(-) create mode 100644 .github/workflows/poseidon.yml create mode 100644 enzyme/Enzyme/Poseidon/Poseidon.cpp create mode 100644 enzyme/Enzyme/Poseidon/Poseidon.h create mode 100644 enzyme/Enzyme/Poseidon/PoseidonEvaluators.cpp create mode 100644 enzyme/Enzyme/Poseidon/PoseidonEvaluators.h create mode 100644 enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.cpp create mode 100644 enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.h create mode 100644 enzyme/Enzyme/Poseidon/PoseidonPrecUtils.cpp create mode 100644 enzyme/Enzyme/Poseidon/PoseidonPrecUtils.h create mode 100644 enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp create mode 100644 enzyme/Enzyme/Poseidon/PoseidonProfUtils.h create mode 100644 enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp create mode 100644 enzyme/Enzyme/Poseidon/PoseidonSolvers.h create mode 100644 enzyme/Enzyme/Poseidon/PoseidonTypes.cpp create mode 100644 enzyme/Enzyme/Poseidon/PoseidonTypes.h create mode 100644 enzyme/Enzyme/Poseidon/PoseidonUtils.cpp create mode 100644 enzyme/Enzyme/Poseidon/PoseidonUtils.h create mode 100644 enzyme/Enzyme/Runtimes/FPProfiler/FPProfiler.cpp create mode 100644 enzyme/test/Enzyme/Poseidon/CMakeLists.txt create mode 100644 enzyme/test/Enzyme/Poseidon/add_noopt.ll create mode 100644 enzyme/test/Enzyme/Poseidon/add_prof.ll create mode 100644 enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll create mode 100644 enzyme/test/Enzyme/Poseidon/fma_prof.ll create mode 100644 enzyme/test/Enzyme/Poseidon/structret_prof.ll diff --git a/.github/workflows/poseidon.yml b/.github/workflows/poseidon.yml new file mode 100644 index 000000000000..d54434dbb0eb --- /dev/null +++ b/.github/workflows/poseidon.yml @@ -0,0 +1,46 @@ +name: Poseidon CI + +on: + push: + branches: + - main + pull_request: + merge_group: + +jobs: + build-linux: + name: Poseidon CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + llvm: ["20"] + build: ["Release", "Debug"] + os: [ubuntu-22.04] + + timeout-minutes: 30 + + steps: + - name: Install dependencies + run: | + wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true + sudo apt-get update + sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libzstd-dev libmpfr-dev racket + sudo python3 -m pip install --upgrade pip lit + - uses: actions/checkout@v4 + - name: mkdir + run: rm -rf build && mkdir build + - name: cmake with Poseidon + working-directory: build + run: cmake ../enzyme -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DENABLE_POSEIDON=1 -DLLVM_EXTERNAL_LIT=`which lit` -DLLVM_DIR=/usr/lib/llvm-${{ matrix.llvm }}/lib/cmake/llvm + - name: make + working-directory: build + run: make -j `nproc` + - name: make check-poseidon + working-directory: build + run: make -j `nproc` check-poseidon + - name: make check-enzyme-integration + working-directory: build + run: make -j3 check-enzyme-integration \ No newline at end of file diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 86341a9fadae..450fb1f303f8 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -48,6 +48,10 @@ #include "TraceUtils.h" #include "TypeAnalysis/TBAA.h" +#ifdef ENABLE_POSEIDON +#include "Poseidon/Poseidon.h" +#endif + #define DEBUG_TYPE "enzyme" // Helper instruction visitor that generates adjoints @@ -4610,7 +4614,8 @@ class AdjointGenerator : public llvm::InstVisitor { .forceAnonymousTape = false, .typeInfo = nextTypeInfo, .runtimeActivity = gutils->runtimeActivity, - .strongZero = gutils->strongZero}, + .strongZero = gutils->strongZero, + .profiled = gutils->profiled}, TR.analyzer->interprocedural, subdata, /*omp*/ true); @@ -5989,7 +5994,8 @@ class AdjointGenerator : public llvm::InstVisitor { .forceAnonymousTape = false, .typeInfo = nextTypeInfo, .runtimeActivity = gutils->runtimeActivity, - .strongZero = gutils->strongZero}, + .strongZero = gutils->strongZero, + .profiled = gutils->profiled}, TR.analyzer->interprocedural, subdata); if (!newcalled) return; diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index bd18c0794a1a..8ba08fa7def5 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -717,7 +717,8 @@ LLVMValueRef EnzymeCreatePrimalAndGradient( .forceAnonymousTape = (bool)forceAnonymousTape, .typeInfo = eunwrap(typeInfo, cast(unwrap(todiff))), .runtimeActivity = (bool)runtimeActivity, - .strongZero = (bool)strongZero}, + .strongZero = (bool)strongZero, + .profiled = false}, eunwrap(TA), eunwrap(augmented))); } EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 93be6d6f0de2..f86b247a4ba8 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -50,6 +50,54 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp TypeAnalysis/TypeAnalysisPrinter.cpp TypeAnalysis/RustDebugInfo.cpp) +set(ENABLE_POSEIDON 0 CACHE BOOL "Enable Poseidon") + +if(ENABLE_POSEIDON) + include(ExternalProject) + ExternalProject_Add(herbie + GIT_REPOSITORY https://github.com/sbrantq/herbie + GIT_TAG 193e2a4e6e4902fe5ee7fa8bd5747c19f70ce19b + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND make clean + COMMAND raco pkg install --auto --update-deps fpbench || true + COMMAND raco pkg install --auto --update-deps rival || true + COMMAND cargo build --release --manifest-path=egg-herbie/Cargo.toml + COMMAND raco pkg install ./egg-herbie + COMMAND mkdir -p herbie-compiled/ + COMMAND raco exe -o herbie --orig-exe --embed-dlls --vv src/main.rkt + COMMAND raco distribute herbie-compiled herbie + BUILD_IN_SOURCE true + INSTALL_COMMAND COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/herbie/install + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_BINARY_DIR}/herbie-prefix/src/herbie/herbie-compiled ${CMAKE_CURRENT_BINARY_DIR}/herbie/install/herbie + ) + list(APPEND ENZYME_SRC + Poseidon/Poseidon.cpp + Poseidon/PoseidonEvaluators.cpp + Poseidon/PoseidonHerbieUtils.cpp + Poseidon/PoseidonProfUtils.cpp + Poseidon/PoseidonPrecUtils.cpp + Poseidon/PoseidonSolvers.cpp + Poseidon/PoseidonTypes.cpp + Poseidon/PoseidonUtils.cpp + ) + add_compile_definitions(ENABLE_POSEIDON=1) + set_source_files_properties(Poseidon/PoseidonHerbieUtils.cpp PROPERTIES COMPILE_DEFINITIONS HERBIE_BINARY="${CMAKE_CURRENT_BINARY_DIR}/herbie/install/herbie/bin/herbie") + + add_library(EnzymeFPProfile STATIC + Runtimes/FPProfiler/FPProfiler.cpp + ) + target_compile_features(EnzymeFPProfile PRIVATE cxx_std_17) + target_include_directories(EnzymeFPProfile PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/Runtimes/FPProfiler) + + install(TARGETS EnzymeFPProfile + EXPORT EnzymeTargets + ARCHIVE DESTINATION lib + COMPONENT dev + ) +endif() + + if (ENZYME_ENABLE_PLUGINS) # on windows `PLUGIN_TOOL` doesn't link against LLVM.dll if ((WIN32 OR CYGWIN) AND LLVM_LINK_LLVM_DYLIB) @@ -119,6 +167,14 @@ if (ENZYME_ENABLE_PLUGINS) ) target_compile_definitions(LLDEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) endif() + + if(ENABLE_POSEIDON) + target_link_libraries(LLVMEnzyme-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + if (${Clang_FOUND}) + target_link_libraries(ClangEnzyme-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + endif() + target_link_libraries(LLDEnzyme-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + endif() endif() if (${ENZYME_STATIC_LIB}) @@ -129,6 +185,9 @@ if (${ENZYME_STATIC_LIB}) DEPENDS intrinsics_gen ) + if(ENABLE_POSEIDON) + target_link_libraries(EnzymeStatic-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + endif() endif() if (${ENZYME_EXTERNAL_SHARED_LIB}) @@ -154,6 +213,10 @@ if (${ENZYME_EXTERNAL_SHARED_LIB}) target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} LLVM) endif() + if(ENABLE_POSEIDON) + target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + endif() + install(TARGETS Enzyme-${LLVM_VERSION_MAJOR} EXPORT EnzymeTargets LIBRARY DESTINATION lib COMPONENT shlib diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 19e3c947b04b..140d492bddf8 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -89,7 +89,8 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( bool strongZero, unsigned width, Function *todiff, TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturn, bool diffeReturnArg, ArrayRef constant_args, - bool returnTape, bool returnPrimal, Type *additionalArg, bool omp) { + bool returnTape, bool returnPrimal, Type *additionalArg, bool omp, + bool profiled) { Function *oldFunc = todiff; assert(mode == DerivativeMode::ReverseModeGradient || mode == DerivativeMode::ReverseModeCombined || @@ -108,14 +109,16 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( std::string prefix; switch (mode) { - case DerivativeMode::ForwardModeError: case DerivativeMode::ForwardMode: case DerivativeMode::ForwardModeSplit: prefix = "fwddiffe"; break; + case DerivativeMode::ForwardModeError: + prefix = "fwderr"; + break; case DerivativeMode::ReverseModeCombined: case DerivativeMode::ReverseModeGradient: - prefix = "diffe"; + prefix = profiled ? "instr" : "diffe"; break; case DerivativeMode::ReverseModePrimal: llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n"); @@ -129,7 +132,7 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( nonconstant_values, returnvals, returnTape, returnPrimal, (mode == DerivativeMode::ReverseModeGradient) ? false : shadowReturn, prefix + oldFunc->getName(), &originalToNew, - /*diffeReturnArg*/ diffeReturnArg, additionalArg); + /*diffeReturnArg*/ diffeReturnArg, additionalArg, profiled); // Convert overwritten args from the input function to the preprocessed // function @@ -165,6 +168,7 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values, nonconstant_values, retType, shadowReturn, constant_args, originalToNew, mode, runtimeActivity, strongZero, width, omp); + res->profiled = profiled; return res; } diff --git a/enzyme/Enzyme/DiffeGradientUtils.h b/enzyme/Enzyme/DiffeGradientUtils.h index d999b08fdbb6..fc53c2aae4d1 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.h +++ b/enzyme/Enzyme/DiffeGradientUtils.h @@ -84,7 +84,8 @@ class DiffeGradientUtils final : public GradientUtils { FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturnArg, bool diffeReturnArg, llvm::ArrayRef constant_args, bool returnTape, - bool returnPrimal, llvm::Type *additionalArg, bool omp); + bool returnPrimal, llvm::Type *additionalArg, bool omp, + bool profiled = false); llvm::AllocaInst *getDifferential(llvm::Value *val); diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index 841b67379f25..93b1be2c5e92 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -123,12 +123,13 @@ inline bool is_value_needed_in_reverse( } } } - if (gutils->mode == DerivativeMode::ForwardModeError && + if ((gutils->profiled || + gutils->mode == DerivativeMode::ForwardModeError) && !gutils->isConstantValue(const_cast(inst))) { if (EnzymePrintDiffUse) - llvm::errs() - << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as forward mode error always needs result\n"; + llvm::errs() << " Need: " << to_string(VT) << " of " << *inst + << " in reverse as profiling mode or forward mode " + "error always needs result\n"; return seen[idx] = true; } } diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index bc0e6343ee97..192a3401369c 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -68,8 +68,11 @@ #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/AbstractCallSite.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -77,6 +80,9 @@ #include "DiffeGradientUtils.h" #include "EnzymeLogic.h" #include "GradientUtils.h" +#ifdef ENABLE_POSEIDON +#include "Poseidon/Poseidon.h" +#endif #include "TraceInterface.h" #include "TraceUtils.h" #include "Utils.h" @@ -382,6 +388,7 @@ static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret, class EnzymeBase { public: EnzymeLogic Logic; + bool poseidonProfiling = false; EnzymeBase(bool PostOpt) : Logic(EnzymePostOpt.getNumOccurrences() ? EnzymePostOpt : PostOpt) { // initializeLowerAutodiffIntrinsicPass(*PassRegistry::getPassRegistry()); @@ -487,6 +494,7 @@ class EnzymeBase { bool runtimeActivity; bool strongZero; bool subsequent_calls_may_write; + double errTol; }; #if LLVM_VERSION_MAJOR > 16 @@ -527,6 +535,7 @@ class EnzymeBase { mode != DerivativeMode::ForwardModeError && mode != DerivativeMode::ReverseModeCombined; StringSet<> ActiveRandomVariables; + double errTol = 0.0; DIFFE_TYPE retType = whatType(fn->getReturnType(), mode); @@ -849,6 +858,21 @@ class EnzymeBase { } skipArg = true; break; + } else if (*metaString == "enzyme_err_tol") { + ++i; + Value *relErrorArg = CI->getArgOperand(i); + if (auto *CFP = dyn_cast(relErrorArg)) { + errTol = CFP->getValueAPF().convertToDouble(); + } else { + EmitFailure( + "InvalidErrorTolerance", CI->getDebugLoc(), CI, + "Relative error tolerance must be a constant floating-point " + "value, got ", + *relErrorArg); + return {}; + } + skipArg = true; + break; } else { EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI, "illegal enzyme metadata classification ", *CI, @@ -1129,7 +1153,8 @@ class EnzymeBase { overwritten_args, runtimeActivity, strongZero, - subsequent_calls_may_write}); + subsequent_calls_may_write, + errTol}); } static FnTypeInfo populate_type_args(TypeAnalysis &TA, llvm::Function *fn, @@ -1543,8 +1568,10 @@ class EnzymeBase { .forceAnonymousTape = false, .typeInfo = type_args, .runtimeActivity = options.runtimeActivity, - .strongZero = options.strongZero}, + .strongZero = options.strongZero, + .profiled = poseidonProfiling}, TA, /*augmented*/ nullptr); + poseidonProfiling = false; break; case DerivativeMode::ReverseModePrimal: case DerivativeMode::ReverseModeGradient: { @@ -1615,7 +1642,8 @@ class EnzymeBase { .forceAnonymousTape = forceAnonymousTape, .typeInfo = type_args, .runtimeActivity = options.runtimeActivity, - .strongZero = options.strongZero}, + .strongZero = options.strongZero, + .profiled = false}, TA, aug); } } @@ -1947,6 +1975,138 @@ class EnzymeBase { return status; } +#ifdef ENABLE_POSEIDON + bool HandlePoseidonProf(CallInst *CI, SmallVectorImpl &calls) { + assert(FPProfileGenerate); + + IRBuilder<> Builder(CI); + Function *F = parseFunctionParameter(CI); + if (!F) + return false; + + assert(F); + + std::map byVal; + std::vector constants; + SmallVector args; + + auto mode = DerivativeMode::ReverseModeCombined; + poseidonProfiling = true; + + auto options = + handleArguments(Builder, CI, F, mode, false, constants, args, byVal); + if (!options) { + poseidonProfiling = false; + return false; + } + + Value *ret = CI; + Type *retElemType = nullptr; + if (CI->hasStructRetAttr()) { + ret = CI->getArgOperand(0); + retElemType = + CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet) + .getValueAsType(); + } + + return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, + byVal, constants, F, mode, *options, false, calls); + } + + bool HandlePoseidonOpt(CallInst *CI, SmallVectorImpl &calls) { + IRBuilder<> Builder(CI); + Function *F = parseFunctionParameter(CI); + if (!F) + return false; + + assert(F); + + auto mode = DerivativeMode::ReverseModeCombined; + poseidonProfiling = true; + + std::map byVal; + std::vector constants; + SmallVector args; + + auto options = + handleArguments(Builder, CI, F, mode, false, constants, args, byVal); + if (!options) { + poseidonProfiling = false; + return false; + } + + SmallVector primalArgs; + size_t argIdx = 0; + for (size_t i = 0; i < constants.size(); ++i) { + if (argIdx >= args.size()) + break; + + if (i == 0 && F->hasParamAttribute(0, Attribute::StructRet)) { + argIdx++; + if (constants[i] == DIFFE_TYPE::DUP_ARG || + constants[i] == DIFFE_TYPE::DUP_NONEED) { + argIdx++; + } + continue; + } + + primalArgs.push_back(args[argIdx]); + argIdx++; + + if (constants[i] == DIFFE_TYPE::DUP_ARG || + constants[i] == DIFFE_TYPE::DUP_NONEED) { + argIdx++; + } + } + + if (!FPProfileUse.getNumOccurrences() || FPProfileUse.empty()) { + EmitWarning("MissingProfileMode", *CI, + "__enzyme_fp_optimize called without -fpprofile-generate or " + "-fpprofile-use=. " + "Emitting unoptimized function call. Use -fpprofile-generate " + "to create a profile or -fpprofile-use= to use an " + "existing profile."); + } else { + F = Logic.PPC.preprocessForClone(F, mode); + setPoseidonMetadata(*F); + + SmallString<128> profilePath(FPProfileUse); + llvm::sys::path::append(profilePath, F->getName() + ".fpprofile"); + if (!llvm::sys::fs::exists(profilePath.str())) { + EmitFailure("NoProfile", CI->getDebugLoc(), CI, "No profile found at ", + profilePath, " (FPProfileUse: ", FPProfileUse, ")"); + return false; + } + + auto &TTI = Logic.PPC.FAM.getResult(*CI->getFunction()); + + if (FPOptPrint) { + llvm::errs() << "FPOpt: Optimizing " << F->getName() + << " with relative error tolerance: " << options->errTol + << "\n"; + } + + bool optimized = fpOptimize(*F, TTI, options->errTol); + + if (!optimized) { + EmitWarning("NoChange", *CI, "Poseidon returned false (no change) for ", + F->getName()); + } + } + + CallInst *optCall = Builder.CreateCall(F->getFunctionType(), F, primalArgs); + optCall->setCallingConv(CI->getCallingConv()); + optCall->setDebugLoc(CI->getDebugLoc()); + + CI->replaceAllUsesWith(optCall); + CI->eraseFromParent(); + + calls.push_back(optCall); + + return true; + } +#endif + bool handleFullModuleTrunc(Function &F) { if (startsWith(F.getName(), EnzymeFPRTPrefix)) return false; @@ -2096,6 +2256,7 @@ class EnzymeBase { SmallVector toTruncateFuncOp; SmallVector toTruncateValue; SmallVector toExpandValue; + SmallVector toFPOpt; MapVector toProbProg; SetVector InactiveCalls; SetVector IterCalls; @@ -2413,6 +2574,7 @@ class EnzymeBase { bool truncateFuncMem = false; bool truncateValue = false; bool expandValue = false; + bool fpOpt = false; bool probProg = false; DerivativeMode derivativeMode; ProbProgMode probProgMode; @@ -2469,6 +2631,9 @@ class EnzymeBase { enableEnzyme = true; probProgMode = ProbProgMode::Condition; probProg = true; + } else if (Fn->getName().contains("__enzyme_fp_optimize")) { + enableEnzyme = true; + fpOpt = true; } if (enableEnzyme) { @@ -2522,9 +2687,11 @@ class EnzymeBase { toTruncateValue.push_back(CI); else if (expandValue) toExpandValue.push_back(CI); - else if (probProg) { + else if (probProg) toProbProg[CI] = probProgMode; - } else + else if (fpOpt) + toFPOpt.push_back(CI); + else toLower[CI] = derivativeMode; if (auto dc = dyn_cast(fn)) { @@ -2628,6 +2795,18 @@ class EnzymeBase { HandleProbProg(call, mode, calls); } +#ifdef ENABLE_POSEIDON + for (auto CI : toFPOpt) { + Changed |= FPProfileGenerate ? HandlePoseidonProf(CI, calls) + : HandlePoseidonOpt(CI, calls); + } +#else + if (!toFPOpt.empty()) { + llvm_unreachable("Poseidon is not enabled. Please specify " + "-DENABLE_POSEIDON=1 when building Enzyme."); + } +#endif + if (Logic.PostOpt) { auto Params = llvm::getInlineParams(); @@ -2977,8 +3156,13 @@ class EnzymeBase { call->eraseFromParent(); } - for (const auto &pair : Logic.PPC.cache) - pair.second->eraseFromParent(); + // Poseidon opt uses preprocessed functions + for (const auto &pair : Logic.PPC.cache) { + if (pair.second->use_empty()) { + pair.second->eraseFromParent(); + } + } + Logic.clear(); if (changed && Logic.PostOpt) { @@ -3492,6 +3676,12 @@ extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB, if (registerFixupJuliaPass(Name, MPM)) { return true; } +#ifdef ENABLE_POSEIDON + if (Name == "fp-opt") { + MPM.addPass(FPOptNewPM()); + return true; + } +#endif if (Name == "preserve-nvvm") { MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); return true; diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index ae7839d6dcd8..6efcdd5833db 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3845,7 +3845,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( .forceAnonymousTape = key.forceAnonymousTape, .typeInfo = key.typeInfo, .runtimeActivity = key.runtimeActivity, - .strongZero = key.strongZero}, + .strongZero = key.strongZero, + .profiled = key.profiled}, TA, &aug, omp); SmallVector revargs; @@ -3944,7 +3945,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( .forceAnonymousTape = key.forceAnonymousTape, .typeInfo = key.typeInfo, .runtimeActivity = key.runtimeActivity, - .strongZero = key.strongZero}, + .strongZero = key.strongZero, + .profiled = key.profiled}, TA, augmenteddata, omp); { @@ -4228,6 +4230,26 @@ Function *EnzymeLogic::CreatePrimalAndGradient( assert(augmenteddata->constant_args == key.constant_args); } + if (key.profiled) { + Module *M = key.todiff->getParent(); + LLVMContext &Ctx = M->getContext(); + + Type *DoubleTy = Type::getDoubleTy(Ctx); + Type *PtrTy = PointerType::getUnqual(Ctx); + Type *Int32Ty = Type::getInt32Ty(Ctx); + Type *SizeTy = Type::getInt64Ty(Ctx); + + M->getOrInsertGlobal("ENZYME_FPPROFILE_RUNTIME_VAR", Int32Ty); + + FunctionType *LogValueFT = FunctionType::get( + Type::getVoidTy(Ctx), {PtrTy, SizeTy, DoubleTy, Int32Ty, PtrTy}, false); + M->getOrInsertFunction("enzymeLogValue", LogValueFT); + + FunctionType *LogGradFT = FunctionType::get( + Type::getVoidTy(Ctx), {PtrTy, SizeTy, DoubleTy, DoubleTy}, false); + M->getOrInsertFunction("enzymeLogGrad", LogGradFT); + } + bool diffeReturnArg = key.retType == DIFFE_TYPE::OUT_DIFF; DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone( @@ -4235,7 +4257,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( key.todiff, TLI, TA, oldTypeInfo, key.retType, augmenteddata ? augmenteddata->shadowReturnUsed : key.shadowReturnUsed, diffeReturnArg, key.constant_args, /*returnTape*/ false, key.returnUsed, - key.additionalType, omp); + key.additionalType, omp, key.profiled); gutils->AtomicAdd = key.AtomicAdd; gutils->FreeMemory = key.freeMemory; diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index fbb54947110d..e32a5603807b 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -174,6 +174,7 @@ struct ReverseCacheKey { const FnTypeInfo typeInfo; bool runtimeActivity; bool strongZero; + bool profiled; ReverseCacheKey replaceTypeInfo(const FnTypeInfo &newTypeInfo) const { return {todiff, @@ -191,7 +192,8 @@ struct ReverseCacheKey { forceAnonymousTape, newTypeInfo, runtimeActivity, - strongZero}; + strongZero, + profiled}; } /* inline bool operator==(const ReverseCacheKey& rhs) const { @@ -298,6 +300,11 @@ struct ReverseCacheKey { if (rhs.strongZero < strongZero) return false; + if (profiled < rhs.profiled) + return true; + if (rhs.profiled < profiled) + return false; + // equal return false; } diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 7dc7586eaf9a..705fb25bd763 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -28,6 +28,7 @@ #include "EnzymeLogic.h" #include "GradientUtils.h" #include "LibraryFuncs.h" +#include "Utils.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -115,6 +116,10 @@ #include "CacheUtility.h" #include "Utils.h" +#ifdef ENABLE_POSEIDON +#include "Poseidon/Poseidon.h" +#endif + #define addAttribute addAttributeAtIndex #define removeAttribute removeAttributeAtIndex #define getAttribute getAttributeAtIndex @@ -2244,8 +2249,8 @@ bool DetectReadonlyOrThrow(Module &M) { return changed; } -Function *PreProcessCache::preprocessForClone(Function *F, - DerivativeMode mode) { +Function *PreProcessCache::preprocessForClone(Function *F, DerivativeMode mode, + bool profiled) { TimeTraceScope timeScope("preprocessForClone", F->getName()); @@ -2691,7 +2696,8 @@ Function *PreProcessCache::preprocessForClone(Function *F, FAM.invalidate(*NewF, PA); } - if (mode != DerivativeMode::ForwardMode) + if (mode != DerivativeMode::ForwardMode && + mode != DerivativeMode::ForwardModeError) ReplaceReallocs(NewF); { @@ -2727,7 +2733,8 @@ Function *PreProcessCache::preprocessForClone(Function *F, PA.preserve(); } - if (mode != DerivativeMode::ForwardMode) + if (mode != DerivativeMode::ForwardMode && + mode != DerivativeMode::ForwardModeError) ReplaceReallocs(NewF); if (mode == DerivativeMode::ReverseModePrimal || @@ -2739,6 +2746,19 @@ Function *PreProcessCache::preprocessForClone(Function *F, UpgradeAllocasToMallocs(NewF, mode, unreachable); } +#ifdef ENABLE_POSEIDON + if (profiled) { + preprocessForPoseidon(NewF); + + auto PA = InstCombinePass().run(*NewF, FAM); + FAM.invalidate(*NewF, PA); + PA = EarlyCSEPass().run(*NewF, FAM); + FAM.invalidate(*NewF, PA); + PA = GVNPass().run(*NewF, FAM); + FAM.invalidate(*NewF, PA); + } +#endif + CanonicalizeLoops(NewF, FAM); RemoveRedundantPHI(NewF, FAM); @@ -2999,9 +3019,14 @@ Function *PreProcessCache::CloneFunctionWithReturns( SmallPtrSetImpl &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const Twine &name, llvm::ValueMap *VMapO, - bool diffeReturnArg, llvm::Type *additionalArg) { + bool diffeReturnArg, llvm::Type *additionalArg, bool profiled) { if (!F->empty()) - F = preprocessForClone(F, mode); + F = preprocessForClone(F, mode, profiled); +#ifdef ENABLE_POSEIDON + if (profiled) { + setPoseidonMetadata(*F); + } +#endif llvm::ValueToValueMapTy VMap; llvm::FunctionType *FTy = getFunctionTypeForClone( F->getFunctionType(), mode, width, additionalArg, constant_args, diff --git a/enzyme/Enzyme/FunctionUtils.h b/enzyme/Enzyme/FunctionUtils.h index d3e85c1b93d3..6dd73a50badd 100644 --- a/enzyme/Enzyme/FunctionUtils.h +++ b/enzyme/Enzyme/FunctionUtils.h @@ -95,7 +95,8 @@ class PreProcessCache { std::map, llvm::Function *> cache; std::map CloneOrigin; - llvm::Function *preprocessForClone(llvm::Function *F, DerivativeMode mode); + llvm::Function *preprocessForClone(llvm::Function *F, DerivativeMode mode, + bool profiled = false); llvm::AAResults &getAAResultsFromFunction(llvm::Function *NewF); @@ -108,7 +109,8 @@ class PreProcessCache { llvm::SmallPtrSetImpl &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const llvm::Twine &name, llvm::ValueMap *VMapO, - bool diffeReturnArg, llvm::Type *additionalArg = nullptr); + bool diffeReturnArg, llvm::Type *additionalArg = nullptr, + bool profiled = false); void ReplaceReallocs(llvm::Function *NewF, bool mem2reg = false); void LowerAllocAddr(llvm::Function *NewF); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 1cdfa35f7c1d..a97bd026bacb 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -163,8 +163,8 @@ GradientUtils::GradientUtils( llvm::ValueMap &originalToNewFn_, DerivativeMode mode, bool runtimeActivity, bool strongZero, unsigned width, bool omp) - : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), - invertedPointers(), + : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), profiled(false), + oldFunc(oldFunc_), invertedPointers(), OrigDT(oldFunc_->empty() ? ((DominatorTree *)nullptr) : &Logic.PPC.FAM.getResult( @@ -4908,7 +4908,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction( .forceAnonymousTape = true, .typeInfo = type_args, .runtimeActivity = runtimeActivity, - .strongZero = strongZero}, + .strongZero = strongZero, + .profiled = false}, TA, /*map*/ &augdata); assert(newf); diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index cdeb6d739fa1..bb6effdf5933 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -127,6 +127,7 @@ class GradientUtils : public CacheUtility { EnzymeLogic &Logic; bool AtomicAdd; DerivativeMode mode; + bool profiled; llvm::Function *oldFunc; llvm::ValueMap invertedPointers; llvm::DominatorTree *OrigDT; diff --git a/enzyme/Enzyme/Poseidon/Poseidon.cpp b/enzyme/Enzyme/Poseidon/Poseidon.cpp new file mode 100644 index 000000000000..69be4a93c74a --- /dev/null +++ b/enzyme/Enzyme/Poseidon/Poseidon.cpp @@ -0,0 +1,1268 @@ +#include + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include "llvm/Analysis/TargetTransformInfo.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Verifier.h" + +#include "llvm/Passes/PassBuilder.h" + +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" + +#include "llvm/Pass.h" + +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "../Utils.h" +#include "Poseidon.h" +#include "PoseidonHerbieUtils.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonProfUtils.h" +#include "PoseidonSolvers.h" +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +using namespace llvm; +#ifdef DEBUG_TYPE +#undef DEBUG_TYPE +#endif +#define DEBUG_TYPE "fp-opt" + +extern "C" { +cl::opt FPProfileGenerate( + "fpprofile-generate", cl::init(false), cl::Hidden, + cl::desc("Generate instrumented program for FP profiling")); +cl::opt FPProfileUse( + "fpprofile-use", cl::Hidden, cl::value_desc("directory"), cl::ValueOptional, + cl::desc("FP profile directory to read from for FP optimization")); +cl::opt FPOptPrint("fpopt-print", cl::init(false), cl::Hidden, + cl::desc("Print FPOpt debug info")); +cl::opt FPOptEnableHerbie( + "fpopt-enable-herbie", cl::init(true), cl::Hidden, + cl::desc("Use Herbie to rewrite floating-point expressions")); +cl::opt FPOptEnablePT( + "fpopt-enable-pt", cl::init(true), cl::Hidden, + cl::desc("Consider precision changes of floating-point expressions")); +cl::opt FPOptCachePath("fpopt-cache-path", cl::init("cache"), + cl::Hidden, + cl::desc("Path to cache Herbie results")); +cl::opt FPOptEnableSolver( + "fpopt-enable-solver", cl::init(true), cl::Hidden, + cl::desc("Use the solver to select desirable rewrite candidates; when " + "disabled, apply all Herbie's first choices")); +cl::opt FPOptMaxExprDepth( + "fpopt-max-expr-depth", cl::init(100), cl::Hidden, + cl::desc( + "The maximum depth of expression construction; abort if exceeded")); +cl::opt FPOptMaxExprLength( + "fpopt-max-expr-length", cl::init(10000), cl::Hidden, + cl::desc("The maximum length of an expression; abort if exceeded")); +cl::opt FPOptReductionEval( + "fpopt-reduction-eval", cl::init("arithmean"), cl::Hidden, + cl::desc("Which reduction result to use in candidate evaluation. " + "Options are 'geomean', 'arithmean', and 'maxabs'")); +cl::opt FPOptMinUsesForSplit( + "fpopt-min-uses-split", cl::init(99), cl::Hidden, + cl::desc("Minimum number of uses of bottleneck node to trigger split")); +cl::opt + FPOptMinOpsForSplit("fpopt-min-ops-split", cl::init(99), cl::Hidden, + cl::desc("Minimum number of upstream operations of " + "bottleneck node to trigger split")); +cl::opt + FPOptAggressiveDCE("fpopt-aggressive-dce", cl::init(false), cl::Hidden, + cl::desc("Aggressively eliminate zero gradient outputs " + "as dead code (non-conditional only)")); +cl::opt FPOptMultiOutputPTOnly( + "fpopt-multi-output-pt-only", cl::init(false), cl::Hidden, + cl::desc("Skip Herbie expression generation for subgraphs with multiple " + "outputs (only apply precision changes)")); +} + +bool Poseidonable(const llvm::Value &V) { + const Instruction *I = dyn_cast(&V); + if (!I) + return false; + + switch (I->getOpcode()) { + case Instruction::FNeg: + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + return I->getType()->isFloatTy() || I->getType()->isDoubleTy(); + case Instruction::Call: { + const CallInst *CI = dyn_cast(I); + if (CI && CI->getCalledFunction() && + (CI->getType()->isFloatTy() || CI->getType()->isDoubleTy())) { + StringRef funcName = CI->getCalledFunction()->getName(); + return + // LLVM intrinsics + startsWith(funcName, "llvm.sin.") || + startsWith(funcName, "llvm.cos.") || + startsWith(funcName, "llvm.tan.") || + startsWith(funcName, "llvm.asin.") || + startsWith(funcName, "llvm.acos.") || + startsWith(funcName, "llvm.atan.") || + startsWith(funcName, "llvm.atan2.") || + startsWith(funcName, "llvm.sinh.") || + startsWith(funcName, "llvm.cosh.") || + startsWith(funcName, "llvm.tanh.") || + startsWith(funcName, "llvm.exp.") || + startsWith(funcName, "llvm.log.") || + startsWith(funcName, "llvm.sqrt.") || + startsWith(funcName, "llvm.pow.") || + startsWith(funcName, "llvm.powi.") || + startsWith(funcName, "llvm.fabs.") || + startsWith(funcName, "llvm.fma.") || + startsWith(funcName, "llvm.fmuladd.") || + startsWith(funcName, "llvm.maxnum.") || + startsWith(funcName, "llvm.minnum.") || + startsWith(funcName, "llvm.ceil.") || + startsWith(funcName, "llvm.floor.") || + startsWith(funcName, "llvm.exp2.") || + startsWith(funcName, "llvm.log10.") || + startsWith(funcName, "llvm.log2.") || + startsWith(funcName, "llvm.rint.") || + startsWith(funcName, "llvm.round.") || + startsWith(funcName, "llvm.trunc.") || + startsWith(funcName, "llvm.copysign.") || + startsWith(funcName, "llvm.fdim.") || + startsWith(funcName, "llvm.fmod.") || + + // libm functions + funcName == "sin" || funcName == "sinf" || funcName == "cos" || + funcName == "cosf" || funcName == "tan" || funcName == "tanf" || + funcName == "asin" || funcName == "asinf" || funcName == "acos" || + funcName == "acosf" || funcName == "atan" || funcName == "atanf" || + funcName == "atan2" || funcName == "atan2f" || funcName == "sinh" || + funcName == "sinhf" || funcName == "cosh" || funcName == "coshf" || + funcName == "tanh" || funcName == "tanhf" || funcName == "asinh" || + funcName == "asinhf" || funcName == "acosh" || funcName == "acoshf" || + funcName == "atanh" || funcName == "atanhf" || funcName == "sqrt" || + funcName == "sqrtf" || funcName == "cbrt" || funcName == "cbrtf" || + funcName == "pow" || funcName == "powf" || funcName == "exp" || + funcName == "expf" || funcName == "log" || funcName == "logf" || + funcName == "fabs" || funcName == "fabsf" || funcName == "fma" || + funcName == "fmaf" || funcName == "hypot" || funcName == "hypotf" || + funcName == "expm1" || funcName == "expm1f" || funcName == "log1p" || + funcName == "log1pf" || funcName == "ceil" || funcName == "ceilf" || + funcName == "floor" || funcName == "floorf" || funcName == "erf" || + funcName == "erff" || funcName == "exp2" || funcName == "exp2f" || + funcName == "lgamma" || funcName == "lgammaf" || + funcName == "log10" || funcName == "log10f" || funcName == "log2" || + funcName == "log2f" || funcName == "rint" || funcName == "rintf" || + funcName == "round" || funcName == "roundf" || funcName == "tgamma" || + funcName == "tgammaf" || funcName == "trunc" || + funcName == "truncf" || funcName == "copysign" || + funcName == "copysignf" || funcName == "fdim" || + funcName == "fdimf" || funcName == "fmod" || funcName == "fmodf" || + funcName == "remainder" || funcName == "remainderf"; + } + return false; + } + default: + return false; + } +} + +void setPoseidonMetadata(Function &F) { + for (auto [idx, I] : enumerate(instructions(F))) { + if (Poseidonable(I)) { + I.setMetadata("enzyme_active", MDNode::get(I.getContext(), {})); + I.setMetadata("enzyme_fpprofile_idx", + MDNode::get(I.getContext(), + {ConstantAsMetadata::get(ConstantInt::get( + Type::getInt64Ty(I.getContext()), idx))})); + } + } +} + +void preprocessForPoseidon(Function *F) { + using namespace llvm::PatternMatch; + + // fmul + fadd -> fmuladd + for (auto &BB : *F) { + for (auto &I : make_early_inc_range(BB)) { + Value *X, *Y, *Z; + + if (auto *FAdd = dyn_cast(&I)) { + if (!isa(FAdd) || !FAdd->hasAllowReassoc() || + !FAdd->hasAllowContract()) + continue; + + // fadd (fmul X, Y), Z + if (match(FAdd, m_FAdd(m_OneUse(m_FMul(m_Value(X), m_Value(Y))), + m_Value(Z)))) { + IRBuilder<> B(FAdd); + B.setFastMathFlags(FAdd->getFastMathFlags()); + + Value *FMulAdd = + B.CreateIntrinsic(Intrinsic::fmuladd, FAdd->getType(), {X, Y, Z}); + FAdd->replaceAllUsesWith(FMulAdd); + FAdd->eraseFromParent(); + } + // fadd Z, (fmul X, Y) + else if (match(FAdd, + m_FAdd(m_Value(Z), + m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))) { + IRBuilder<> B(FAdd); + B.setFastMathFlags(FAdd->getFastMathFlags()); + + Value *FMulAdd = + B.CreateIntrinsic(Intrinsic::fmuladd, FAdd->getType(), {X, Y, Z}); + FAdd->replaceAllUsesWith(FMulAdd); + + FAdd->eraseFromParent(); + } + } + } + } + + for (auto &BB : *F) { + for (auto &I : make_early_inc_range(BB)) { + Value *X, *Y, *Z; + + if (auto *FSub = dyn_cast(&I)) { + if (!isa(FSub) || !FSub->hasAllowReassoc() || + !FSub->hasAllowContract()) + continue; + + // Pattern: fsub (fmul X, Y), Z -> fmuladd(X, Y, -Z) + if (match(FSub, m_FSub(m_OneUse(m_FMul(m_Value(X), m_Value(Y))), + m_Value(Z)))) { + IRBuilder<> B(FSub); + B.setFastMathFlags(FSub->getFastMathFlags()); + + Value *NegZ = B.CreateFNeg(Z); + Value *FMulAdd = B.CreateIntrinsic(Intrinsic::fmuladd, + FSub->getType(), {X, Y, NegZ}); + FSub->replaceAllUsesWith(FMulAdd); + FSub->eraseFromParent(); + } + // Pattern: fsub Z, (fmul X, Y) -> fmuladd(-X, Y, Z) + else if (match(FSub, + m_FSub(m_Value(Z), + m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))) { + IRBuilder<> B(FSub); + B.setFastMathFlags(FSub->getFastMathFlags()); + + Value *NegX = B.CreateFNeg(X); + Value *FMulAdd = B.CreateIntrinsic(Intrinsic::fmuladd, + FSub->getType(), {NegX, Y, Z}); + FSub->replaceAllUsesWith(FMulAdd); + FSub->eraseFromParent(); + } + } + } + } + + // fcmp + select -> fmax/fmin + for (auto &BB : *F) { + for (auto &I : make_early_inc_range(BB)) { + if (auto *Select = dyn_cast(&I)) { + Value *Cond = Select->getCondition(); + Value *TrueVal = Select->getTrueValue(); + Value *FalseVal = Select->getFalseValue(); + + if (!Select->getType()->isFloatingPointTy()) + continue; + + CmpPredicate Pred; + Value *CmpLHS, *CmpRHS; + + if (match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) { + IRBuilder<> B(Select); + Value *Result = nullptr; + + // select (fcmp ogt X, 0.0), X, 0.0 -> maxnum(X, 0.0) + if (Pred == FCmpInst::FCMP_OGT && match(CmpRHS, m_AnyZeroFP()) && + CmpLHS == TrueVal && match(FalseVal, m_AnyZeroFP())) { + Result = B.CreateIntrinsic( + Intrinsic::maxnum, CmpLHS->getType(), + {CmpLHS, ConstantFP::get(CmpLHS->getType(), 0.0)}); + } + // select (fcmp olt X, 0.0), 0.0, X -> maxnum(X, 0.0) + else if (Pred == FCmpInst::FCMP_OLT && match(CmpRHS, m_AnyZeroFP()) && + CmpLHS == FalseVal && match(TrueVal, m_AnyZeroFP())) { + Result = B.CreateIntrinsic( + Intrinsic::maxnum, CmpLHS->getType(), + {CmpLHS, ConstantFP::get(CmpLHS->getType(), 0.0)}); + } + // select (fcmp ogt X, Y), X, Y -> maxnum(X, Y) + else if (Pred == FCmpInst::FCMP_OGT && CmpLHS == TrueVal && + CmpRHS == FalseVal) { + Result = B.CreateIntrinsic(Intrinsic::maxnum, CmpLHS->getType(), + {CmpLHS, CmpRHS}); + } + // select (fcmp olt X, Y), X, Y -> minnum(X, Y) + else if (Pred == FCmpInst::FCMP_OLT && CmpLHS == TrueVal && + CmpRHS == FalseVal) { + Result = B.CreateIntrinsic(Intrinsic::minnum, CmpLHS->getType(), + {CmpLHS, CmpRHS}); + } + + if (Result) { + Select->replaceAllUsesWith(Result); + Select->eraseFromParent(); + + if (auto *FCmp = dyn_cast(Cond)) { + if (FCmp->use_empty()) { + FCmp->eraseFromParent(); + } + } + } + } + } + } + } +} + +// Run (our choice of) floating point optimizations on function `F`. +// Return whether or not we change the function. +bool fpOptimize(Function &F, const TargetTransformInfo &TTI, double errorTol) { + bool changed = false; + + // llvm::errs() << "FPOpt: Starting optimization for " << F.getName() << "\n"; + // F.print(llvm::errs()); + + const std::string functionName = F.getName().str(); + assert(!FPProfileUse.empty()); + SmallString<128> profilePathBuf(FPProfileUse); + llvm::sys::path::append(profilePathBuf, F.getName() + ".fpprofile"); + const std::string profilePath = profilePathBuf.str().str(); + + if (!FPOptCachePath.empty()) { + if (auto EC = llvm::sys::fs::create_directories(FPOptCachePath, true)) + llvm::errs() << "Warning: Could not create cache directory: " + << EC.message() << "\n"; + } + + std::unordered_map profileMap; + if (!profilePath.empty()) { + parseProfileFile(profilePath, profileMap); + if (profileMap.empty()) { + llvm::errs() << "Warning: No profile data found in " << profilePath + << "\n"; + } + } + + int symbolCounter = 0; + auto getNextSymbol = [&symbolCounter]() -> std::string { + return "v" + std::to_string(symbolCounter++); + }; + + // Extract change: + + // E1) create map for all instructions I, map[I] = FPLLValue(I) + // E2) for all instructions, if Poseidonable(I), map[I] = FPNode(operation(I), + // map[operands(I)]) + // E3) floodfill for all starting locations I to find all distinct graphs / + // outputs. + + /* + B1: + x = sin(arg) + + B2: + y = 1 - x * x + + + -> result y = cos(arg)^2 + +B1: + nothing + +B2: + costmp = cos(arg) + y = costmp * costmp + + */ + + std::unordered_map> valueToNodeMap; + std::unordered_map symbolToValueMap; + + llvm::errs() << "FPOpt: Starting Floodfill for " << F.getName() << "\n"; + + for (auto &BB : F) { + for (auto &I : BB) { + if (!Poseidonable(I)) { + valueToNodeMap[&I] = + std::make_shared(&I, "__nh", "__nh"); // Non-Poseidonable + if (FPOptPrint) + llvm::errs() + << "Registered FPLLValue for non-Poseidonable instruction: " << I + << "\n"; + continue; + } + + std::string dtype; + if (I.getType()->isFloatTy()) { + dtype = "f32"; + } else if (I.getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for instruction"); + } + auto node = std::make_shared(&I, getHerbieOperator(I), dtype); + + auto operands = + isa(I) ? cast(I).args() : I.operands(); + for (auto &operand : operands) { + if (!valueToNodeMap.count(operand)) { + if (auto Arg = dyn_cast(operand)) { + std::string dtype; + if (Arg->getType()->isFloatTy()) { + dtype = "f32"; + } else if (Arg->getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for argument"); + } + valueToNodeMap[operand] = + std::make_shared(Arg, "__arg", dtype); + if (FPOptPrint) + llvm::errs() << "Registered FPNode for argument: " << *Arg + << "\n"; + } else if (auto C = dyn_cast(operand)) { + SmallString<10> value; + C->getValueAPF().toString(value); + std::string dtype; + if (C->getType()->isFloatTy()) { + dtype = "f32"; + } else if (C->getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for constant"); + } + valueToNodeMap[operand] = + std::make_shared(value.c_str(), dtype); + if (FPOptPrint) + llvm::errs() << "Registered FPNode for " << dtype + << " constant: " << value << "\n"; + } else if (auto CI = dyn_cast(operand)) { + // e.g., powi intrinsic has a constant int as its exponent + double exponent = static_cast(CI->getSExtValue()); + std::string dtype = "f64"; + std::string doubleStr = std::to_string(exponent); + valueToNodeMap[operand] = + std::make_shared(doubleStr.c_str(), dtype); + if (FPOptPrint) + llvm::errs() << "Registered FPNode for " << dtype + << " constant (casted from integer): " << doubleStr + << "\n"; + } else if (auto GV = dyn_cast(operand)) { + Type *elemType = GV->getValueType(); + + assert(elemType->isFloatingPointTy() && + "Global variable is not floating point type"); + std::string dtype; + if (elemType->isFloatTy()) { + dtype = "f32"; + } else if (elemType->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable( + "Unexpected floating point type for global variable"); + } + valueToNodeMap[operand] = + std::make_shared(GV, "__gv", dtype); + if (FPOptPrint) + llvm::errs() << "Registered FPNode for global variable: " << *GV + << "\n"; + } else { + assert(0 && "Unknown operand"); + } + } + node->addOperand(valueToNodeMap[operand]); + } + valueToNodeMap[&I] = node; + } + } + + SmallSet processed; + SmallVector subgraphs; + for (auto &BB : F) { + for (auto &I : BB) { + // Not a Poseidonable instruction, doesn't make sense to create graph node + // out of. + if (!Poseidonable(I)) { + if (FPOptPrint) + llvm::errs() << "Skipping non-Poseidonable instruction: " << I + << "\n"; + continue; + } + + // Instruction is already in a set + if (processed.contains(&I)) { + if (FPOptPrint) + llvm::errs() << "Skipping already seen instruction: " << I << "\n"; + continue; + } + + if (FPOptPrint) + llvm::errs() << "Starting floodfill from: " << I << "\n"; + + SmallVector todo; + SetVector input_seen; + SetVector output_seen; + SetVector operation_seen; + todo.push_back(&I); + while (!todo.empty()) { + auto cur = todo.pop_back_val(); + assert(valueToNodeMap.count(cur) && "Node not found in valueToNodeMap"); + + // We now can assume that this is a Poseidonable expression + // Since we can only herbify instructions, let's assert that + assert(isa(cur)); + auto I2 = cast(cur); + + // Don't repeat any instructions we've already seen (to avoid loops + // for phi nodes) + if (operation_seen.contains(I2)) { + if (FPOptPrint) + llvm::errs() << "Skipping already seen instruction: " << *I2 + << "\n"; + continue; + } + + assert(!processed.contains(cur)); + + if (FPOptPrint) + llvm::errs() << "Insert to operation_seen and processed: " << *I2 + << "\n"; + operation_seen.insert(I2); + processed.insert(cur); + + auto operands = + isa(I2) ? cast(I2)->args() : I2->operands(); + + for (const auto &operand : operands) { + if (!Poseidonable(*operand)) { + if (FPOptPrint) + llvm::errs() << "Non-Poseidonable input found: " << *operand + << "\n"; + + // Don't mark constants as input `llvm::Value`s + if (!isa(operand)) + input_seen.insert(operand); + } else { + if (FPOptPrint) + llvm::errs() << "Adding operand to todo list: " << *operand + << "\n"; + todo.push_back(operand); + } + } + + for (auto U : I2->users()) { + if (auto I3 = dyn_cast(U)) { + if (!Poseidonable(*I3)) { + if (FPOptPrint) + llvm::errs() << "Output instruction found: " << *I2 << "\n"; + output_seen.insert(I2); + } else { + if (FPOptPrint) + llvm::errs() << "Adding user to todo list: " << *I3 << "\n"; + todo.push_back(I3); + } + } + } + } + + // Don't bother with graphs without any Poseidonable operations + if (!operation_seen.empty()) { + if (FPOptPrint) { + llvm::errs() << "Found a subgraph with " << operation_seen.size() + << " operations and " << input_seen.size() + << " inputs and " << output_seen.size() << " outputs\n"; + + llvm::errs() << "Inputs:\n"; + + for (auto &input : input_seen) { + llvm::errs() << *input << "\n"; + } + + llvm::errs() << "Outputs:\n"; + for (auto &output : output_seen) { + llvm::errs() << *output << "\n"; + } + + llvm::errs() << "Operations:\n"; + for (auto &operation : operation_seen) { + llvm::errs() << *operation << "\n"; + } + } + + if (operation_seen.size() == 1) { + if (FPOptPrint) + llvm::errs() << "Skipping trivial subgraph\n"; + continue; + } + + subgraphs.emplace_back(input_seen, output_seen, operation_seen); + } + } + } + + if (FPOptPrint) { + llvm::errs() << "FPOpt: Found " << subgraphs.size() + << " initial subgraphs in " << F.getName() << "\n"; + } + + // Profile read must happen before aggressive DCE as it requires gradients + for (auto &subgraph : subgraphs) { + for (auto op : subgraph.operations) { + if (auto MD = op->getMetadata("enzyme_fpprofile_idx")) { + if (auto C = dyn_cast(MD->getOperand(0))) { + size_t idx = cast(C->getValue())->getZExtValue(); + auto it = profileMap.find(idx); + + if (it != profileMap.end()) { + const auto &profileInfo = it->second; + + auto node = valueToNodeMap[op]; + node->sens = profileInfo.sumSens; + node->grad = profileInfo.sumGrad; + node->executions = profileInfo.exec; + node->updateBounds(profileInfo.minRes, profileInfo.maxRes); + + if (FPOptPrint) { + llvm::errs() << "Range of " << *op << " is [" + << node->getLowerBound() << ", " + << node->getUpperBound() << "]\n"; + llvm::errs() << "Sensitivity score of " << *op + << " is: " << node->sens << "\n" + << "Gradient sum of " << *op << " is: " << node->grad + << "\n" + << "Execution count of " << *op + << " is: " << node->executions << "\n"; + } + + auto operands = + isa(op) ? cast(op)->args() : op->operands(); + + for (const auto &operand_ : enumerate(operands)) { + auto &operand = operand_.value(); + auto i = operand_.index(); + + if (i < profileInfo.minOperands.size()) { + auto operandNode = valueToNodeMap[operand]; + operandNode->updateBounds(profileInfo.minOperands[i], + profileInfo.maxOperands[i]); + if (FPOptPrint) { + llvm::errs() << "Range of " << *operand << " is [" + << operandNode->getLowerBound() << ", " + << operandNode->getUpperBound() << "]\n"; + } + } + } + } else { + if (!FPOptLooseCoverage) { + llvm::errs() << "FP Instruction " << *op + << " has no execution logged (idx=" << idx << ")!\n"; + llvm_unreachable("Unexecuted instruction found; set " + "-fpopt-loose-coverage " + "to suppress this error\n"); + } + if (FPOptPrint) + llvm::errs() << "Sensitivity of " << *op + << " not found in the log; using 0 instead\n"; + } + } + } + } + } + + if (FPOptAggressiveDCE) { + SmallSet critical; + for (auto &BB : F) { + for (auto &I : BB) { + if (auto *fcmp = dyn_cast(&I)) { + critical.insert(fcmp->getOperand(0)); + critical.insert(fcmp->getOperand(1)); + } + } + } + + SmallVector worklist(critical.begin(), critical.end()); + while (!worklist.empty()) { + Value *V = worklist.pop_back_val(); + + if (auto *inst = dyn_cast(V)) { + auto operands = isa(inst) ? cast(inst)->args() + : inst->operands(); + for (auto &op : operands) { + if (op->getType()->isFloatingPointTy() && + critical.insert(op).second) { + worklist.push_back(op); + } + } + + if (auto *phi = dyn_cast(inst)) { + for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) { + Value *incoming = phi->getIncomingValue(i); + if (incoming->getType()->isFloatingPointTy() && + critical.insert(incoming).second) { + worklist.push_back(incoming); + } + } + } + + if (auto *load = dyn_cast(inst)) { + Value *ptr = load->getPointerOperand(); + for (auto *user : ptr->users()) { + if (auto *store = dyn_cast(user)) { + Value *storedVal = store->getValueOperand(); + if (storedVal->getType()->isFloatingPointTy() && + critical.insert(storedVal).second) { + worklist.push_back(storedVal); + } + } + } + } + } + } + + if (FPOptPrint) { + llvm::errs() << "Critical values:\n"; + for (auto *value : critical) { + llvm::errs() << "\t" << *value << "\n"; + } + } + + auto subgraphIt = subgraphs.begin(); + while (subgraphIt != subgraphs.end()) { + Subgraph &subgraph = *subgraphIt; + SmallVector toRemove; + + for (auto *op : subgraph.operations) { + auto node = valueToNodeMap[op]; + if (node->grad == 0. && !critical.count(op)) { + if (FPOptPrint) + llvm::errs() << "Aggressive DCE: eliminating zero-gradient " + << "non-critical instruction: " << *op << "\n"; + toRemove.push_back(op); + } + } + + for (auto *op : toRemove) { + if (!op->use_empty()) { + op->replaceAllUsesWith(UndefValue::get(op->getType())); + } + + valueToNodeMap.erase(op); + subgraph.operations.remove(op); + subgraph.outputs.remove(op); + + op->eraseFromParent(); + } + + if (subgraph.outputs.empty()) { + if (FPOptPrint) + llvm::errs() << "Removing empty subgraph\n"; + subgraphIt = subgraphs.erase(subgraphIt); + } else { + ++subgraphIt; + } + } + } + + if (FPOptPrint && FPOptAggressiveDCE) { + llvm::errs() << "FPOpt: After aggressive DCE, have " << subgraphs.size() + << " subgraphs in " << F.getName() << "\n"; + } + + splitSubgraphs(subgraphs); + + if (FPOptPrint) { + llvm::errs() << "FPOpt: After splitting, have " << subgraphs.size() + << " subgraphs in " << F.getName() << "\n"; + } + + if (FPOptPrint) { + llvm::errs() << "\n=== Function IR after Subgraph Splitting ===\n"; + + std::unordered_map instToSubgraphIdx; + for (size_t idx = 0; idx < subgraphs.size(); ++idx) { + for (auto *inst : subgraphs[idx].operations) { + instToSubgraphIdx[inst] = idx; + } + for (auto *inst : subgraphs[idx].outputs) { + if (instToSubgraphIdx.find(inst) == instToSubgraphIdx.end()) { + instToSubgraphIdx[inst] = idx; + } + } + } + + for (auto &BB : F) { + BB.printAsOperand(llvm::errs(), false); + llvm::errs() << ":\n"; + for (auto &I : BB) { + llvm::errs() << " "; + I.print(llvm::errs()); + + auto it = instToSubgraphIdx.find(&I); + if (it != instToSubgraphIdx.end()) { + llvm::errs() << " ; [SG" << it->second << "]"; + } + llvm::errs() << "\n"; + } + } + llvm::errs() << "=== End of Function IR ===\n\n"; + } + + // 1) Identify subgraphs of the computation which can be entirely represented + // in herbie-style arithmetic + // 2) Make the herbie FP-style expression by + // converting llvm instructions into herbie string (FPNode ....) + if (subgraphs.empty()) { + if (FPOptPrint) + llvm::errs() << "No subgraphs found\n"; + return false; + } + + SmallVector COs; + SmallVector CSs; + + int subgraphCounter = 0; + + for (auto &subgraph : subgraphs) { + assert(subgraph.inputs.size() > 0 && "No inputs found for subgraph"); + + bool skipHerbie = false; + if (FPOptMultiOutputPTOnly && subgraph.outputs.size() > 1) { + skipHerbie = true; + if (FPOptPrint) + llvm::errs() << "Skipping Herbie for subgraph with " + << subgraph.outputs.size() + << " outputs (fpopt-multi-output-pt-only is set)\n"; + } + + if (FPOptEnableHerbie && !skipHerbie) { + for (const auto &input : subgraph.inputs) { + auto node = valueToNodeMap[input]; + if (node->op == "__const") { + // Constants don't need a symbol + continue; + } + + if (!node->hasSymbol()) { + node->symbol = getNextSymbol(); + } + symbolToValueMap[node->symbol] = input; + if (FPOptPrint) + llvm::errs() << "assigning symbol: " << node->symbol << " to " + << *input << "\n"; + } + + std::vector herbieInputs; + std::vector newCOs; + + assert(subgraph.outputs.size() > 0 && "No outputs found for subgraph"); + for (auto &output : subgraph.outputs) { + // 3) run fancy opts + double grad = valueToNodeMap[output]->grad; + unsigned executions = valueToNodeMap[output]->executions; + + if (grad == 0.) { + llvm::errs() << "Skipping zero gradient instruction: " << *output + << "\n"; + continue; + } + + std::string expr = valueToNodeMap[output]->toFullExpression( + valueToNodeMap, subgraph.inputs); + + if (expr.length() > FPOptMaxExprLength) { + llvm::errs() << "WARNING: Skipping Herbie optimization for " + << *output << " since expression length " + << expr.length() << " exceeds limit of " + << FPOptMaxExprLength << "\n"; + continue; + } + + // Skip trivial expressions with only one operation + auto parenCount = std::count(expr.begin(), expr.end(), '('); + assert(parenCount > 0); + if (parenCount == 1) { + if (FPOptPrint) + llvm::errs() << "Skipping Herbie for simple expression: " << expr + << "\n"; + continue; + } + + SmallSet args; + getUniqueArgs(expr, args); + + std::string properties = ":herbie-conversions ([binary64 binary32])"; + if (valueToNodeMap[output]->dtype == "f32") { + properties += " :precision binary32"; + } else if (valueToNodeMap[output]->dtype == "f64") { + properties += " :precision binary64"; + } else { + llvm_unreachable("Unexpected dtype"); + } + + std::string precondition = + getPrecondition(args, valueToNodeMap, symbolToValueMap); + properties += " :pre " + precondition; + + CandidateOutput CO(subgraph, output, expr, grad, executions, TTI); + properties += " :name \"" + std::to_string(newCOs.size()) + "\""; + + std::string argStr; + for (const auto &arg : args) { + if (!argStr.empty()) + argStr += " "; + argStr += arg; + } + + std::string herbieInput = + "(FPCore (" + argStr + ") " + properties + " " + expr + ")"; + if (FPOptPrint) + llvm::errs() << "Herbie input:\n" << herbieInput << "\n"; + + herbieInputs.push_back(herbieInput); + newCOs.push_back(CO); + } + + if (!herbieInputs.empty()) { + if (!improveViaHerbie(herbieInputs, newCOs, F.getParent(), TTI, + valueToNodeMap, symbolToValueMap, + subgraphCounter)) { + if (FPOptPrint) + llvm::errs() << "Failed to optimize expressions using Herbie!\n"; + } + + COs.insert(COs.end(), newCOs.begin(), newCOs.end()); + } + } + + if (FPOptEnablePT) { + // Sort `cs.operations` by the gradient and construct + // `PrecisionChange`s. + CandidateSubgraph CS(subgraph, TTI); + auto *o0 = subgraph.outputs[0]; + CS.executions = valueToNodeMap[o0]->executions; + + const SmallVector precTypes{ + PrecisionChangeType::FP32, + PrecisionChangeType::FP64, + }; + + const auto &PTFuncs = getPTFuncs(); + + // Check if we have a cached DP table + std::string cacheFilePath = FPOptCachePath + "/table.json"; + bool skipEvaluation = FPOptSolverType == "dp" && + !FPOptCachePath.empty() && + llvm::sys::fs::exists(cacheFilePath); + + SetVector operations; + for (auto *I : subgraph.operations) { + assert(isa(valueToNodeMap[I].get()) && + "Corrupted FPNode for original instructions"); + auto node = cast(valueToNodeMap[I].get()); + if (PTFuncs.count(node->op) != 0) { + operations.insert(node); + llvm::errs() << "FPOpt: PT Function identified: " << *I << "\n"; + } + } + + // Prioritize operations with low sensitivity scores + SmallVector sortedOps(operations.begin(), operations.end()); + llvm::sort(sortedOps, [](const auto &a, const auto &b) { + return a->sens < b->sens; + }); + + // Create PrecisionChanges for 0-10%, 0-20%, ..., up to 0-100% + size_t lastNumChanged = 0; + for (int percent = 10; percent <= 100; percent += 10) { + size_t numToChange = sortedOps.size() * percent / 100; + if (numToChange == 0 || numToChange == lastNumChanged) { + continue; + } + + lastNumChanged = numToChange; + + if (FPOptPrint && numToChange > 0) { + llvm::errs() << "Created PrecisionChange for " << percent + << "% of Funcs (" << numToChange << ")\n"; + double minSens = sortedOps[0]->sens; + double maxSens = sortedOps[numToChange - 1]->sens; + llvm::errs() << "Sensitivity score range: [" << minSens << ", " + << maxSens << "]\n"; + } + + for (auto prec : precTypes) { + PrecisionChangeType currentPrec = + getPrecisionChangeType(subgraph.outputs[0]->getType()); + if (prec == currentPrec) { + continue; + } + + std::string precStr = getPrecisionChangeTypeString(prec).str(); + std::string desc = + "Funcs 0% -- " + std::to_string(percent) + "% -> " + precStr; + + SetVector nodesToChange(sortedOps.begin(), + sortedOps.begin() + numToChange); + PrecisionChange change(nodesToChange, currentPrec, prec); + + SmallVector changes{std::move(change)}; + PTCandidate candidate{std::move(changes), desc}; + + if (!skipEvaluation) { + candidate.CompCost = getCompCost(subgraph, TTI, candidate); + } + + CS.candidates.push_back(std::move(candidate)); + } + } + + SetVector allOperations; + for (auto *I : subgraph.operations) { + assert(isa(valueToNodeMap[I].get()) && + "Corrupted FPNode for original instructions"); + auto node = cast(valueToNodeMap[I].get()); + allOperations.insert(node); + } + + // Prioritize operations with low sensitivity scores + SmallVector sortedAllOps(allOperations.begin(), + allOperations.end()); + llvm::sort(sortedAllOps, [](const auto &a, const auto &b) { + return a->sens < b->sens; + }); + + // Create PrecisionChanges for 0-10%, 0-20%, ..., up to 0-100% + lastNumChanged = 0; + for (int percent = 10; percent <= 100; percent += 10) { + size_t numToChange = sortedAllOps.size() * percent / 100; + if (numToChange == 0 || numToChange == lastNumChanged) { + continue; + } + + lastNumChanged = numToChange; + + if (FPOptPrint && numToChange > 0) { + llvm::errs() << "Created PrecisionChange for " << percent + << "% of all operations (" << numToChange << ")\n"; + double minSens = sortedAllOps[0]->sens; + double maxSens = sortedAllOps[numToChange - 1]->sens; + llvm::errs() << "Sensitivity score range: [" << minSens << ", " + << maxSens << "]\n"; + } + + for (auto prec : precTypes) { + PrecisionChangeType currentPrec = + getPrecisionChangeType(subgraph.outputs[0]->getType()); + if (prec == currentPrec) { + continue; + } + + std::string precStr = getPrecisionChangeTypeString(prec).str(); + std::string desc = + "All 0% -- " + std::to_string(percent) + "% -> " + precStr; + + SetVector nodesToChange( + sortedAllOps.begin(), sortedAllOps.begin() + numToChange); + PrecisionChange change(nodesToChange, currentPrec, prec); + + SmallVector changes{std::move(change)}; + PTCandidate candidate{std::move(changes), desc}; + + if (!skipEvaluation) { + candidate.CompCost = getCompCost(subgraph, TTI, candidate); + } + + CS.candidates.push_back(std::move(candidate)); + } + } + + if (!skipEvaluation) { + setUnifiedAccuracyCost(CS, valueToNodeMap, symbolToValueMap); + } + + CSs.push_back(std::move(CS)); + } + llvm::errs() << "##### Finished synthesizing candidates for " + << ++subgraphCounter << " of " << subgraphs.size() + << " subgraphs! #####\n"; + } + + // Perform rewrites + if (FPOptPrint) { + if (FPOptEnableHerbie) { + for (auto &CO : COs) { + llvm::errs() << "\n################################\n"; + llvm::errs() << "Initial AccuracyCost: " << CO.initialAccCost << "\n"; + llvm::errs() << "Initial ComputationCost: " << CO.initialCompCost + << "\n"; + llvm::errs() << "Initial HerbieCost: " << CO.initialHerbieCost << "\n"; + llvm::errs() << "Initial HerbieAccuracy: " << CO.initialHerbieAccuracy + << "\n"; + llvm::errs() << "Initial Expression: " << CO.expr << "\n"; + llvm::errs() << "Grad: " << CO.grad << "\n\n"; + llvm::errs() << "Candidates:\n"; + llvm::errs() << "Δ AccCost\t\tΔ " + "CompCost\t\tHerbieCost\t\tAccuracy\t\tExpression\n"; + llvm::errs() << "--------------------------------\n"; + for (size_t i = 0; i < CO.candidates.size(); ++i) { + auto &candidate = CO.candidates[i]; + llvm::errs() << CO.getAccCostDelta(i) << "\t\t" + << CO.getCompCostDelta(i) << "\t\t" + << candidate.herbieCost << "\t\t" + << candidate.herbieAccuracy << "\t\t" << candidate.expr + << "\n"; + } + llvm::errs() << "################################\n\n"; + } + } + if (FPOptEnablePT) { + for (auto &CS : CSs) { + llvm::errs() << "\n################################\n"; + llvm::errs() << "Initial AccuracyCost: " << CS.initialAccCost << "\n"; + llvm::errs() << "Initial ComputationCost: " << CS.initialCompCost + << "\n"; + llvm::errs() << "Candidates:\n"; + llvm::errs() << "Δ AccCost\t\tΔ CompCost\t\tDescription\n" + << "---------------------------\n"; + for (size_t i = 0; i < CS.candidates.size(); ++i) { + auto &candidate = CS.candidates[i]; + llvm::errs() << CS.getAccCostDelta(i) << "\t\t" + << CS.getCompCostDelta(i) << "\t\t" << candidate.desc + << "\n"; + } + llvm::errs() << "################################\n\n"; + } + } + } + + if (!FPOptEnableSolver) { + if (FPOptEnableHerbie) { + for (auto &CO : COs) { + CO.apply(0, valueToNodeMap, symbolToValueMap); + changed = true; + } + } + } else { + if (FPOptSolverType == "greedy") { + changed = + accuracyGreedySolver(COs, CSs, valueToNodeMap, symbolToValueMap); + } else if (FPOptSolverType == "dp") { + changed = accuracyDPSolver(F, TTI, COs, CSs, valueToNodeMap, + symbolToValueMap, errorTol); + } else { + llvm::errs() << "FPOpt: Unknown solver type: " << FPOptSolverType << "\n"; + return false; + } + } + + llvm::errs() << "FPOpt: Finished optimizing " << F.getName() << "\n"; + + // Cleanup + if (changed) { + for (auto &subgraph : subgraphs) { + if (subgraph.outputs_rewritten != subgraph.outputs.size()) { + if (FPOptPrint) + llvm::errs() << "Skip erasing a subgraph: only rewrote " + << subgraph.outputs_rewritten << " of " + << subgraph.outputs.size() << " outputs\n"; + continue; // Intermediate operations cannot be erased safely + } + for (auto *I : subgraph.operations) { + if (FPOptPrint) + llvm::errs() << "Erasing: " << *I << "\n"; + if (!I->use_empty()) { + I->replaceAllUsesWith(UndefValue::get(I->getType())); + } + I->eraseFromParent(); + } + } + + llvm::errs() << "FPOpt: Finished cleaning up " << F.getName() << "\n"; + } + + runPoseidonFunctionSimplify(F, OptimizationLevel::O3); + + if (FPOptPrint) { + llvm::errs() << "FPOpt: Finished Optimization\n"; + F.print(llvm::errs()); + } + + return changed; +} + +namespace {} // namespace + +char FPOpt::ID = 0; + +FPOpt::FPOpt() : FunctionPass(ID) {} + +void FPOpt::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + FunctionPass::getAnalysisUsage(AU); +} + +bool FPOpt::runOnFunction(Function &F) { + auto &TTI = getAnalysis().getTTI(F); + return fpOptimize(F, TTI); +} + +static RegisterPass + X("fp-opt", "Run Enzyme/Poseidon Floating point optimizations"); + +FunctionPass *createFPOptPass() { return new FPOpt(); } + +#include +#include + +#include "llvm/IR/LegacyPassManager.h" + +extern "C" void AddFPOptPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createFPOptPass()); +} + +FPOptNewPM::Result FPOptNewPM::run(llvm::Module &M, + llvm::ModuleAnalysisManager &MAM) { + bool changed = false; + FunctionAnalysisManager &FAM = + MAM.getResult(M).getManager(); + for (auto &F : M) { + if (!F.isDeclaration()) { + const auto &TTI = FAM.getResult(F); + changed |= fpOptimize(F, TTI); + } + } + + return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} +llvm::AnalysisKey FPOptNewPM::Key; diff --git a/enzyme/Enzyme/Poseidon/Poseidon.h b/enzyme/Enzyme/Poseidon/Poseidon.h new file mode 100644 index 000000000000..f018b3d23797 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/Poseidon.h @@ -0,0 +1,61 @@ +#ifndef ENZYME_POSEIDON_H +#define ENZYME_POSEIDON_H + +#include + +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPProfileGenerate; +extern llvm::cl::opt FPProfileUse; +extern llvm::cl::opt FPOptPrint; +extern llvm::cl::opt FPOptEnableHerbie; +extern llvm::cl::opt FPOptEnablePT; +extern llvm::cl::opt FPOptEnableSolver; +extern llvm::cl::opt FPOptMaxExprDepth; +extern llvm::cl::opt FPOptMaxExprLength; +extern llvm::cl::opt FPOptReductionEval; +extern llvm::cl::opt FPOptCachePath; +extern llvm::cl::opt FPOptMultiOutputPTOnly; +} + +bool Poseidonable(const Value &V); +void setPoseidonMetadata(Function &F); +void preprocessForPoseidon(Function *F); +bool fpOptimize(Function &F, const TargetTransformInfo &TTI, + double errorTol = 0.0); + +class FPOpt final : public FunctionPass { +public: + static char ID; + FPOpt(); + + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnFunction(Function &F) override; +}; + +llvm::FunctionPass *createFPOptPass(); + +class FPOptNewPM final : public llvm::AnalysisInfoMixin { + friend struct llvm::AnalysisInfoMixin; + +private: + static llvm::AnalysisKey Key; + +public: + using Result = llvm::PreservedAnalyses; + FPOptNewPM() {} + + Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM); + + static bool isRequired() { return true; } +}; + +#endif // ENZYME_POSEIDON_H diff --git a/enzyme/Enzyme/Poseidon/PoseidonEvaluators.cpp b/enzyme/Enzyme/Poseidon/PoseidonEvaluators.cpp new file mode 100644 index 000000000000..5655aecad196 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonEvaluators.cpp @@ -0,0 +1,974 @@ +//=- PoseidonEvaluators.cpp - Expression evaluators for Poseidon ----------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the evaluator classes for floating-point expressions +// in the Poseidon optimization pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/raw_ostream.h" + +#include + +#include "PoseidonEvaluators.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonTypes.h" + +using namespace llvm; + +extern "C" { +cl::opt FPOptStrictMode( + "fpopt-strict-mode", cl::init(false), cl::Hidden, + cl::desc( + "Discard all FPOpt candidates that produce NaN or inf outputs for any " + "input point that originally produced finite outputs")); +cl::opt FPOptGeoMeanEps( + "fpopt-geo-mean-eps", cl::init(0.0), cl::Hidden, + cl::desc("The offset used in the geometric mean " + "calculation; if = 0, zeros are replaced with ULPs")); +} + +FPEvaluator::FPEvaluator(PTCandidate *pt) { + if (pt) { + for (const auto &change : pt->changes) { + for (auto node : change.nodes) { + nodePrecisions[node] = change.newType; + } + } + } +} + +PrecisionChangeType FPEvaluator::getNodePrecision(const FPNode *node) const { + // If the node has a new precision from PT, use it + PrecisionChangeType precType; + + auto it = nodePrecisions.find(node); + if (it != nodePrecisions.end()) { + precType = it->second; + } else { + // Otherwise, use the node's original precision + if (node->dtype == "f32") { + precType = PrecisionChangeType::FP32; + } else if (node->dtype == "f64") { + precType = PrecisionChangeType::FP64; + } else { + llvm_unreachable( + ("Operator " + node->op + " has unexpected dtype: " + node->dtype) + .c_str()); + } + } + + if (precType != PrecisionChangeType::FP32 && + precType != PrecisionChangeType::FP64) { + llvm_unreachable("Unsupported FP precision"); + } + + return precType; +} + +void FPEvaluator::evaluateNode(const FPNode *node, + const MapVector &inputValues) { + if (cache.find(node) != cache.end()) + return; + + if (isa(node)) { + double constVal = node->getLowerBound(); + cache.emplace(node, constVal); + return; + } + + if (isa(node) && inputValues.count(cast(node)->value)) { + double inputValue = inputValues.lookup(cast(node)->value); + cache.emplace(node, inputValue); + return; + } + + if (node->op == "if") { + evaluateNode(node->operands[0].get(), inputValues); + double cond = getResult(node->operands[0].get()); + + if (cond == 1.0) { + evaluateNode(node->operands[1].get(), inputValues); + double then_val = getResult(node->operands[1].get()); + cache.emplace(node, then_val); + } else { + evaluateNode(node->operands[2].get(), inputValues); + double else_val = getResult(node->operands[2].get()); + cache.emplace(node, else_val); + } + return; + } else if (node->op == "and") { + evaluateNode(node->operands[0].get(), inputValues); + double op0 = getResult(node->operands[0].get()); + if (op0 != 1.0) { + cache.emplace(node, 0.0); + return; + } + evaluateNode(node->operands[1].get(), inputValues); + double op1 = getResult(node->operands[1].get()); + if (op1 != 1.0) { + cache.emplace(node, 0.0); + return; + } + cache.emplace(node, 1.0); + return; + } else if (node->op == "or") { + evaluateNode(node->operands[0].get(), inputValues); + double op0 = getResult(node->operands[0].get()); + if (op0 == 1.0) { + cache.emplace(node, 1.0); + return; + } + evaluateNode(node->operands[1].get(), inputValues); + double op1 = getResult(node->operands[1].get()); + if (op1 == 1.0) { + cache.emplace(node, 1.0); + return; + } + cache.emplace(node, 0.0); + return; + } else if (node->op == "not") { + evaluateNode(node->operands[0].get(), inputValues); + double op = getResult(node->operands[0].get()); + cache.emplace(node, (op == 1.0) ? 0.0 : 1.0); + return; + } else if (node->op == "TRUE") { + cache.emplace(node, 1.0); + return; + } else if (node->op == "FALSE") { + cache.emplace(node, 0.0); + return; + } + + PrecisionChangeType nodePrec = getNodePrecision(node); + + for (const auto &operand : node->operands) { + evaluateNode(operand.get(), inputValues); + } + + double res = 0.0; + + auto evalUnary = [&](auto doubleFunc, auto floatFunc) -> double { + double op = getResult(node->operands[0].get()); + if (nodePrec == PrecisionChangeType::FP32) + return floatFunc(static_cast(op)); + else + return doubleFunc(op); + }; + + auto evalBinary = [&](auto doubleFunc, auto floatFunc) -> double { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + if (nodePrec == PrecisionChangeType::FP32) + return floatFunc(static_cast(op0), static_cast(op1)); + else + return doubleFunc(op0, op1); + }; + + auto evalTernary = [&](auto doubleFunc, auto floatFunc) -> double { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + double op2 = getResult(node->operands[2].get()); + if (nodePrec == PrecisionChangeType::FP32) + return floatFunc(static_cast(op0), static_cast(op1), + static_cast(op2)); + else + return doubleFunc(op0, op1, op2); + }; + + if (node->op == "neg") { + double op = getResult(node->operands[0].get()); + res = + (nodePrec == PrecisionChangeType::FP32) ? -static_cast(op) : -op; + } else if (node->op == "+") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) + static_cast(op1) + : op0 + op1; + } else if (node->op == "-") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) - static_cast(op1) + : op0 - op1; + } else if (node->op == "*") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) * static_cast(op1) + : op0 * op1; + } else if (node->op == "/") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) / static_cast(op1) + : op0 / op1; + } else if (node->op == "sin") { + res = evalUnary(static_cast(std::sin), + static_cast(sinf)); + } else if (node->op == "cos") { + res = evalUnary(static_cast(std::cos), + static_cast(cosf)); + } else if (node->op == "tan") { + res = evalUnary(static_cast(std::tan), + static_cast(tanf)); + } else if (node->op == "exp") { + res = evalUnary(static_cast(std::exp), + static_cast(expf)); + } else if (node->op == "expm1") { + res = evalUnary(static_cast(std::expm1), + static_cast(expm1f)); + } else if (node->op == "log") { + res = evalUnary(static_cast(std::log), + static_cast(logf)); + } else if (node->op == "log1p") { + res = evalUnary(static_cast(std::log1p), + static_cast(log1pf)); + } else if (node->op == "sqrt") { + res = evalUnary(static_cast(std::sqrt), + static_cast(sqrtf)); + } else if (node->op == "cbrt") { + res = evalUnary(static_cast(std::cbrt), + static_cast(cbrtf)); + } else if (node->op == "asin") { + res = evalUnary(static_cast(std::asin), + static_cast(asinf)); + } else if (node->op == "acos") { + res = evalUnary(static_cast(std::acos), + static_cast(acosf)); + } else if (node->op == "atan") { + res = evalUnary(static_cast(std::atan), + static_cast(atanf)); + } else if (node->op == "sinh") { + res = evalUnary(static_cast(std::sinh), + static_cast(sinhf)); + } else if (node->op == "cosh") { + res = evalUnary(static_cast(std::cosh), + static_cast(coshf)); + } else if (node->op == "tanh") { + res = evalUnary(static_cast(std::tanh), + static_cast(tanhf)); + } else if (node->op == "asinh") { + res = evalUnary(static_cast(std::asinh), + static_cast(asinhf)); + } else if (node->op == "acosh") { + res = evalUnary(static_cast(std::acosh), + static_cast(acoshf)); + } else if (node->op == "atanh") { + res = evalUnary(static_cast(std::atanh), + static_cast(atanhf)); + } else if (node->op == "ceil") { + res = evalUnary(static_cast(std::ceil), + static_cast(ceilf)); + } else if (node->op == "floor") { + res = evalUnary(static_cast(std::floor), + static_cast(floorf)); + } else if (node->op == "exp2") { + res = evalUnary(static_cast(std::exp2), + static_cast(exp2f)); + } else if (node->op == "log10") { + res = evalUnary(static_cast(std::log10), + static_cast(log10f)); + } else if (node->op == "log2") { + res = evalUnary(static_cast(std::log2), + static_cast(log2f)); + } else if (node->op == "rint") { + res = evalUnary(static_cast(std::rint), + static_cast(rintf)); + } else if (node->op == "round") { + res = evalUnary(static_cast(std::round), + static_cast(roundf)); + } else if (node->op == "trunc") { + res = evalUnary(static_cast(std::trunc), + static_cast(truncf)); + } else if (node->op == "pow") { + res = evalBinary(static_cast(std::pow), + static_cast(powf)); + } else if (node->op == "fabs") { + res = evalUnary(static_cast(std::fabs), + static_cast(fabsf)); + } else if (node->op == "hypot") { + res = evalBinary(static_cast(std::hypot), + static_cast(hypotf)); + } else if (node->op == "atan2") { + res = evalBinary(static_cast(std::atan2), + static_cast(atan2f)); + } else if (node->op == "copysign") { + res = evalBinary(static_cast(std::copysign), + static_cast(copysignf)); + } else if (node->op == "fmax") { + res = evalBinary(static_cast(std::fmax), + static_cast(fmaxf)); + } else if (node->op == "fmin") { + res = evalBinary(static_cast(std::fmin), + static_cast(fminf)); + } else if (node->op == "fdim") { + res = evalBinary(static_cast(std::fdim), + static_cast(fdimf)); + } else if (node->op == "fmod") { + res = evalBinary(static_cast(std::fmod), + static_cast(fmodf)); + } else if (node->op == "remainder") { + res = evalBinary(static_cast(std::remainder), + static_cast(remainderf)); + } else if (node->op == "fma") { + res = evalTernary(static_cast(std::fma), + static_cast(fmaf)); + } else if (node->op == "lgamma") { + res = evalUnary(static_cast(std::lgamma), + static_cast(lgammaf)); + } else if (node->op == "tgamma") { + res = evalUnary(static_cast(std::tgamma), + static_cast(tgammaf)); + } else if (node->op == "==") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) == static_cast(op1) + : op0 == op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "!=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) != static_cast(op1) + : op0 != op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "<") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) < static_cast(op1) + : op0 < op1; + res = result ? 1.0 : 0.0; + } else if (node->op == ">") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) > static_cast(op1) + : op0 > op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "<=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) <= static_cast(op1) + : op0 <= op1; + res = result ? 1.0 : 0.0; + } else if (node->op == ">=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) >= static_cast(op1) + : op0 >= op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "PI") { + res = M_PI; + } else if (node->op == "E") { + res = M_E; + } else if (node->op == "INFINITY") { + res = INFINITY; + } else if (node->op == "NAN") { + res = NAN; + } else { + std::string msg = "FPEvaluator: Unexpected operator " + node->op; + llvm_unreachable(msg.c_str()); + } + + cache.emplace(node, res); +} + +double FPEvaluator::getResult(const FPNode *node) const { + auto it = cache.find(node); + assert(it != cache.end() && "Node not evaluated yet"); + return it->second; +} + +MPFREvaluator::CachedValue::CachedValue(unsigned prec) : prec(prec) { + mpfr_init2(value, prec); + mpfr_set_zero(value, 1); +} + +MPFREvaluator::CachedValue::CachedValue(CachedValue &&other) noexcept + : prec(other.prec) { + mpfr_init2(value, other.prec); + mpfr_swap(value, other.value); +} + +MPFREvaluator::CachedValue & +MPFREvaluator::CachedValue::operator=(CachedValue &&other) noexcept { + if (this != &other) { + mpfr_set_prec(value, other.prec); + prec = other.prec; + mpfr_swap(value, other.value); + } + return *this; +} + +MPFREvaluator::CachedValue::~CachedValue() { mpfr_clear(value); } + +MPFREvaluator::MPFREvaluator(unsigned prec, PTCandidate *pt) : prec(prec) { + if (pt) { + for (const auto &change : pt->changes) { + for (auto node : change.nodes) { + nodeToNewPrec[node] = getMPFRPrec(change.newType); + } + } + } +} + +unsigned MPFREvaluator::getNodePrecision(const FPNode *node, + bool groundTruth) const { + // If trying to evaluate the ground truth, use the current MPFR precision + if (groundTruth) + return prec; + + // If the node has a new precision for PT, use it + auto it = nodeToNewPrec.find(node); + if (it != nodeToNewPrec.end()) { + return it->second; + } + + // Otherwise, use the original precision + return node->getMPFRPrec(); +} + +// Compute the expression with MPFR at `prec` precision +// recursively. When operand is a FPConst, use its lower +// bound. When operand is a FPLLValue, get its inputs from +// `inputs`. +void MPFREvaluator::evaluateNode(const FPNode *node, + const MapVector &inputValues, + bool groundTruth) { + if (cache.find(node) != cache.end()) + return; + + if (isa(node)) { + double constVal = node->getLowerBound(); + CachedValue cv(53); + mpfr_set_d(cv.value, constVal, MPFR_RNDN); + cache.emplace(node, CachedValue(std::move(cv))); + return; + } + + if (isa(node) && inputValues.count(cast(node)->value)) { + double inputValue = inputValues.lookup(cast(node)->value); + CachedValue cv(53); + mpfr_set_d(cv.value, inputValue, MPFR_RNDN); + cache.emplace(node, std::move(cv)); + return; + } + + if (node->op == "if") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &cond = getResult(node->operands[0].get()); + if (0 == mpfr_cmp_ui(cond, 1)) { + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &then_val = getResult(node->operands[1].get()); + cache.emplace(node, CachedValue(cache.at(node->operands[1].get()).prec)); + mpfr_set(cache.at(node).value, then_val, MPFR_RNDN); + } else { + evaluateNode(node->operands[2].get(), inputValues, groundTruth); + mpfr_t &else_val = getResult(node->operands[2].get()); + cache.emplace(node, CachedValue(cache.at(node->operands[2].get()).prec)); + mpfr_set(cache.at(node).value, else_val, MPFR_RNDN); + } + return; + } + + unsigned nodePrec = getNodePrecision(node, groundTruth); + cache.emplace(node, CachedValue(nodePrec)); + mpfr_t &res = cache.at(node).value; + + if (node->op == "neg") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_neg(res, op, MPFR_RNDN); + } else if (node->op == "+") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_add(res, op0, op1, MPFR_RNDN); + } else if (node->op == "-") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_sub(res, op0, op1, MPFR_RNDN); + } else if (node->op == "*") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_mul(res, op0, op1, MPFR_RNDN); + } else if (node->op == "/") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_div(res, op0, op1, MPFR_RNDN); + } else if (node->op == "sin") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_sin(res, op, MPFR_RNDN); + } else if (node->op == "cos") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_cos(res, op, MPFR_RNDN); + } else if (node->op == "tan") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_tan(res, op, MPFR_RNDN); + } else if (node->op == "asin") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_asin(res, op, MPFR_RNDN); + } else if (node->op == "acos") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_acos(res, op, MPFR_RNDN); + } else if (node->op == "atan") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_atan(res, op, MPFR_RNDN); + } else if (node->op == "atan2") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_atan2(res, op0, op1, MPFR_RNDN); + } else if (node->op == "exp") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_exp(res, op, MPFR_RNDN); + } else if (node->op == "expm1") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_expm1(res, op, MPFR_RNDN); + } else if (node->op == "log") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log(res, op, MPFR_RNDN); + } else if (node->op == "log1p") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log1p(res, op, MPFR_RNDN); + } else if (node->op == "sqrt") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_sqrt(res, op, MPFR_RNDN); + } else if (node->op == "cbrt") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_cbrt(res, op, MPFR_RNDN); + } else if (node->op == "pow") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_pow(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fma") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + evaluateNode(node->operands[2].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_t &op2 = getResult(node->operands[2].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_prec_round(op2, nodePrec, MPFR_RNDN); + mpfr_fma(res, op0, op1, op2, MPFR_RNDN); + } else if (node->op == "fabs") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_abs(res, op, MPFR_RNDN); + } else if (node->op == "hypot") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_hypot(res, op0, op1, MPFR_RNDN); + } else if (node->op == "asinh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_asinh(res, op, MPFR_RNDN); + } else if (node->op == "acosh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_acosh(res, op, MPFR_RNDN); + } else if (node->op == "atanh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_atanh(res, op, MPFR_RNDN); + } else if (node->op == "sinh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_sinh(res, op, MPFR_RNDN); + } else if (node->op == "cosh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_cosh(res, op, MPFR_RNDN); + } else if (node->op == "tanh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_tanh(res, op, MPFR_RNDN); + } else if (node->op == "ceil") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_ceil(res, op); + } else if (node->op == "floor") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_floor(res, op); + } else if (node->op == "erf") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_erf(res, op, MPFR_RNDN); + } else if (node->op == "exp2") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_exp2(res, op, MPFR_RNDN); + } else if (node->op == "log10") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log10(res, op, MPFR_RNDN); + } else if (node->op == "log2") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log2(res, op, MPFR_RNDN); + } else if (node->op == "rint") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_rint(res, op, MPFR_RNDN); + } else if (node->op == "round") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_round(res, op); + } else if (node->op == "trunc") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_trunc(res, op); + } else if (node->op == "copysign") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_copysign(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fdim") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_dim(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fmod") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_fmod(res, op0, op1, MPFR_RNDN); + } else if (node->op == "remainder") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_remainder(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fmax") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_max(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fmin") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_min(res, op0, op1, MPFR_RNDN); + } else if (node->op == "==") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 == mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "!=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 != mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "<") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 > mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == ">") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 < mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "<=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 >= mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == ">=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 <= mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "and") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 == mpfr_cmp_ui(op0, 1) && 0 == mpfr_cmp_ui(op1, 1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "or") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 == mpfr_cmp_ui(op0, 1) || 0 == mpfr_cmp_ui(op1, 1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "not") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_set_prec(res, nodePrec); + if (0 == mpfr_cmp_ui(op, 1)) + mpfr_set_ui(res, 0, MPFR_RNDN); + else + mpfr_set_ui(res, 1, MPFR_RNDN); + } else if (node->op == "TRUE") { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else if (node->op == "FALSE") { + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "PI") { + mpfr_const_pi(res, MPFR_RNDN); + } else if (node->op == "E") { + mpfr_const_euler(res, MPFR_RNDN); + } else if (node->op == "INFINITY") { + mpfr_set_inf(res, 1); + } else if (node->op == "NAN") { + mpfr_set_nan(res); + } else { + llvm::errs() << "MPFREvaluator: Unexpected operator '" << node->op << "'\n"; + llvm_unreachable("Unexpected operator encountered"); + } +} + +mpfr_t &MPFREvaluator::getResult(FPNode *node) { + assert(cache.count(node) > 0 && "MPFREvaluator: Unexpected unevaluated node"); + return cache.at(node).value; +} + +// Emulate computation using native floating-point types +void getFPValues(ArrayRef outputs, + const MapVector &inputValues, + SmallVectorImpl &results, PTCandidate *pt) { + assert(!outputs.empty()); + results.resize(outputs.size()); + + FPEvaluator evaluator(pt); + + for (const auto *output : outputs) { + evaluator.evaluateNode(output, inputValues); + } + + for (size_t i = 0; i < outputs.size(); ++i) { + results[i] = evaluator.getResult(outputs[i]); + } +} + +// If looking for ground truth, compute a "correct" answer with MPFR. +// For each sampled input configuration: +// 0. Ignore `FPNode.dtype`. +// 1. Compute the expression with MPFR at `prec` precision +// by calling `MPFRValueHelper`. When operand is a FPConst, use its +// lower bound. When operand is a FPLLValue, get its inputs from +// `inputs`. +// 2. Dynamically extend precisions +// until the first `groundTruthPrec` bits of significand don't change. +// Otherwise, compute the expression with MPFR at precisions specified within +// `FPNode`s or new precisions specified by `pt`. +void getMPFRValues(ArrayRef outputs, + const MapVector &inputValues, + SmallVectorImpl &results, bool groundTruth, + const unsigned groundTruthPrec, PTCandidate *pt) { + assert(!outputs.empty()); + results.resize(outputs.size()); + + if (!groundTruth) { + MPFREvaluator evaluator(0, pt); + // if (pt) { + // llvm::errs() << "getMPFRValues: PT candidate detected: " << pt->desc + // << "\n"; + // } else { + // llvm::errs() << "getMPFRValues: emulating original computation\n"; + // } + + for (const auto *output : outputs) { + evaluator.evaluateNode(output, inputValues, false); + } + for (size_t i = 0; i < outputs.size(); ++i) { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + return; + } + + unsigned curPrec = 64; + std::vector prevResExp(outputs.size(), 0); + std::vector prevResStr(outputs.size(), nullptr); + std::vector prevResSign(outputs.size(), 0); + std::vector converged(outputs.size(), false); + size_t numConverged = 0; + + while (true) { + MPFREvaluator evaluator(curPrec, nullptr); + + // llvm::errs() << "getMPFRValues: computing ground truth with precision " + // << curPrec << "\n"; + + for (const auto *output : outputs) { + evaluator.evaluateNode(output, inputValues, true); + } + + for (size_t i = 0; i < outputs.size(); ++i) { + if (converged[i]) + continue; + + mpfr_t &res = evaluator.getResult(outputs[i]); + int resSign = mpfr_sgn(res); + mpfr_exp_t resExp; + char *resStr = + mpfr_get_str(nullptr, &resExp, 2, groundTruthPrec, res, MPFR_RNDN); + + if (prevResStr[i] != nullptr && resSign == prevResSign[i] && + resExp == prevResExp[i] && strcmp(resStr, prevResStr[i]) == 0) { + converged[i] = true; + numConverged++; + mpfr_free_str(resStr); + mpfr_free_str(prevResStr[i]); + prevResStr[i] = nullptr; + continue; + } + + if (prevResStr[i]) { + mpfr_free_str(prevResStr[i]); + } + prevResStr[i] = resStr; + prevResExp[i] = resExp; + prevResSign[i] = resSign; + } + + if (numConverged == outputs.size()) { + for (size_t i = 0; i < outputs.size(); ++i) { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + break; + } + + curPrec *= 2; + + if (curPrec > FPOptMaxMPFRPrec) { + llvm::errs() << "getMPFRValues: MPFR precision limit reached for some " + "outputs, returning NaN\n"; + for (size_t i = 0; i < outputs.size(); ++i) { + if (!converged[i]) { + mpfr_free_str(prevResStr[i]); + results[i] = std::numeric_limits::quiet_NaN(); + } else { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + } + return; + } + } +} \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonEvaluators.h b/enzyme/Enzyme/Poseidon/PoseidonEvaluators.h new file mode 100644 index 000000000000..99258b76dc08 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonEvaluators.h @@ -0,0 +1,86 @@ +//=- PoseidonEvaluators.h - Expression evaluators for Poseidon ------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the evaluator classes for floating-point expressions +// in the Poseidon optimization pass. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_EVALUATORS_H +#define ENZYME_POSEIDON_EVALUATORS_H + +#include "llvm/ADT/MapVector.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CommandLine.h" +#include +#include + +#include "PoseidonPrecUtils.h" +#include "PoseidonTypes.h" + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptStrictMode; +extern llvm::cl::opt FPOptGeoMeanEps; +} + +class FPEvaluator { +private: + std::unordered_map cache; + std::unordered_map nodePrecisions; + +public: + FPEvaluator(PTCandidate *pt = nullptr); + + PrecisionChangeType getNodePrecision(const FPNode *node) const; + void evaluateNode(const FPNode *node, + const MapVector &inputValues); + double getResult(const FPNode *node) const; +}; + +class MPFREvaluator { +private: + struct CachedValue { + mpfr_t value; + unsigned prec; + + CachedValue(unsigned prec); + CachedValue(const CachedValue &) = delete; + CachedValue &operator=(const CachedValue &) = delete; + CachedValue(CachedValue &&other) noexcept; + CachedValue &operator=(CachedValue &&other) noexcept; + virtual ~CachedValue(); + }; + + std::unordered_map cache; + unsigned prec; + std::unordered_map nodeToNewPrec; + +public: + MPFREvaluator(unsigned prec, PTCandidate *pt = nullptr); + virtual ~MPFREvaluator() = default; + + unsigned getNodePrecision(const FPNode *node, bool groundTruth) const; + void evaluateNode(const FPNode *node, + const MapVector &inputValues, + bool groundTruth); + mpfr_t &getResult(FPNode *node); +}; + +void getFPValues(ArrayRef outputs, + const MapVector &inputValues, + SmallVectorImpl &results, PTCandidate *pt = nullptr); + +void getMPFRValues(ArrayRef outputs, + const MapVector &inputValues, + SmallVectorImpl &results, bool groundTruth, + const unsigned groundTruthPrec = 53, + PTCandidate *pt = nullptr); + +#endif // ENZYME_POSEIDON_EVALUATORS_H \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.cpp new file mode 100644 index 000000000000..a2516be3917d --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.cpp @@ -0,0 +1,895 @@ +//=- PoseidonHerbieUtils.cpp - Herbie integration utilities ---------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for integrating with the Herbie tool for +// floating-point expression optimization. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar/EarlyCSE.h" +#include "llvm/Transforms/Utils/Mem2Reg.h" + +#include "../Utils.h" +#include "Poseidon.h" +#include "PoseidonEvaluators.h" +#include "PoseidonHerbieUtils.h" +#include "PoseidonSolvers.h" +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +#include +#include +#include +#include +#include + +using namespace llvm; + +extern "C" { +cl::opt HerbieNumThreads("herbie-num-threads", cl::init(8), cl::Hidden, + cl::desc("Number of threads Herbie uses")); +cl::opt HerbieTimeout("herbie-timeout", cl::init(9999), cl::Hidden, + cl::desc("Herbie's timeout to use for each " + "candidate expressions.")); +cl::opt + HerbieNumPoints("herbie-num-pts", cl::init(1024), cl::Hidden, + cl::desc("Number of input points Herbie uses to evaluate " + "candidate expressions.")); +cl::opt HerbieNumIters( + "herbie-num-iters", cl::init(6), cl::Hidden, + cl::desc("Number of times Herbie attempts to improve accuracy.")); +cl::opt HerbieNumEnodes( + "herbie-num-enodes", cl::init(8000), cl::Hidden, + cl::desc("Number of equivalence graph nodes to use when doing algebraic " + "reasoning in Herbie.")); +cl::opt HerbieDisableNumerics( + "herbie-disable-numerics", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie rewrite rules that produce numerical shorthands " + "expm1, log1p, fma, and hypot")); +cl::opt HerbieDisableArithmetic( + "herbie-disable-arithmetic", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie rewrite rules on basic arithmetic fasts.")); +cl::opt HerbieDisableFractions( + "herbie-disable-fractions", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie rewrite rules on fraction arithmetic.")); +cl::opt + HerbieDisableTaylor("herbie-disable-taylor", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie's series expansion")); +cl::opt HerbieDisableSetupSimplify( + "herbie-disable-setup-simplify", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from pre-simplifying expressions")); +cl::opt HerbieDisableGenSimplify( + "herbie-disable-gen-simplify", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from simplifying expressions " + "during the main improvement loop")); +cl::opt HerbieDisableRegime( + "herbie-disable-regime", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from branching between expressions candidates")); +cl::opt HerbieDisableBranchExpr( + "herbie-disable-branch-expr", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from branching on expressions")); +cl::opt HerbieDisableAvgError( + "herbie-disable-avg-error", cl::init(false), cl::Hidden, + cl::desc("Make Herbie choose the candidates with the least maximum error")); +} + +std::shared_ptr parseHerbieExpr( + const std::string &expr, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + // if (FPOptPrint) + // llvm::errs() << "Parsing: " << expr << "\n"; + std::string trimmedExpr = expr; + trimmedExpr.erase(0, trimmedExpr.find_first_not_of(" ")); + trimmedExpr.erase(trimmedExpr.find_last_not_of(" ") + 1); + + // Arguments + if (trimmedExpr.front() != '(' && trimmedExpr.front() != '#') { + if (auto node = valueToNodeMap[symbolToValueMap[trimmedExpr]]) { + return node; + } + } + + // Constants + static const std::regex constantPattern( + "^#s\\(literal\\s+([-+]?\\d+(/\\d+)?|[-+]?inf\\.0)\\s+(\\w+)\\)$"); + static const std::regex plainConstantPattern( + R"(^([-+]?(\d+(\.\d+)?)(/\d+)?|[-+]?inf\.0))"); + + { + std::smatch matches; + if (std::regex_match(trimmedExpr, matches, constantPattern)) { + std::string value = matches[1].str(); + std::string dtype = matches[3].str(); + if (dtype == "binary64") { + dtype = "f64"; + } else if (dtype == "binary32") { + dtype = "f32"; + } else { + std::string msg = + "Herbie expr parser: Unexpected constant dtype: " + dtype; + llvm_unreachable(msg.c_str()); + } + // if (FPOptPrint) + // llvm::errs() << "Herbie expr parser: Found __const " << value + // << " with dtype " << dtype << "\n"; + return std::make_shared(value, dtype); + } else if (std::regex_match(trimmedExpr, matches, plainConstantPattern)) { + std::string value = matches[1].str(); + std::string dtype = "f64"; // Assume f64 by default + return std::make_shared(value, dtype); + } + } + + if (trimmedExpr.substr(0, 9) == "#s(approx") { + if (trimmedExpr.back() != ')') { + llvm_unreachable(("Malformed approx expression: " + trimmedExpr).c_str()); + } + std::string inner = trimmedExpr.substr(9, trimmedExpr.size() - 9 - 1); + inner.erase(0, inner.find_first_not_of(" ")); + inner.erase(inner.find_last_not_of(" ") + 1); + + int depth = 0; + size_t splitPos = std::string::npos; + for (size_t i = 0; i < inner.size(); ++i) { + if (inner[i] == '(') + depth++; + else if (inner[i] == ')') + depth--; + else if (inner[i] == ' ' && depth == 0) { + splitPos = i; + break; + } + } + if (splitPos == std::string::npos) { + llvm_unreachable(("Malformed approx expression: " + trimmedExpr).c_str()); + } + std::string resultPart = inner.substr(splitPos + 1); + resultPart.erase(0, resultPart.find_first_not_of(" ")); + resultPart.erase(resultPart.find_last_not_of(" ") + 1); + return parseHerbieExpr(resultPart, valueToNodeMap, symbolToValueMap); + } + + if (trimmedExpr.front() != '(' || trimmedExpr.back() != ')') { + llvm::errs() << "Unexpected subexpression: " << trimmedExpr << "\n"; + assert(0 && "Failed to parse Herbie expression"); + } + + trimmedExpr = trimmedExpr.substr(1, trimmedExpr.size() - 2); + + auto endOp = trimmedExpr.find(' '); + std::string fullOp = trimmedExpr.substr(0, endOp); + + size_t pos = fullOp.find('.'); + std::string dtype; + std::string op; + if (pos != std::string::npos) { + op = fullOp.substr(0, pos); + dtype = fullOp.substr(pos + 1); + assert(dtype == "f64" || dtype == "f32"); + // llvm::errs() << "Herbie expr parser: Found operator " << op + // << " with dtype " << dtype << "\n"; + } else { + op = fullOp; + // llvm::errs() << "Herbie expr parser: Found operator " << op << "\n"; + } + + auto node = std::make_shared(op, dtype); + + int depth = 0; + auto start = trimmedExpr.find_first_not_of(" ", endOp); + std::string::size_type curr; + for (curr = start; curr < trimmedExpr.size(); ++curr) { + if (trimmedExpr[curr] == '(') + depth++; + if (trimmedExpr[curr] == ')') + depth--; + if (depth == 0 && trimmedExpr[curr] == ' ') { + node->addOperand(parseHerbieExpr(trimmedExpr.substr(start, curr - start), + valueToNodeMap, symbolToValueMap)); + start = curr + 1; + } + } + if (start < curr) { + node->addOperand(parseHerbieExpr(trimmedExpr.substr(start, curr - start), + valueToNodeMap, symbolToValueMap)); + } + + return node; +} + +bool improveViaHerbie( + const std::vector &inputExprs, + std::vector &COs, Module *M, + const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + int subgraphIdx) { + std::string Program = HERBIE_BINARY; + llvm::errs() << "random seed: " << std::to_string(FPOptRandomSeed) << "\n"; + + SmallVector BaseArgs = { + Program, "report", + "--seed", std::to_string(FPOptRandomSeed), + "--timeout", std::to_string(HerbieTimeout), + "--threads", std::to_string(HerbieNumThreads), + "--num-points", std::to_string(HerbieNumPoints), + "--num-iters", std::to_string(HerbieNumIters), + "--num-enodes", std::to_string(HerbieNumEnodes)}; + + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:proofs"); + + if (HerbieDisableNumerics) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("rules:numerics"); + } + + if (HerbieDisableArithmetic) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("rules:arithmetic"); + } + + if (HerbieDisableFractions) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("rules:fractions"); + } + + if (HerbieDisableSetupSimplify) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("setup:simplify"); + } + + if (HerbieDisableGenSimplify) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:simplify"); + } + + if (HerbieDisableTaylor) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:taylor"); + } + + if (HerbieDisableRegime) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:regimes"); + } + + if (HerbieDisableBranchExpr) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:branch-expressions"); + } + + if (HerbieDisableAvgError) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:avg-error"); + } + + SmallVector> BaseArgsList; + BaseArgsList.push_back(BaseArgs); + + std::vector> seenExprs(COs.size()); + + bool success = false; + + auto processHerbieOutput = [&](const std::string &content, + bool skipEvaluation = false) -> bool { + Expected parsed = json::parse(content); + if (!parsed) { + llvm::errs() << "Failed to parse Herbie result!\n"; + return false; + } + + json::Object *obj = parsed->getAsObject(); + json::Array &tests = *obj->getArray("tests"); + + for (size_t testIndex = 0; testIndex < tests.size(); ++testIndex) { + auto &test = *tests[testIndex].getAsObject(); + + StringRef bestExpr = test.getString("output").value(); + if (bestExpr == "#f") { + continue; + } + + StringRef ID = test.getString("name").value(); + size_t index = std::stoul(ID.str()); + if (index >= COs.size()) { + llvm::errs() << "Invalid CO index: " << index << "\n"; + continue; + } + + CandidateOutput &CO = COs[index]; + auto &seenExprSet = seenExprs[index]; + + double bits = test.getNumber("bits").value(); + json::Array &costAccuracy = *test.getArray("cost-accuracy"); + + json::Array &initial = *costAccuracy[0].getAsArray(); + double initialCostVal = initial[0].getAsNumber().value(); + double initialAccuracy = 1.0 - initial[1].getAsNumber().value() / bits; + double initialCost = 1.0; + + CO.initialHerbieCost = initialCost; + CO.initialHerbieAccuracy = initialAccuracy; + + if (seenExprSet.count(bestExpr.str()) == 0) { + seenExprSet.insert(bestExpr.str()); + + json::Array &best = *costAccuracy[1].getAsArray(); + double bestCost = best[0].getAsNumber().value() / initialCostVal; + double bestAccuracy = 1.0 - best[1].getAsNumber().value() / bits; + + RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str()); + if (!skipEvaluation) { + bestCandidate.CompCost = getCompCost( + bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(CO.oldOutput)->getFastMathFlags()); + } + CO.candidates.push_back(bestCandidate); + } + + json::Array &alternatives = *costAccuracy[2].getAsArray(); + + // Handle alternatives + for (size_t j = 0; j < alternatives.size(); ++j) { + json::Array &entry = *alternatives[j].getAsArray(); + StringRef expr = entry[2].getAsString().value(); + + if (seenExprSet.count(expr.str()) != 0) { + continue; + } + seenExprSet.insert(expr.str()); + + double cost = entry[0].getAsNumber().value() / initialCostVal; + double accuracy = 1.0 - entry[1].getAsNumber().value() / bits; + + RewriteCandidate candidate(cost, accuracy, expr.str()); + if (!skipEvaluation) { + candidate.CompCost = + getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(CO.oldOutput)->getFastMathFlags()); + } + CO.candidates.push_back(candidate); + } + + if (!skipEvaluation) { + setUnifiedAccuracyCost(CO, valueToNodeMap, symbolToValueMap); + } + } + return true; + }; + + for (size_t baseArgsIndex = 0; baseArgsIndex < BaseArgsList.size(); + ++baseArgsIndex) { + const auto &BaseArgs = BaseArgsList[baseArgsIndex]; + std::string content; + + // Try to get cached Herbie output first + std::string cacheFilePath; + bool cached = false; + + if (!FPOptCachePath.empty()) { + cacheFilePath = FPOptCachePath + "/cachedHerbieOutput_" + + std::to_string(subgraphIdx) + "_" + + std::to_string(baseArgsIndex) + ".txt"; + std::ifstream cacheFile(cacheFilePath); + if (cacheFile) { + content.assign((std::istreambuf_iterator(cacheFile)), + std::istreambuf_iterator()); + cacheFile.close(); + llvm::errs() << "Using cached Herbie output from " << cacheFilePath + << "\n"; + cached = true; + } + } + + // If we have cached output, process it directly + if (cached) { + llvm::errs() << "Herbie output: " << content << "\n"; + std::string dpCacheFilePath = FPOptCachePath + "/table.json"; + bool skipEvaluation = FPOptSolverType == "dp" && + !FPOptCachePath.empty() && + llvm::sys::fs::exists(dpCacheFilePath); + if (processHerbieOutput(content, skipEvaluation)) { + success = true; + } + continue; + } + + // No cached result, need to run Herbie + SmallString<32> tmpin, tmpout; + + if (llvm::sys::fs::createUniqueFile("herbie_input_%%%%%%%%%%%%%%%%", tmpin, + llvm::sys::fs::perms::owner_all)) { + llvm::errs() << "Failed to create a unique input file.\n"; + continue; + } + + if (llvm::sys::fs::createUniqueDirectory("herbie_output_%%%%%%%%%%%%%%%%", + tmpout)) { + llvm::errs() << "Failed to create a unique output directory.\n"; + if (auto EC = llvm::sys::fs::remove(tmpin)) + llvm::errs() << "Warning: Failed to remove temporary input file: " + << EC.message() << "\n"; + continue; + } + + std::ofstream input(tmpin.c_str()); + if (!input) { + llvm::errs() << "Failed to open input file.\n"; + if (auto EC = llvm::sys::fs::remove(tmpin)) + llvm::errs() << "Warning: Failed to remove temporary input file: " + << EC.message() << "\n"; + if (auto EC = llvm::sys::fs::remove_directories(tmpout)) + llvm::errs() << "Warning: Failed to remove temporary output directory: " + << EC.message() << "\n"; + continue; + } + for (const auto &expr : inputExprs) { + input << expr << "\n"; + } + input.close(); + + SmallVector Args; + Args.reserve(BaseArgs.size()); + for (const auto &arg : BaseArgs) { + Args.emplace_back(arg); + } + + Args.push_back(tmpin); + Args.push_back(tmpout); + + std::string ErrMsg; + bool ExecutionFailed = false; + + if (FPOptPrint) { + llvm::errs() << "Executing Herbie with arguments: "; + for (const auto &arg : Args) { + llvm::errs() << arg << " "; + } + llvm::errs() << "\n"; + } + + llvm::sys::ExecuteAndWait(Program, Args, /*Env=*/{}, + /*Redirects=*/{}, + /*SecondsToWait=*/0, /*MemoryLimit=*/0, &ErrMsg, + &ExecutionFailed); + + std::remove(tmpin.c_str()); + if (ExecutionFailed) { + llvm::errs() << "Execution failed: " << ErrMsg << "\n"; + if (auto EC = llvm::sys::fs::remove_directories(tmpout)) + llvm::errs() << "Warning: Failed to remove temporary output directory: " + << EC.message() << "\n"; + continue; + } + + std::ifstream output((tmpout + "/results.json").str()); + if (!output) { + llvm::errs() << "Failed to open output file.\n"; + if (auto EC = llvm::sys::fs::remove_directories(tmpout)) + llvm::errs() << "Warning: Failed to remove temporary output directory: " + << EC.message() << "\n"; + continue; + } + content.assign((std::istreambuf_iterator(output)), + std::istreambuf_iterator()); + output.close(); + if (auto EC = llvm::sys::fs::remove_directories(tmpout)) + llvm::errs() << "Warning: Failed to remove temporary output directory: " + << EC.message() << "\n"; + + llvm::errs() << "Herbie output: " << content << "\n"; + + // Save output to cache if needed + if (!FPOptCachePath.empty()) { + if (auto EC = llvm::sys::fs::create_directories(FPOptCachePath, true)) + llvm::errs() << "Warning: Could not create cache directory: " + << EC.message() << "\n"; + std::ofstream cacheFile(cacheFilePath); + if (!cacheFile) { + llvm_unreachable("Failed to open cache file for writing"); + } else { + cacheFile << content; + cacheFile.close(); + llvm::errs() << "Saved Herbie output to cache file " << cacheFilePath + << "\n"; + } + } + + // Process the output + if (processHerbieOutput(content, false)) { + success = true; + } + } + + return success; +} + +std::string getHerbieOperator(const Instruction &I) { + switch (I.getOpcode()) { + case Instruction::FNeg: + return "neg"; + case Instruction::FAdd: + return "+"; + case Instruction::FSub: + return "-"; + case Instruction::FMul: + return "*"; + case Instruction::FDiv: + return "/"; + case Instruction::Call: { + const CallInst *CI = dyn_cast(&I); + assert(CI && CI->getCalledFunction() && + "getHerbieOperator: Call without a function"); + + StringRef funcName = CI->getCalledFunction()->getName(); + + // LLVM intrinsics + if (startsWith(funcName, "llvm.")) { + std::regex regex("llvm\\.(\\w+)\\.?.*"); + std::smatch matches; + std::string nameStr = funcName.str(); + if (std::regex_search(nameStr, matches, regex) && matches.size() > 1) { + std::string intrinsic = matches[1]; + // Special case mappings + if (intrinsic == "fmuladd") + return "fma"; + if (intrinsic == "maxnum") + return "fmax"; + if (intrinsic == "minnum") + return "fmin"; + if (intrinsic == "powi") + return "pow"; + return intrinsic; + } + assert(0 && "getHerbieOperator: Unknown LLVM intrinsic"); + } + // libm functions + else { + std::string name = funcName.str(); + if (!name.empty() && name.back() == 'f') { + name.pop_back(); + } + return name; + } + } + default: + assert(0 && "getHerbieOperator: Unknown operator"); + } +} + +std::string getPrecondition( + const SmallSet &args, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap) { + std::string preconditions; + + for (const auto &arg : args) { + const auto node = valueToNodeMap.at(symbolToValueMap.at(arg)); + double lower = node->getLowerBound(); + double upper = node->getUpperBound(); + + if (upper - lower < 1e-10 && !std::isinf(lower) && !std::isinf(upper)) { + double midpoint = (lower + upper) / 2.0; + double tolerance = std::max(1e-10, std::abs(midpoint) * 1e-6); + lower = midpoint - tolerance; + upper = midpoint + tolerance; + } + + std::ostringstream lowerStr, upperStr; + lowerStr << std::setprecision(std::numeric_limits::max_digits10) + << std::scientific << lower; + upperStr << std::setprecision(std::numeric_limits::max_digits10) + << std::scientific << upper; + + preconditions += + " (<=" + + (std::isinf(lower) ? (lower > 0 ? " INFINITY" : " (- INFINITY)") + : (" " + lowerStr.str())) + + " " + arg + + (std::isinf(upper) ? (upper > 0 ? " INFINITY" : " (- INFINITY)") + : (" " + upperStr.str())) + + ")"; + } + + return preconditions.empty() ? "TRUE" : "(and" + preconditions + ")"; +} + +void setUnifiedAccuracyCost( + CandidateOutput &CO, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + + SmallVector, 4> sampledPoints; + getSampledPoints(CO.subgraph->inputs.getArrayRef(), valueToNodeMap, + symbolToValueMap, sampledPoints); + + SmallVector goldVals; + goldVals.resize(FPOptNumSamples); + + double origCost = 0.0; + if (FPOptReductionEval == "geomean") { + double sumLog = 0.0; + unsigned count = 0; + for (const auto &pair : enumerate(sampledPoints)) { + std::shared_ptr node = valueToNodeMap[CO.oldOutput]; + SmallVector results; + getMPFRValues({node.get()}, pair.value(), results, true, 53); + double goldVal = results[0]; + goldVals[pair.index()] = goldVal; + + getFPValues({node.get()}, pair.value(), results); + double realVal = results[0]; + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + if (error == 0.0) { + if (FPOptGeoMeanEps == 0.0) + error = getOneULP(goldVal); + else + error += FPOptGeoMeanEps; + } else if (FPOptGeoMeanEps != 0.0) { + error += FPOptGeoMeanEps; + } + sumLog += std::log(error); + ++count; + } + } + assert(count != 0 && "No valid sample found for original expr"); + origCost = std::exp(sumLog / count); + } else if (FPOptReductionEval == "arithmean") { + double sum = 0.0; + unsigned count = 0; + for (const auto &pair : enumerate(sampledPoints)) { + std::shared_ptr node = valueToNodeMap[CO.oldOutput]; + SmallVector results; + getMPFRValues({node.get()}, pair.value(), results, true, 53); + double goldVal = results[0]; + goldVals[pair.index()] = goldVal; + + getFPValues({node.get()}, pair.value(), results); + double realVal = results[0]; + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + sum += error; + ++count; + } + } + assert(count != 0 && "No valid sample found for original expr"); + origCost = sum / count; + } else if (FPOptReductionEval == "maxabs") { + double maxErr = 0.0; + for (const auto &pair : enumerate(sampledPoints)) { + std::shared_ptr node = valueToNodeMap[CO.oldOutput]; + SmallVector results; + getMPFRValues({node.get()}, pair.value(), results, true, 53); + double goldVal = results[0]; + goldVals[pair.index()] = goldVal; + + getFPValues({node.get()}, pair.value(), results); + double realVal = results[0]; + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) + maxErr = std::max(maxErr, error); + } + origCost = maxErr; + } else { + llvm_unreachable("Unknown fpopt-reduction strategy"); + } + CO.initialAccCost = origCost * std::fabs(CO.grad); + if (std::isnan(CO.initialAccCost)) { + llvm::errs() << "Warning: NaN in initialAccCost computation:\n"; + llvm::errs() << " origCost = " << origCost << "\n"; + llvm::errs() << " CO.grad = " << CO.grad << "\n"; + llvm::errs() << " fabs(CO.grad) = " << std::fabs(CO.grad) << "\n"; + } + + SmallVector newCandidates; + for (auto &candidate : CO.candidates) { + bool discardCandidate = false; + double candCost = 0.0; + + std::shared_ptr parsedNode = + parseHerbieExpr(candidate.expr, valueToNodeMap, symbolToValueMap); + + if (FPOptReductionEval == "geomean") { + double sumLog = 0.0; + unsigned count = 0; + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues({parsedNode.get()}, pair.value(), results); + double realVal = results[0]; + double goldVal = goldVals[pair.index()]; + + if (FPOptStrictMode && (!std::isnan(goldVal)) && std::isnan(realVal)) { + discardCandidate = true; + break; + } + + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + if (error == 0.0) { + if (FPOptGeoMeanEps == 0.0) + error = getOneULP(goldVal); + else + error += FPOptGeoMeanEps; + } else if (FPOptGeoMeanEps != 0.0) { + error += FPOptGeoMeanEps; + } + sumLog += std::log(error); + ++count; + } + } + if (!discardCandidate) { + if (count == 0) { + discardCandidate = true; + } else { + candCost = std::exp(sumLog / count); + } + } + } else if (FPOptReductionEval == "arithmean") { + double sum = 0.0; + unsigned count = 0; + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues({parsedNode.get()}, pair.value(), results); + double realVal = results[0]; + double goldVal = goldVals[pair.index()]; + + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(realVal)) { + discardCandidate = true; + break; + } + + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + sum += error; + ++count; + } + } + if (!discardCandidate) { + if (count == 0) { + discardCandidate = true; + } else { + candCost = sum / count; + } + } + } else if (FPOptReductionEval == "maxabs") { + double maxErr = 0.0; + bool hasValid = false; + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues({parsedNode.get()}, pair.value(), results); + double realVal = results[0]; + double goldVal = goldVals[pair.index()]; + + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(realVal)) { + discardCandidate = true; + break; + } + + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + hasValid = true; + maxErr = std::max(maxErr, error); + } + } + if (!discardCandidate) { + if (!hasValid) { + discardCandidate = true; + } else { + candCost = maxErr; + } + } + } else { + llvm_unreachable("Unknown fpopt-reduction strategy"); + } + + if (!discardCandidate) { + candidate.accuracyCost = candCost * std::fabs(CO.grad); + assert(!std::isnan(candidate.accuracyCost)); + newCandidates.push_back(std::move(candidate)); + } + } + CO.candidates = std::move(newCandidates); +} + +InstructionCost getCompCost( + const std::string &expr, Module *M, const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + const FastMathFlags &FMF) { + // llvm::errs() << "Evaluating cost of " << expr << "\n"; + SmallSet argStrSet; + getUniqueArgs(expr, argStrSet); + + SetVector args; + SmallVector argTypes; + SmallVector argNames; + for (const auto &argStr : argStrSet) { + Value *argValue = symbolToValueMap[argStr]; + args.insert(argValue); + argTypes.push_back(argValue->getType()); + argNames.push_back(argStr); + } + + auto parsedNode = parseHerbieExpr(expr, valueToNodeMap, symbolToValueMap); + + Type *ReturnType = nullptr; + for (Type *ArgTy : argTypes) { + if (ArgTy->isFloatingPointTy()) { + ReturnType = ArgTy; + break; + } + if (ArgTy->isVectorTy()) { + if (auto *VT = dyn_cast(ArgTy)) { + if (VT->getElementType()->isFloatingPointTy()) { + ReturnType = ArgTy; + break; + } + } + } + } + if (!ReturnType) { + if (parsedNode->dtype == "f32") + ReturnType = Type::getFloatTy(M->getContext()); + else if (parsedNode->dtype == "f16") + ReturnType = Type::getHalfTy(M->getContext()); + else + ReturnType = Type::getDoubleTy(M->getContext()); + } + + FunctionType *FT = FunctionType::get(ReturnType, argTypes, false); + Function *tempFunction = + Function::Create(FT, Function::InternalLinkage, "tempFunc", M); + + ValueToValueMapTy VMap; + Function::arg_iterator AI = tempFunction->arg_begin(); + for (const auto &argStr : argNames) { + VMap[symbolToValueMap[argStr]] = &*AI; + ++AI; + } + + BasicBlock *entry = + BasicBlock::Create(M->getContext(), "entry", tempFunction); + + IRBuilder<> builder(entry); + + builder.setFastMathFlags(FMF); + Value *RetVal = parsedNode->getLLValue(builder, &VMap); + assert(RetVal && "Parsed node did not produce a value"); + assert((RetVal->getType() == ReturnType) && + "Return value type mismatch with temp function return type"); + builder.CreateRet(RetVal); + + // llvm::errs() << "Temp function before optimizations:\n"; + // tempFunction->print(llvm::errs()); + + runPoseidonFunctionSimplify(*tempFunction, OptimizationLevel::O3); + + // llvm::errs() << "Temp function after optimizations:\n"; + // tempFunction->print(llvm::errs()); + + InstructionCost cost = getCompCost(tempFunction, TTI); + + tempFunction->eraseFromParent(); + return cost; +} \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.h b/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.h new file mode 100644 index 000000000000..5a31244ecdfd --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.h @@ -0,0 +1,77 @@ +//=- PoseidonHerbieUtils.h - Herbie integration utilities -----------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares utilities for integrating with the Herbie tool for +// floating-point expression optimization. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_HERBIE_UTILS_H +#define ENZYME_POSEIDON_HERBIE_UTILS_H + +#include "llvm/ADT/SmallSet.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CommandLine.h" + +#include +#include +#include + +#include "PoseidonTypes.h" + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt HerbieNumThreads; +extern llvm::cl::opt HerbieTimeout; +extern llvm::cl::opt HerbieNumPoints; +extern llvm::cl::opt HerbieNumIters; +extern llvm::cl::opt HerbieDisableNumerics; +extern llvm::cl::opt HerbieDisableArithmetic; +extern llvm::cl::opt HerbieDisableFractions; +extern llvm::cl::opt HerbieDisableTaylor; +extern llvm::cl::opt HerbieDisableSetupSimplify; +extern llvm::cl::opt HerbieDisableGenSimplify; +extern llvm::cl::opt HerbieDisableRegime; +extern llvm::cl::opt HerbieDisableBranchExpr; +extern llvm::cl::opt HerbieDisableAvgError; +} + +std::shared_ptr parseHerbieExpr( + const std::string &expr, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + +bool improveViaHerbie( + const std::vector &inputExprs, + std::vector &COs, Module *M, + const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + int subgraphIdx); + +std::string getHerbieOperator(const Instruction &I); + +std::string getPrecondition( + const SmallSet &args, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap); + +void setUnifiedAccuracyCost( + CandidateOutput &CO, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + +InstructionCost getCompCost( + const std::string &expr, Module *M, const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + const FastMathFlags &FMF); + +#endif // ENZYME_POSEIDON_HERBIE_UTILS_H diff --git a/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.cpp new file mode 100644 index 000000000000..4d3c96744c2d --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.cpp @@ -0,0 +1,711 @@ +//=- PoseidonPrecUtils.cpp - Precision change utilities for Poseidon ------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for handling precision changes in the Poseidon +// optimization pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Transforms/Scalar/EarlyCSE.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Mem2Reg.h" + +#include +#include +#include + +#include "../Utils.h" +#include "Poseidon.h" +#include "PoseidonEvaluators.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +using namespace llvm; + +extern "C" { +cl::opt FPOptShowPTDetails( + "fpopt-show-pt-details", cl::init(false), cl::Hidden, + cl::desc("Print details of precision tuning candidates along with the DP " + "table (highly verbose for large applications)")); +cl::opt + FPOptMaxMPFRPrec("fpopt-max-mpfr-prec", cl::init(1024), cl::Hidden, + cl::desc("Max precision for MPFR gold value computation")); +} + +unsigned getMPFRPrec(PrecisionChangeType type) { + switch (type) { + case PrecisionChangeType::BF16: + return 8; + case PrecisionChangeType::FP16: + return 11; + case PrecisionChangeType::FP32: + return 24; + case PrecisionChangeType::FP64: + return 53; + case PrecisionChangeType::FP80: + return 64; + case PrecisionChangeType::FP128: + return 113; + default: + llvm_unreachable("Unsupported FP precision"); + } +} + +Type *getLLVMFPType(PrecisionChangeType type, LLVMContext &context) { + switch (type) { + case PrecisionChangeType::BF16: + return Type::getBFloatTy(context); + case PrecisionChangeType::FP16: + return Type::getHalfTy(context); + case PrecisionChangeType::FP32: + return Type::getFloatTy(context); + case PrecisionChangeType::FP64: + return Type::getDoubleTy(context); + case PrecisionChangeType::FP80: + return Type::getX86_FP80Ty(context); + case PrecisionChangeType::FP128: + return Type::getFP128Ty(context); + default: + llvm_unreachable("Unsupported FP precision"); + } +} + +PrecisionChangeType getPrecisionChangeType(Type *type) { + if (type->isHalfTy()) { + return PrecisionChangeType::BF16; + } else if (type->isHalfTy()) { + return PrecisionChangeType::FP16; + } else if (type->isFloatTy()) { + return PrecisionChangeType::FP32; + } else if (type->isDoubleTy()) { + return PrecisionChangeType::FP64; + } else if (type->isX86_FP80Ty()) { + return PrecisionChangeType::FP80; + } else if (type->isFP128Ty()) { + return PrecisionChangeType::FP128; + } else { + llvm_unreachable("Unsupported FP precision"); + } +} + +StringRef getPrecisionChangeTypeString(PrecisionChangeType type) { + switch (type) { + case PrecisionChangeType::BF16: + return "BF16"; + case PrecisionChangeType::FP16: + return "FP16"; + case PrecisionChangeType::FP32: + return "FP32"; + case PrecisionChangeType::FP64: + return "FP64"; + case PrecisionChangeType::FP80: + return "FP80"; + case PrecisionChangeType::FP128: + return "FP128"; + default: + return "Unknown PT type"; + } +} + +void changePrecision(Instruction *I, PrecisionChange &change, + MapVector &oldToNew) { + if (!Poseidonable(*I)) { + llvm_unreachable("Trying to tune an instruction is not Poseidonable"); + } + + IRBuilder<> Builder(I); + Builder.setFastMathFlags(I->getFastMathFlags()); + Type *newType = getLLVMFPType(change.newType, I->getContext()); + Value *newI = nullptr; + + if (isa(I) || isa(I)) { + SmallVector newOps; + for (auto &operand : I->operands()) { + Value *newOp = nullptr; + if (oldToNew.count(operand)) { + newOp = oldToNew[operand]; + } else if (operand->getType()->isIntegerTy()) { + newOp = operand; + } else { + IRBuilder<> OpBuilder(I); + OpBuilder.setFastMathFlags(I->getFastMathFlags()); + if (isa(operand)) { + newOp = + OpBuilder.CreateFPCast(operand, newType, "fpopt.const.fpcast"); + } else if (isa(operand) || isa(operand)) { + newOp = OpBuilder.CreateFPCast(operand, newType, "fpopt.fpcast"); + } else { + llvm_unreachable("Unsupported operand type"); + } + } + newOps.push_back(newOp); + } + newI = Builder.CreateNAryOp(I->getOpcode(), newOps); + } else if (auto *CI = dyn_cast(I)) { + SmallVector newArgs; + for (auto &arg : CI->args()) { + Value *newArg = nullptr; + if (oldToNew.count(arg)) { + newArg = oldToNew[arg]; + } else if (arg->getType()->isIntegerTy()) { + newArg = arg; + } else { + IRBuilder<> ArgBuilder(I); + ArgBuilder.setFastMathFlags(I->getFastMathFlags()); + if (isa(arg)) { + newArg = ArgBuilder.CreateFPCast(arg, newType, "fpopt.const.fpcast"); + } else if (isa(arg) || isa(arg)) { + newArg = ArgBuilder.CreateFPCast(arg, newType, "fpopt.fpcast"); + } else { + llvm_unreachable("Unsupported argument type"); + } + } + newArgs.push_back(newArg); + } + auto *calledFunc = CI->getCalledFunction(); + if (calledFunc && calledFunc->isIntrinsic()) { + Intrinsic::ID intrinsicID = calledFunc->getIntrinsicID(); + if (intrinsicID != Intrinsic::not_intrinsic) { + // Special cases for intrinsics with mixed types + if (intrinsicID == Intrinsic::powi) { + // powi + SmallVector overloadedTypes; + overloadedTypes.push_back(newType); + overloadedTypes.push_back(CI->getArgOperand(1)->getType()); + Function *newFunc = getIntrinsicDeclaration( + CI->getModule(), intrinsicID, overloadedTypes); + newI = Builder.CreateCall(newFunc, newArgs); + } else { + Function *newFunc = + getIntrinsicDeclaration(CI->getModule(), intrinsicID, newType); + newI = Builder.CreateCall(newFunc, newArgs); + } + } else { + llvm::errs() << "PT: Unknown intrinsic: " << *CI << "\n"; + llvm_unreachable("changePrecision: Unknown intrinsic call to change"); + } + } else { + StringRef funcName = calledFunc->getName(); + std::string newFuncName = getLibmFunctionForPrecision(funcName, newType); + + if (!newFuncName.empty()) { + Module *M = CI->getModule(); + SmallVector newArgTypes(newArgs.size(), newType); + + FunctionCallee newFuncCallee = M->getOrInsertFunction( + newFuncName, FunctionType::get(newType, newArgTypes, false)); + + if (Function *newFunc = dyn_cast(newFuncCallee.getCallee())) { + newI = Builder.CreateCall(newFunc, newArgs); + } else { + llvm::errs() << "PT: Failed to get " + << getPrecisionChangeTypeString(change.newType) + << " libm function for: " << *CI << "\n"; + llvm_unreachable("changePrecision: Failed to get libm function"); + } + } else { + llvm::errs() << "PT: Unknown function call: " << *CI << "\n"; + llvm_unreachable("changePrecision: Unknown function call to change"); + } + } + + } else { + llvm::errs() << "Unexpectedly Poseidonable instruction: " << *I << "\n"; + llvm_unreachable("Unexpectedly Poseidonable instruction"); + } + + oldToNew[I] = newI; +} + +// If `VMap` is passed, map `llvm::Value`s in `subgraph` to their cloned +// values and change outputs in VMap to new casted outputs. +void PTCandidate::apply(Subgraph &subgraph, ValueToValueMapTy *VMap) { + SetVector operations; + ValueToValueMapTy clonedToOriginal; // Maps cloned outputs to old outputs + if (VMap) { + for (auto *I : subgraph.operations) { + assert(VMap->count(I)); + if (Value *Mapped = VMap->lookup(I)) { + if (auto *MappedI = dyn_cast(Mapped)) { + operations.insert(MappedI); + clonedToOriginal[MappedI] = I; + } + } + // llvm::errs() << "Mapping back: " << *VMap->lookup(I) << " (in " + // << cast(VMap->lookup(I)) + // ->getParent() + // ->getParent() + // ->getName() + // << ") --> " << *I << " (in " + // << I->getParent()->getParent()->getName() << ")\n"; + } + } else { + operations = subgraph.operations; + } + + for (auto &change : changes) { + SmallPtrSet seen; + SmallVector todo; + MapVector oldToNew; + + SetVector instsToChange; + for (auto node : change.nodes) { + if (!node || !node->value) { + continue; + } + assert(isa(node->value)); + auto *I = cast(node->value); + if (VMap) { + assert(VMap->count(I)); + if (Value *Mapped = VMap->lookup(I)) { + if (auto *MappedI = dyn_cast(Mapped)) { + I = MappedI; + } else { + continue; + } + } else { + continue; + } + } + if (!operations.contains(I)) { + // Already erased by `CO.apply()`. + continue; + } + instsToChange.insert(I); + } + + SmallVector instsToChangeSorted; + topoSort(instsToChange, instsToChangeSorted); + + for (auto *I : instsToChangeSorted) { + changePrecision(I, change, oldToNew); + } + + // Restore the precisions of the last level of instructions to be changed. + // Clean up old instructions. + for (auto &[oldV, newV] : oldToNew) { + if (!isa(oldV)) { + continue; + } + + if (!instsToChange.contains(cast(oldV))) { + continue; + } + + SmallPtrSet users; + for (auto *user : oldV->users()) { + assert(isa(user) && + "PT: Unexpected non-instruction user of a changed instruction"); + if (!instsToChange.contains(cast(user))) { + users.insert(cast(user)); + } + } + + Value *casted = nullptr; + if (!users.empty()) { + IRBuilder<> builder(cast(oldV)->getParent(), + ++BasicBlock::iterator(cast(oldV))); + casted = builder.CreateFPCast( + newV, getLLVMFPType(change.oldType, builder.getContext())); + + if (VMap) { + assert(VMap->count(clonedToOriginal[oldV])); + (*VMap)[clonedToOriginal[oldV]] = casted; + } + } + + for (auto *user : users) { + user->replaceUsesOfWith(oldV, casted); + } + + // Assumes no external uses of the old value since all corresponding new + // values are already restored to original precision and used to replace + // uses of their old value. This is also advantageous to the solvers. + for (auto *user : oldV->users()) { + assert(instsToChange.contains(cast(user)) && + "PT: Unexpected external user of a changed instruction"); + } + + if (!oldV->use_empty()) { + oldV->replaceAllUsesWith(UndefValue::get(oldV->getType())); + } + + cast(oldV)->eraseFromParent(); + + // The change is being materialized to the original subgraph + if (!VMap) + subgraph.operations.remove(cast(oldV)); + } + } +} + +void setUnifiedAccuracyCost( + CandidateSubgraph &CS, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + + SmallVector, 4> sampledPoints; + getSampledPoints(CS.subgraph->inputs.getArrayRef(), valueToNodeMap, + symbolToValueMap, sampledPoints); + + MapVector> goldVals; + for (auto *output : CS.subgraph->outputs) { + auto *node = valueToNodeMap[output].get(); + goldVals[node].resize(FPOptNumSamples); + CS.perOutputInitialAccCost[node] = 0.; + } + + SmallVector outputs; + for (auto *output : CS.subgraph->outputs) + outputs.push_back(valueToNodeMap[output].get()); + + if (FPOptReductionEval == "geomean") { + struct RunningAcc { + double sumLog = 0.0; + unsigned count = 0; + }; + std::unordered_map runAcc; + for (auto *node : outputs) + runAcc[node] = RunningAcc(); + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getMPFRValues(outputs, pair.value(), results, true, 53); + for (const auto &[node, result] : zip(outputs, results)) + goldVals[node][pair.index()] = result; + + getFPValues(outputs, pair.value(), results); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) { + if (error == 0.0) { + if (FPOptGeoMeanEps == 0.0) + error = getOneULP(goldVal); + else + error += FPOptGeoMeanEps; + } else if (FPOptGeoMeanEps != 0.0) { + error += FPOptGeoMeanEps; + } + runAcc[node].sumLog += std::log(error); + ++runAcc[node].count; + } + } + } + CS.initialAccCost = 0.0; + for (auto *node : outputs) { + RunningAcc &ra = runAcc[node]; + assert(ra.count != 0 && "No valid sample found for original subgraph"); + double red = std::exp(ra.sumLog / ra.count); + CS.perOutputInitialAccCost[node] = red * std::fabs(node->grad); + CS.initialAccCost += CS.perOutputInitialAccCost[node]; + } + } else if (FPOptReductionEval == "arithmean") { + struct RunningAccArith { + double sum = 0.0; + unsigned count = 0; + }; + std::unordered_map runAcc; + for (auto *node : outputs) + runAcc[node] = RunningAccArith(); + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getMPFRValues(outputs, pair.value(), results, true, 53); + for (const auto &[node, result] : zip(outputs, results)) + goldVals[node][pair.index()] = result; + + getFPValues(outputs, pair.value(), results); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) { + runAcc[node].sum += error; + ++runAcc[node].count; + } + } + } + CS.initialAccCost = 0.0; + for (auto *node : outputs) { + auto &ra = runAcc[node]; + assert(ra.count != 0 && "No valid sample found for original subgraph"); + double red = ra.sum / ra.count; + CS.perOutputInitialAccCost[node] = red * std::fabs(node->grad); + CS.initialAccCost += CS.perOutputInitialAccCost[node]; + } + } else if (FPOptReductionEval == "maxabs") { + std::unordered_map runAcc; + for (auto *node : outputs) + runAcc[node] = 0.0; + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getMPFRValues(outputs, pair.value(), results, true, 53); + for (const auto &[node, result] : zip(outputs, results)) + goldVals[node][pair.index()] = result; + + getFPValues(outputs, pair.value(), results); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) + runAcc[node] = std::max(runAcc[node], error); + } + } + CS.initialAccCost = 0.0; + for (auto *node : outputs) { + double red = runAcc[node]; + CS.perOutputInitialAccCost[node] = red * std::fabs(node->grad); + CS.initialAccCost += CS.perOutputInitialAccCost[node]; + } + } else { + llvm_unreachable("Unknown fpopt-reduction strategy"); + } + assert(!std::isnan(CS.initialAccCost)); + + SmallVector newCandidates; + for (auto &candidate : CS.candidates) { + bool discardCandidate = false; + if (FPOptReductionEval == "geomean") { + struct RunningAcc { + double sumLog = 0.0; + unsigned count = 0; + }; + std::unordered_map candAcc; + for (auto *node : outputs) + candAcc[node] = RunningAcc(); + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues(outputs, pair.value(), results, &candidate); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(result)) { + discardCandidate = true; + break; + } + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) { + if (error == 0.0) { + if (FPOptGeoMeanEps == 0.0) + error = getOneULP(goldVal); + else + error += FPOptGeoMeanEps; + } else if (FPOptGeoMeanEps != 0.0) { + error += FPOptGeoMeanEps; + } + candAcc[node].sumLog += std::log(error); + ++candAcc[node].count; + } + } + if (discardCandidate) + break; + } + if (!discardCandidate) { + candidate.accuracyCost = 0.0; + for (auto *node : outputs) { + RunningAcc &ra = candAcc[node]; + assert(ra.count != 0 && + "No valid sample found for candidate subgraph"); + double red = std::exp(ra.sumLog / ra.count); + candidate.perOutputAccCost[node] = red * std::fabs(node->grad); + candidate.accuracyCost += candidate.perOutputAccCost[node]; + } + assert(!std::isnan(candidate.accuracyCost)); + newCandidates.push_back(std::move(candidate)); + } + } else if (FPOptReductionEval == "arithmean") { + struct RunningAccArith { + double sum = 0.0; + unsigned count = 0; + }; + std::unordered_map candAcc; + for (auto *node : outputs) + candAcc[node] = RunningAccArith(); + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues(outputs, pair.value(), results, &candidate); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(result)) { + discardCandidate = true; + break; + } + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) { + candAcc[node].sum += error; + ++candAcc[node].count; + } + } + if (discardCandidate) + break; + } + if (!discardCandidate) { + candidate.accuracyCost = 0.0; + for (auto *node : outputs) { + auto &ra = candAcc[node]; + assert(ra.count != 0 && + "No valid sample found for candidate subgraph"); + double red = ra.sum / ra.count; + candidate.perOutputAccCost[node] = red * std::fabs(node->grad); + candidate.accuracyCost += candidate.perOutputAccCost[node]; + } + assert(!std::isnan(candidate.accuracyCost)); + newCandidates.push_back(std::move(candidate)); + } + } else if (FPOptReductionEval == "maxabs") { + std::unordered_map candAcc; + for (auto *node : outputs) + candAcc[node] = 0.0; + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues(outputs, pair.value(), results, &candidate); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(result)) { + discardCandidate = true; + break; + } + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) + candAcc[node] = std::max(candAcc[node], error); + } + if (discardCandidate) + break; + } + if (!discardCandidate) { + candidate.accuracyCost = 0.0; + for (auto *node : outputs) { + double red = candAcc[node]; + candidate.perOutputAccCost[node] = red * std::fabs(node->grad); + candidate.accuracyCost += candidate.perOutputAccCost[node]; + } + assert(!std::isnan(candidate.accuracyCost)); + newCandidates.push_back(std::move(candidate)); + } + } else { + llvm_unreachable("Unknown fpopt-reduction strategy"); + } + } + CS.candidates = std::move(newCandidates); +} + +InstructionCost getCompCost(Subgraph &subgraph, const TargetTransformInfo &TTI, + PTCandidate &pt) { + assert(!subgraph.outputs.empty()); + + InstructionCost cost = 0; + + Function *F = cast(subgraph.outputs[0])->getFunction(); + + ValueToValueMapTy VMap; + Function *FClone = CloneFunction(F, VMap); + FClone->setName(F->getName() + "_clone"); + + pt.apply(subgraph, &VMap); + + // llvm::errs() << "\n========================================\n"; + // llvm::errs() << "DEBUG PT Cost Measurement: " << pt.desc << "\n"; + // llvm::errs() << "========================================\n"; + // llvm::errs() << "Before optimization:\n"; + // FClone->print(llvm::errs()); + // llvm::errs() << "========================================\n"; + + SmallVector trackedInputs; + for (auto &input : subgraph.inputs) { + if (VMap.count(input)) + trackedInputs.emplace_back(VMap[input]); + } + SmallVector trackedOutputs; + for (auto &output : subgraph.outputs) { + if (VMap.count(output)) + trackedOutputs.emplace_back(VMap[output]); + } + + runPoseidonFunctionSimplify(*FClone, OptimizationLevel::O3); + + // llvm::errs() << "Cloned function AFTER precision changes and + // optimization:\n"; FClone->print(llvm::errs()); llvm::errs() << + // "========================================\n"; + + SmallPtrSet clonedInputs; + for (auto &VH : trackedInputs) { + if (!VH.pointsToAliveValue()) + continue; + Value *V = VH; + if (V) + clonedInputs.insert(V); + } + + SmallPtrSet clonedOutputs; + for (auto &VH : trackedOutputs) { + if (!VH.pointsToAliveValue()) + continue; + Value *V = VH; + if (V) + clonedOutputs.insert(V); + } + + SmallPtrSet seen; + SmallVector todo; + + todo.insert(todo.end(), clonedOutputs.begin(), clonedOutputs.end()); + while (!todo.empty()) { + auto cur = todo.pop_back_val(); + if (!seen.insert(cur).second) + continue; + + if (clonedInputs.contains(cur)) + continue; + + if (auto *I = dyn_cast(cur)) { + auto instCost = getInstructionCompCost(I, TTI); + + // if (I->getType()->isFPOrFPVectorTy() || + // (I->getOpcode() == Instruction::FCmp)) { + // Type *Ty = I->getType(); + // if (I->getOpcode() == Instruction::FCmp) + // Ty = I->getOperand(0)->getType(); + + // std::string precStr = "unknown"; + // if (Ty->isFloatTy()) precStr = "FP32"; + // else if (Ty->isDoubleTy()) precStr = "FP64"; + + // llvm::errs() << " [" << precStr << "] Cost=" << instCost + // << " : " << *I << "\n"; + // } + + cost += instCost; + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &operand : operands) { + todo.push_back(operand); + } + } + } + // llvm::errs() << "Total PT cost for " << pt.desc << ": " << cost << "\n"; + // llvm::errs() << "========================================\n\n"; + + FClone->eraseFromParent(); + + return cost; +} \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.h b/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.h new file mode 100644 index 000000000000..909e0b2512c4 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.h @@ -0,0 +1,85 @@ +//=- PoseidonPrecUtils.h - Precision change utilities for Poseidon --------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares utilities for handling precision changes in the Poseidon +// optimization pass. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_PREC_UTILS_H +#define ENZYME_POSEIDON_PREC_UTILS_H + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +#include +#include +#include + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptShowPTDetails; +extern llvm::cl::opt FPOptMaxMPFRPrec; +} + +class FPNode; +class FPLLValue; +struct Subgraph; +class CandidateSubgraph; + +enum class PrecisionChangeType { BF16, FP16, FP32, FP64, FP80, FP128 }; +unsigned getMPFRPrec(PrecisionChangeType type); +Type *getLLVMFPType(PrecisionChangeType type, LLVMContext &context); +PrecisionChangeType getPrecisionChangeType(Type *type); +StringRef getPrecisionChangeTypeString(PrecisionChangeType type); + +struct PrecisionChange { + SetVector nodes; + PrecisionChangeType oldType; + PrecisionChangeType newType; + + explicit PrecisionChange(SetVector &nodes, + PrecisionChangeType oldType, + PrecisionChangeType newType) + : nodes(nodes), oldType(oldType), newType(newType) {} +}; + +struct PTCandidate { + SmallVector changes; + double accuracyCost = std::numeric_limits::quiet_NaN(); + InstructionCost CompCost = std::numeric_limits::max(); + std::string desc; + std::unordered_map perOutputAccCost; + std::unordered_map> errors; + + explicit PTCandidate(SmallVector changes, + const std::string &desc) + : changes(std::move(changes)), desc(desc) {} + + void apply(Subgraph &subgraph, ValueToValueMapTy *VMap = nullptr); +}; + +void changePrecision(Instruction *I, PrecisionChange &change, + MapVector &oldToNew); + +InstructionCost getCompCost(Subgraph &subgraph, const TargetTransformInfo &TTI, + PTCandidate &pt); + +void setUnifiedAccuracyCost( + CandidateSubgraph &CS, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + +#endif // ENZYME_POSEIDON_PREC_UTILS_H \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp new file mode 100644 index 000000000000..b15de0c2f9f2 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp @@ -0,0 +1,158 @@ +//=- PoseidonProfUtils.cpp - Profiling utilities for Poseidon +//------------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements profiling-related utilities for the Poseidon +// optimization pass. +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "PoseidonProfUtils.h" +#include "PoseidonUtils.h" + +using namespace llvm; + +extern "C" { +cl::opt FPOptLooseCoverage( + "fpopt-loose-coverage", cl::init(false), cl::Hidden, + cl::desc("Allow unexecuted FP instructions in subgraph indentification")); +cl::opt + FPOptWidenRange("fpopt-widen-range", cl::init(1), cl::Hidden, + cl::desc("Ablation study only: widen the range of input " + "hypercube by this factor")); +} + +void parseProfileFile(const std::string &profilePath, + std::unordered_map &profileMap) { + profileMap.clear(); + std::ifstream file(profilePath); + if (!file.is_open()) { + llvm::errs() << "Warning: Could not open profile file: " << profilePath + << "\n"; + return; + } + + std::string line; + std::regex indexPattern(R"(^(\d+)$)"); + + std::regex minResPattern(R"(^\s*MinRes\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex maxResPattern(R"(^\s*MaxRes\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex sumValuePattern(R"(^\s*SumValue\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex sumSensPattern(R"(^\s*SumSens\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex sumGradPattern(R"(^\s*SumGrad\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex execPattern(R"(^\s*Exec\s*=\s*(\d+))"); + std::regex numOperandsPattern(R"(^\s*NumOperands\s*=\s*(\d+))"); + std::regex operandPattern( + R"(^\s*Operand\[(\d+)\]\s*=\s*\[([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan),\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan)\])"); + + while (std::getline(file, line)) { + if (!line.empty() && line.back() == '\r') { + line.pop_back(); + } + + std::smatch match; + if (std::regex_match(line, match, indexPattern)) { + size_t idx = std::stoull(match[1]); + ProfileInfo info; + + std::string minResLine, maxResLine, sumValueLine, sumSensLine, + sumGradLine, execLine, numOperandsLine; + + if (std::getline(file, minResLine) && std::getline(file, maxResLine) && + std::getline(file, sumValueLine) && std::getline(file, sumSensLine) && + std::getline(file, sumGradLine) && std::getline(file, execLine) && + std::getline(file, numOperandsLine)) { + + auto stripCR = [](std::string &s) { + if (!s.empty() && s.back() == '\r') + s.pop_back(); + }; + stripCR(minResLine); + stripCR(maxResLine); + stripCR(sumValueLine); + stripCR(sumSensLine); + stripCR(sumGradLine); + stripCR(execLine); + stripCR(numOperandsLine); + + std::smatch mMinRes, mMaxRes, mSumValue, mSumSens, mSumGrad, mExec, + mNumOperands; + if (std::regex_search(minResLine, mMinRes, minResPattern) && + std::regex_search(maxResLine, mMaxRes, maxResPattern) && + std::regex_search(sumValueLine, mSumValue, sumValuePattern) && + std::regex_search(sumSensLine, mSumSens, sumSensPattern) && + std::regex_search(sumGradLine, mSumGrad, sumGradPattern) && + std::regex_search(execLine, mExec, execPattern) && + std::regex_search(numOperandsLine, mNumOperands, + numOperandsPattern)) { + + info.minRes = stringToDouble(mMinRes[1]); + info.maxRes = stringToDouble(mMaxRes[1]); + info.sumValue = stringToDouble(mSumValue[1]); + info.sumSens = stringToDouble(mSumSens[1]); + info.sumGrad = stringToDouble(mSumGrad[1]); + info.exec = static_cast(std::stoul(mExec[1])); + unsigned numOperands = + static_cast(std::stoul(mNumOperands[1])); + + info.minOperands.resize(numOperands, 0.0); + info.maxOperands.resize(numOperands, 0.0); + + for (unsigned i = 0; i < numOperands; ++i) { + if (std::getline(file, line)) { + if (!line.empty() && line.back() == '\r') + line.pop_back(); + + std::smatch operandMatch; + if (std::regex_search(line, operandMatch, operandPattern)) { + unsigned opIdx = + static_cast(std::stoul(operandMatch[1])); + double minVal = stringToDouble(operandMatch[2]); + double maxVal = stringToDouble(operandMatch[3]); + + if (opIdx < numOperands) { + info.minOperands[opIdx] = minVal; + info.maxOperands[opIdx] = maxVal; + } + } + } + } + + if (FPOptWidenRange != 1.0) { + double center = (info.minRes + info.maxRes) / 2.0; + double half_range = (info.maxRes - info.minRes) / 2.0; + double new_half_range = half_range * FPOptWidenRange; + info.minRes = center - new_half_range; + info.maxRes = center + new_half_range; + + for (size_t i = 0; i < info.minOperands.size(); ++i) { + double op_center = + (info.minOperands[i] + info.maxOperands[i]) / 2.0; + double op_half_range = + (info.maxOperands[i] - info.minOperands[i]) / 2.0; + double op_new_half_range = op_half_range * FPOptWidenRange; + info.minOperands[i] = op_center - op_new_half_range; + info.maxOperands[i] = op_center + op_new_half_range; + } + } + + profileMap[idx] = info; + } else { + llvm::errs() << "Warning: Failed to parse profile fields for index " + << idx << "\n"; + } + } else { + llvm::errs() << "Warning: Incomplete profile entry for index " << idx + << "\n"; + } + } + } +} diff --git a/enzyme/Enzyme/Poseidon/PoseidonProfUtils.h b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.h new file mode 100644 index 000000000000..6c0f45c12463 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.h @@ -0,0 +1,51 @@ +//=- PoseidonProfUtils.h - Profiling utilities for Poseidon +//--------------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares profiling-related utilities for the Poseidon optimization +// pass. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_PROF_UTILS_H +#define ENZYME_POSEIDON_PROF_UTILS_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CommandLine.h" + +#include +#include + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptLooseCoverage; +extern llvm::cl::opt FPOptWidenRange; +} + +struct ProfileInfo { + double minRes; + double maxRes; + double sumValue; // Sum of values (not abs) + double sumSens; // Sum of sensitivity scores = |grad * value| + double sumGrad; // Sum of gradients (not abs) + unsigned exec; + + SmallVector minOperands; + SmallVector maxOperands; + + ProfileInfo() + : minRes(std::numeric_limits::max()), + maxRes(std::numeric_limits::lowest()), sumValue(0.0), + sumSens(0.0), sumGrad(0.0), exec(0) {} +}; + +void parseProfileFile(const std::string &profilePath, + std::unordered_map &profileMap); + +#endif // ENZYME_POSEIDON_PROF_UTILS_H \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp b/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp new file mode 100644 index 000000000000..fedf24857a0e --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp @@ -0,0 +1,889 @@ +//=- PoseidonSolvers.cpp - Solver utilities for Poseidon ------------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements solver-related utilities for the Poseidon optimization +// pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_ostream.h" + +#include "Poseidon.h" +#include "PoseidonSolvers.h" +#include "PoseidonTypes.h" + +#include "PoseidonHerbieUtils.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonUtils.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/Transforms/Utils/Cloning.h" + +#include + +using namespace llvm; + +extern "C" { +cl::opt FPOptSolverType("fpopt-solver-type", cl::init("dp"), + cl::Hidden, + cl::desc("Which solver to use; " + "either 'dp' or 'greedy'")); +cl::opt FPOptComputationCostBudget( + "fpopt-comp-cost-budget", cl::init(0L), cl::Hidden, + cl::desc("The maximum computation cost budget for the solver")); +cl::opt FPOptShowTable( + "fpopt-show-table", cl::init(false), cl::Hidden, + cl::desc( + "Print the full DP table (highly verbose for large applications)")); +cl::list FPOptShowTableCosts( + "fpopt-show-table-costs", cl::ZeroOrMore, cl::CommaSeparated, cl::Hidden, + cl::desc( + "Comma-separated list of computation costs for which to print DP table " + "entries. If provided, only specified computation costs are " + "printed. ")); +cl::opt FPOptEarlyPrune( + "fpopt-early-prune", cl::init(true), cl::Hidden, + cl::desc("Prune dominated candidates in expression transformation phases")); +cl::opt FPOptCostDominanceThreshold( + "fpopt-cost-dom-thres", cl::init(0.0), cl::Hidden, + cl::desc("The threshold for cost dominance in DP solver")); +cl::opt FPOptAccuracyDominanceThreshold( + "fpopt-acc-dom-thres", cl::init(0.0), cl::Hidden, + cl::desc("The threshold for accuracy dominance in DP solver")); +cl::opt FPOptRefineDPTable( + "fpopt-refine-dp", cl::init(false), cl::Hidden, + cl::desc("After initial DP build, materialize each solution in a cloned\n" + "function, run O3, and recompute cost deltas. Only applies when\n" + "generating a new cache, not when loading from cache.")); +} + +#if LLVM_VERSION_MAJOR >= 21 +#define GET_INSTRUCTION_COST(cost) (cost.getValue()) +#else +#define GET_INSTRUCTION_COST(cost) (cost.getValue().value()) +#endif + +// Given the cost budget `FPOptComputationCostBudget`, we want to minimize the +// accuracy cost of the rewritten expressions. +bool accuracyGreedySolver( + SmallVector &COs, + SmallVector &CSs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + bool changed = false; + llvm::errs() << "Starting accuracy greedy solver with computation budget: " + << FPOptComputationCostBudget << "\n"; + InstructionCost totalComputationCost = 0; + + SmallVector aoIndices; + for (size_t i = 0; i < COs.size(); ++i) { + aoIndices.push_back(i); + } + std::mt19937 g(FPOptRandomSeed); + std::shuffle(aoIndices.begin(), aoIndices.end(), g); + + for (size_t idx : aoIndices) { + auto &CO = COs[idx]; + int bestCandidateIndex = -1; + double bestAccuracyCost = std::numeric_limits::infinity(); + InstructionCost bestCandidateComputationCost; + + for (const auto &candidate : enumerate(CO.candidates)) { + size_t i = candidate.index(); + auto candCompCost = CO.getCompCostDelta(i); + auto candAccCost = CO.getAccCostDelta(i); + // llvm::errs() << "CO Candidate " << i << " for " << CO.expr + // << " has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << "\n"; + + if (totalComputationCost + candCompCost <= FPOptComputationCostBudget) { + if (candAccCost < bestAccuracyCost) { + // llvm::errs() << "CO Candidate " << i << " selected!\n"; + bestCandidateIndex = i; + bestAccuracyCost = candAccCost; + bestCandidateComputationCost = candCompCost; + } + } + } + + if (bestCandidateIndex != -1) { + CO.apply(bestCandidateIndex, valueToNodeMap, symbolToValueMap); + changed = true; + totalComputationCost += bestCandidateComputationCost; + if (FPOptPrint) { + llvm::errs() << "Greedy solver selected candidate " + << bestCandidateIndex << " for " << CO.expr + << " with accuracy cost: " << bestAccuracyCost + << " and computation cost: " + << bestCandidateComputationCost << "\n"; + } + } + } + + SmallVector accIndices; + for (size_t i = 0; i < CSs.size(); ++i) { + accIndices.push_back(i); + } + std::shuffle(accIndices.begin(), accIndices.end(), g); + + for (size_t idx : accIndices) { + auto &CS = CSs[idx]; + int bestCandidateIndex = -1; + double bestAccuracyCost = std::numeric_limits::infinity(); + InstructionCost bestCandidateComputationCost; + + for (const auto &candidate : enumerate(CS.candidates)) { + size_t i = candidate.index(); + auto candCompCost = CS.getCompCostDelta(i); + auto candAccCost = CS.getAccCostDelta(i); + // llvm::errs() << "CS Candidate " << i << " (" << candidate.value().desc + // << ") has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << "\n"; + + if (totalComputationCost + candCompCost <= FPOptComputationCostBudget) { + if (candAccCost < bestAccuracyCost) { + // llvm::errs() << "CS Candidate " << i << " selected!\n"; + bestCandidateIndex = i; + bestAccuracyCost = candAccCost; + bestCandidateComputationCost = candCompCost; + } + } + } + + if (bestCandidateIndex != -1) { + CS.apply(bestCandidateIndex); + changed = true; + totalComputationCost += bestCandidateComputationCost; + if (FPOptPrint) { + llvm::errs() << "Greedy solver selected candidate " + << bestCandidateIndex << " for " + << CS.candidates[bestCandidateIndex].desc + << " with accuracy cost: " << bestAccuracyCost + << " and computation cost: " + << bestCandidateComputationCost << "\n"; + } + } + } + + llvm::errs() << "Greedy solver finished with total computation cost: " + << totalComputationCost + << "; total allowance: " << FPOptComputationCostBudget << "\n"; + + return changed; +} + +bool accuracyDPSolver( + Function &F, const TargetTransformInfo &TTI, + SmallVector &COs, + SmallVector &CSs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + double errorTol) { + bool changed = false; + llvm::errs() << "Starting accuracy DP solver with computation budget: " + << FPOptComputationCostBudget << "\n"; + if (errorTol > 0.0) { + llvm::errs() << "Absolute error tolerance: " << errorTol << "\n"; + } + + using CostMap = std::map; + using SolutionMap = std::map>; + + CostMap costToAccuracyMap; + SolutionMap costToSolutionMap; + CostMap newCostToAccuracyMap; + SolutionMap newCostToSolutionMap; + CostMap prunedCostToAccuracyMap; + SolutionMap prunedCostToSolutionMap; + + std::string cacheFilePath = FPOptCachePath + "/table.json"; + + if (llvm::sys::fs::exists(cacheFilePath)) { + llvm::errs() << "Cache file found. Loading DP tables from cache.\n"; + + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFile(cacheFilePath); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Error reading cache file: " << ec.message() << "\n"; + return changed; + } + llvm::StringRef buffer = fileOrErr.get()->getBuffer(); + llvm::Expected jsonOrErr = llvm::json::parse(buffer); + if (!jsonOrErr) { + llvm::errs() << "Error parsing JSON from cache file: " + << llvm::toString(jsonOrErr.takeError()) << "\n"; + return changed; + } + + llvm::json::Object *jsonObj = jsonOrErr->getAsObject(); + if (!jsonObj) { + llvm::errs() << "Invalid JSON format in cache file.\n"; + return changed; + } + + if (llvm::json::Object *costAccMap = + jsonObj->getObject("costToAccuracyMap")) { + for (auto &pair : *costAccMap) { + InstructionCost compCost(std::stoll(pair.first.str())); + double accCost = pair.second.getAsNumber().value(); + costToAccuracyMap[compCost] = accCost; + } + } else { + llvm_unreachable("Invalid costToAccuracyMap in cache file."); + } + + if (llvm::json::Object *costSolMap = + jsonObj->getObject("costToSolutionMap")) { + for (auto &pair : *costSolMap) { + InstructionCost compCost(std::stoll(pair.first.str())); + SmallVector solutionSteps; + + llvm::json::Array *stepsArray = pair.second.getAsArray(); + if (!stepsArray) { + llvm::errs() << "Invalid steps array in cache file.\n"; + return changed; + } + + for (llvm::json::Value &stepVal : *stepsArray) { + llvm::json::Object *stepObj = stepVal.getAsObject(); + if (!stepObj) { + llvm_unreachable("Invalid step object in cache file."); + } + + StringRef itemType = stepObj->getString("itemType").value(); + size_t candidateIndex = stepObj->getInteger("candidateIndex").value(); + size_t itemIndex = stepObj->getInteger("itemIndex").value(); + + if (itemType == "CO") { + if (itemIndex >= COs.size()) { + llvm_unreachable("Invalid CandidateOutput index in cache file."); + } + solutionSteps.emplace_back(&COs[itemIndex], candidateIndex); + } else if (itemType == "CS") { + if (itemIndex >= CSs.size()) { + llvm_unreachable( + "Invalid CandidateSubgraph index in cache file."); + } + solutionSteps.emplace_back(&CSs[itemIndex], candidateIndex); + } else { + llvm_unreachable("Invalid itemType in cache file."); + } + } + + costToSolutionMap[compCost] = solutionSteps; + } + } else { + llvm::errs() << "costToSolutionMap not found in cache file.\n"; + return changed; + } + + llvm::errs() << "Loaded DP tables from cache.\n"; + + } else { + llvm::errs() << "Cache file not found. Proceeding to solve DP.\n"; + + costToAccuracyMap[0] = 0; + costToSolutionMap[0] = {}; + + std::unordered_map aoPtrToIndex; + for (size_t i = 0; i < COs.size(); ++i) { + aoPtrToIndex[&COs[i]] = i; + } + std::unordered_map accPtrToIndex; + for (size_t i = 0; i < CSs.size(); ++i) { + accPtrToIndex[&CSs[i]] = i; + } + + int COCounter = 0; + + for (auto &CO : COs) { + // It is possible to apply zero candidate for an CO. + // When no candidate is applied, the resulting accuracy cost + // and solution steps remain the same. + newCostToAccuracyMap = costToAccuracyMap; + newCostToSolutionMap = costToSolutionMap; + + for (const auto &pair : costToAccuracyMap) { + InstructionCost currCompCost = pair.first; + double currAccCost = pair.second; + + for (const auto &candidate : enumerate(CO.candidates)) { + size_t i = candidate.index(); + auto candCompCost = CO.getCompCostDelta(i); + auto candAccCost = CO.getAccCostDelta(i); + + // Don't apply a candidate that strictly makes things worse + if (candCompCost >= 0 && candAccCost >= 0.) { + continue; + } + + InstructionCost newCompCost = currCompCost + candCompCost; + double newAccCost = currAccCost + candAccCost; + + // if (FPOptPrint) + // llvm::errs() << "CO candidate " << i + // << " has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + + if (newCostToAccuracyMap.find(newCompCost) == + newCostToAccuracyMap.end() || + newCostToAccuracyMap[newCompCost] > newAccCost) { + newCostToAccuracyMap[newCompCost] = newAccCost; + newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; + newCostToSolutionMap[newCompCost].emplace_back(&CO, i); + // if (FPOptPrint) + // llvm::errs() << "Updating accuracy map (CO candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; + } + } + } + + // TODO: Do not prune CO parts of the DP table since COs influence CSs + if (!FPOptEarlyPrune) { + costToAccuracyMap = newCostToAccuracyMap; + costToSolutionMap = newCostToSolutionMap; + + llvm::errs() << "##### Finished processing " << ++COCounter << " of " + << COs.size() << " COs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + continue; + } + + for (const auto &l : newCostToAccuracyMap) { + InstructionCost currCompCost = l.first; + double currAccCost = l.second; + + bool dominated = false; + for (const auto &r : newCostToAccuracyMap) { + InstructionCost otherCompCost = r.first; + double otherAccCost = r.second; + + if (currCompCost - otherCompCost > + std::fabs(FPOptCostDominanceThreshold * + GET_INSTRUCTION_COST(otherCompCost)) && + currAccCost - otherAccCost >= + std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { + // if (FPOptPrint) + // llvm::errs() << "CO candidate with computation cost: " + // << currCompCost + // << " and accuracy cost: " << currAccCost + // << " is dominated by candidate with computation + // cost:" + // << otherCompCost + // << " and accuracy cost: " << otherAccCost << "\n"; + dominated = true; + break; + } + } + + if (!dominated) { + prunedCostToAccuracyMap[currCompCost] = currAccCost; + prunedCostToSolutionMap[currCompCost] = + newCostToSolutionMap[currCompCost]; + } + } + + costToAccuracyMap = prunedCostToAccuracyMap; + costToSolutionMap = prunedCostToSolutionMap; + prunedCostToAccuracyMap.clear(); + prunedCostToSolutionMap.clear(); + + llvm::errs() << "##### Finished processing " << ++COCounter << " of " + << COs.size() << " COs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + } + + int CSCounter = 0; + + for (auto &CS : CSs) { + // It is possible to apply zero candidate for an CS. + // When no candidate is applied, the resulting accuracy cost + // and solution steps remain the same. + newCostToAccuracyMap = costToAccuracyMap; + newCostToSolutionMap = costToSolutionMap; + + for (const auto &pair : costToAccuracyMap) { + InstructionCost currCompCost = pair.first; + double currAccCost = pair.second; + + for (const auto &candidate : enumerate(CS.candidates)) { + size_t i = candidate.index(); + auto candCompCost = + CS.getAdjustedCompCostDelta(i, costToSolutionMap[currCompCost]); + auto candAccCost = + CS.getAdjustedAccCostDelta(i, costToSolutionMap[currCompCost], + valueToNodeMap, symbolToValueMap); + + // Don't ever try to apply a strictly useless candidate + if (candCompCost >= 0 && candAccCost >= 0.) { + continue; + } + + InstructionCost newCompCost = currCompCost + candCompCost; + double newAccCost = currAccCost + candAccCost; + + // if (FPOptPrint) + // llvm::errs() << "CS candidate " << i << " (" + // << candidate.value().desc + // << ") has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + + if (newCostToAccuracyMap.find(newCompCost) == + newCostToAccuracyMap.end() || + newCostToAccuracyMap[newCompCost] > newAccCost) { + newCostToAccuracyMap[newCompCost] = newAccCost; + newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; + newCostToSolutionMap[newCompCost].emplace_back(&CS, i); + // if (FPOptPrint) { + // llvm::errs() << "CS candidate " << i << " (" + // << candidate.value().desc + // << ") added; has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + // llvm::errs() << "Updating accuracy map (CS candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; + // } + } + } + } + + for (const auto &l : newCostToAccuracyMap) { + InstructionCost currCompCost = l.first; + double currAccCost = l.second; + + bool dominated = false; + for (const auto &r : newCostToAccuracyMap) { + InstructionCost otherCompCost = r.first; + double otherAccCost = r.second; + + if (currCompCost - otherCompCost > + std::fabs(FPOptCostDominanceThreshold * + GET_INSTRUCTION_COST(otherCompCost)) && + currAccCost - otherAccCost >= + std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { + // if (FPOptPrint) + // llvm::errs() << "CS candidate with computation cost: " + // << currCompCost + // << " and accuracy cost: " << currAccCost + // << " is dominated by candidate with computation + // cost:" + // << otherCompCost + // << " and accuracy cost: " << otherAccCost << "\n"; + dominated = true; + break; + } + } + + if (!dominated) { + prunedCostToAccuracyMap[currCompCost] = currAccCost; + prunedCostToSolutionMap[currCompCost] = + newCostToSolutionMap[currCompCost]; + } + } + + costToAccuracyMap = prunedCostToAccuracyMap; + costToSolutionMap = prunedCostToSolutionMap; + prunedCostToAccuracyMap.clear(); + prunedCostToSolutionMap.clear(); + + llvm::errs() << "##### Finished processing " << ++CSCounter << " of " + << CSs.size() << " CSs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + } + + json::Object jsonObj; + + json::Object costAccMap; + for (const auto &pair : costToAccuracyMap) { + costAccMap[std::to_string(GET_INSTRUCTION_COST(pair.first))] = + pair.second; + } + jsonObj["costToAccuracyMap"] = std::move(costAccMap); + + json::Object costSolMap; + for (const auto &pair : costToSolutionMap) { + json::Array stepsArray; + for (const auto &step : pair.second) { + json::Object stepObj; + stepObj["candidateIndex"] = static_cast(step.candidateIndex); + + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + stepObj["itemType"] = "CO"; + size_t index = aoPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } else if constexpr (std::is_same_v) { + stepObj["itemType"] = "CS"; + size_t index = accPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } + }, + step.item); + stepsArray.push_back(std::move(stepObj)); + } + costSolMap[std::to_string(GET_INSTRUCTION_COST(pair.first))] = + std::move(stepsArray); + } + jsonObj["costToSolutionMap"] = std::move(costSolMap); + + if (!FPOptRefineDPTable) { + std::error_code EC; + llvm::raw_fd_ostream cacheFile(cacheFilePath, EC, llvm::sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing cache file: " << EC.message() << "\n"; + } else { + cacheFile << llvm::formatv("{0:2}", + llvm::json::Value(std::move(jsonObj))) + << "\n"; + cacheFile.close(); + llvm::errs() << "DP tables cached to file.\n"; + } + } else if (!COs.empty() || !CSs.empty()) { + ValueToValueMapTy BaseVMap; + Function *BaseClone = CloneFunction(&F, BaseVMap); + runPoseidonFunctionSimplify(*BaseClone, OptimizationLevel::O3); + InstructionCost BaseCost = getCompCost(BaseClone, TTI); + BaseClone->eraseFromParent(); + + using CostMap = std::map; + using SolutionMap = std::map>; + CostMap refinedCostToAccuracyMap; + SolutionMap refinedCostToSolutionMap; + + for (const auto &pair : costToSolutionMap) { + const SmallVector &steps = pair.second; + + ValueToValueMapTy VMap; + Function *FClone = CloneFunction(&F, VMap); + + for (const auto &step : steps) { + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + auto &CO = *item; + Instruction *oldI = cast(CO.oldOutput); + Instruction *clonedOldI = cast(VMap[oldI]); + IRBuilder<> builder(clonedOldI->getParent(), + ++BasicBlock::iterator(clonedOldI)); + builder.setFastMathFlags(clonedOldI->getFastMathFlags()); + auto parsedNode = + parseHerbieExpr(CO.candidates[step.candidateIndex].expr, + valueToNodeMap, symbolToValueMap); + Value *newVal = parsedNode->getLLValue(builder, &VMap); + clonedOldI->replaceAllUsesWith(newVal); + } else if constexpr (std::is_same_v) { + auto &CS = *item; + PTCandidate &pt = CS.candidates[step.candidateIndex]; + pt.apply(*CS.subgraph, &VMap); + } else { + llvm_unreachable("accuracyDPSolver refine: unexpected step"); + } + }, + step.item); + } + + runPoseidonFunctionSimplify(*FClone, OptimizationLevel::O3); + InstructionCost NewTotal = getCompCost(FClone, TTI); + InstructionCost NewDelta = NewTotal - BaseCost; + FClone->eraseFromParent(); + + double accCost = costToAccuracyMap[pair.first]; + auto it = refinedCostToAccuracyMap.find(NewDelta); + if (it == refinedCostToAccuracyMap.end() || it->second > accCost) { + refinedCostToAccuracyMap[NewDelta] = accCost; + refinedCostToSolutionMap[NewDelta] = steps; + } + } + + // One-pass domination pruning + std::map refinedPrunedAcc; + std::map> refinedPrunedSol; + for (const auto &l : refinedCostToAccuracyMap) { + InstructionCost currCompCost = l.first; + double currAccCost = l.second; + bool dominated = false; + for (const auto &r : refinedCostToAccuracyMap) { + InstructionCost otherCompCost = r.first; + double otherAccCost = r.second; + if (currCompCost - otherCompCost > + std::fabs(FPOptCostDominanceThreshold * + GET_INSTRUCTION_COST(otherCompCost)) && + currAccCost - otherAccCost >= + std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { + dominated = true; + break; + } + } + if (!dominated) { + refinedPrunedAcc[currCompCost] = currAccCost; + refinedPrunedSol[currCompCost] = + refinedCostToSolutionMap[currCompCost]; + } + } + + costToAccuracyMap = std::move(refinedPrunedAcc); + costToSolutionMap = std::move(refinedPrunedSol); + + json::Object jsonObj; + + json::Object costAccMap; + for (const auto &p : costToAccuracyMap) { + costAccMap[std::to_string(GET_INSTRUCTION_COST(p.first))] = p.second; + } + jsonObj["costToAccuracyMap"] = std::move(costAccMap); + + json::Object costSolMap; + std::unordered_map aoPtrToIndex; + for (size_t i = 0; i < COs.size(); ++i) + aoPtrToIndex[&COs[i]] = i; + std::unordered_map accPtrToIndex; + for (size_t i = 0; i < CSs.size(); ++i) + accPtrToIndex[&CSs[i]] = i; + + for (const auto &p : costToSolutionMap) { + json::Array stepsArray; + for (const auto &step : p.second) { + json::Object stepObj; + stepObj["candidateIndex"] = static_cast(step.candidateIndex); + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + stepObj["itemType"] = "CO"; + size_t index = aoPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } else if constexpr (std::is_same_v) { + stepObj["itemType"] = "CS"; + size_t index = accPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } + }, + step.item); + stepsArray.push_back(std::move(stepObj)); + } + costSolMap[std::to_string(GET_INSTRUCTION_COST(p.first))] = + std::move(stepsArray); + } + jsonObj["costToSolutionMap"] = std::move(costSolMap); + + std::error_code EC; + llvm::raw_fd_ostream cacheFile(cacheFilePath, EC, llvm::sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing refined cache file: " << EC.message() + << "\n"; + } else { + cacheFile << llvm::formatv("{0:2}", + llvm::json::Value(std::move(jsonObj))) + << "\n"; + cacheFile.close(); + llvm::errs() << "Refined DP tables cached to file.\n"; + } + } + } + + if (FPOptPrint) { + if (FPOptShowTable) { + llvm::errs() << "\n*** DP Table ***\n"; + for (const auto &pair : costToAccuracyMap) { + if (!FPOptShowTableCosts.empty()) { + bool shouldPrint = false; + for (auto selectedCost : FPOptShowTableCosts) + if (pair.first == selectedCost) { + shouldPrint = true; + break; + } + if (!shouldPrint) + continue; + } + + llvm::errs() << "Computation cost: " << pair.first + << ", Accuracy cost: " << pair.second << "\n"; + llvm::errs() << "\tSolution steps: \n"; + for (const auto &step : costToSolutionMap[pair.first]) { + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + llvm::errs() + << "\t\t" << item->expr << " --(" << step.candidateIndex + << ")-> " << item->candidates[step.candidateIndex].expr + << "\n"; + } else if constexpr (std::is_same_v) { + llvm::errs() << "\t\tCS: " + << item->candidates[step.candidateIndex].desc + << " (#" << step.candidateIndex << ")\n"; + if (FPOptShowPTDetails) { + auto &candidate = item->candidates[step.candidateIndex]; + for (const auto &change : candidate.changes) { + llvm::errs() + << "\t\t\tChanging from " + << getPrecisionChangeTypeString(change.oldType) + << " to " + << getPrecisionChangeTypeString(change.newType) + << ":\n"; + for (auto *val : change.nodes) { + llvm::errs() << "\t\t\t\t" << *val->value << "\n"; + } + } + } + } else { + llvm_unreachable( + "accuracyDPSolver: Unexpected type of solution step"); + } + }, + step.item); + } + } + llvm::errs() << "*** End of DP Table ***\n\n"; + } + } + + std::string budgetsFile = FPOptCachePath + "/budgets.txt"; + if (!llvm::sys::fs::exists(budgetsFile)) { + std::string budgetsStr; + for (const auto &pair : costToAccuracyMap) { + budgetsStr += std::to_string(GET_INSTRUCTION_COST(pair.first)) + ","; + } + + if (!budgetsStr.empty()) + budgetsStr.pop_back(); + + std::error_code EC; + llvm::raw_fd_ostream Out(budgetsFile, EC, llvm::sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error opening " << budgetsFile << ": " << EC.message() + << "\n"; + } else { + Out << budgetsStr; + } + } + + llvm::errs() << "Critical computation cost range: [" + << costToAccuracyMap.begin()->first << ", " + << costToAccuracyMap.rbegin()->first << "]\n"; + + llvm::errs() << "DP table contains " << costToAccuracyMap.size() + << " entries.\n"; + + double totalCandidateCompositions = 1.0; + for (const auto &CO : COs) { + // +1 for the "do nothing" possibility + totalCandidateCompositions *= CO.candidates.size() + 1; + } + for (const auto &CS : CSs) { + totalCandidateCompositions *= CS.candidates.size() + 1; + } + llvm::errs() << "Total candidate compositions: " << totalCandidateCompositions + << "\n"; + + if (costToSolutionMap.find(0) != costToSolutionMap.end()) { + if (costToSolutionMap[0].empty()) { + llvm::errs() << "WARNING: No-op solution (utilized cost budget = 0) is " + "considered Pareto-optimal.\n"; + } + } + + double minAccCost = std::numeric_limits::infinity(); + InstructionCost bestCompCost = 0; + + if (errorTol > 0.0) { + InstructionCost minCompCost = std::numeric_limits::max(); + bool foundSolution = false; + + for (const auto &pair : costToAccuracyMap) { + InstructionCost compCost = pair.first; + double accCost = pair.second; + + if (accCost <= errorTol) { + if (compCost < minCompCost) { + minCompCost = compCost; + minAccCost = accCost; + bestCompCost = compCost; + foundSolution = true; + } + } + } + + if (!foundSolution) { + llvm::errs() << "No solution found that meets accuracy tolerance " + << errorTol << "!\n"; + llvm::errs() << "Best achievable accuracy in DP table: " + << costToAccuracyMap.begin()->second << "\n"; + return changed; + } + + llvm::errs() << "Found solution meeting accuracy tolerance " << errorTol + << "\n"; + llvm::errs() << "Accuracy cost achieved: " << minAccCost << "\n"; + llvm::errs() << "Computation cost required: " << bestCompCost << "\n"; + } else { + for (const auto &pair : costToAccuracyMap) { + InstructionCost compCost = pair.first; + double accCost = pair.second; + + if (compCost <= FPOptComputationCostBudget && accCost < minAccCost) { + minAccCost = accCost; + bestCompCost = compCost; + } + } + + if (minAccCost == std::numeric_limits::infinity()) { + llvm::errs() << "No solution found within the computation cost budget!\n"; + return changed; + } + + llvm::errs() << "Minimum accuracy cost within budget: " << minAccCost + << "\n"; + llvm::errs() << "Computation cost budget used: " << bestCompCost << "\n"; + } + + assert(costToSolutionMap.find(bestCompCost) != costToSolutionMap.end() && + "FPOpt DP solver: expected a solution!"); + + llvm::errs() << "\n!!! DP solver: Applying solution ... !!!\n"; + for (const auto &solution : costToSolutionMap[bestCompCost]) { + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + llvm::errs() << "Applying solution for " << item->expr << " --(" + << solution.candidateIndex << ")-> " + << item->candidates[solution.candidateIndex].expr + << "\n"; + item->apply(solution.candidateIndex, valueToNodeMap, + symbolToValueMap); + } else if constexpr (std::is_same_v) { + llvm::errs() << "Applying solution for CS: " + << item->candidates[solution.candidateIndex].desc + << " (#" << solution.candidateIndex << ")\n"; + item->apply(solution.candidateIndex); + } else { + llvm_unreachable( + "accuracyDPSolver: Unexpected type of solution step"); + } + }, + solution.item); + changed = true; + } + llvm::errs() << "!!! DP Solver: Solution applied !!!\n\n"; + + return changed; +} \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonSolvers.h b/enzyme/Enzyme/Poseidon/PoseidonSolvers.h new file mode 100644 index 000000000000..8fb3445d3227 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonSolvers.h @@ -0,0 +1,54 @@ +//=- PoseidonSolvers.h - Solver utilities for Poseidon --------------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares solver-related utilities for the Poseidon optimization +// pass. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_SOLVERS_H +#define ENZYME_POSEIDON_SOLVERS_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CommandLine.h" + +#include +#include + +#include "PoseidonTypes.h" + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptSolverType; +extern llvm::cl::opt FPOptComputationCostBudget; +extern llvm::cl::opt FPOptShowTable; +extern llvm::cl::list FPOptShowTableCosts; +extern llvm::cl::opt FPOptEarlyPrune; +extern llvm::cl::opt FPOptCostDominanceThreshold; +extern llvm::cl::opt FPOptAccuracyDominanceThreshold; +} + +bool accuracyGreedySolver( + SmallVector &COs, + SmallVector &CSs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + +bool accuracyDPSolver( + Function &F, const TargetTransformInfo &TTI, + SmallVector &COs, + SmallVector &CSs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + double errorTol = 0.0); + +#endif // ENZYME_POSEIDON_SOLVERS_H \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp b/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp new file mode 100644 index 000000000000..6900cd1fef4f --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp @@ -0,0 +1,959 @@ +//=- PoseidonNodes.cpp - AST node implementations for Poseidon ------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST node classes for representing floating-point +// expressions in the Poseidon optimization pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +#include +#include +#include + +#include "../Utils.h" +#include "Poseidon.h" +#include "PoseidonHerbieUtils.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +FPNode::NodeType FPNode::getType() const { return ntype; } + +void FPNode::addOperand(std::shared_ptr operand) { + operands.push_back(operand); +} + +bool FPNode::hasSymbol() const { + std::string msg = "Unexpected invocation of `hasSymbol` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +std::string FPNode::toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth) { + std::string msg = "Unexpected invocation of `toFullExpression` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +unsigned FPNode::getMPFRPrec() const { + if (dtype == "f16") + return 11; + if (dtype == "f32") + return 24; + if (dtype == "f64") + return 53; + std::string msg = + "getMPFRPrec: operator " + op + " has unknown dtype " + dtype; + llvm_unreachable(msg.c_str()); +} + +void FPNode::updateBounds(double lower, double upper) { + std::string msg = "Unexpected invocation of `updateBounds` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +double FPNode::getLowerBound() const { + std::string msg = "Unexpected invocation of `getLowerBound` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +double FPNode::getUpperBound() const { + std::string msg = "Unexpected invocation of `getUpperBound` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +Value *FPNode::getLLValue(IRBuilder<> &builder, const ValueToValueMapTy *VMap) { + Module *M = builder.GetInsertBlock()->getModule(); + if (op == "if") { + Value *condValue = operands[0]->getLLValue(builder, VMap); + Value *trueValue = operands[1]->getLLValue(builder, VMap); + Value *falseValue = operands[2]->getLLValue(builder, VMap); + + return builder.CreateSelect(condValue, trueValue, falseValue, + "herbie.select"); + } + + SmallVector operandValues; + for (auto operand : operands) { + Value *val = operand->getLLValue(builder, VMap); + assert(val && "Operand produced a null value!"); + operandValues.push_back(val); + } + + static const std::unordered_map< + std::string, std::function &, Module *, + const SmallVectorImpl &)>> + opMap = { + {"neg", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &ops) + -> Value * { return b.CreateFNeg(ops[0], "herbie.neg"); }}, + {"+", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFAdd(ops[0], ops[1], "herbie.add"); + }}, + {"-", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFSub(ops[0], ops[1], "herbie.sub"); + }}, + {"*", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFMul(ops[0], ops[1], "herbie.mul"); + }}, + {"/", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFDiv(ops[0], ops[1], "herbie.div"); + }}, + {"fmin", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateBinaryIntrinsic(Intrinsic::minnum, ops[0], ops[1], + nullptr, "herbie.fmin"); + }}, + {"fmax", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateBinaryIntrinsic(Intrinsic::maxnum, ops[0], ops[1], + nullptr, "herbie.fmax"); + }}, + {"sin", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::sin, ops[0], nullptr, + "herbie.sin"); + }}, + {"cos", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::cos, ops[0], nullptr, + "herbie.cos"); + }}, + {"tan", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::tan, ops[0], nullptr, + "herbie.tan"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "tan" : "tanf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee tanFunc = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(tanFunc, {ops[0]}, "herbie.tan"); +#endif + }}, + {"exp", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::exp, ops[0], nullptr, + "herbie.exp"); + }}, + {"expm1", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "expm1" : "expm1f"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.expm1"); + }}, + {"log", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::log, ops[0], nullptr, + "herbie.log"); + }}, + {"log1p", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "log1p" : "log1pf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.log1p"); + }}, + {"sqrt", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::sqrt, ops[0], nullptr, + "herbie.sqrt"); + }}, + {"cbrt", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "cbrt" : "cbrtf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.cbrt"); + }}, + {"pow", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + // Use powi when possible + if (auto *CF = dyn_cast(ops[1])) { + double value = CF->getValueAPF().convertToDouble(); + if (value == std::floor(value) && value >= INT_MIN && + value <= INT_MAX) { + int exp = static_cast(value); + SmallVector overloadedTypes = { + ops[0]->getType(), Type::getInt32Ty(M->getContext())}; + Function *powiFunc = getIntrinsicDeclaration( + M, Intrinsic::powi, overloadedTypes); + Value *exponent = + ConstantInt::get(Type::getInt32Ty(M->getContext()), exp); + return b.CreateCall(powiFunc, {ops[0], exponent}, + "herbie.powi"); + } + } + + return b.CreateBinaryIntrinsic(Intrinsic::pow, ops[0], ops[1], + nullptr, "herbie.pow"); + }}, + {"fma", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateIntrinsic(Intrinsic::fma, {ops[0]->getType()}, + {ops[0], ops[1], ops[2]}, nullptr, + "herbie.fma"); + }}, + {"fabs", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::fabs, ops[0], nullptr, + "herbie.fabs"); + }}, + {"hypot", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "hypot" : "hypotf"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.hypot"); + }}, + {"asin", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::asin, ops[0], nullptr, + "herbie.asin"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "asin" : "asinf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.asin"); +#endif + }}, + {"acos", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::acos, ops[0], nullptr, + "herbie.acos"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "acos" : "acosf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.acos"); +#endif + }}, + {"atan", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::atan, ops[0], nullptr, + "herbie.atan"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "atan" : "atanf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.atan"); +#endif + }}, + {"atan2", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateBinaryIntrinsic(Intrinsic::atan2, ops[0], ops[1], + nullptr, "herbie.atan2"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "atan2" : "atan2f"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.atan2"); +#endif + }}, + {"sinh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::sinh, ops[0], nullptr, + "herbie.sinh"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "sinh" : "sinhf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.sinh"); +#endif + }}, + {"cosh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::cosh, ops[0], nullptr, + "herbie.cosh"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "cosh" : "coshf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.cosh"); +#endif + }}, + {"tanh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::tanh, ops[0], nullptr, + "herbie.tanh"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "tanh" : "tanhf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.tanh"); +#endif + }}, + {"copysign", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateBinaryIntrinsic(Intrinsic::copysign, ops[0], ops[1], + nullptr, "herbie.copysign"); + }}, + {"rem", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFRem(ops[0], ops[1], "herbie.rem"); + }}, + {"ceil", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::ceil, ops[0], nullptr, + "herbie.ceil"); + }}, + {"floor", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::floor, ops[0], nullptr, + "herbie.floor"); + }}, + {"exp2", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::exp2, ops[0], nullptr, + "herbie.exp2"); + }}, + {"log10", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::log10, ops[0], nullptr, + "herbie.log10"); + }}, + {"log2", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::log2, ops[0], nullptr, + "herbie.log2"); + }}, + {"rint", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::rint, ops[0], nullptr, + "herbie.rint"); + }}, + {"round", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::round, ops[0], nullptr, + "herbie.round"); + }}, + {"trunc", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::trunc, ops[0], nullptr, + "herbie.trunc"); + }}, + {"fdim", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "fdim" : "fdimf"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.fdim"); + }}, + {"fmod", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "fmod" : "fmodf"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.fmod"); + }}, + {"remainder", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = + Ty->isDoubleTy() ? "remainder" : "remainderf"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.remainder"); + }}, + {"erf", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "erf" : "erff"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.erf"); + }}, + {"lgamma", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "lgamma" : "lgammaf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.lgamma"); + }}, + {"tgamma", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "tgamma" : "tgammaf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.tgamma"); + }}, + {"asinh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "asinh" : "asinhf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.asinh"); + }}, + {"acosh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "acosh" : "acoshf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.acosh"); + }}, + {"atanh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "atanh" : "atanhf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.atanh"); + }}, + {"==", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOEQ(ops[0], ops[1], "herbie.eq"); + }}, + {"!=", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpONE(ops[0], ops[1], "herbie.ne"); + }}, + {"<", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOLT(ops[0], ops[1], "herbie.lt"); + }}, + {">", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOGT(ops[0], ops[1], "herbie.gt"); + }}, + {"<=", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOLE(ops[0], ops[1], "herbie.le"); + }}, + {">=", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOGE(ops[0], ops[1], "herbie.ge"); + }}, + {"and", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateAnd(ops[0], ops[1], "herbie.and"); + }}, + {"or", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &ops) + -> Value * { return b.CreateOr(ops[0], ops[1], "herbie.or"); }}, + {"not", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &ops) + -> Value * { return b.CreateNot(ops[0], "herbie.not"); }}, + {"TRUE", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantInt::getTrue(b.getContext()); }}, + {"FALSE", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantInt::getFalse(b.getContext()); }}, + {"PI", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantFP::get(b.getDoubleTy(), M_PI); }}, + {"E", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantFP::get(b.getDoubleTy(), M_E); }}, + {"INFINITY", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &) -> Value * { + return ConstantFP::getInfinity(b.getDoubleTy(), false); + }}, + {"NaN", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantFP::getNaN(b.getDoubleTy()); }}, + }; + + auto it = opMap.find(op); + if (it != opMap.end()) + return it->second(builder, M, operandValues); + else { + std::string msg = "FPNode getLLValue: Unexpected operator " + op; + llvm_unreachable(msg.c_str()); + } +} + +bool FPLLValue::hasSymbol() const { return !symbol.empty(); } + +std::string FPLLValue::toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth) { + // Check if this value is an input to the current subgraph + if (subgraphInputs.contains(value)) { + assert(hasSymbol() && "FPLLValue has no symbol!"); + return symbol; + } else { + assert(!operands.empty() && "FPNode has no operands!"); + + if (depth > FPOptMaxExprDepth) { + std::string msg = "Expression depth exceeded maximum allowed depth of " + + std::to_string(FPOptMaxExprDepth) + " for " + op + + "; consider disabling loop unrolling"; + + llvm_unreachable(msg.c_str()); + } + + std::string expr = "(" + (op == "neg" ? "-" : op); + for (auto operand : operands) { + expr += " " + operand->toFullExpression(valueToNodeMap, subgraphInputs, + depth + 1); + } + expr += ")"; + return expr; + } +} + +void FPLLValue::updateBounds(double lower, double upper) { + lb = std::min(lb, lower); + ub = std::max(ub, upper); + if (FPOptPrint) + llvm::errs() << "Updated bounds for " << *value << ": [" << lb << ", " << ub + << "]\n"; +} + +double FPLLValue::getLowerBound() const { return lb; } +double FPLLValue::getUpperBound() const { return ub; } + +Value *FPLLValue::getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap) { + if (VMap) { + assert(VMap->count(value) && "FPLLValue not found in passed-in VMap!"); + return VMap->lookup(value); + } + return value; +} + +bool FPLLValue::classof(const FPNode *N) { + return N->getType() == NodeType::LLValue; +} + +std::string FPConst::toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth) { + return strValue; +} + +bool FPConst::hasSymbol() const { + std::string msg = "Unexpected invocation of `hasSymbol` on an FPConst"; + llvm_unreachable(msg.c_str()); +} + +void FPConst::updateBounds(double lower, double upper) { return; } + +double FPConst::getLowerBound() const { + if (strValue == "+inf.0") { + return std::numeric_limits::infinity(); + } else if (strValue == "-inf.0") { + return -std::numeric_limits::infinity(); + } + + double constantValue; + size_t div = strValue.find('/'); + + if (div != std::string::npos) { + std::string numerator = strValue.substr(0, div); + std::string denominator = strValue.substr(div + 1); + double num = stringToDouble(numerator); + double denom = stringToDouble(denominator); + + constantValue = num / denom; + } else { + constantValue = stringToDouble(strValue); + } + + return constantValue; +} + +double FPConst::getUpperBound() const { return getLowerBound(); } + +Value *FPConst::getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap) { + Type *Ty; + if (dtype == "f64") { + Ty = builder.getDoubleTy(); + } else if (dtype == "f32") { + Ty = builder.getFloatTy(); + } else { + std::string msg = "FPConst getValue: Unexpected dtype: " + dtype; + llvm_unreachable(msg.c_str()); + } + if (strValue == "+inf.0") { + return ConstantFP::getInfinity(Ty, false); + } else if (strValue == "-inf.0") { + return ConstantFP::getInfinity(Ty, true); + } + + double constantValue; + size_t div = strValue.find('/'); + + if (div != std::string::npos) { + std::string numerator = strValue.substr(0, div); + std::string denominator = strValue.substr(div + 1); + double num = stringToDouble(numerator); + double denom = stringToDouble(denominator); + + constantValue = num / denom; + } else { + constantValue = stringToDouble(strValue); + } + + // if (FPOptPrint) + // llvm::errs() << "Returning " << strValue << " as " << dtype + // << " constant: " << constantValue << "\n"; + return ConstantFP::get(Ty, constantValue); +} + +bool FPConst::classof(const FPNode *N) { + return N->getType() == NodeType::Const; +} + +void CandidateOutput::apply( + size_t candidateIndex, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + // 4) parse the output string solution from herbieland + // 5) convert into a solution in llvm vals/instructions + + // if (FPOptPrint) + // llvm::errs() << "Parsing Herbie output: " << herbieOutput << "\n"; + auto parsedNode = parseHerbieExpr(candidates[candidateIndex].expr, + valueToNodeMap, symbolToValueMap); + // if (FPOptPrint) + // llvm::errs() << "Parsed Herbie output: " + // << parsedNode->toFullExpression(valueToNodeMap) << "\n"; + + IRBuilder<> builder(cast(oldOutput)->getParent(), + ++BasicBlock::iterator(cast(oldOutput))); + builder.setFastMathFlags(cast(oldOutput)->getFastMathFlags()); + + // auto *F = cast(oldOutput)->getParent()->getParent(); + // llvm::errs() << "Before: " << *F << "\n"; + Value *newOutput = parsedNode->getLLValue(builder); + assert(newOutput && "Failed to get value from parsed node"); + + oldOutput->replaceAllUsesWith(newOutput); + symbolToValueMap[valueToNodeMap[oldOutput]->symbol] = newOutput; + valueToNodeMap[newOutput] = std::make_shared( + newOutput, "__no", valueToNodeMap[oldOutput]->dtype); + + for (auto *I : erasableInsts) { + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->eraseFromParent(); + subgraph->operations.remove(I); // Avoid a second removal + cast(valueToNodeMap[I].get())->value = nullptr; + } + + // llvm::errs() << "After: " << *F << "\n"; + + subgraph->outputs_rewritten++; +} + +// Lower is better +InstructionCost CandidateOutput::getCompCostDelta(size_t candidateIndex) { + InstructionCost erasableCost = 0; + + for (auto *I : erasableInsts) { + erasableCost += getInstructionCompCost(I, *TTI); + } + + return (candidates[candidateIndex].CompCost - erasableCost) * executions; +} + +void CandidateOutput::findErasableInstructions() { + SmallPtrSet visited; + SmallPtrSet exprInsts; + collectExprInsts(oldOutput, subgraph->inputs, exprInsts, visited); + visited.clear(); + + SetVector instsToProcess(exprInsts.begin(), exprInsts.end()); + + SmallVector instsToProcessSorted; + reverseTopoSort(instsToProcess, instsToProcessSorted); + + // `oldOutput` is trivially erasable + erasableInsts.clear(); + erasableInsts.insert(cast(oldOutput)); + + for (auto *I : instsToProcessSorted) { + if (erasableInsts.contains(I)) + continue; + + bool usedOutside = false; + for (auto user : I->users()) { + if (auto *userI = dyn_cast(user)) { + if (erasableInsts.contains(userI)) { + continue; + } + } + // If the user is not an intruction or the user instruction is not an + // erasable instruction, then the current instruction is not erasable + // llvm::errs() << "Can't erase " << *I << " because of " << *user << + // "\n"; + usedOutside = true; + break; + } + + if (!usedOutside) { + erasableInsts.insert(I); + } + } + + // llvm::errs() << "Erasable instructions:\n"; + // for (auto *I : erasableInsts) { + // llvm::errs() << *I << "\n"; + // } + // llvm::errs() << "End of erasable instructions\n"; +} + +bool CandidateSubgraph::CacheKey::operator==(const CacheKey &other) const { + return candidateIndex == other.candidateIndex && + CandidateOutputs == other.CandidateOutputs; +} + +std::size_t +CandidateSubgraph::CacheKeyHash::operator()(const CacheKey &key) const { + std::size_t seed = std::hash{}(key.candidateIndex); + for (const auto *ao : key.CandidateOutputs) { + seed ^= std::hash{}(ao) + 0x9e3779b9 + + (seed << 6) + (seed >> 2); + } + return seed; +} + +void CandidateSubgraph::apply(size_t candidateIndex) { + if (candidateIndex >= candidates.size()) { + llvm_unreachable("Invalid candidate index"); + } + + // Traverse all the instructions to be changed precisions in a + // topological order with respect to operand dependencies. Insert FP casts + // between llvm::Value inputs and first level of instructions to be changed. + // Restore precisions of the last level of instructions to be changed. + candidates[candidateIndex].apply(*subgraph); +} + +// Lower is better +InstructionCost CandidateSubgraph::getCompCostDelta(size_t candidateIndex) { + // TODO: adjust this based on erasured instructions + return (candidates[candidateIndex].CompCost - initialCompCost) * executions; +} + +// Lower is better +double CandidateSubgraph::getAccCostDelta(size_t candidateIndex) { + return candidates[candidateIndex].accuracyCost - initialAccCost; +} + +// Lower is better +double CandidateOutput::getAccCostDelta(size_t candidateIndex) { + return candidates[candidateIndex].accuracyCost - initialAccCost; +} + +InstructionCost CandidateSubgraph::getAdjustedCompCostDelta( + size_t candidateIndex, const SmallVectorImpl &steps) { + CandidateOutputSet CandidateOutputs; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + if ((*ptr)->subgraph == subgraph) { + CandidateOutputs.insert(*ptr); + } + } + } + + CacheKey key{candidateIndex, CandidateOutputs}; + + auto cacheIt = compCostDeltaCache.find(key); + if (cacheIt != compCostDeltaCache.end()) { + return cacheIt->second; + } + + Subgraph newSubgraph = *this->subgraph; + + for (auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + const auto &CO = **ptr; + if (CO.subgraph == subgraph) { + // Eliminate erasadable instructions from the adjusted CS + newSubgraph.operations.remove_if( + [&CO](Instruction *I) { return CO.erasableInsts.contains(I); }); + newSubgraph.outputs.remove(cast(CO.oldOutput)); + } + } + } + + // If all outputs are rewritten, then the adjusted CS is empty + if (newSubgraph.outputs.empty()) { + compCostDeltaCache[key] = 0; + return 0; + } + + InstructionCost initialCompCost = + getCompCost({newSubgraph.outputs.begin(), newSubgraph.outputs.end()}, + newSubgraph.inputs, TTI); + + InstructionCost candidateCompCost = + getCompCost(newSubgraph, TTI, candidates[candidateIndex]); + + InstructionCost adjustedCostDelta = + (candidateCompCost - initialCompCost) * executions; + // llvm::errs() << "Initial cost: " << initialCompCost << "\n"; + // llvm::errs() << "Candidate cost: " << candidateCompCost << "\n"; + // llvm::errs() << "Num executions: " << executions << "\n"; + // llvm::errs() << "Adjusted cost delta: " << adjustedCostDelta << "\n\n"; + + compCostDeltaCache[key] = adjustedCostDelta; + return adjustedCostDelta; +} + +double CandidateSubgraph::getAdjustedAccCostDelta( + size_t candidateIndex, SmallVectorImpl &steps, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + CandidateOutputSet CandidateOutputs; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + if ((*ptr)->subgraph == subgraph) { + CandidateOutputs.insert(*ptr); + } + } + } + + CacheKey key{candidateIndex, CandidateOutputs}; + + auto cacheIt = accCostDeltaCache.find(key); + if (cacheIt != accCostDeltaCache.end()) { + return cacheIt->second; + } + + double totalCandidateAccCost = 0.0; + double totalInitialAccCost = 0.0; + + // Collect erased output nodes + SmallPtrSet stepNodes; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + const auto &CO = **ptr; + if (CO.subgraph == subgraph) { + auto it = valueToNodeMap.find(CO.oldOutput); + assert(it != valueToNodeMap.end() && it->second); + stepNodes.insert(it->second.get()); + } + } + } + + // Iterate over all output nodes and sum costs for nodes not erased + for (auto &[node, cost] : perOutputInitialAccCost) { + if (!stepNodes.count(node)) { + totalInitialAccCost += cost; + } + } + + for (auto &[node, cost] : candidates[candidateIndex].perOutputAccCost) { + if (!stepNodes.count(node)) { + totalCandidateAccCost += cost; + } + } + + double adjustedAccCostDelta = totalCandidateAccCost - totalInitialAccCost; + + accCostDeltaCache[key] = adjustedAccCostDelta; + return adjustedAccCostDelta; +} diff --git a/enzyme/Enzyme/Poseidon/PoseidonTypes.h b/enzyme/Enzyme/Poseidon/PoseidonTypes.h new file mode 100644 index 000000000000..def99cb727d5 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonTypes.h @@ -0,0 +1,252 @@ +//=- PoseidonTypes.h - AST node declarations for Poseidon -----------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the AST node classes for representing floating-point +// expressions in the Poseidon optimization pass. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_TYPES_H +#define ENZYME_POSEIDON_TYPES_H + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +#include +#include +#include +#include +#include +#include + +#include "PoseidonUtils.h" + +#include "PoseidonPrecUtils.h" + +using namespace llvm; + +class FPNode { +public: + enum class NodeType { Node, LLValue, Const }; + +private: + const NodeType ntype; + +public: + std::string op; + std::string dtype; + std::string symbol; + SmallVector, 2> operands; + double sens = + std::numeric_limits::quiet_NaN(); // Sensitivity score, sum of + // |grad * value| + double grad = + std::numeric_limits::quiet_NaN(); // Sum of gradients (not abs) + unsigned executions = 0; + + explicit FPNode(const std::string &op, const std::string &dtype) + : ntype(NodeType::Node), op(op), dtype(dtype) {} + explicit FPNode(NodeType ntype, const std::string &op, + const std::string &dtype) + : ntype(ntype), op(op), dtype(dtype) {} + virtual ~FPNode() = default; + + NodeType getType() const; + void addOperand(std::shared_ptr operand); + virtual bool hasSymbol() const; + virtual std::string toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth = 0); + unsigned getMPFRPrec() const; + virtual void updateBounds(double lower, double upper); + virtual double getLowerBound() const; + virtual double getUpperBound() const; + virtual Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr); +}; + +class FPLLValue : public FPNode { +private: + double lb = std::numeric_limits::infinity(); + double ub = -std::numeric_limits::infinity(); + +public: + Value *value; + + explicit FPLLValue(Value *value, const std::string &op, + const std::string &dtype) + : FPNode(NodeType::LLValue, op, dtype), value(value) {} + + bool hasSymbol() const override; + std::string toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth = 0) override; + void updateBounds(double lower, double upper) override; + double getLowerBound() const override; + double getUpperBound() const override; + Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr) override; + + static bool classof(const FPNode *N); +}; + +class FPConst : public FPNode { +private: + std::string strValue; + +public: + explicit FPConst(const std::string &strValue, const std::string &dtype) + : FPNode(NodeType::Const, "__const", dtype), strValue(strValue) {} + + std::string toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth = 0) override; + bool hasSymbol() const override; + void updateBounds(double lower, double upper) override; + double getLowerBound() const override; + double getUpperBound() const override; + Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr) override; + + static bool classof(const FPNode *N); +}; + +struct Subgraph { + SetVector inputs; + SetVector outputs; + SetVector operations; + size_t outputs_rewritten = 0; + + Subgraph() = default; + explicit Subgraph(SetVector inputs, SetVector outputs, + SetVector operations) + : inputs(inputs), outputs(outputs), operations(operations) {} +}; + +struct SolutionStep; + +struct RewriteCandidate { + InstructionCost CompCost = std::numeric_limits::max(); + double herbieCost = std::numeric_limits::quiet_NaN(); + double herbieAccuracy = std::numeric_limits::quiet_NaN(); + double accuracyCost = std::numeric_limits::quiet_NaN(); + std::string expr; + + RewriteCandidate(double cost, double accuracy, std::string expression) + : herbieCost(cost), herbieAccuracy(accuracy), expr(expression) {} +}; + +class CandidateOutput { +public: + Subgraph *subgraph; + Value *oldOutput; + std::string expr; + double grad = std::numeric_limits::quiet_NaN(); + unsigned executions = 0; + const TargetTransformInfo *TTI = nullptr; + double initialAccCost = std::numeric_limits::quiet_NaN(); + InstructionCost initialCompCost = + std::numeric_limits::quiet_NaN(); + double initialHerbieCost = std::numeric_limits::quiet_NaN(); + double initialHerbieAccuracy = std::numeric_limits::quiet_NaN(); + SmallVector candidates; + SmallPtrSet erasableInsts; + + explicit CandidateOutput(Subgraph &subgraph, Value *oldOutput, + std::string expr, double grad, unsigned executions, + const TargetTransformInfo &TTI) + : subgraph(&subgraph), oldOutput(oldOutput), expr(expr), grad(grad), + executions(executions), TTI(&TTI) { + initialCompCost = getCompCost({oldOutput}, subgraph.inputs, TTI); + findErasableInstructions(); + } + + void + apply(size_t candidateIndex, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + InstructionCost getCompCostDelta(size_t candidateIndex); + double getAccCostDelta(size_t candidateIndex); + +private: + void findErasableInstructions(); +}; + +class CandidateSubgraph { +public: + Subgraph *subgraph; + const TargetTransformInfo &TTI; + double initialAccCost = std::numeric_limits::quiet_NaN(); + InstructionCost initialCompCost = + std::numeric_limits::quiet_NaN(); + unsigned executions = 0; + std::unordered_map perOutputInitialAccCost; + SmallVector candidates; + + using CandidateOutputSet = std::set; + struct CacheKey { + size_t candidateIndex; + CandidateOutputSet CandidateOutputs; + bool operator==(const CacheKey &other) const; + }; + + struct CacheKeyHash { + std::size_t operator()(const CacheKey &key) const; + }; + + std::unordered_map + compCostDeltaCache; + std::unordered_map accCostDeltaCache; + + explicit CandidateSubgraph(Subgraph &subgraph, const TargetTransformInfo &TTI) + : subgraph(&subgraph), TTI(TTI) { + initialCompCost = + getCompCost({subgraph.outputs.begin(), subgraph.outputs.end()}, + subgraph.inputs, TTI); + } + + void apply(size_t candidateIndex); + InstructionCost getCompCostDelta(size_t candidateIndex); + double getAccCostDelta(size_t candidateIndex); + InstructionCost + getAdjustedCompCostDelta(size_t candidateIndex, + const SmallVectorImpl &steps); + double getAdjustedAccCostDelta( + size_t candidateIndex, SmallVectorImpl &steps, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); +}; + +struct SolutionStep { + std::variant item; + size_t candidateIndex; + + SolutionStep(CandidateOutput *ao_, size_t idx) + : item(ao_), candidateIndex(idx) {} + SolutionStep(CandidateSubgraph *acc_, size_t idx) + : item(acc_), candidateIndex(idx) {} +}; + +void getSampledPoints( + ArrayRef inputs, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints); + +void getSampledPoints( + const std::string &expr, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints); + +#endif // ENZYME_POSEIDON_TYPES_H \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp new file mode 100644 index 000000000000..2b675985e080 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp @@ -0,0 +1,1022 @@ +//=- PoseidonUtils.cpp - Utility functions for Poseidon optimization pass --=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utility functions for the Poseidon floating-point +// optimization pass. +// +//===----------------------------------------------------------------------===// + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include "llvm/Analysis/TargetTransformInfo.h" + +#include "llvm/IR/Function.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Verifier.h" + +#include "llvm/Passes/PassBuilder.h" + +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Support/raw_ostream.h" + +#include "llvm/Pass.h" + +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +using namespace llvm; + +extern "C" { +extern cl::opt FPOptPrint; +cl::opt + FPOptCostModelPath("fpopt-cost-model-path", cl::init(""), cl::Hidden, + cl::desc("Use a custom cost model in the FPOpt pass")); +cl::opt + FPOptNumSamples("fpopt-num-samples", cl::init(1024), cl::Hidden, + cl::desc("Number of sampled points for input hypercube")); +cl::opt + FPOptRandomSeed("fpopt-random-seed", cl::init(239778888), cl::Hidden, + cl::desc("The random seed used in the FPOpt pass")); +} + +void runPoseidonFunctionSimplify(Function &F, OptimizationLevel Level) { + + PassBuilder PB; + LoopAnalysisManager LAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + ModuleAnalysisManager MAM; + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + if (verifyFunction(F, &llvm::errs())) { + llvm_unreachable("Poseidon intermediate function failed verification"); + } + + FunctionPassManager FPM = + PB.buildFunctionSimplificationPipeline(Level, ThinOrFullLTOPhase::None); + (void)FPM.run(F, FAM); +} + +static const std::unordered_set LibmFuncs = { + "sin", "cos", "tan", "asin", "acos", "atan", "atan2", + "sinh", "cosh", "tanh", "asinh", "acosh", "atanh", "exp", + "log", "sqrt", "cbrt", "pow", "powi", "fabs", "fma", + "hypot", "expm1", "log1p", "ceil", "floor", "erf", "exp2", + "lgamma", "log10", "log2", "rint", "round", "tgamma", "trunc", + "copysign", "fdim", "fmod", "remainder"}; + +double getOneULP(double value) { + assert(!std::isnan(value) && !std::isinf(value)); + + double next = std::nextafter(value, std::numeric_limits::infinity()); + double ulp = std::fabs(next - value); + + return ulp; +} + +std::string getLibmFunctionForPrecision(StringRef funcName, Type *newType) { + std::string baseName = funcName.str(); + if (baseName.back() == 'f' || baseName.back() == 'l') { + baseName.pop_back(); + } + + if (LibmFuncs.count(baseName)) { + if (newType->isFloatTy()) { + return baseName + "f"; + } else if (newType->isDoubleTy()) { + return baseName; + } else if (newType->isFP128Ty() || newType->isX86_FP80Ty()) { + return baseName + "l"; + } + } + + return ""; +} + +double stringToDouble(const std::string &str) { + char *end; + errno = 0; + double result = std::strtod(str.c_str(), &end); + + if (errno == ERANGE) { + if (result == HUGE_VAL) { + result = std::numeric_limits::infinity(); + } else if (result == -HUGE_VAL) { + result = -std::numeric_limits::infinity(); + } + } + + return result; // Denormalized values are fine +} + +void topoSort(const SetVector &insts, + SmallVectorImpl &instsSorted) { + SmallPtrSet visited; + SmallPtrSet onStack; + + std::function dfsVisit = [&](Instruction *I) { + if (visited.count(I)) + return; + visited.insert(I); + onStack.insert(I); + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &op : operands) { + if (isa(op)) { + Instruction *oI = cast(op); + if (insts.contains(oI)) { + if (onStack.count(oI)) { + llvm_unreachable( + "topoSort: Cycle detected in instruction dependencies!"); + } + dfsVisit(oI); + } + } + } + + onStack.erase(I); + instsSorted.push_back(I); + }; + + for (auto *I : insts) { + if (!visited.count(I)) { + dfsVisit(I); + } + } +} + +void reverseTopoSort(const SetVector &insts, + SmallVectorImpl &instsSorted) { + topoSort(insts, instsSorted); + std::reverse(instsSorted.begin(), instsSorted.end()); +} + +void getUniqueArgs(const std::string &expr, SmallSet &args) { + std::regex argPattern("v\\d+"); + + std::sregex_iterator begin(expr.begin(), expr.end(), argPattern); + std::sregex_iterator end; + + while (begin != end) { + args.insert(begin->str()); + ++begin; + } +} + +const std::map, InstructionCost> & +getCostModel() { + static std::map, InstructionCost> + CostModel; + static bool Loaded = false; + if (!Loaded) { + std::ifstream CostFile(FPOptCostModelPath); + if (!CostFile.is_open()) { + std::string msg = + "Cost model file could not be opened: " + FPOptCostModelPath; + llvm_unreachable(msg.c_str()); + } + std::string Line; + while (std::getline(CostFile, Line)) { + std::istringstream SS(Line); + std::string OpcodeStr, PrecisionStr, CostStr; + if (!std::getline(SS, OpcodeStr, ',')) { + llvm_unreachable( + ("Unexpected line in custom cost model: " + Line).c_str()); + } + if (!std::getline(SS, PrecisionStr, ',')) { + llvm_unreachable( + ("Unexpected line in custom cost model: " + Line).c_str()); + } + if (!std::getline(SS, CostStr)) { + llvm_unreachable( + ("Unexpected line in custom cost model: " + Line).c_str()); + } + CostModel[{OpcodeStr, PrecisionStr}] = std::stoi(CostStr); + } + Loaded = true; + } + return CostModel; +} + +InstructionCost queryCostModel(const std::string &OpcodeName, + const std::string &PrecisionName) { + const auto &CostModel = getCostModel(); + auto Key = std::make_pair(OpcodeName, PrecisionName); + auto It = CostModel.find(Key); + if (It != CostModel.end()) + return It->second; + + std::string msg = "Custom cost model: entry not found for " + OpcodeName + + " @ " + PrecisionName; + llvm::errs() << msg << "\n"; + llvm_unreachable(msg.c_str()); +} + +InstructionCost getInstructionCompCost(const Instruction *I, + const TargetTransformInfo &TTI) { + if (!I->getType()->isFPOrFPVectorTy()) + return 0; + + if (!FPOptCostModelPath.empty()) { + std::string OpcodeName; + switch (I->getOpcode()) { + case Instruction::FNeg: + OpcodeName = "fneg"; + break; + case Instruction::FAdd: + OpcodeName = "fadd"; + break; + case Instruction::FSub: + OpcodeName = "fsub"; + break; + case Instruction::FMul: + OpcodeName = "fmul"; + break; + case Instruction::FDiv: + OpcodeName = "fdiv"; + break; + case Instruction::FCmp: + OpcodeName = "fcmp"; + break; + case Instruction::FPExt: + OpcodeName = "fpext"; + break; + case Instruction::FPTrunc: + OpcodeName = "fptrunc"; + break; + case Instruction::PHI: + case Instruction::Select: + case Instruction::Load: + return 0; + case Instruction::Call: { + auto *Call = cast(I); + if (auto *CalledFunc = Call->getCalledFunction()) { + if (CalledFunc->isIntrinsic()) { + switch (CalledFunc->getIntrinsicID()) { + case Intrinsic::sin: + OpcodeName = "sin"; + break; + case Intrinsic::cos: + OpcodeName = "cos"; + break; +#if LLVM_VERSION_MAJOR > 16 + case Intrinsic::tan: + OpcodeName = "tan"; + break; + case Intrinsic::asin: + OpcodeName = "asin"; + break; + case Intrinsic::acos: + OpcodeName = "acos"; + break; + case Intrinsic::atan: + OpcodeName = "atan"; + break; + case Intrinsic::atan2: + OpcodeName = "atan2"; + break; + case Intrinsic::sinh: + OpcodeName = "sinh"; + break; + case Intrinsic::cosh: + OpcodeName = "cosh"; + break; + case Intrinsic::tanh: + OpcodeName = "tanh"; +#endif + case Intrinsic::exp: + OpcodeName = "exp"; + break; + case Intrinsic::log: + OpcodeName = "log"; + break; + case Intrinsic::sqrt: + OpcodeName = "sqrt"; + break; + case Intrinsic::fabs: + OpcodeName = "fabs"; + break; + case Intrinsic::fma: + OpcodeName = "fma"; + break; + case Intrinsic::pow: + OpcodeName = "pow"; + break; + case Intrinsic::powi: + OpcodeName = "powi"; + break; + case Intrinsic::fmuladd: + OpcodeName = "fmuladd"; + break; + case Intrinsic::maxnum: + OpcodeName = "maxnum"; + break; + case Intrinsic::minnum: + OpcodeName = "minnum"; + break; + case Intrinsic::ceil: + OpcodeName = "ceil"; + break; + case Intrinsic::floor: + OpcodeName = "floor"; + break; + case Intrinsic::exp2: + OpcodeName = "exp2"; + break; + case Intrinsic::log10: + OpcodeName = "log10"; + break; + case Intrinsic::log2: + OpcodeName = "log2"; + break; + case Intrinsic::rint: + OpcodeName = "rint"; + break; + case Intrinsic::round: + OpcodeName = "round"; + break; + case Intrinsic::trunc: + OpcodeName = "trunc"; + break; + case Intrinsic::copysign: + OpcodeName = "copysign"; + break; + default: { + std::string msg = "Custom cost model: unsupported intrinsic " + + CalledFunc->getName().str(); + llvm_unreachable(msg.c_str()); + } + } + } else { + std::string FuncName = CalledFunc->getName().str(); + if (!FuncName.empty() && + (FuncName.back() == 'f' || FuncName.back() == 'l')) + FuncName.pop_back(); + + if (LibmFuncs.count(FuncName)) + OpcodeName = FuncName; + else { + std::string msg = + "Custom cost model: unknown function call " + FuncName; + llvm_unreachable(msg.c_str()); + } + } + } else { + llvm_unreachable("Custom cost model: unknown function call"); + } + break; + } + default: { + llvm::errs() << "Problematic instruction: " << *I << "\n"; + std::string msg = "Custom cost model: unexpected opcode " + + std::string(I->getOpcodeName()); + llvm_unreachable(msg.c_str()); + } + } + + std::string PrecisionName; + Type *Ty = I->getType(); + if (I->getOpcode() == Instruction::FCmp) + Ty = I->getOperand(0)->getType(); + + if (Ty->isBFloatTy()) + PrecisionName = "bf16"; + else if (Ty->isHalfTy()) + PrecisionName = "half"; + else if (Ty->isFloatTy()) + PrecisionName = "float"; + else if (Ty->isDoubleTy()) + PrecisionName = "double"; + else if (Ty->isX86_FP80Ty()) + PrecisionName = "fp80"; + else if (Ty->isFP128Ty()) + PrecisionName = "fp128"; + else { + std::string msg = "Custom cost model: unsupported precision type!"; + llvm_unreachable(msg.c_str()); + } + + // For FPExt/FPTrunc, update the opcode name to include conversion info. + if (I->getOpcode() == Instruction::FPExt || + I->getOpcode() == Instruction::FPTrunc) { + Type *SrcTy = I->getOperand(0)->getType(); + std::string SrcPrecisionName; + if (SrcTy->isBFloatTy()) + SrcPrecisionName = "bf16"; + else if (SrcTy->isHalfTy()) + SrcPrecisionName = "half"; + else if (SrcTy->isFloatTy()) + SrcPrecisionName = "float"; + else if (SrcTy->isDoubleTy()) + SrcPrecisionName = "double"; + else if (SrcTy->isX86_FP80Ty()) + SrcPrecisionName = "fp80"; + else if (SrcTy->isFP128Ty()) + SrcPrecisionName = "fp128"; + else { + std::string msg = "Custom cost model: unsupported precision type!"; + llvm_unreachable(msg.c_str()); + } + + OpcodeName += "_" + SrcPrecisionName + "_to_" + PrecisionName; + PrecisionName = SrcPrecisionName; + } + + return queryCostModel(OpcodeName, PrecisionName); + } else { + llvm::errs() << "WARNING: Custom cost model not found, using TTI cost!\n"; + return TTI.getInstructionCost(I, TargetTransformInfo::TCK_RecipThroughput); + } +} + +const std::unordered_set &getPTFuncs() { + static const std::unordered_set PTFuncs = []() { + if (FPOptCostModelPath.empty()) + return std::unordered_set{}; + std::unordered_set funcs; + for (const auto &func : LibmFuncs) { + InstructionCost costFP32 = queryCostModel(func, "float"); + InstructionCost costFP64 = queryCostModel(func, "double"); + InstructionCost costFPTrunc = + queryCostModel("fptrunc_double_to_float", "double"); + InstructionCost costFPExt = + queryCostModel("fpext_float_to_double", "float"); + InstructionCost costFPCast = costFPTrunc + costFPExt; + if (costFP32 + costFPCast < costFP64) + funcs.insert(func); + } + return funcs; + }(); + return PTFuncs; +} + +InstructionCost computeMaxCost( + BasicBlock *BB, std::unordered_map &MaxCost, + std::unordered_set &Visited, const TargetTransformInfo &TTI) { + if (MaxCost.find(BB) != MaxCost.end()) + return MaxCost[BB]; + + if (!Visited.insert(BB).second) + return 0; + + InstructionCost BBCost = 0; + for (const Instruction &I : *BB) { + if (I.isTerminator()) + continue; + + auto instCost = getInstructionCompCost(&I, TTI); + + // if (FPOptPrint) + // llvm::errs() << "Cost of " << I << " is: " << instCost << "\n"; + + BBCost += instCost; + } + + InstructionCost succCost = 0; + + if (!succ_empty(BB)) { + InstructionCost maxSuccCost = 0; + for (BasicBlock *Succ : successors(BB)) { + InstructionCost succBBCost = computeMaxCost(Succ, MaxCost, Visited, TTI); + if (succBBCost > maxSuccCost) + maxSuccCost = succBBCost; + } + // llvm::errs() << "Max succ cost: " << maxSuccCost << "\n"; + succCost = maxSuccCost; + } + + InstructionCost totalCost = BBCost + succCost; + // llvm::errs() << "BB " << BB->getName() << " cost: " << totalCost << "\n"; + MaxCost[BB] = totalCost; + Visited.erase(BB); + return totalCost; +} + +void getSampledPoints( + ArrayRef inputs, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints) { + std::default_random_engine gen; + gen.seed(FPOptRandomSeed); + std::uniform_real_distribution<> dis; + + MapVector> hypercube; + for (const auto input : inputs) { + const auto node = valueToNodeMap.at(input); + + double lower = node->getLowerBound(); + double upper = node->getUpperBound(); + + hypercube.insert({input, {lower, upper}}); + } + + // llvm::errs() << "Hypercube:\n"; + // for (const auto &entry : hypercube) { + // Value *val = entry.first; + // double lower = entry.second[0]; + // double upper = entry.second[1]; + // llvm::errs() << valueToNodeMap.at(val)->symbol << ": [" << lower << ", " + // << upper << "]\n"; + // } + + // Sample `FPOptNumSamples` points from the hypercube. Store it in + // `sampledPoints`. + sampledPoints.clear(); + sampledPoints.resize(FPOptNumSamples); + for (size_t i = 0; i < FPOptNumSamples; ++i) { + MapVector point; + for (const auto &entry : hypercube) { + Value *val = entry.first; + double lower = entry.second[0]; + double upper = entry.second[1]; + double sample = dis(gen, decltype(dis)::param_type{lower, upper}); + point.insert({val, sample}); + } + sampledPoints[i] = point; + // llvm::errs() << "Sample " << i << ":\n"; + // for (const auto &entry : point) { + // llvm::errs() << valueToNodeMap.at(entry.first)->symbol << ": " + // << entry.second << "\n"; + // } + } +} + +void getSampledPoints( + const std::string &expr, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints) { + SmallSet argStrSet; + getUniqueArgs(expr, argStrSet); + + SmallVector inputs; + for (const auto &argStr : argStrSet) { + inputs.push_back(symbolToValueMap.at(argStr)); + } + + getSampledPoints(inputs, valueToNodeMap, symbolToValueMap, sampledPoints); +} + +InstructionCost getCompCost(Function *F, const TargetTransformInfo &TTI) { + std::unordered_map MaxCost; + std::unordered_set Visited; + + BasicBlock *EntryBB = &F->getEntryBlock(); + InstructionCost TotalCost = computeMaxCost(EntryBB, MaxCost, Visited, TTI); + // llvm::errs() << "Total cost: " << TotalCost << "\n"; + return TotalCost; +} + +// Sum up the cost of `output` and its FP operands recursively up to `inputs` +// (exclusive). +InstructionCost getCompCost(const SmallVector &outputs, + const SetVector &inputs, + const TargetTransformInfo &TTI) { + assert(!outputs.empty()); + SmallPtrSet seen; + SmallVector todo; + InstructionCost cost = 0; + + todo.insert(todo.end(), outputs.begin(), outputs.end()); + while (!todo.empty()) { + auto cur = todo.pop_back_val(); + if (!seen.insert(cur).second) + continue; + + if (inputs.contains(cur)) + continue; + + if (auto *I = dyn_cast(cur)) { + // TODO: unfair to ignore branches when calculating cost + auto instCost = getInstructionCompCost(I, TTI); + + // if (FPOptPrint) + // llvm::errs() << "Cost of " << *I << " is: " << instCost << "\n"; + + // Only add the cost of the instruction if it is not an input + cost += instCost; + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &operand : operands) { + todo.push_back(operand); + } + } + } + + return cost; +} + +void collectExprInsts(Value *V, const SetVector &inputs, + SmallPtrSetImpl &exprInsts, + SmallPtrSetImpl &visited) { + if (!V || inputs.contains(V) || visited.contains(V)) { + return; + } + + visited.insert(V); + + if (auto *I = dyn_cast(V)) { + exprInsts.insert(I); + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + + for (auto &op : operands) { + collectExprInsts(op, inputs, exprInsts, visited); + } + } +} + +bool isExpansionBottleneck(Instruction *I, const Subgraph &subgraph) { + // Allow splitting at outputs if there are multiple outputs + if (subgraph.outputs.contains(I) && subgraph.outputs.size() <= 1) { + return false; + } + + // Criteria 1: Number of internal uses + unsigned internalUses = 0; + for (auto *U : I->users()) { + if (auto *UI = dyn_cast(U)) { + if (subgraph.operations.contains(UI)) { + ++internalUses; + } + } + } + + if (internalUses < FPOptMinUsesForSplit) { + return false; + } + + // Criteria 2: Number of upstream internal operations (complexity) + SetVector relevantTree; + SmallVector worklist; + worklist.push_back(I); + relevantTree.insert(I); + + while (!worklist.empty()) { + Instruction *current = worklist.pop_back_val(); + auto operands = isa(current) ? cast(current)->args() + : current->operands(); + for (auto &op : operands) { + if (subgraph.inputs.contains(op)) { + continue; + } + if (auto *OpI = dyn_cast(op)) { + if (subgraph.operations.contains(OpI) && !relevantTree.contains(OpI)) { + worklist.push_back(OpI); + relevantTree.insert(OpI); + } + } + } + } + + SetVector movedOps; + SetVector keptSubtreeRoots; + + for (auto *op : relevantTree) { + if (op == I) { + movedOps.insert(op); + continue; + } + + bool hasExternalUse = false; + for (auto *U : op->users()) { + if (auto *UI = dyn_cast(U)) { + if (!relevantTree.contains(UI)) { + hasExternalUse = true; + break; + } + } + } + + if (hasExternalUse) { + keptSubtreeRoots.insert(op); + } else { + movedOps.insert(op); + } + } + + for (auto *keptRoot : keptSubtreeRoots) { + SetVector keptUpstream; + SmallVector keptWorklist; + keptWorklist.push_back(keptRoot); + keptUpstream.insert(keptRoot); + + while (!keptWorklist.empty()) { + Instruction *current = keptWorklist.pop_back_val(); + auto operands = isa(current) ? cast(current)->args() + : current->operands(); + for (auto &op : operands) { + if (subgraph.inputs.contains(op)) { + continue; + } + if (auto *OpI = dyn_cast(op)) { + if (relevantTree.contains(OpI) && !keptUpstream.contains(OpI)) { + keptWorklist.push_back(OpI); + keptUpstream.insert(OpI); + movedOps.remove(OpI); + } + } + } + } + } + + bool isBottleneck = movedOps.size() >= FPOptMinOpsForSplit; + if (FPOptPrint && isBottleneck) { + llvm::errs() << "Bottleneck: " << *I << "\n"; + llvm::errs() << "Num of operations that would be moved: " << movedOps.size() + << " (>=" << FPOptMinOpsForSplit << ")\n"; + llvm::errs() << "Num of internal uses: " << internalUses + << " (>=" << FPOptMinUsesForSplit << ")\n"; + llvm::errs() << "Operations that would be moved:\n"; + for (auto *op : movedOps) { + llvm::errs() << "\t" << *op << "\n"; + } + } + return isBottleneck; +} + +SetVector +findReachedInputs(const SetVector &operations) { + SetVector reachedInputs; + + for (auto *I : operations) { + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &op : operands) { + if (auto *OpI = dyn_cast(op)) { + if (operations.contains(OpI)) { + continue; + } + } + + reachedInputs.insert(op); + } + } + + return reachedInputs; +} + +void splitSubgraphAtBottleneck(Subgraph ¤tSubgraph, + Instruction *bottleneck, Subgraph &newSubgraph, + Subgraph &remainingSubgraph) { + + // 1. Get all upstream ops from bottleneck (relevant tree) + SetVector relevantTree; + SmallVector worklist; + worklist.push_back(bottleneck); + relevantTree.insert(bottleneck); + + while (!worklist.empty()) { + auto current = worklist.pop_back_val(); + auto operands = isa(current) ? cast(current)->args() + : current->operands(); + for (auto &op : operands) { + if (currentSubgraph.inputs.contains(op)) { + continue; + } + if (auto *OpI = dyn_cast(op)) { + if (currentSubgraph.operations.contains(OpI) && + !relevantTree.contains(OpI)) { + worklist.push_back(OpI); + relevantTree.insert(OpI); + } + } + } + } + + // 2. Move ops that have external uses to the new subgraph + SetVector movedOps; + SetVector keptSubtreeRoots; + + for (auto op : relevantTree) { + // The bottleneck is always moved to the new subgraph + if (op == bottleneck) { + movedOps.insert(op); + continue; + } + + bool hasExternalUse = false; + for (auto U : op->users()) { + if (auto UI = dyn_cast(U)) { + if (!relevantTree.contains(UI)) { + hasExternalUse = true; + break; + } + } + } + + if (hasExternalUse) { + keptSubtreeRoots.insert(op); + } else { + movedOps.insert(op); + } + } + + // 3. Remove upstream ops of kept subtree roots from movedOps + for (auto keptRoot : keptSubtreeRoots) { + SetVector keptUpstream; + SmallVector keptWorklist; + keptWorklist.push_back(keptRoot); + keptUpstream.insert(keptRoot); + + while (!keptWorklist.empty()) { + Instruction *current = keptWorklist.pop_back_val(); + auto operands = isa(current) ? cast(current)->args() + : current->operands(); + for (auto &op : operands) { + if (currentSubgraph.inputs.contains(op)) { + continue; + } + if (auto *OpI = dyn_cast(op)) { + if (relevantTree.contains(OpI) && !keptUpstream.contains(OpI)) { + keptWorklist.push_back(OpI); + keptUpstream.insert(OpI); + movedOps.remove(OpI); + } + } + } + } + } + + // 4. Build new subgraph. + // `operations`: dependencies of `movedOps` + // `inputs`: reached inputs of `operations` + // `outputs`: the bottleneck instruction + newSubgraph.operations = movedOps; + newSubgraph.outputs.insert(bottleneck); + newSubgraph.inputs = findReachedInputs(newSubgraph.operations); + + // 5. Build remaining subgraph. + // `operations`: unmoved ops of `currentSubgraph.operations` + // `inputs`: reached inputs of `operations` + // `outputs`: the outputs of the original subgraph (excluding bottleneck if it + // was an output) + for (auto I : currentSubgraph.operations) { + if (!movedOps.contains(I)) { + remainingSubgraph.operations.insert(I); + } + } + + remainingSubgraph.outputs = currentSubgraph.outputs; + if (currentSubgraph.outputs.contains(bottleneck)) { + remainingSubgraph.outputs.remove(bottleneck); + } + + remainingSubgraph.inputs = findReachedInputs(remainingSubgraph.operations); + assert(remainingSubgraph.inputs.contains(bottleneck)); +} + +void splitSubgraphs(SmallVectorImpl &subgraphs) { + + SmallVector resultSubgraphs; + SmallVector workQueue; + + for (const auto &subgraph : subgraphs) { + workQueue.push_back(subgraph); + } + + while (!workQueue.empty()) { + Subgraph currentSubgraph = workQueue.pop_back_val(); + + if (currentSubgraph.operations.size() <= 2) { + resultSubgraphs.push_back(currentSubgraph); + continue; + } + + // Sort all operations in topological order (inputs to outputs) + SmallVector sortedOps; + topoSort(currentSubgraph.operations, sortedOps); + + bool madeSplit = false; + for (auto *I : sortedOps) { + assert(currentSubgraph.operations.contains(I)); + + if (isExpansionBottleneck(I, currentSubgraph)) { + if (FPOptPrint) { + llvm::errs() << "Bottleneck: " << *I << "\n"; + llvm::errs() << "currentSubgraph.inputs (" + << currentSubgraph.inputs.size() << "): "; + for (auto *input : currentSubgraph.inputs) { + llvm::errs() << "\t" << *input << "\n"; + } + llvm::errs() << "\n"; + llvm::errs() << "currentSubgraph.operations (" + << currentSubgraph.operations.size() << "): "; + for (auto *op : currentSubgraph.operations) { + llvm::errs() << "\t" << *op << "\n"; + } + llvm::errs() << "\n"; + llvm::errs() << "currentSubgraph.outputs (" + << currentSubgraph.outputs.size() << "): "; + for (auto *output : currentSubgraph.outputs) { + llvm::errs() << "\t" << *output << "\n"; + } + llvm::errs() << "\n"; + } + + Subgraph newSubgraph, remainingSubgraph; + splitSubgraphAtBottleneck(currentSubgraph, I, newSubgraph, + remainingSubgraph); + + if (FPOptPrint) { + llvm::errs() << "=== Splitting subgraph at bottleneck: " << *I + << " ===\n"; + + llvm::errs() << " New subgraph:\n"; + llvm::errs() << " Inputs (" << newSubgraph.inputs.size() << "):\n"; + for (auto *input : newSubgraph.inputs) { + llvm::errs() << " " << *input << "\n"; + } + llvm::errs() << " Operations (" << newSubgraph.operations.size() + << "):\n"; + for (auto *op : newSubgraph.operations) { + llvm::errs() << " " << *op << "\n"; + } + llvm::errs() << " Outputs (" << newSubgraph.outputs.size() + << "):\n"; + for (auto *output : newSubgraph.outputs) { + llvm::errs() << " " << *output << "\n"; + } + + llvm::errs() << " Remaining subgraph:\n"; + + llvm::errs() << " Inputs (" << remainingSubgraph.inputs.size() + << "):\n"; + for (auto *input : remainingSubgraph.inputs) { + llvm::errs() << " " << *input << "\n"; + } + llvm::errs() << " Operations (" + << remainingSubgraph.operations.size() << "):\n"; + for (auto *op : remainingSubgraph.operations) { + llvm::errs() << " " << *op << "\n"; + } + llvm::errs() << " Outputs (" << remainingSubgraph.outputs.size() + << "):\n"; + for (auto *output : remainingSubgraph.outputs) { + llvm::errs() << " " << *output << "\n"; + } + } + + resultSubgraphs.push_back(newSubgraph); + currentSubgraph = remainingSubgraph; + madeSplit = true; + } + } + + if (madeSplit) { + workQueue.push_back(currentSubgraph); + } else { + resultSubgraphs.push_back(currentSubgraph); + } + } + + if (FPOptPrint) { + llvm::errs() << "=== Subgraph splitting complete ===\n"; + llvm::errs() << " Original subgraphs: " << subgraphs.size() << "\n"; + llvm::errs() << " Final subgraphs after splitting: " + << resultSubgraphs.size() << "\n"; + } + + subgraphs = std::move(resultSubgraphs); +} diff --git a/enzyme/Enzyme/Poseidon/PoseidonUtils.h b/enzyme/Enzyme/Poseidon/PoseidonUtils.h new file mode 100644 index 000000000000..4eae75e0f3ab --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonUtils.h @@ -0,0 +1,83 @@ +//=- PoseidonUtils.h - Utility functions for Poseidon optimization pass ----=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares utility functions for the Poseidon floating-point +// optimization pass. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_UTILS_H +#define ENZYME_POSEIDON_UTILS_H + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Type.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Support/InstructionCost.h" + +#include +#include +#include +#include + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptCostModelPath; +extern llvm::cl::opt FPOptNumSamples; +extern llvm::cl::opt FPOptRandomSeed; +extern llvm::cl::opt FPOptMinUsesForSplit; +extern llvm::cl::opt FPOptMinOpsForSplit; +} + +struct Subgraph; +class FPNode; + +// Utility function declarations +double getOneULP(double value); +std::string getLibmFunctionForPrecision(StringRef funcName, Type *newType); +double stringToDouble(const std::string &str); +void topoSort(const SetVector &insts, + SmallVectorImpl &instsSorted); +void reverseTopoSort(const SetVector &insts, + SmallVectorImpl &instsSorted); + +void getUniqueArgs(const std::string &expr, SmallSet &args); + +const std::map, InstructionCost> & +getCostModel(); +InstructionCost queryCostModel(const std::string &OpcodeName, + const std::string &TypeName); +InstructionCost getInstructionCompCost(const Instruction *I, + const TargetTransformInfo &TTI); + +const std::unordered_set &getPTFuncs(); + +InstructionCost computeMaxCost( + BasicBlock *BB, std::unordered_map &MaxCost, + std::unordered_set &Visited, const TargetTransformInfo &TTI); + +InstructionCost getCompCost(Function *F, const TargetTransformInfo &TTI); + +InstructionCost getCompCost(const SmallVector &outputs, + const SetVector &inputs, + const TargetTransformInfo &TTI); + +void collectExprInsts(Value *V, const SetVector &inputs, + SmallPtrSetImpl &exprInsts, + SmallPtrSetImpl &visited); + +void splitSubgraphs(SmallVectorImpl &subgraphs); + +void runPoseidonFunctionSimplify(Function &F, OptimizationLevel Level); + +#endif // ENZYME_POSEIDON_UTILS_H diff --git a/enzyme/Enzyme/Runtimes/FPProfiler/FPProfiler.cpp b/enzyme/Enzyme/Runtimes/FPProfiler/FPProfiler.cpp new file mode 100644 index 000000000000..563cb71494ab --- /dev/null +++ b/enzyme/Enzyme/Runtimes/FPProfiler/FPProfiler.cpp @@ -0,0 +1,195 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ProfileInfo { +public: + double minRes = std::numeric_limits::max(); + double maxRes = std::numeric_limits::lowest(); + std::vector minOperands; + std::vector maxOperands; + double sumValue = 0.0; + double sumSens = 0.0; + double sumGrad = 0.0; + unsigned exec = 0; + + void updateValue(double value, const double *operands, size_t numOperands) { + ++exec; + + if (minOperands.empty()) { + minOperands.resize(numOperands, std::numeric_limits::max()); + maxOperands.resize(numOperands, std::numeric_limits::lowest()); + } + for (size_t i = 0; i < numOperands; ++i) { + if (!std::isnan(operands[i])) { + minOperands[i] = std::min(minOperands[i], operands[i]); + maxOperands[i] = std::max(maxOperands[i], operands[i]); + } + } + + if (!std::isnan(value)) { + minRes = std::min(minRes, value); + maxRes = std::max(maxRes, value); + sumValue += value; + } + } + + void updateGradient(double value, double grad) { + if (!std::isnan(grad) && !std::isnan(value) && !std::isinf(grad) && + !std::isinf(value)) { + sumGrad += grad; + sumSens += std::fabs(grad * value); + } + } +}; + +class FPProfiler { +private: + std::string functionName; + std::unordered_map profileInfo; + static std::string dir; + +public: + FPProfiler(const std::string &funcName) : functionName(funcName) {} + + static void setOutputDir(const std::string &dir_) { dir = dir_; } + + std::string getOutputPath() const { + return dir + "/" + functionName + ".fpprofile"; + } + + void updateValue(size_t idx, double res, size_t numOperands, + const double *operands) { + auto it = profileInfo.try_emplace(idx).first; + it->second.updateValue(res, operands, numOperands); + } + + void updateGradient(size_t idx, double value, double grad) { + auto it = profileInfo.try_emplace(idx).first; + it->second.updateGradient(value, grad); + } + + void write() const { + std::string outputPath = getOutputPath(); + + struct stat st = {0}; + if (stat(dir.c_str(), &st) == -1) { + mkdir(dir.c_str(), 0755); + } + + std::ofstream out(outputPath); + if (!out.is_open()) { + std::cerr << "Warning: Could not open profile file: " << outputPath + << std::endl; + return; + } + + out << std::scientific + << std::setprecision(std::numeric_limits::max_digits10); + + for (const auto &pair : profileInfo) { + const auto i = pair.first; + const auto &info = pair.second; + out << i << "\n"; + out << "\tMinRes = " << info.minRes << "\n"; + out << "\tMaxRes = " << info.maxRes << "\n"; + out << "\tSumValue = " << info.sumValue << "\n"; + out << "\tSumSens = " << info.sumSens << "\n"; + out << "\tSumGrad = " << info.sumGrad << "\n"; + out << "\tExec = " << info.exec << "\n"; + out << "\tNumOperands = " << info.minOperands.size() << "\n"; + for (size_t i = 0; i < info.minOperands.size(); ++i) { + out << "\tOperand[" << i << "] = [" << info.minOperands[i] << ", " + << info.maxOperands[i] << "]\n"; + } + } + + out << "\n"; + out.close(); + } +}; + +std::string FPProfiler::dir = "./fpprofile"; + +static std::unordered_map> + profilerRegistry; +static std::mutex registryMutex; + +static void writeAllProfilesAtExit() { + std::lock_guard lock(registryMutex); + for (auto &pair : profilerRegistry) { + pair.second->write(); + } + profilerRegistry.clear(); +} + +static int RegisterFPProfileRuntime() { + const char *envPath = getenv("ENZYME_FPPROFILE_DIR"); + if (envPath) { + FPProfiler::setOutputDir(envPath); + } else { + FPProfiler::setOutputDir("./fpprofile"); + } + + std::atexit(writeAllProfilesAtExit); + + return 0; +} + +extern "C" int ENZYME_FPPROFILE_RUNTIME_VAR = RegisterFPProfileRuntime(); + +extern "C" { + +void ProfilerWrite() { + std::lock_guard lock(registryMutex); + for (auto &pair : profilerRegistry) { + pair.second->write(); + } +} + +void enzymeLogGrad(const char *funcName, size_t idx, double value, + double grad) { + if (!funcName) + return; + + std::lock_guard lock(registryMutex); + + auto it = profilerRegistry.find(funcName); + if (it == profilerRegistry.end()) { + profilerRegistry[funcName] = std::make_unique(funcName); + it = profilerRegistry.find(funcName); + } + + it->second->updateGradient(idx, value, grad); +} + +void enzymeLogValue(const char *funcName, size_t idx, double res, + size_t numOperands, double *operands) { + if (!funcName) + return; + + std::lock_guard lock(registryMutex); + + auto it = profilerRegistry.find(funcName); + if (it == profilerRegistry.end()) { + profilerRegistry[funcName] = std::make_unique(funcName); + it = profilerRegistry.find(funcName); + } + + it->second->updateValue(idx, res, numOperands, operands); +} + +} // extern "C" \ No newline at end of file diff --git a/enzyme/test/Enzyme/Poseidon/CMakeLists.txt b/enzyme/test/Enzyme/Poseidon/CMakeLists.txt new file mode 100644 index 000000000000..6a3191ef451e --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-poseidon "Running Poseidon regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-poseidon PROPERTIES FOLDER "Tests") + +#add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} + # DEPENDS ${ENZYME_TEST_DEPS} +#) diff --git a/enzyme/test/Enzyme/Poseidon/add_noopt.ll b/enzyme/test/Enzyme/Poseidon/add_noopt.ll new file mode 100644 index 000000000000..222f1ca5841c --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/add_noopt.ll @@ -0,0 +1,24 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -S | FileCheck %s -dump-input=always + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + ret double %0 +} + +define double @test_profile(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fp_optimize(double (double, double)* nonnull @tester, double %x, double %y, metadata !"enzyme_err_tol", double 1.0e-6) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fp_optimize(double (double, double)*, ...) + +; CHECK: define double @test_profile(double %x, double %y) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[fadd:.+]] = call double @tester(double %x, double %y) +; CHECK-NEXT: ret double %[[fadd]] +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Poseidon/add_prof.ll b/enzyme/test/Enzyme/Poseidon/add_prof.ll new file mode 100644 index 000000000000..15520ba2d034 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/add_prof.ll @@ -0,0 +1,38 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s -dump-input=always + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + ret double %0 +} + +define double @test_profile(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fp_optimize(double (double, double)* nonnull @tester, double %x, double %y, metadata !"enzyme_err_tol", double 1.0e-6) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fp_optimize(double (double, double)*, ...) + +; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 +; CHECK: @fpprofiled_preprocess_tester = private unnamed_addr constant [18 x i8] c"preprocess_tester\00", align 1 + +; CHECK: define internal { double, double } @instrtester(double %x, double %y, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[fadd:.+]] = fadd fast double %x, %y, !enzyme_active !0, !enzyme_fpprofile_idx !1 +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_tester, i64 0, double %[[fadd]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_tester, i64 0, double %[[fadd]], double %[[differet]]) +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double } undef, double %[[differet]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double } %[[ins1]], double %[[differet]], 1 +; CHECK-NEXT: ret { double, double } %[[ins2]] +; CHECK-NEXT: } + +; CHECK: !0 = !{} +; CHECK: !1 = !{i64 0} \ No newline at end of file diff --git a/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll b/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll new file mode 100644 index 000000000000..e68216164d8a --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll @@ -0,0 +1,187 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -fpprofile-generate -S | FileCheck %s + +; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 +; CHECK: @fpprofiled_preprocess_test_maxnum_zero = private unnamed_addr constant [28 x i8] c"preprocess_test_maxnum_zero\00", align 1 +; CHECK: @fpprofiled_preprocess_test_maxnum_zero_reversed = private unnamed_addr constant [37 x i8] c"preprocess_test_maxnum_zero_reversed\00", align 1 +; CHECK: @fpprofiled_preprocess_test_maxnum_general = private unnamed_addr constant [31 x i8] c"preprocess_test_maxnum_general\00", align 1 +; CHECK: @fpprofiled_preprocess_test_minnum_general = private unnamed_addr constant [31 x i8] c"preprocess_test_minnum_general\00", align 1 +; CHECK: @fpprofiled_preprocess_test_combined = private unnamed_addr constant [25 x i8] c"preprocess_test_combined\00", align 1 + +define double @test_maxnum_zero(double %x) { +entry: + %cmp = fcmp ogt double %x, 0.0 + %result = select i1 %cmp, double %x, double 0.0 + ret double %result +} + +; CHECK: define internal { double } @instrtest_maxnum_zero(double %x, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[maxnum:.+]] = call double @llvm.maxnum.f64(double %x, double 0.000000e+00), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double 0.000000e+00, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_maxnum_zero, i64 0, double %[[maxnum]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_maxnum_zero, i64 0, double %[[maxnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp:.+]] = fcmp fast olt double %x, 0.000000e+00 +; CHECK-NEXT: %[[sel:.+]] = select fast i1 %[[cmp]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: %[[ins:.+]] = insertvalue { double } undef, double %[[sel]], 0 +; CHECK-NEXT: ret { double } %[[ins]] +; CHECK-NEXT: } + +define double @test_maxnum_zero_reversed(double %x) { +entry: + %cmp = fcmp olt double %x, 0.0 + %result = select i1 %cmp, double 0.0, double %x + ret double %result +} + +; CHECK: define internal { double } @instrtest_maxnum_zero_reversed(double %x, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[maxnum:.+]] = call double @llvm.maxnum.f64(double %x, double 0.000000e+00), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double 0.000000e+00, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_maxnum_zero_reversed, i64 0, double %[[maxnum]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_maxnum_zero_reversed, i64 0, double %[[maxnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp:.+]] = fcmp fast olt double %x, 0.000000e+00 +; CHECK-NEXT: %[[sel:.+]] = select fast i1 %[[cmp]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: %[[ins:.+]] = insertvalue { double } undef, double %[[sel]], 0 +; CHECK-NEXT: ret { double } %[[ins]] +; CHECK-NEXT: } + +define double @test_maxnum_general(double %x, double %y) { +entry: + %cmp = fcmp ogt double %x, %y + %result = select i1 %cmp, double %x, double %y + ret double %result +} + +; CHECK: define internal { double, double } @instrtest_maxnum_general(double %x, double %y, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[maxnum:.+]] = call double @llvm.maxnum.f64(double %x, double %y), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_maxnum_general, i64 0, double %[[maxnum]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_maxnum_general, i64 0, double %[[maxnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp:.+]] = fcmp fast olt double %x, %y +; CHECK-NEXT: %[[sel1:.+]] = select fast i1 %[[cmp]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: %[[cmp2:.+]] = fcmp fast olt double %x, %y +; CHECK-NEXT: %[[sel2:.+]] = select fast i1 %[[cmp2]], double %[[differet]], double 0.000000e+00 +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double } undef, double %[[sel1]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double } %[[ins1]], double %[[sel2]], 1 +; CHECK-NEXT: ret { double, double } %[[ins2]] +; CHECK-NEXT: } + +define double @test_minnum_general(double %x, double %y) { +entry: + %cmp = fcmp olt double %x, %y + %result = select i1 %cmp, double %x, double %y + ret double %result +} + +; CHECK: define internal { double, double } @instrtest_minnum_general(double %x, double %y, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[minnum:.+]] = call double @llvm.minnum.f64(double %x, double %y), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_minnum_general, i64 0, double %[[minnum]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_minnum_general, i64 0, double %[[minnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp:.+]] = fcmp fast olt double %x, %y +; CHECK-NEXT: %[[sel1:.+]] = select fast i1 %[[cmp]], double %[[differet]], double 0.000000e+00 +; CHECK-NEXT: %[[cmp2:.+]] = fcmp fast olt double %x, %y +; CHECK-NEXT: %[[sel2:.+]] = select fast i1 %[[cmp2]], double 0.000000e+00 +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double } undef, double %[[sel1]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double } %[[ins1]], double %[[sel2]], 1 +; CHECK-NEXT: ret { double, double } %[[ins2]] +; CHECK-NEXT: } + +define double @test_combined(double %x, double %y, double %z) { +entry: + %cmp1 = fcmp ogt double %x, 0.0 + %max_x = select i1 %cmp1, double %x, double 0.0 + %cmp2 = fcmp olt double %max_x, %y + %min_xy = select i1 %cmp2, double %max_x, double %y + %result = fadd fast double %min_xy, %z + ret double %result +} + +; CHECK: define internal { double, double, double } @instrtest_combined(double %x, double %y, double %z, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca0:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[alloca1:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[alloca2:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[maxnum:.+]] = call double @llvm.maxnum.f64(double %x, double 0.000000e+00), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca2]], align 8 +; CHECK-NEXT: %[[gep2:.+]] = getelementptr [2 x double], ptr %[[alloca2]], i32 0, i32 1 +; CHECK-NEXT: store double 0.000000e+00, ptr %[[gep2]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_combined, i64 0, double %[[maxnum]], i32 2, ptr %[[alloca2]]) +; CHECK-NEXT: %[[minnum:.+]] = call double @llvm.minnum.f64(double %[[maxnum]], double %y), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %[[maxnum]], ptr %[[alloca1]], align 8 +; CHECK-NEXT: %[[gep1:.+]] = getelementptr [2 x double], ptr %[[alloca1]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep1]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_combined, i64 1, double %[[minnum]], i32 2, ptr %[[alloca1]]) +; CHECK-NEXT: %result = fadd fast double %[[minnum]], %z, !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %[[minnum]], ptr %[[alloca0]], align 8 +; CHECK-NEXT: %[[gep0:.+]] = getelementptr [2 x double], ptr %[[alloca0]], i32 0, i32 1 +; CHECK-NEXT: store double %z, ptr %[[gep0]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_combined, i64 2, double %result, i32 2, ptr %[[alloca0]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_combined, i64 2, double %result, double %[[differet]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_combined, i64 1, double %[[minnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp1:.+]] = fcmp fast olt double %[[maxnum]], %y +; CHECK-NEXT: %[[sel1:.+]] = select fast i1 %[[cmp1]], double %[[differet]], double 0.000000e+00 +; CHECK-NEXT: %[[cmp2:.+]] = fcmp fast olt double %[[maxnum]], %y +; CHECK-NEXT: %[[sel2:.+]] = select fast i1 %[[cmp2]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_combined, i64 0, double %[[maxnum]], double %[[sel1]]) +; CHECK-NEXT: %[[cmp3:.+]] = fcmp fast olt double %x, 0.000000e+00 +; CHECK-NEXT: %[[sel3:.+]] = select fast i1 %[[cmp3]], double 0.000000e+00, double %[[sel1]] +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double, double } undef, double %[[sel3]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double, double } %[[ins1]], double %[[sel2]], 1 +; CHECK-NEXT: %[[ins3:.+]] = insertvalue { double, double, double } %[[ins2]], double %[[differet]], 2 +; CHECK-NEXT: ret { double, double, double } %[[ins3]] +; CHECK-NEXT: } + +define double @test_profile_maxnum_zero(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fp_optimize(double (double)* nonnull @test_maxnum_zero, double %x) + ret double %0 +} + +define double @test_profile_maxnum_zero_reversed(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fp_optimize(double (double)* nonnull @test_maxnum_zero_reversed, double %x) + ret double %0 +} + +define double @test_profile_maxnum(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fp_optimize(double (double, double)* nonnull @test_maxnum_general, double %x, double %y) + ret double %0 +} + + +define double @test_profile_minnum(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fp_optimize(double (double, double)* nonnull @test_minnum_general, double %x, double %y) + ret double %0 +} + +define double @test_profile_combined(double %x, double %y, double %z) { +entry: + %0 = tail call double (double (double, double, double)*, ...) @__enzyme_fp_optimize(double (double, double, double)* nonnull @test_combined, double %x, double %y, double %z) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fp_optimize(...) + +; CHECK: !0 = !{} +; CHECK: !1 = !{i64 0} +; CHECK: !2 = !{i64 1} +; CHECK: !3 = !{i64 2} \ No newline at end of file diff --git a/enzyme/test/Enzyme/Poseidon/fma_prof.ll b/enzyme/test/Enzyme/Poseidon/fma_prof.ll new file mode 100644 index 000000000000..0b0df87d3456 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/fma_prof.ll @@ -0,0 +1,44 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y, double %z) { +entry: + %0 = fmul fast double %x, %y + %1 = fadd fast double %0, %z + ret double %1 +} + +define double @test_profile(double %x, double %y, double %z) { +entry: + %0 = tail call double (double (double, double, double)*, ...) @__enzyme_fp_optimize(double (double, double, double)* nonnull @tester, double %x, double %y, double %z) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fp_optimize(double (double, double, double)*, ...) + +; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 +; CHECK: @fpprofiled_preprocess_tester = private unnamed_addr constant [18 x i8] c"preprocess_tester\00", align 1 + +; CHECK: define internal { double, double, double } @instrtester(double %x, double %y, double %z, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [3 x double], align 8 +; CHECK-NEXT: %[[fmuladd:.+]] = call fast double @llvm.fmuladd.f64(double %x, double %y, double %z), !enzyme_active !0, !enzyme_fpprofile_idx !1 +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep1:.+]] = getelementptr [3 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep1]], align 8 +; CHECK-NEXT: %[[gep2:.+]] = getelementptr [3 x double], ptr %[[alloca]], i32 0, i32 2 +; CHECK-NEXT: store double %z, ptr %[[gep2]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_tester, i64 0, double %[[fmuladd]], i32 3, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_tester, i64 0, double %[[fmuladd]], double %[[differet]]) +; CHECK-NEXT: %[[grad1:.+]] = fmul fast double %[[differet]], %y +; CHECK-NEXT: %[[grad2:.+]] = fmul fast double %[[differet]], %x +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double, double } undef, double %[[grad1]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double, double } %[[ins1]], double %[[grad2]], 1 +; CHECK-NEXT: %[[ins3:.+]] = insertvalue { double, double, double } %[[ins2]], double %[[differet]], 2 +; CHECK-NEXT: ret { double, double, double } %[[ins3]] +; CHECK-NEXT: } + +; CHECK: !0 = !{} +; CHECK: !1 = !{i64 0} \ No newline at end of file diff --git a/enzyme/test/Enzyme/Poseidon/structret_prof.ll b/enzyme/test/Enzyme/Poseidon/structret_prof.ll new file mode 100644 index 000000000000..8f75a789021e --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/structret_prof.ll @@ -0,0 +1,44 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s + +; Adapted from enzyme/test/Enzyme/ReverseMode/gradient-struct-ret.ll + +%struct.Gradients = type { double, double } + +; Function Attrs: noinline nounwind readnone uwtable +define dso_local double @muldd(double %x, double %y) { +entry: + %mul = fmul fast double %x, %y + ret double %mul +} + +define dso_local %struct.Gradients @test_profile(double %x, double %y) { +entry: + %call = call %struct.Gradients (i8*, ...) @__enzyme_fp_optimize(i8* bitcast (double (double, double)* @muldd to i8*), double %x, double %y) + ret %struct.Gradients %call +} + +; Function Attrs: nounwind +declare %struct.Gradients @__enzyme_fp_optimize(i8*, ...) + +; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 +; CHECK: @fpprofiled_preprocess_muldd = private unnamed_addr constant [17 x i8] c"preprocess_muldd\00", align 1 + +; CHECK: define internal { double, double } @instrmuldd(double %x, double %y, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[mul:.+]] = fmul fast double %x, %y, !enzyme_active !0, !enzyme_fpprofile_idx !1 +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_muldd, i64 0, double %[[mul]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_muldd, i64 0, double %[[mul]], double %[[differet]]) +; CHECK-NEXT: %[[grad1:.+]] = fmul fast double %[[differet]], %y +; CHECK-NEXT: %[[grad2:.+]] = fmul fast double %[[differet]], %x +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double } undef, double %[[grad1]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double } %[[ins1]], double %[[grad2]], 1 +; CHECK-NEXT: ret { double, double } %[[ins2]] +; CHECK-NEXT: } + +; CHECK: !0 = !{} +; CHECK: !1 = !{i64 0} \ No newline at end of file diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 0cc5e6f28f38..aed8f3a9946e 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -117,6 +117,7 @@ config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < config.substitutions.append(('%newLoadClangEnzyme', newPM)) config.substitutions.append(('%hasMPFR', has_mpfr)) +config.substitutions.append(('%FPProfileLib', '@ENZYME_BINARY_DIR@/Enzyme/libEnzymeFPProfile.a')) # Let the main config do the real work. cfgfile = "@ENZYME_SOURCE_DIR@/test/lit.cfg.py" From 112540723fc7ca979ac2a656ed097c6398dcbd00 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Mar 2026 18:41:21 -0500 Subject: [PATCH 02/18] tests + tblgen fix up --- enzyme/test/Enzyme/CMakeLists.txt | 4 + enzyme/test/Enzyme/ForwardError/add.ll | 2 +- enzyme/test/Enzyme/ForwardError/cos.ll | 2 +- enzyme/test/Enzyme/ForwardError/div.ll | 2 +- enzyme/test/Integration/CMakeLists.txt | 3 + .../test/Integration/Poseidon/CMakeLists.txt | 8 + enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 258 +++++++++++------- 7 files changed, 171 insertions(+), 108 deletions(-) create mode 100644 enzyme/test/Integration/Poseidon/CMakeLists.txt diff --git a/enzyme/test/Enzyme/CMakeLists.txt b/enzyme/test/Enzyme/CMakeLists.txt index 3617d3425d34..0865c61e474e 100644 --- a/enzyme/test/Enzyme/CMakeLists.txt +++ b/enzyme/test/Enzyme/CMakeLists.txt @@ -11,6 +11,10 @@ add_subdirectory(ProbProg) add_subdirectory(JLSimplify) add_subdirectory(SimpleGVN) +if(ENABLE_POSEIDON) + add_subdirectory(Poseidon) +endif() + # Run regression and unit tests add_lit_testsuite(check-enzyme "Running enzyme regression tests" ${CMAKE_CURRENT_BINARY_DIR} diff --git a/enzyme/test/Enzyme/ForwardError/add.ll b/enzyme/test/Enzyme/ForwardError/add.ll index 265759867f7a..d75b5646df65 100644 --- a/enzyme/test/Enzyme/ForwardError/add.ll +++ b/enzyme/test/Enzyme/ForwardError/add.ll @@ -18,7 +18,7 @@ entry: declare double @__enzyme_error_estimate(double (double, double)*, ...) -; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK: define internal double @fwderrtester(double %x, double %"x'", double %y, double %"y'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[i0:.+]] = fadd double %x, %y ; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x diff --git a/enzyme/test/Enzyme/ForwardError/cos.ll b/enzyme/test/Enzyme/ForwardError/cos.ll index 3d75b9115e2c..4036adb53916 100644 --- a/enzyme/test/Enzyme/ForwardError/cos.ll +++ b/enzyme/test/Enzyme/ForwardError/cos.ll @@ -24,7 +24,7 @@ declare double @llvm.sin.f64(double) declare double @__enzyme_error_estimate(double (double)*, ...) -; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK: define internal double @fwderrtester(double %x, double %"x'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[i0:.+]] = tail call fast double @llvm.cos.f64(double %x) ; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x diff --git a/enzyme/test/Enzyme/ForwardError/div.ll b/enzyme/test/Enzyme/ForwardError/div.ll index ea8d9a6ab39b..87af99012ae3 100644 --- a/enzyme/test/Enzyme/ForwardError/div.ll +++ b/enzyme/test/Enzyme/ForwardError/div.ll @@ -18,7 +18,7 @@ entry: declare double @__enzyme_error_estimate(double (double, double)*, ...) -; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK: define internal double @fwderrtester(double %x, double %"x'", double %y, double %"y'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[i0:.+]] = fdiv double %x, %y ; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x diff --git a/enzyme/test/Integration/CMakeLists.txt b/enzyme/test/Integration/CMakeLists.txt index f76a45cfda30..e2ddaf268a5d 100644 --- a/enzyme/test/Integration/CMakeLists.txt +++ b/enzyme/test/Integration/CMakeLists.txt @@ -6,6 +6,9 @@ add_subdirectory(ReverseMode) add_subdirectory(BatchMode) add_subdirectory(Sparse) add_subdirectory(Truncate) +if(ENABLE_POSEIDON) + add_subdirectory(Poseidon) +endif() # Run regression and unit tests add_lit_testsuite(check-enzyme-integration "Running enzyme integration tests" diff --git a/enzyme/test/Integration/Poseidon/CMakeLists.txt b/enzyme/test/Integration/Poseidon/CMakeLists.txt new file mode 100644 index 000000000000..6f5dcfa1dde7 --- /dev/null +++ b/enzyme/test/Integration/Poseidon/CMakeLists.txt @@ -0,0 +1,8 @@ +# Run regression and unit tests +add_lit_testsuite(check-poseidon-integration "Running Poseidon integration tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR} + ARGS -v +) + +set_target_properties(check-poseidon-integration PROPERTIES FOLDER "Tests") diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index a7ace51e2bed..2b3fdd3d7b5c 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -2177,6 +2177,103 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, else os << " gutils->eraseIfUnused(" << origName << ");\n"; + if (intrinsic != MLIRDerivatives) { + os << "#ifdef ENABLE_POSEIDON\n"; + os << " if (gutils->profiled && " + "Poseidonable(" + << origName << ")) {\n" + << " if (auto md = " << origName + << ".getMetadata(\"enzyme_fpprofile_idx\")) {\n" + << " size_t instIdx = " + "cast(cast(md->getOperand(0))->" + "getValue())->getZExtValue();\n" + << " Type *PtrTy = PointerType::getUnqual(" << origName + << ".getContext());\n" + << " Type *DoubleTy = Type::getDoubleTy(" << origName + << ".getContext());\n" + << " Type *SizeTy = Type::getInt64Ty(" << origName + << ".getContext());\n" + << " Type *Int32Ty = Type::getInt32Ty(" << origName + << ".getContext());\n" + << " FunctionType *LogValueFT = " + "FunctionType::get(Type::getVoidTy(" + << origName << ".getContext()),\n" + << " {PtrTy, SizeTy, DoubleTy, Int32Ty, PtrTy}, false);\n" + << " FunctionCallee logFunc = " << origName + << ".getModule()->getOrInsertFunction(\"enzymeLogValue\", " + "LogValueFT);\n" + << " IRBuilder<> BuilderZ(&" << origName << ");\n" + << " getForwardBuilder(BuilderZ);\n" + << " std::string funcName = gutils->oldFunc->getName().str();\n" + << " GlobalVariable *gv = " + "gutils->oldFunc->getParent()->getNamedGlobal(\"fpprofiled_\" + " + "funcName);\n" + << " if (!gv)\n" + << " gv = BuilderZ.CreateGlobalString(funcName, " + "\"fpprofiled_\" + funcName);\n" + << " Value *funcNamePtr = " + "BuilderZ.CreateInBoundsGEP(gv->getValueType(), gv, " + "{BuilderZ.getInt32(0), BuilderZ.getInt32(0)});\n" + << " Value *origValue = " + "BuilderZ.CreateFPExt(gutils->getNewFromOriginal(&" + << origName << "),\n" + << " Type::getDoubleTy(" << origName << ".getContext()));\n" + << " unsigned numOperands = isa(" << origName + << ") ?\n" + << " cast(" << origName + << ").arg_size() : " << origName << ".getNumOperands();\n" + << " Value *numOperandsValue = ConstantInt::get(\n" + << " Type::getInt32Ty(" << origName + << ".getContext()), numOperands);\n" + << " auto operands = isa(" << origName << ") ?\n" + << " cast(" << origName << ").args() : " << origName + << ".operands();\n" + << " ArrayType *operandArrayType = ArrayType::get(\n" + << " Type::getDoubleTy(" << origName + << ".getContext()), numOperands);\n" + << " Value *operandArrayValue = " + "IRBuilder<>(gutils->inversionAllocs).\n" + << " CreateAlloca(operandArrayType);\n" + << " for (auto operand : enumerate(operands)) {\n" + << " Value *origOp = " + "gutils->getNewFromOriginal(operand.value());\n" + << " Value *operandValue = nullptr;\n" + << " if (origOp->getType()->isFloatingPointTy()) {\n" + << " operandValue = BuilderZ.CreateFPExt(\n" + << " origOp, Type::getDoubleTy(" << origName + << ".getContext()));\n" + << " } else if (origOp->getType()->isIntegerTy()) {\n" + << " operandValue = BuilderZ.CreateSIToFP(\n" + << " origOp, Type::getDoubleTy(" << origName + << ".getContext()));\n" + << " } else {\n" + << " llvm_unreachable(\"Unsupported operand type\");\n" + << " }\n" + << " Value *ptr = BuilderZ.CreateGEP(\n" + << " operandArrayType, operandArrayValue,\n" + << " {ConstantInt::get(Type::getInt32Ty(" << origName + << ".getContext()), 0),\n" + << " ConstantInt::get(Type::getInt32Ty(" << origName + << ".getContext()), operand.index())});\n" + << " BuilderZ.CreateStore(operandValue, ptr);\n" + << " }\n" + << " Value *operandPtrValue = BuilderZ.CreateGEP(\n" + << " operandArrayType, operandArrayValue,\n" + << " {ConstantInt::get(Type::getInt32Ty(" << origName + << ".getContext()), 0),\n" + << " ConstantInt::get(Type::getInt32Ty(" << origName + << ".getContext()), 0)});\n" + << " CallInst *logCallInst = BuilderZ.CreateCall(\n" + << " logFunc, {funcNamePtr, ConstantInt::get(SizeTy, " + "instIdx), " + "origValue, numOperandsValue, operandPtrValue});\n" + << " logCallInst->setDebugLoc(gutils->getNewFromOriginal(" + << origName << ".getDebugLoc()));\n" + << " }\n" + << " }\n"; + os << "#endif\n"; + } + if (intrinsic == MLIRDerivatives) { os << " if (gutils->isConstantInstruction(op))\n"; os << " return success();\n"; @@ -2350,8 +2447,6 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } os << " }\n"; - // forward error TODO: `ForwardFromSummedReverse` behavior - // also for custom derivatives. if (intrinsic != MLIRDerivatives) { os << " case DerivativeMode::ForwardModeError: {\n"; os << " IRBuilder<> Builder2(&" << origName << ");\n"; @@ -2465,113 +2560,10 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } // Perform the max with 1 ulp - // error TODO os << " res = Builder2.CreateMaxNum(get1ULP(Builder2, " "gutils->getNewFromOriginal(&" << origName << ")), res);\n"; - os << " assert(res);\n"; - - // Insert logging function call (optional) - os << " Function *logFunc = " << origName - << ".getModule()->getFunction(\"enzymeLogError\");\n"; - os << " if (logFunc) {\n" - << " std::string moduleName = " << origName - << ".getModule()->getModuleIdentifier() ;\n" - << " std::string functionName = " << origName - << ".getFunction()->getName().str();\n" - << " std::string blockName = " << origName - << ".getParent()->getName().str();\n" - << " int funcIdx = -1, blockIdx = -1, instIdx = -1;\n" - << " auto funcIt = std::find_if(" << origName - << ".getModule()->begin(), " << origName - << ".getModule()->end(),\n" - " [&](const auto& func) { return &func == " - << origName - << ".getFunction(); });\n" - " if (funcIt != " - << origName - << ".getModule()->end()) {\n" - " funcIdx = " - "std::distance(" - << origName << ".getModule()->begin(), funcIt);\n" - << " }\n" - << " auto blockIt = std::find_if(" << origName - << ".getFunction()->begin(), " << origName - << ".getFunction()->end(),\n" - " [&](const auto& block) { return &block == " - << origName - << ".getParent(); });\n" - " if (blockIt != " - << origName - << ".getFunction()->end()) {\n" - " blockIdx = std::distance(" - << origName << ".getFunction()->begin(), blockIt);\n" - << " }\n" - << " auto instIt = std::find_if(" << origName - << ".getParent()->begin(), " << origName - << ".getParent()->end(),\n" - " [&](const auto& curr) { return &curr == &" - << origName - << "; });\n" - " if (instIt != " - << origName - << ".getParent()->end()) {\n" - " instIdx = std::distance(" - << origName << ".getParent()->begin(), instIt);\n" - << " }\n" - << " Value *origValue = " - "Builder2.CreateFPExt(gutils->getNewFromOriginal(&" - << origName << "), Type::getDoubleTy(" << origName - << ".getContext()));\n" - << " Value *errValue = Builder2.CreateFPExt(res, " - "Type::getDoubleTy(" - << origName << ".getContext()));\n" - << " std::string opcodeName = " << origName - << ".getOpcodeName();\n" - << " std::string calleeName = \"\";\n" - << " if (auto CI = dyn_cast(&" << origName - << ")) {\n" - << " if (Function *fn = CI->getCalledFunction()) {\n" - << " calleeName = fn->getName();\n" - << " } else {\n" - << " calleeName = \"\";\n" - << " }\n" - << " }\n" - << "#if LLVM_VERSION_MAJOR >= 17\n" - << " Value *moduleNameValue = " - "Builder2.CreateGlobalString(moduleName);\n" - << " Value *functionNameValue = " - "Builder2.CreateGlobalString(functionName + \" (\" +" - "std::to_string(funcIdx) + \")\");\n" - << " Value *blockNameValue = " - "Builder2.CreateGlobalString(blockName + \" (\" +" - "std::to_string(blockIdx) + \")\");\n" - << " Value *opcodeNameValue = " - "Builder2.CreateGlobalString(opcodeName + \" (\" " - "+std::to_string(instIdx) + \")\");\n" - << " Value *calleeNameValue = " - "Builder2.CreateGlobalString(calleeName);\n" - << "#else\n" - << " Value *moduleNameValue = " - "Builder2.CreateGlobalStringPtr(moduleName);\n" - << " Value *functionNameValue = " - "Builder2.CreateGlobalStringPtr(functionName + \" (\" +" - "std::to_string(funcIdx) + \")\");\n" - << " Value *blockNameValue = " - "Builder2.CreateGlobalStringPtr(blockName + \" (\" +" - "std::to_string(blockIdx) + \")\");\n" - << " Value *opcodeNameValue = " - "Builder2.CreateGlobalStringPtr(opcodeName + \" (\" " - "+std::to_string(instIdx) + \")\");\n" - << " Value *calleeNameValue = " - "Builder2.CreateGlobalStringPtr(calleeName);\n" - << "#endif\n" - << " Builder2.CreateCall(logFunc, {origValue, " - "errValue, opcodeNameValue, calleeNameValue, moduleNameValue, " - "functionNameValue, blockNameValue});\n" - << " }\n"; - os << " setDiffe(&" << origName << ", res, Builder2);\n"; os << " break;\n"; os << " }\n"; @@ -2583,6 +2575,62 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " IRBuilder<> Builder2(&" << origName << ");\n"; os << " getReverseBuilder(Builder2);\n"; os << " Value *dif = nullptr;\n"; + + // Set `dif` immediately for profiled instructions + os << "#ifdef ENABLE_POSEIDON\n" + << " if (gutils->profiled && " + "Poseidonable(" + << origName << ")) {\n" + << " if (auto md = " << origName + << ".getMetadata(\"enzyme_fpprofile_idx\")) {\n" + << " size_t instIdx = " + "cast(cast(md->getOperand(0))->" + "getValue())->getZExtValue();\n" + << " dif = diffe(&" << origName << ", Builder2);\n" + << " setDiffe(&" << origName + << ", Constant::getNullValue(gutils->getShadowType(" << origName + << ".getType())), Builder2);\n" + << " Type *PtrTy = PointerType::getUnqual(" << origName + << ".getContext());\n" + << " Type *SizeTy = Type::getInt64Ty(" << origName + << ".getContext());\n" + << " Type *DoubleTy = Type::getDoubleTy(" << origName + << ".getContext());\n" + << " FunctionType *LogGradFT = " + "FunctionType::get(Type::getVoidTy(" + << origName << ".getContext()),\n" + << " {PtrTy, SizeTy, DoubleTy, DoubleTy}, false);\n" + << " FunctionCallee logFunc = " << origName + << ".getModule()->getOrInsertFunction(\"enzymeLogGrad\", LogGradFT);\n" + << " std::string funcName = " + "gutils->oldFunc->getName().str();\n" + << " GlobalVariable *gv = " + "gutils->oldFunc->getParent()->getNamedGlobal(\"fpprofiled_\" + " + "funcName);\n" + << " if (!gv)\n" + << " gv = Builder2.CreateGlobalString(funcName, " + "\"fpprofiled_\" + funcName);\n" + << " Value *funcNamePtr = " + "Builder2.CreateInBoundsGEP(gv->getValueType(), gv, " + "{Builder2.getInt32(0), Builder2.getInt32(0)});\n" + << " Value *primalInst = " + "gutils->lookupM(gutils->getNewFromOriginal(&" + << origName << "), Builder2);\n" + << " Value *primalDouble = " + "Builder2.CreateFPExt(primalInst, " + "Type::getDoubleTy(" + << origName << ".getContext()));\n" + << " Value *gradDouble = Builder2.CreateFPExt(dif, " + "Type::getDoubleTy(" + << origName << ".getContext()));\n" + << " CallInst *logCallInst = Builder2.CreateCall(logFunc, " + "{funcNamePtr, ConstantInt::get(SizeTy, instIdx), primalDouble, " + "gradDouble});\n" + << " logCallInst->setDebugLoc(gutils->getNewFromOriginal(" + << origName << ".getDebugLoc()));\n" + << " }\n" + << " }\n" + << "#endif\n"; } else { os << "};\n"; emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); From 634e2faec65e6c44c00e17fdb08a13a84a972e40 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Mar 2026 19:45:49 -0500 Subject: [PATCH 03/18] minor fix + more tests --- enzyme/Enzyme/Poseidon/PoseidonUtils.cpp | 1 + .../preprocess_tester.fpprofile | 29 +++++++++++++++++ .../fma_opt_cache/cachedHerbieOutput_0_0.txt | 1 + .../preprocess_tester.fpprofile | 20 ++++++++++++ .../Poseidon/expm1div_opt_herbie_live.ll | 28 +++++++++++++++++ enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll | 30 ++++++++++++++++++ .../test/Enzyme/Poseidon/fma_opt_noherbie.ll | 30 ++++++++++++++++++ enzyme/test/Integration/Poseidon/prof_add.c | 31 +++++++++++++++++++ 8 files changed, 170 insertions(+) create mode 100644 enzyme/test/Enzyme/Poseidon/Inputs/expm1div_profiles/preprocess_tester.fpprofile create mode 100644 enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_cache/cachedHerbieOutput_0_0.txt create mode 100644 enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_profiles/preprocess_tester.fpprofile create mode 100644 enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll create mode 100644 enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll create mode 100644 enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll create mode 100644 enzyme/test/Integration/Poseidon/prof_add.c diff --git a/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp index 2b675985e080..343b6bcc61fb 100644 --- a/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp +++ b/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp @@ -321,6 +321,7 @@ InstructionCost getInstructionCompCost(const Instruction *I, break; case Intrinsic::tanh: OpcodeName = "tanh"; + break; #endif case Intrinsic::exp: OpcodeName = "exp"; diff --git a/enzyme/test/Enzyme/Poseidon/Inputs/expm1div_profiles/preprocess_tester.fpprofile b/enzyme/test/Enzyme/Poseidon/Inputs/expm1div_profiles/preprocess_tester.fpprofile new file mode 100644 index 000000000000..da68eeac6e4c --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/Inputs/expm1div_profiles/preprocess_tester.fpprofile @@ -0,0 +1,29 @@ +0 + MinRes = 0.3679 + MaxRes = 7.389 + SumValue = 200.0 + SumSens = 80.0 + SumGrad = 40.0 + Exec = 100 + NumOperands = 1 + Operand[0] = [-1.0, 2.0] +1 + MinRes = -0.6321 + MaxRes = 6.389 + SumValue = 150.0 + SumSens = 60.0 + SumGrad = 30.0 + Exec = 100 + NumOperands = 2 + Operand[0] = [0.3679, 7.389] + Operand[1] = [-1.0, -1.0] +2 + MinRes = 0.5 + MaxRes = 3.195 + SumValue = 100.0 + SumSens = 50.0 + SumGrad = 25.0 + Exec = 100 + NumOperands = 2 + Operand[0] = [-0.6321, 6.389] + Operand[1] = [-1.0, 2.0] diff --git a/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_cache/cachedHerbieOutput_0_0.txt b/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_cache/cachedHerbieOutput_0_0.txt new file mode 100644 index 000000000000..1a281c4f0297 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_cache/cachedHerbieOutput_0_0.txt @@ -0,0 +1 @@ +{"branch":"release","commit":"2.2","date":1747796669,"flags":[],"iterations":6,"merged-cost-accuracy":[[1.0,0.5],[0.8,0.1],[[0.9,0.3,"(+.f64 (*.f64 v0 v2) (*.f64 v1 v2))"]]],"seed":239778888,"tests":[{"name":"0","input":"(FPCore (v0 v1 v2) (* (+ v0 v1) v2))","output":"(fma.f64 v0 v2 (*.f64 v1 v2))","bits":64,"cost-accuracy":[[320.0,0.5],[480.0,0.1],[[400.0,0.3,"(+.f64 (*.f64 v0 v2) (*.f64 v1 v2))"]]]}]} diff --git a/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_profiles/preprocess_tester.fpprofile b/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_profiles/preprocess_tester.fpprofile new file mode 100644 index 000000000000..bea69a1bb708 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_profiles/preprocess_tester.fpprofile @@ -0,0 +1,20 @@ +0 + MinRes = 3.0 + MaxRes = 15.0 + SumValue = 100.0 + SumSens = 50.0 + SumGrad = 25.0 + Exec = 100 + NumOperands = 2 + Operand[0] = [1.0, 5.0] + Operand[1] = [2.0, 10.0] +1 + MinRes = 6.0 + MaxRes = 150.0 + SumValue = 500.0 + SumSens = 100.0 + SumGrad = 50.0 + Exec = 100 + NumOperands = 2 + Operand[0] = [3.0, 15.0] + Operand[1] = [0.5, 10.0] diff --git a/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll b/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll new file mode 100644 index 000000000000..438ac59ac9a6 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/expm1div_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -fpopt-print -S 2>&1 | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/expm1div_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -fpopt-print -S 2>&1 | FileCheck %s + +; Live Herbie on (exp(x) - 1) / x: verifies expm1 candidate is produced. + +declare double @llvm.exp.f64(double) + +define double @tester(double %x) { +entry: + %exp = call fast double @llvm.exp.f64(double %x) + %sub = fsub fast double %exp, 1.0 + %div = fdiv fast double %sub, %x + ret double %div +} + +define double @test_opt(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fp_optimize(double (double)* nonnull @tester, double %x, metadata !"enzyme_err_tol", double 0.5) + ret double %0 +} + +declare double @__enzyme_fp_optimize(double (double)*, ...) + +; CHECK: Candidates: +; CHECK: expm1 +; CHECK: Finished +; CHECK: define double @preprocess_tester(double %x) +; CHECK: ret double diff --git a/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll b/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll new file mode 100644 index 000000000000..669c5b8a5f4f --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll @@ -0,0 +1,30 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path=%S/Inputs/fma_opt_cache -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path=%S/Inputs/fma_opt_cache -S | FileCheck %s + +; Cached Herbie rewrite: (x + y) * z --> fma(x, z, y * z) + +define double @tester(double %x, double %y, double %z) { +entry: + %add = fadd fast double %x, %y + %mul = fmul fast double %add, %z + ret double %mul +} + +define double @test_opt(double %x, double %y, double %z) { +entry: + %0 = tail call double (double (double, double, double)*, ...) @__enzyme_fp_optimize(double (double, double, double)* nonnull @tester, double %x, double %y, double %z, metadata !"enzyme_err_tol", double 0.5) + ret double %0 +} + +declare double @__enzyme_fp_optimize(double (double, double, double)*, ...) + +; CHECK: define double @test_opt(double %x, double %y, double %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[call:.+]] = call double @preprocess_tester(double %x, double %y, double %z) +; CHECK-NEXT: ret double %[[call]] + +; CHECK: define double @preprocess_tester(double %x, double %y, double %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[mul:.+]] = fmul fast double %z, %y +; CHECK-NEXT: %[[fma:.+]] = tail call fast double @llvm.fma.f64(double %x, double %z, double %[[mul]]) +; CHECK-NEXT: ret double %[[fma]] diff --git a/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll b/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll new file mode 100644 index 000000000000..cf5a979448a0 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll @@ -0,0 +1,30 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=false -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=false -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -S | FileCheck %s + +; Opt phase with no Herbie/PT: __enzyme_fp_optimize lowers to preprocess_tester call. + +define double @tester(double %x, double %y, double %z) { +entry: + %add = fadd fast double %x, %y + %mul = fmul fast double %add, %z + ret double %mul +} + +define double @test_opt(double %x, double %y, double %z) { +entry: + %0 = tail call double (double (double, double, double)*, ...) @__enzyme_fp_optimize(double (double, double, double)* nonnull @tester, double %x, double %y, double %z, metadata !"enzyme_err_tol", double 0.5) + ret double %0 +} + +declare double @__enzyme_fp_optimize(double (double, double, double)*, ...) + +; CHECK: define double @test_opt(double %x, double %y, double %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[call:.+]] = call double @preprocess_tester(double %x, double %y, double %z) +; CHECK-NEXT: ret double %[[call]] + +; CHECK: define double @preprocess_tester(double %x, double %y, double %z) +; CHECK-NEXT: entry: +; CHECK: fadd fast double +; CHECK: fmul fast double +; CHECK: ret double diff --git a/enzyme/test/Integration/Poseidon/prof_add.c b/enzyme/test/Integration/Poseidon/prof_add.c new file mode 100644 index 000000000000..6875edbc2456 --- /dev/null +++ b/enzyme/test/Integration/Poseidon/prof_add.c @@ -0,0 +1,31 @@ +// RUN: %clang -O0 %s -S -emit-llvm -o %t.ll +// RUN: %opt %t.ll %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S -o %t.opt.ll +// RUN: %clang -O0 %t.opt.ll -c -o %t.o +// RUN: %clang++ %t.o %FPProfileLib -lstdc++ -lm -o %t.exe +// RUN: rm -rf %t.profiles && ENZYME_FPPROFILE_DIR=%t.profiles %t.exe +// RUN: cat %t.profiles/preprocess_tester.fpprofile | FileCheck %s + +#include + +extern double __enzyme_fp_optimize(void *, ...); + +double tester(double x, double y) { + return x + y; +} + +int main() { + double res = __enzyme_fp_optimize((void *)tester, 3.0, 4.0); + printf("result = %f\n", res); + + res = __enzyme_fp_optimize((void *)tester, 1.0, 2.0); + printf("result = %f\n", res); + + return 0; +} + +// CHECK: MinRes = 3.{{[0-9e+]+}} +// CHECK: MaxRes = 7.{{[0-9e+]+}} +// CHECK: Exec = 2 +// CHECK: NumOperands = 2 +// CHECK: Operand[0] = [1.{{[0-9e+]+}}, 3.{{[0-9e+]+}}] +// CHECK: Operand[1] = [2.{{[0-9e+]+}}, 4.{{[0-9e+]+}}] From 813b0b6f7ce265c8e29e4ca18b4dd45c335e1b70 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Mar 2026 19:51:14 -0500 Subject: [PATCH 04/18] format --- enzyme/Enzyme/DifferentialUseAnalysis.h | 2 +- enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index 93b1be2c5e92..6df795208249 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -481,7 +481,7 @@ static inline bool is_value_needed_in_reverse( struct Node { llvm::Value *V; bool outgoing; - Node(llvm::Value *V, bool outgoing) : V(V), outgoing(outgoing){}; + Node(llvm::Value *V, bool outgoing) : V(V), outgoing(outgoing) {}; bool operator<(const Node N) const { if (V < N.V) return true; diff --git a/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp index b15de0c2f9f2..975c730bfe38 100644 --- a/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp +++ b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp @@ -43,11 +43,16 @@ void parseProfileFile(const std::string &profilePath, std::string line; std::regex indexPattern(R"(^(\d+)$)"); - std::regex minResPattern(R"(^\s*MinRes\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); - std::regex maxResPattern(R"(^\s*MaxRes\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); - std::regex sumValuePattern(R"(^\s*SumValue\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); - std::regex sumSensPattern(R"(^\s*SumSens\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); - std::regex sumGradPattern(R"(^\s*SumGrad\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex minResPattern( + R"(^\s*MinRes\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex maxResPattern( + R"(^\s*MaxRes\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex sumValuePattern( + R"(^\s*SumValue\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex sumSensPattern( + R"(^\s*SumSens\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); + std::regex sumGradPattern( + R"(^\s*SumGrad\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))"); std::regex execPattern(R"(^\s*Exec\s*=\s*(\d+))"); std::regex numOperandsPattern(R"(^\s*NumOperands\s*=\s*(\d+))"); std::regex operandPattern( From f8ba429cdffe6f7016407d05c10c49cb53fd5fb6 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Mar 2026 20:16:26 -0500 Subject: [PATCH 05/18] poseidon test guards --- enzyme/test/Enzyme/Poseidon/add_noopt.ll | 1 + enzyme/test/Enzyme/Poseidon/add_prof.ll | 1 + enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll | 1 + enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll | 1 + enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll | 1 + enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll | 1 + enzyme/test/Enzyme/Poseidon/fma_prof.ll | 1 + enzyme/test/Enzyme/Poseidon/structret_prof.ll | 1 + enzyme/test/Integration/Poseidon/prof_add.c | 1 + enzyme/test/lit.site.cfg.py.in | 5 ++++- 10 files changed, 13 insertions(+), 1 deletion(-) diff --git a/enzyme/test/Enzyme/Poseidon/add_noopt.ll b/enzyme/test/Enzyme/Poseidon/add_noopt.ll index 222f1ca5841c..54d222723338 100644 --- a/enzyme/test/Enzyme/Poseidon/add_noopt.ll +++ b/enzyme/test/Enzyme/Poseidon/add_noopt.ll @@ -1,5 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -S | FileCheck %s -dump-input=always +; REQUIRES: poseidon ; Function Attrs: noinline nounwind readnone uwtable define double @tester(double %x, double %y) { diff --git a/enzyme/test/Enzyme/Poseidon/add_prof.ll b/enzyme/test/Enzyme/Poseidon/add_prof.ll index 15520ba2d034..bd0fde39c746 100644 --- a/enzyme/test/Enzyme/Poseidon/add_prof.ll +++ b/enzyme/test/Enzyme/Poseidon/add_prof.ll @@ -1,5 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s -dump-input=always +; REQUIRES: poseidon ; Function Attrs: noinline nounwind readnone uwtable define double @tester(double %x, double %y) { diff --git a/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll b/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll index 438ac59ac9a6..4730f786c0ee 100644 --- a/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll +++ b/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll @@ -1,5 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/expm1div_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -fpopt-print -S 2>&1 | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/expm1div_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -fpopt-print -S 2>&1 | FileCheck %s +; REQUIRES: poseidon ; Live Herbie on (exp(x) - 1) / x: verifies expm1 candidate is produced. diff --git a/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll b/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll index e68216164d8a..80611080e5a2 100644 --- a/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll +++ b/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll @@ -1,5 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -fpprofile-generate -S | FileCheck %s +; REQUIRES: poseidon ; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 ; CHECK: @fpprofiled_preprocess_test_maxnum_zero = private unnamed_addr constant [28 x i8] c"preprocess_test_maxnum_zero\00", align 1 diff --git a/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll b/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll index 669c5b8a5f4f..2a17078ac389 100644 --- a/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll +++ b/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll @@ -1,5 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path=%S/Inputs/fma_opt_cache -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path=%S/Inputs/fma_opt_cache -S | FileCheck %s +; REQUIRES: poseidon ; Cached Herbie rewrite: (x + y) * z --> fma(x, z, y * z) diff --git a/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll b/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll index cf5a979448a0..d783ee9f48e0 100644 --- a/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll +++ b/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll @@ -1,5 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=false -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=false -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -S | FileCheck %s +; REQUIRES: poseidon ; Opt phase with no Herbie/PT: __enzyme_fp_optimize lowers to preprocess_tester call. diff --git a/enzyme/test/Enzyme/Poseidon/fma_prof.ll b/enzyme/test/Enzyme/Poseidon/fma_prof.ll index 0b0df87d3456..b3ead4bae6d1 100644 --- a/enzyme/test/Enzyme/Poseidon/fma_prof.ll +++ b/enzyme/test/Enzyme/Poseidon/fma_prof.ll @@ -1,5 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s +; REQUIRES: poseidon ; Function Attrs: noinline nounwind readnone uwtable define double @tester(double %x, double %y, double %z) { diff --git a/enzyme/test/Enzyme/Poseidon/structret_prof.ll b/enzyme/test/Enzyme/Poseidon/structret_prof.ll index 8f75a789021e..15b1dfbfe2ce 100644 --- a/enzyme/test/Enzyme/Poseidon/structret_prof.ll +++ b/enzyme/test/Enzyme/Poseidon/structret_prof.ll @@ -1,5 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s +; REQUIRES: poseidon ; Adapted from enzyme/test/Enzyme/ReverseMode/gradient-struct-ret.ll diff --git a/enzyme/test/Integration/Poseidon/prof_add.c b/enzyme/test/Integration/Poseidon/prof_add.c index 6875edbc2456..242201509c30 100644 --- a/enzyme/test/Integration/Poseidon/prof_add.c +++ b/enzyme/test/Integration/Poseidon/prof_add.c @@ -4,6 +4,7 @@ // RUN: %clang++ %t.o %FPProfileLib -lstdc++ -lm -o %t.exe // RUN: rm -rf %t.profiles && ENZYME_FPPROFILE_DIR=%t.profiles %t.exe // RUN: cat %t.profiles/preprocess_tester.fpprofile | FileCheck %s +// REQUIRES: poseidon #include diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index aed8f3a9946e..678a454b6090 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -117,7 +117,10 @@ config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < config.substitutions.append(('%newLoadClangEnzyme', newPM)) config.substitutions.append(('%hasMPFR', has_mpfr)) -config.substitutions.append(('%FPProfileLib', '@ENZYME_BINARY_DIR@/Enzyme/libEnzymeFPProfile.a')) + +if "@ENABLE_POSEIDON@" == "1": + config.available_features.add('poseidon') + config.substitutions.append(('%FPProfileLib', '@ENZYME_BINARY_DIR@/Enzyme/libEnzymeFPProfile.a')) # Let the main config do the real work. cfgfile = "@ENZYME_SOURCE_DIR@/test/lit.cfg.py" From ad09ce65b7f18140c09c7a233462d25e8bd50624 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Mar 2026 20:16:54 -0500 Subject: [PATCH 06/18] ci fix up --- .github/workflows/poseidon.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/poseidon.yml b/.github/workflows/poseidon.yml index d54434dbb0eb..f131eb833628 100644 --- a/.github/workflows/poseidon.yml +++ b/.github/workflows/poseidon.yml @@ -41,6 +41,6 @@ jobs: - name: make check-poseidon working-directory: build run: make -j `nproc` check-poseidon - - name: make check-enzyme-integration + - name: make check-poseidon-integration working-directory: build - run: make -j3 check-enzyme-integration \ No newline at end of file + run: make -j3 check-poseidon-integration \ No newline at end of file From ed596030b8820b1bbc180a4735097357d622a8fa Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Mar 2026 00:20:46 -0500 Subject: [PATCH 07/18] remove legacy log behavior --- enzyme/test/Integration/ForwardError/binops.c | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/enzyme/test/Integration/ForwardError/binops.c b/enzyme/test/Integration/ForwardError/binops.c index 2770060575ce..8c814b4c4bcd 100644 --- a/enzyme/test/Integration/ForwardError/binops.c +++ b/enzyme/test/Integration/ForwardError/binops.c @@ -11,24 +11,13 @@ double fabs(double); extern double __enzyme_error_estimate(void *, ...); -int errorLogCount = 0; - -void enzymeLogError(double res, double err, const char *opcodeName, - const char *calleeName, const char *moduleName, - const char *functionName, const char *blockName) { - ++errorLogCount; - printf("Res = %e, Error = %e, Op = %s, Callee = %s, Module = %s, Function = " - "%s, BasicBlock = %s\n", - res, err, opcodeName, calleeName, moduleName, functionName, blockName); -} - // An example from https://dl.acm.org/doi/10.1145/3371128 double fun(double x) { double v1 = cos(x); double v2 = 1 - v1; double v3 = x * x; double v4 = v2 / v3; - double v5 = sin(v4); // Inactive -- logger is not invoked. + double v5 = sin(v4); printf("v1 = %.18e, v2 = %.18e, v3 = %.18e, v4 = %.18e, v5 = %.18e\n", v1, v2, v3, v4, v5); @@ -42,5 +31,4 @@ int main() { printf("res = %.18e, abs error = %.18e, rel error = %.18e\n", res, error, fabs(error / res)); APPROX_EQ(error, 2.2222222222e-2, 1e-4); - TEST_EQ(errorLogCount, 4); } From 8c8e6210e5f1b76b4f8d3195f1eca7a9f11d1b07 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Mar 2026 00:45:30 -0500 Subject: [PATCH 08/18] fmt --- enzyme/Enzyme/DifferentialUseAnalysis.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index 6df795208249..93b1be2c5e92 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -481,7 +481,7 @@ static inline bool is_value_needed_in_reverse( struct Node { llvm::Value *V; bool outgoing; - Node(llvm::Value *V, bool outgoing) : V(V), outgoing(outgoing) {}; + Node(llvm::Value *V, bool outgoing) : V(V), outgoing(outgoing){}; bool operator<(const Node N) const { if (V < N.V) return true; From 6a4eb7c5437885b4c88069a2fd720e3eb9582709 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Mar 2026 00:59:33 -0500 Subject: [PATCH 09/18] ci fix --- .github/workflows/poseidon.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/poseidon.yml b/.github/workflows/poseidon.yml index f131eb833628..f2fccd604ae6 100644 --- a/.github/workflows/poseidon.yml +++ b/.github/workflows/poseidon.yml @@ -19,7 +19,7 @@ jobs: build: ["Release", "Debug"] os: [ubuntu-22.04] - timeout-minutes: 30 + timeout-minutes: 60 steps: - name: Install dependencies @@ -27,8 +27,12 @@ jobs: wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true sudo apt-get update - sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libzstd-dev libmpfr-dev racket + sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libzstd-dev libmpfr-dev sudo python3 -m pip install --upgrade pip lit + - name: Install Racket + uses: Bogdanp/setup-racket@v1.12 + with: + version: '8.15' - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build From 9d06cd04496516c8492de960794d63ba835aed5e Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Mar 2026 01:21:49 -0500 Subject: [PATCH 10/18] fix ci --- .github/workflows/poseidon.yml | 40 +++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/.github/workflows/poseidon.yml b/.github/workflows/poseidon.yml index f2fccd604ae6..bcbdb805cba4 100644 --- a/.github/workflows/poseidon.yml +++ b/.github/workflows/poseidon.yml @@ -1,50 +1,56 @@ name: Poseidon CI -on: +on: push: branches: - main pull_request: merge_group: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: build-linux: name: Poseidon CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} runs-on: ${{ matrix.os }} - + strategy: fail-fast: false matrix: llvm: ["20"] build: ["Release", "Debug"] os: [ubuntu-22.04] - + timeout-minutes: 60 - + steps: - name: Install dependencies run: | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true sudo apt-get update - sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libzstd-dev libmpfr-dev + sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libclang-${{ matrix.llvm }}-dev libzstd-dev libmpfr-dev sudo python3 -m pip install --upgrade pip lit + - name: Install Racket uses: Bogdanp/setup-racket@v1.12 with: version: '8.15' + - uses: actions/checkout@v4 - - name: mkdir - run: rm -rf build && mkdir build - - name: cmake with Poseidon - working-directory: build - run: cmake ../enzyme -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DENABLE_POSEIDON=1 -DLLVM_EXTERNAL_LIT=`which lit` -DLLVM_DIR=/usr/lib/llvm-${{ matrix.llvm }}/lib/cmake/llvm - - name: make - working-directory: build - run: make -j `nproc` - - name: make check-poseidon + + - name: Build Enzyme with Poseidon + run: | + mkdir build && cd build + cmake ../enzyme -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DENABLE_POSEIDON=1 -DLLVM_EXTERNAL_LIT=$(which lit) -DLLVM_DIR=/usr/lib/llvm-${{ matrix.llvm }}/lib/cmake/llvm + make -j $(nproc) + + - name: check-poseidon working-directory: build - run: make -j `nproc` check-poseidon - - name: make check-poseidon-integration + run: make -j $(nproc) check-poseidon + + - name: check-poseidon-integration working-directory: build - run: make -j3 check-poseidon-integration \ No newline at end of file + run: make -j3 check-poseidon-integration From f71e20a04bfdede600baa600c70236cb2b96a40b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Mar 2026 01:39:51 -0500 Subject: [PATCH 11/18] fix ci --- .github/workflows/poseidon.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/poseidon.yml b/.github/workflows/poseidon.yml index bcbdb805cba4..65809eecf9e9 100644 --- a/.github/workflows/poseidon.yml +++ b/.github/workflows/poseidon.yml @@ -31,7 +31,7 @@ jobs: wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true sudo apt-get update - sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libclang-${{ matrix.llvm }}-dev libzstd-dev libmpfr-dev + sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libzstd-dev libmpfr-dev sudo python3 -m pip install --upgrade pip lit - name: Install Racket From cc695e694b5028ce892396954a314117581fe781 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Mar 2026 15:36:28 -0500 Subject: [PATCH 12/18] save readme draft --- enzyme/Enzyme/Poseidon/README.md | 286 +++++++++++++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 enzyme/Enzyme/Poseidon/README.md diff --git a/enzyme/Enzyme/Poseidon/README.md b/enzyme/Enzyme/Poseidon/README.md new file mode 100644 index 000000000000..40f1d5d5795c --- /dev/null +++ b/enzyme/Enzyme/Poseidon/README.md @@ -0,0 +1,286 @@ +# Poseidon + +Poseidon is a modular and extensible framework that fully automates floating-point optimizations for real-world applications within a production compiler. It operates as an LLVM pass inside [Enzyme](https://enzyme.mit.edu/) and uses a PGO-like two-phase compilation to automatically extract numerical context (value ranges, sensitivities) from small surrogate profiling runs. It then synthesizes algebraic rewrites via [Herbie](https://herbie.uwplse.org/), generates precision tuning candidates, and uses a dynamic programming solver to find Pareto-optimal combinations that trade off computation cost against numerical accuracy. + +Unlike prior tools that require DSL inputs or manual code edits, Poseidon works directly on compiled LLVM IR and interoperates with standard compiler analyses and optimizations (mem2reg, inlining, loop unrolling, SimplifyCFG), enabling it to extract larger FP subgraphs than would be available in source code or binaries. + +For details, please read our paper [Thinking Fast and Correct: Automated Rewriting of Numerical Code through Compiler Augmentation](https://ece.is/assets/pdf/poseidon-cgo26.pdf) (CGO 2026). + +If you use Poseidon in an academic setting, please cite: + +```bibtex +@inproceedings{poseidon, + author={Qian, Siyuan Brant and Sathia, Vimarsh and Ivanov, Ivan R. and H\"{u}ckelheim, Jan and Hovland, Paul and Moses, William S.}, + booktitle={2026 IEEE/ACM International Symposium on Code Generation and Optimization (CGO)}, + title={Thinking Fast and Correct: Automated Rewriting of Numerical Code through Compiler Augmentation}, + year={2026}, + pages={548-562}, + doi={10.1109/CGO68049.2026.11395228} +} +``` + +## Build + +See the [artifact repository](https://github.com/PRONTOLab/Poseidon) for full build instructions including Docker. In short: + +```bash +cd Enzyme && mkdir build && cd build +cmake -G Ninja ../enzyme/ \ + -DLLVM_DIR=/path/to/llvm/build/lib/cmake/llvm \ + -DLLVM_EXTERNAL_LIT=$(which lit) \ + -DCMAKE_BUILD_TYPE=Release \ + -DENABLE_POSEIDON=ON +ninja +``` + +Both phases must use identical compiler flags (e.g., `-O3`, `-ffast-math`, `-march=native`, etc.) to ensure profile indices match between compilations. + +## Two-Phase Pipeline + +### User Code + +Annotate the function to optimize with `__enzyme_fp_optimize`: + +```cpp +template +return_type __enzyme_fp_optimize(void *, T...); +int enzyme_dup; + +__attribute__((noinline)) +void my_kernel(double *out, const double *in); + +void run() { + double out[N], out_grad[N], in[M], in_grad[M]; + // Gradients seed the sensitivity analysis (set to 1.0 for uniform weighting) + std::fill(out_grad, out_grad + N, 1.0); + __enzyme_fp_optimize((void *)my_kernel, + enzyme_dup, out, out_grad, + enzyme_dup, in, in_grad); +} +``` + +### Phase 1: Profiling + +```bash +clang++ mycode.cpp $CXXFLAGS \ + -fpass-plugin=ClangEnzyme-XX.so -Xclang -load -Xclang ClangEnzyme-XX.so \ + -mllvm --fpprofile-generate \ + -L$ENZYME_BUILD/Enzyme -lEnzymeFPProfile -lm -o mycode-prof + +ENZYME_FPPROFILE_DIR=./fpprofile ./mycode-prof +``` + +The instrumented binary records per-instruction value ranges, gradient sensitivities, and execution counts into a `.fpprofile` directory. A small surrogate input (e.g., 1000 samples) is often sufficient. + +### Phase 2: Optimization + +```bash +clang++ mycode.cpp $CXXFLAGS \ + -fpass-plugin=ClangEnzyme-XX.so -Xclang -load -Xclang ClangEnzyme-XX.so \ + -mllvm --fpprofile-use=./fpprofile \ + -mllvm --fpopt-enable-herbie=1 \ + -mllvm --fpopt-enable-solver \ + -mllvm --fpopt-enable-pt \ + -mllvm --fpopt-comp-cost-budget=0 \ + -mllvm --fpopt-cache-path=./cache \ + -mllvm --fpopt-cost-model-path=cm.csv \ + -mllvm --fpopt-strict-mode \ + -lmpfr -lm -o mycode-opt +``` + +The first run invokes Herbie and the DP solver; results are cached in `--fpopt-cache-path`. Subsequent runs reuse the cache. The `--fpopt-comp-cost-budget` selects a point on the Pareto curve (0 = no-op baseline). + +## Reporting + +Add `--fpopt-report-path=` to emit structured optimization reports: + +```bash +clang++ mycode.cpp $CXXFLAGS ... \ + -mllvm --fpopt-report-path=./report \ + -mllvm --fpopt-raptor-dir=/path/to/RAPTOR/build # optional, for validate.py +``` + +### Report Contents + +| File | Description | +|------|-------------| +| `.json` | Full Pareto table: per-step source locations, symbolic expressions (original and rewritten), affected LLVM IR, Herbie accuracy bits, cost/accuracy deltas, gradient and execution count | +| `.txt` | Human-readable version of the same (suitable for feeding to an LLM for explanation) | +| `validate_config.json` | Baked-in configuration: all Pareto budgets, profile path, cache path, RAPTOR dir | +| `validate.py` | Self-contained validation script | + +Source locations (file, line, column) are populated when the source is compiled with `-g`. Without `-g`, the symbolic expressions and affected IR are still available. + +### Understanding Rewrites + +The text report lists each Pareto point with the rewrites applied. For example: + +``` +--- Pareto Point #5: Cost=-6863264, Accuracy=3.964900e-01 --- + [Rewrite] (* (sqrt (fma v6 v6 (fma v5 v5 (fma v4 v4 0)))) 0.25) + --> (*.f64 (sqrt.f64 ...) #s(literal 1/4 binary64)) + Herbie accuracy: 0.935 -> 0.999 bits + Source: dquat.cpp:87:18 + Source: dquat.cpp:56:11 + Affected IR: + %mul = fmul fast double %sqrt, 2.500000e-01 +``` + +Each rewrite shows: +- **Original and rewritten symbolic expressions** in Herbie's S-expression format +- **Source locations** (when compiled with `-g`) pointing to the C/C++ lines affected +- **Affected LLVM IR instructions** that are replaced or erased +- **Herbie accuracy improvement** in bits + +For complex rewrites, consider feeding the `.txt` report along with the source code to an LLM and asking it to explain what each rewrite does and why it improves accuracy. The symbolic expressions are self-contained and map directly to the source locations listed. + +## Validation with validate.py + +The `validate.py` script (emitted alongside the report) automates accuracy and performance measurement across Pareto-optimal variants: + +```bash +python3 report/validate.py mycode.cpp \ + --enzyme-plugin /path/to/ClangEnzyme-XX.so \ + --cxx /path/to/clang++ \ + --extra-flags "-O3 -ffast-math -march=native -fno-exceptions -lmpfr" \ + --extra-run-args "--num-tests 100 --seed 42" \ + --gold-path gold_mpfr.txt \ + --num-samples 10 \ + --num-runs 5 +``` + +### What it does + +1. **Loads a gold reference** from `--gold-path` (MPFR ground truth; see [RAPTOR section](#generating-mpfr-gold-references-with-raptor) below) +2. **Compiles the original** (unoptimized) binary, measures its runtime and accuracy against gold +3. **Uniformly samples** N budgets from the full Pareto table (87 points for dquat with `-ffast-math`) +4. **For each budget:** recompiles with Poseidon at that budget using the cached Herbie results + DP table, runs the resulting binary, captures output +5. **Computes** geomean and max relative error against gold for each variant +6. **Measures** median runtime over multiple runs +7. **Prints a summary table** showing the original program as baseline, then each Pareto variant with estimated vs. validated accuracy and speedup + +### Example output + +``` +======================================================================== + Budget Est.AccCost GeomErr MaxErr Runtime Speedup +------------------------------------------------------------------------ + ORIGINAL -- 6.3760e-11 1.1794e+03 0.000015 1.00x +------------------------------------------------------------------------ + -7064000 3.5985e+01 6.3990e-03 3.4521e+02 0.000013 1.10x + -6863264 3.9649e-01 5.9403e-09 1.1794e+03 0.000014 1.08x + 1096856 -1.1591e-16 6.8253e-11 1.1794e+03 0.000013 1.13x +======================================================================== +``` + +Here `ORIGINAL` is the unoptimized double-precision program. Negative budgets allow the solver to trade accuracy for speed; positive budgets allow extra computation to improve accuracy. The `MaxErr: 1179` in the original comes from catastrophic cancellation (`1 - cos(theta)` near zero), which Herbie's rewrites at positive budgets fix. + +## Generating MPFR Gold References with RAPTOR + +For meaningful accuracy validation, the gold reference should be computed at high precision (MPFR-2048 bits) rather than using the original double-precision output. [RAPTOR](https://github.com/RIKEN-RCCS/RAPTOR) provides this capability by running every floating-point operation through MPFR with garbage collection to avoid OOM on large applications. + +### Building RAPTOR + +```bash +git clone https://github.com/RIKEN-RCCS/RAPTOR +cd RAPTOR && mkdir build && cd build +cmake .. -DLLVM_DIR=/path/to/llvm -DCMAKE_BUILD_TYPE=Release +make -j +``` + +RAPTOR requires LLVM >= 15 and MPFR. For LLVM >= 23, two patches are needed in `pass/Raptor.cpp`: +- `#include "llvm/Passes/PassPlugin.h"` -> `#include "llvm/Plugins/PassPlugin.h"` +- Remove the third argument from `createFunctionToLoopPassAdaptor` + +### Source preparation + +Add a `#ifdef POSEIDON_GOLD` guard that wraps the target function call with RAPTOR's MPFR truncation: + +```cpp +#ifdef POSEIDON_GOLD +template +__attribute__((nothrow)) fty *__raptor_truncate_mem_func(fty *, int, int, int, int); +__attribute__((nothrow)) extern double __raptor_truncate_mem_value(...); +__attribute__((nothrow)) extern double __raptor_expand_mem_value(...); +extern "C" double raptor_fprt_gc_mark_seen(double); +extern "C" void raptor_fprt_gc_doit(); +// RAPTOR API: (func_ptr, from_bits, type_selector, exponent, mantissa) +// type_selector: 0=IEEE, 1=MPFR +#define RAPTOR_FROM 64 +#define RAPTOR_TYPE 1 +#define RAPTOR_TO_E 64 +#define RAPTOR_TO_M 2048 +#define TRUNC_SELF(X) X = __raptor_truncate_mem_value(X, RAPTOR_FROM, RAPTOR_TYPE, RAPTOR_TO_E, RAPTOR_TO_M) +#define EXPAND_SELF(X) X = __raptor_expand_mem_value(X, RAPTOR_FROM, RAPTOR_TYPE, RAPTOR_TO_E, RAPTOR_TO_M) +#else +int enzyme_dup; +template +return_type __enzyme_fp_optimize(void *, T...); +#endif +``` + +In the main loop: + +```cpp +#ifdef POSEIDON_GOLD + for (int i = 0; i < num_inputs; i++) TRUNC_SELF(inputs[i]); + __raptor_truncate_mem_func(my_kernel, RAPTOR_FROM, RAPTOR_TYPE, + RAPTOR_TO_E, RAPTOR_TO_M)(outputs, inputs); + for (int i = 0; i < num_outputs; i++) EXPAND_SELF(outputs[i]); + for (int i = 0; i < num_inputs; i++) EXPAND_SELF(inputs[i]); + raptor_fprt_gc_doit(); +#else + __enzyme_fp_optimize((void *)my_kernel, + enzyme_dup, outputs, outputs_grad, enzyme_dup, inputs, inputs_grad); +#endif +``` + +### Compiling and running + +```bash +clang++ mycode.cpp -O0 -fno-exceptions -DPOSEIDON_GOLD \ + -fpass-plugin=/path/to/RAPTOR/build/pass/ClangRaptor-XX.so \ + -Xclang -load -Xclang /path/to/RAPTOR/build/pass/ClangRaptor-XX.so \ + -L/path/to/RAPTOR/build/runtime -lRaptor-RT-XX \ + -lmpfr -lm -o gold.exe + +./gold.exe --output-path gold_mpfr.txt +``` + +Then pass `--gold-path gold_mpfr.txt` to `validate.py`. + +## Command-Line Reference + +### Profiling +| Flag | Default | Description | +|------|---------|-------------| +| `--fpprofile-generate` | false | Instrument for FP profiling | +| `--fpprofile-use=` | | Profile directory for optimization | + +### Herbie +| Flag | Default | Description | +|------|---------|-------------| +| `--fpopt-enable-herbie` | true | Use Herbie for algebraic rewrites | +| `--fpopt-cache-path` | "cache" | Cache directory for Herbie results and DP table | +| `--herbie-num-threads` | 8 | Herbie worker threads | +| `--herbie-timeout` | 120 | Per-expression Herbie timeout (seconds) | + +### Solver +| Flag | Default | Description | +|------|---------|-------------| +| `--fpopt-enable-solver` | true | Enable DP solver | +| `--fpopt-enable-pt` | true | Enable precision tuning candidates | +| `--fpopt-comp-cost-budget` | 0 | Computation cost budget (0 = no-op) | +| `--fpopt-cost-model-path` | | Hardware-specific cost model CSV | +| `--fpopt-strict-mode` | false | Discard candidates that produce NaN/inf | +| `--fpopt-num-samples` | 1024 | Number of MPFR evaluation samples | +| `--fpopt-early-prune` | true | Prune dominated candidates during DP | +| `--fpopt-loose-coverage` | false | Allow unexecuted instructions (suppress coverage errors) | + +### Reporting +| Flag | Default | Description | +|------|---------|-------------| +| `--fpopt-report-path` | | Output directory for JSON/text reports + validate.py | +| `--fpopt-raptor-dir` | | RAPTOR build directory (baked into validate.py config) | +| `--fpopt-print` | false | Print debug info to stderr | +| `--fpopt-show-table` | false | Print full DP table to stderr | From 7292ba12ed2f68285c27c342a4a94a293c475ca0 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 17:06:41 -0500 Subject: [PATCH 13/18] generate user friendly reports --- enzyme/Enzyme/Poseidon/Poseidon.cpp | 79 ++- enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp | 553 +++++++++++++++++++++ enzyme/Enzyme/Poseidon/PoseidonSolvers.h | 1 + enzyme/Enzyme/Poseidon/PoseidonTypes.cpp | 7 +- enzyme/Enzyme/Poseidon/README.md | 3 +- 5 files changed, 637 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/Poseidon/Poseidon.cpp b/enzyme/Enzyme/Poseidon/Poseidon.cpp index 69be4a93c74a..2def5da8bf82 100644 --- a/enzyme/Enzyme/Poseidon/Poseidon.cpp +++ b/enzyme/Enzyme/Poseidon/Poseidon.cpp @@ -1166,7 +1166,84 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI, double errorTol) { } } - if (!FPOptEnableSolver) { + if (!FPOptApplyRewrites.empty()) { + // User-selected rewrites: parse IDs and apply the specified candidates. + // IDs are R{coIdx}_{candIdx} for rewrites, PT{csIdx}_{candIdx} for PT. + SmallVector ids; + StringRef(FPOptApplyRewrites).split(ids, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + + // Track which CO/CS has been applied to enforce at-most-one constraint + SmallDenseSet appliedCOs, appliedCSs; + + for (auto id : ids) { + id = id.trim(); + if (id.starts_with("R")) { + // Parse R{coIdx}_{candIdx} + auto rest = id.drop_front(1); + auto [coStr, candStr] = rest.split('_'); + size_t coIdx, candIdx; + if (coStr.getAsInteger(10, coIdx) || candStr.getAsInteger(10, candIdx)) { + llvm::errs() << "FPOpt: Invalid rewrite ID '" << id << "'\n"; + continue; + } + if (coIdx >= COs.size()) { + llvm::errs() << "FPOpt: CO index " << coIdx << " out of range (" + << COs.size() << " COs)\n"; + continue; + } + if (candIdx >= COs[coIdx].candidates.size()) { + llvm::errs() << "FPOpt: Candidate index " << candIdx + << " out of range for CO " << coIdx << " (" + << COs[coIdx].candidates.size() << " candidates)\n"; + continue; + } + if (!appliedCOs.insert(coIdx).second) { + llvm::errs() << "FPOpt: CO " << coIdx + << " already has a rewrite applied. Skipping " << id + << "\n"; + continue; + } + llvm::errs() << "FPOpt: Applying " << id << ": " << COs[coIdx].expr + << " -> " << COs[coIdx].candidates[candIdx].expr << "\n"; + COs[coIdx].apply(candIdx, valueToNodeMap, symbolToValueMap); + changed = true; + + } else if (id.starts_with("PT")) { + auto rest = id.drop_front(2); + auto [csStr, candStr] = rest.split('_'); + size_t csIdx, candIdx; + if (csStr.getAsInteger(10, csIdx) || candStr.getAsInteger(10, candIdx)) { + llvm::errs() << "FPOpt: Invalid PT ID '" << id << "'\n"; + continue; + } + if (csIdx >= CSs.size()) { + llvm::errs() << "FPOpt: CS index " << csIdx << " out of range (" + << CSs.size() << " CSs)\n"; + continue; + } + if (candIdx >= CSs[csIdx].candidates.size()) { + llvm::errs() << "FPOpt: Candidate index " << candIdx + << " out of range for CS " << csIdx << " (" + << CSs[csIdx].candidates.size() << " candidates)\n"; + continue; + } + if (!appliedCSs.insert(csIdx).second) { + llvm::errs() << "FPOpt: CS " << csIdx + << " already has a PT applied. Skipping " << id << "\n"; + continue; + } + llvm::errs() << "FPOpt: Applying " << id << ": " + << CSs[csIdx].candidates[candIdx].desc << "\n"; + CSs[csIdx].apply(candIdx); + changed = true; + + } else { + llvm::errs() << "FPOpt: Unknown ID prefix in '" << id + << "' (expected R or PT)\n"; + } + } + } else if (!FPOptEnableSolver) { if (FPOptEnableHerbie) { for (auto &CO : COs) { CO.apply(0, valueToNodeMap, symbolToValueMap); diff --git a/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp b/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp index fedf24857a0e..79d7cef7d484 100644 --- a/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp +++ b/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp @@ -63,6 +63,19 @@ cl::opt FPOptRefineDPTable( cl::desc("After initial DP build, materialize each solution in a cloned\n" "function, run O3, and recompute cost deltas. Only applies when\n" "generating a new cache, not when loading from cache.")); +cl::opt FPOptReportPath( + "fpopt-report-path", cl::init(""), cl::Hidden, + cl::desc("Directory to write Poseidon optimization reports.\n" + "Emits .json (Pareto table with source locations),\n" + ".txt (human-readable), _rewrites.json\n" + "(curated per-rewrite analysis with IDs), and\n" + "validate_config.json + validate.py (validation script).")); +cl::opt FPOptApplyRewrites( + "fpopt-apply-rewrites", cl::init(""), cl::Hidden, + cl::desc("Comma-separated rewrite IDs to apply (e.g.\n" + "R0_1,R3_0,PT1_2). IDs are from the _rewrites.json\n" + "report. Bypasses the DP solver. At most one candidate\n" + "per expression (R) or subgraph (PT).")); } #if LLVM_VERSION_MAJOR >= 21 @@ -71,6 +84,540 @@ cl::opt FPOptRefineDPTable( #define GET_INSTRUCTION_COST(cost) (cost.getValue().value()) #endif +static json::Value jsonFloat(double v) { + if (std::isfinite(v)) + return json::Value(v); + return json::Value(nullptr); +} + +static json::Object getSourceLocationJSON(Value *V) { + json::Object loc; + if (auto *I = dyn_cast(V)) { + if (const auto &DL = I->getDebugLoc()) { + loc["file"] = DL->getFilename().str(); + loc["line"] = static_cast(DL.getLine()); + loc["col"] = static_cast(DL.getCol()); + } + } + return loc; +} + +static std::string getSourceLocationStr(Value *V) { + if (auto *I = dyn_cast(V)) { + if (const auto &DL = I->getDebugLoc()) { + return (DL->getFilename() + ":" + Twine(DL.getLine()) + ":" + + Twine(DL.getCol())) + .str(); + } + } + return ""; +} + +static std::string getValueStr(Value *V) { + std::string str; + raw_string_ostream OS(str); + V->print(OS); + return str; +} + +// Collect unique source locations from a set of LLVM instructions. +static json::Array getSourceLocationsJSON(ArrayRef insts) { + json::Array locs; + SmallSet seen; + for (auto *I : insts) { + auto locStr = getSourceLocationStr(I); + if (!locStr.empty() && seen.insert(locStr).second) { + locs.push_back(getSourceLocationJSON(I)); + } + } + return locs; +} + +static json::Object +buildStepJSON(const SolutionStep &step, + const std::map &costToAccuracyMap) { + json::Object stepObj; + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + stepObj["type"] = "rewrite"; + stepObj["original_expr"] = item->expr; + auto &cand = item->candidates[step.candidateIndex]; + stepObj["rewritten_expr"] = cand.expr; + stepObj["herbie_cost"] = jsonFloat(cand.herbieCost); + stepObj["herbie_accuracy"] = jsonFloat(cand.herbieAccuracy); + stepObj["initial_herbie_cost"] = jsonFloat(item->initialHerbieCost); + stepObj["initial_herbie_accuracy"] = + jsonFloat(item->initialHerbieAccuracy); + stepObj["computation_cost_delta"] = + GET_INSTRUCTION_COST(item->getCompCostDelta(step.candidateIndex)); + stepObj["accuracy_cost_delta"] = + jsonFloat(item->getAccCostDelta(step.candidateIndex)); + stepObj["gradient"] = jsonFloat(item->grad); + stepObj["executions"] = static_cast(item->executions); + + // Source locations from the output instruction and erasable insts + SmallVector insts; + if (auto *I = dyn_cast(item->oldOutput)) + insts.push_back(I); + for (auto *I : item->erasableInsts) + insts.push_back(I); + stepObj["source_locations"] = getSourceLocationsJSON(insts); + + // Affected LLVM IR instructions + json::Array affectedIR; + if (auto *I = dyn_cast(item->oldOutput)) + affectedIR.push_back(getValueStr(I)); + for (auto *I : item->erasableInsts) { + if (I != item->oldOutput) + affectedIR.push_back(getValueStr(I)); + } + stepObj["affected_instructions"] = std::move(affectedIR); + } else if constexpr (std::is_same_v) { + auto &cand = item->candidates[step.candidateIndex]; + stepObj["type"] = "precision_change"; + stepObj["description"] = cand.desc; + stepObj["candidate_index"] = + static_cast(step.candidateIndex); + + json::Array changes; + for (const auto &change : cand.changes) { + json::Object changeObj; + changeObj["from"] = getPrecisionChangeTypeString(change.oldType); + changeObj["to"] = getPrecisionChangeTypeString(change.newType); + changeObj["num_operations"] = + static_cast(change.nodes.size()); + + SmallVector insts; + json::Array nodeIR; + for (auto *node : change.nodes) { + if (auto *I = dyn_cast(node->value)) { + insts.push_back(I); + nodeIR.push_back(getValueStr(I)); + } + } + changeObj["source_locations"] = getSourceLocationsJSON(insts); + changeObj["affected_instructions"] = std::move(nodeIR); + changes.push_back(std::move(changeObj)); + } + stepObj["changes"] = std::move(changes); + } + }, + step.item); + return stepObj; +} + +static void writeTextReportStep(raw_ostream &OS, const SolutionStep &step, + unsigned indent) { + std::string pad(indent, ' '); + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + auto &cand = item->candidates[step.candidateIndex]; + OS << pad << "[Rewrite] " << item->expr << " --> " << cand.expr + << "\n"; + if (!std::isnan(item->initialHerbieAccuracy) && + !std::isnan(cand.herbieAccuracy)) + OS << pad << " Herbie accuracy: " << item->initialHerbieAccuracy + << " -> " << cand.herbieAccuracy << " bits\n"; + OS << pad << " Gradient: " << item->grad + << ", Executions: " << item->executions << "\n"; + + // Source locations + SmallSet seen; + auto printLoc = [&](Instruction *I) { + auto loc = getSourceLocationStr(I); + if (!loc.empty() && seen.insert(loc).second) + OS << pad << " Source: " << loc << "\n"; + }; + if (auto *I = dyn_cast(item->oldOutput)) + printLoc(I); + for (auto *I : item->erasableInsts) + printLoc(I); + + // Affected IR + OS << pad << " Affected IR:\n"; + if (auto *I = dyn_cast(item->oldOutput)) + OS << pad << " " << *I << "\n"; + for (auto *I : item->erasableInsts) { + if (I != item->oldOutput) + OS << pad << " " << *I << "\n"; + } + } else if constexpr (std::is_same_v) { + auto &cand = item->candidates[step.candidateIndex]; + OS << pad << "[Precision] " << cand.desc << " (#" + << step.candidateIndex << ")\n"; + for (const auto &change : cand.changes) { + OS << pad << " " << getPrecisionChangeTypeString(change.oldType) + << " -> " << getPrecisionChangeTypeString(change.newType) + << " for " << change.nodes.size() << " operations\n"; + SmallSet seen; + for (auto *node : change.nodes) { + if (auto *I = dyn_cast(node->value)) { + auto loc = getSourceLocationStr(I); + if (!loc.empty() && seen.insert(loc).second) + OS << pad << " Source: " << loc << "\n"; + } + } + } + } + }, + step.item); +} + +static void +emitPoseidonReport(StringRef funcName, + const std::map &costToAccuracyMap, + const std::map> + &costToSolutionMap, + SmallVector &COs, + SmallVector &CSs) { + if (FPOptReportPath.empty()) + return; + + std::error_code EC; + if (!llvm::sys::fs::exists(FPOptReportPath)) { + EC = llvm::sys::fs::create_directories(FPOptReportPath); + if (EC) { + llvm::errs() << "Error creating report directory: " << EC.message() + << "\n"; + return; + } + } + + // Build JSON report + json::Object report; + report["function"] = funcName.str(); + report["num_pareto_points"] = static_cast(costToAccuracyMap.size()); + + if (!costToAccuracyMap.empty()) { + report["cost_range_min"] = + GET_INSTRUCTION_COST(costToAccuracyMap.begin()->first); + report["cost_range_max"] = + GET_INSTRUCTION_COST(costToAccuracyMap.rbegin()->first); + } + + // Candidate summary + json::Array coSummary; + for (const auto &CO : COs) { + json::Object co; + co["original_expr"] = CO.expr; + co["num_candidates"] = static_cast(CO.candidates.size()); + co["gradient"] = jsonFloat(CO.grad); + co["executions"] = static_cast(CO.executions); + co["initial_accuracy_cost"] = jsonFloat(CO.initialAccCost); + co["initial_computation_cost"] = GET_INSTRUCTION_COST(CO.initialCompCost); + if (auto *I = dyn_cast(CO.oldOutput)) { + auto loc = getSourceLocationJSON(I); + if (!loc.empty()) + co["source_location"] = std::move(loc); + } + coSummary.push_back(std::move(co)); + } + report["candidate_outputs"] = std::move(coSummary); + + json::Array csSummary; + for (const auto &CS : CSs) { + json::Object cs; + cs["num_candidates"] = static_cast(CS.candidates.size()); + cs["initial_accuracy_cost"] = jsonFloat(CS.initialAccCost); + cs["initial_computation_cost"] = GET_INSTRUCTION_COST(CS.initialCompCost); + csSummary.push_back(std::move(cs)); + } + report["candidate_subgraphs"] = std::move(csSummary); + + // Pareto table + json::Array paretoPoints; + for (const auto &pair : costToAccuracyMap) { + json::Object point; + point["computation_cost"] = GET_INSTRUCTION_COST(pair.first); + point["accuracy_cost"] = pair.second; + + auto it = costToSolutionMap.find(pair.first); + if (it != costToSolutionMap.end()) { + json::Array steps; + for (const auto &step : it->second) { + steps.push_back(buildStepJSON(step, costToAccuracyMap)); + } + point["steps"] = std::move(steps); + } + paretoPoints.push_back(std::move(point)); + } + report["pareto_points"] = std::move(paretoPoints); + + // Write JSON + std::string jsonFile = + (Twine(FPOptReportPath) + "/" + funcName + ".json").str(); + raw_fd_ostream jsonOut(jsonFile, EC, sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing JSON report: " << EC.message() << "\n"; + } else { + jsonOut << formatv("{0:2}", json::Value(std::move(report))) << "\n"; + llvm::errs() << "Poseidon JSON report written to " << jsonFile << "\n"; + } + + // Write human-readable text report + std::string textFile = + (Twine(FPOptReportPath) + "/" + funcName + ".txt").str(); + raw_fd_ostream textOut(textFile, EC, sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing text report: " << EC.message() << "\n"; + return; + } + + textOut << "=== Poseidon Report: " << funcName << " ===\n"; + textOut << "Pareto table: " << costToAccuracyMap.size() << " points"; + if (!costToAccuracyMap.empty()) { + textOut << ", cost range [" + << GET_INSTRUCTION_COST(costToAccuracyMap.begin()->first) << ", " + << GET_INSTRUCTION_COST(costToAccuracyMap.rbegin()->first) << "]"; + } + textOut << "\n"; + textOut << "Candidate outputs: " << COs.size() + << ", Candidate subgraphs: " << CSs.size() << "\n\n"; + + // Per-CO summary + for (size_t i = 0; i < COs.size(); ++i) { + textOut << "Expression #" << i << ": " << COs[i].expr << "\n"; + textOut << " Gradient: " << COs[i].grad + << ", Executions: " << COs[i].executions << "\n"; + if (auto *I = dyn_cast(COs[i].oldOutput)) { + auto loc = getSourceLocationStr(I); + if (!loc.empty()) + textOut << " Source: " << loc << "\n"; + } + textOut << " Candidates: " << COs[i].candidates.size() << "\n"; + for (size_t j = 0; j < COs[i].candidates.size(); ++j) { + auto &cand = COs[i].candidates[j]; + textOut << " [" << j << "] " << cand.expr; + if (!std::isnan(cand.herbieAccuracy)) + textOut << " (accuracy: " << cand.herbieAccuracy << " bits)"; + textOut << "\n"; + } + } + textOut << "\n"; + + // Pareto points + unsigned pointIdx = 0; + for (const auto &pair : costToAccuracyMap) { + textOut << "--- Pareto Point #" << pointIdx++ + << ": Cost=" << GET_INSTRUCTION_COST(pair.first) + << ", Accuracy=" << pair.second << " ---\n"; + auto it = costToSolutionMap.find(pair.first); + if (it != costToSolutionMap.end() && !it->second.empty()) { + for (const auto &step : it->second) { + writeTextReportStep(textOut, step, 2); + } + } else { + textOut << " (no changes)\n"; + } + textOut << "\n"; + } + + llvm::errs() << "Poseidon text report written to " << textFile << "\n"; + + // Emit validation config + script + { + std::string configFile = + (Twine(FPOptReportPath) + "/validate_config.json").str(); + { + json::Object cfg; + cfg["function"] = funcName.str(); + cfg["profile_path"] = FPProfileUse.getValue(); + cfg["cache_path"] = FPOptCachePath.getValue(); + json::Array budgetArr; + for (const auto &pair : costToAccuracyMap) + budgetArr.push_back(GET_INSTRUCTION_COST(pair.first)); + cfg["budgets"] = std::move(budgetArr); + json::Array accArr; + for (const auto &pair : costToAccuracyMap) + accArr.push_back(pair.second); + cfg["estimated_accuracy_costs"] = std::move(accArr); + + raw_fd_ostream cfgOut(configFile, EC, sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing validate config: " << EC.message() + << "\n"; + } else { + cfgOut << formatv("{0:2}", json::Value(std::move(cfg))) << "\n"; + } + } + } + + // --- Curated per-rewrite analysis --- + // Enumerate every individual rewrite candidate, categorize by its marginal + // impact, and rank tradeoffs by efficiency (benefit per unit cost). + { + json::Array rewrites; + for (size_t coIdx = 0; coIdx < COs.size(); ++coIdx) { + auto &CO = COs[coIdx]; + for (size_t ci = 0; ci < CO.candidates.size(); ++ci) { + auto &cand = CO.candidates[ci]; + double accDelta = CO.getAccCostDelta(ci); + auto compDelta = CO.getCompCostDelta(ci); + int64_t compDeltaVal = GET_INSTRUCTION_COST(compDelta); + + // Skip candidates with no usable data + if (std::isnan(accDelta)) + continue; + + // Categorize + // comp < 0 means faster, acc < 0 means more accurate + std::string category; + double efficiency = 0.0; + if (compDeltaVal <= 0 && accDelta <= 0) { + category = "free_win"; + // Rank by combined magnitude of improvement + efficiency = std::abs(accDelta) + std::abs((double)compDeltaVal); + } else if (compDeltaVal <= 0 && accDelta > 0) { + category = "speed_for_accuracy"; + // Efficiency: speed gained per unit of accuracy lost + efficiency = (accDelta > 1e-30) + ? std::abs((double)compDeltaVal) / accDelta + : std::abs((double)compDeltaVal); + } else if (compDeltaVal > 0 && accDelta <= 0) { + category = "accuracy_for_speed"; + // Efficiency: accuracy gained per unit of speed lost + efficiency = (compDeltaVal > 0) + ? std::abs(accDelta) / (double)compDeltaVal + : std::abs(accDelta); + } else { + continue; // Both worse — skip + } + + json::Object rw; + std::string id = "R" + std::to_string(coIdx) + "_" + std::to_string(ci); + rw["id"] = id; + rw["original_expr"] = CO.expr; + rw["rewritten_expr"] = cand.expr; + rw["category"] = category; + rw["efficiency"] = jsonFloat(efficiency); + rw["computation_cost_delta"] = compDeltaVal; + rw["accuracy_cost_delta"] = jsonFloat(accDelta); + rw["gradient"] = jsonFloat(CO.grad); + rw["executions"] = static_cast(CO.executions); + rw["herbie_accuracy"] = jsonFloat(cand.herbieAccuracy); + rw["initial_herbie_accuracy"] = jsonFloat(CO.initialHerbieAccuracy); + + // Source location + if (auto *I = dyn_cast(CO.oldOutput)) { + auto loc = getSourceLocationJSON(I); + if (!loc.empty()) + rw["source_location"] = std::move(loc); + } + + rewrites.push_back(std::move(rw)); + } + } + + // Also include precision tuning candidates + for (size_t csIdx = 0; csIdx < CSs.size(); ++csIdx) { + auto &CS = CSs[csIdx]; + for (size_t ci = 0; ci < CS.candidates.size(); ++ci) { + auto &pt = CS.candidates[ci]; + double accDelta = CS.getAccCostDelta(ci); + auto compDelta = CS.getCompCostDelta(ci); + int64_t compDeltaVal = GET_INSTRUCTION_COST(compDelta); + + if (std::isnan(accDelta)) + continue; + + std::string category; + double efficiency = 0.0; + if (compDeltaVal <= 0 && accDelta <= 0) { + category = "free_win"; + efficiency = std::abs(accDelta) + std::abs((double)compDeltaVal); + } else if (compDeltaVal <= 0 && accDelta > 0) { + category = "speed_for_accuracy"; + efficiency = (accDelta > 1e-30) + ? std::abs((double)compDeltaVal) / accDelta + : std::abs((double)compDeltaVal); + } else if (compDeltaVal > 0 && accDelta <= 0) { + category = "accuracy_for_speed"; + efficiency = (compDeltaVal > 0) + ? std::abs(accDelta) / (double)compDeltaVal + : std::abs(accDelta); + } else { + continue; + } + + json::Object rw; + std::string id = + "PT" + std::to_string(csIdx) + "_" + std::to_string(ci); + rw["id"] = id; + rw["type"] = "precision_change"; + rw["description"] = pt.desc; + rw["category"] = category; + rw["efficiency"] = jsonFloat(efficiency); + rw["computation_cost_delta"] = compDeltaVal; + rw["accuracy_cost_delta"] = jsonFloat(accDelta); + + // Collect affected instructions + deduplicated source locations + SmallVector ptInsts; + for (const auto &change : pt.changes) { + for (auto *node : change.nodes) { + if (auto *I = dyn_cast(node->value)) + ptInsts.push_back(I); + } + } + rw["source_locations"] = getSourceLocationsJSON(ptInsts); + json::Array affectedIR; + for (auto *I : ptInsts) + affectedIR.push_back(getValueStr(I)); + rw["affected_instructions"] = std::move(affectedIR); + + rewrites.push_back(std::move(rw)); + } + } + + // Sort: free_wins first, then by efficiency descending + auto rewriteVec = + SmallVector(rewrites.begin(), rewrites.end()); + llvm::sort(rewriteVec, [](const json::Value &a, const json::Value &b) { + auto *ao = a.getAsObject(); + auto *bo = b.getAsObject(); + StringRef aCat = ao->getString("category").value_or(""); + StringRef bCat = bo->getString("category").value_or(""); + auto catRank = [](StringRef c) -> int { + if (c == "free_win") + return 0; + if (c == "accuracy_for_speed") + return 1; + if (c == "speed_for_accuracy") + return 2; + return 3; + }; + int ar = catRank(aCat), br = catRank(bCat); + if (ar != br) + return ar < br; + double ae = ao->getNumber("efficiency").value_or(0); + double be = bo->getNumber("efficiency").value_or(0); + return ae > be; // higher efficiency first + }); + + json::Array sortedRewrites; + for (auto &v : rewriteVec) + sortedRewrites.push_back(std::move(v)); + + std::string rewritesFile = + (Twine(FPOptReportPath) + "/" + funcName + "_rewrites.json").str(); + raw_fd_ostream rwOut(rewritesFile, EC, sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing rewrites report: " << EC.message() << "\n"; + } else { + json::Object root; + root["function"] = funcName.str(); + root["total_rewrites"] = static_cast(sortedRewrites.size()); + root["rewrites"] = std::move(sortedRewrites); + rwOut << formatv("{0:2}", json::Value(std::move(root))) << "\n"; + llvm::errs() << "Poseidon curated rewrites written to " << rewritesFile + << "\n"; + } + } +} + // Given the cost budget `FPOptComputationCostBudget`, we want to minimize the // accuracy cost of the rewritten expressions. bool accuracyGreedySolver( @@ -205,9 +752,11 @@ bool accuracyDPSolver( SolutionMap prunedCostToSolutionMap; std::string cacheFilePath = FPOptCachePath + "/table.json"; + bool loadedFromCache = false; if (llvm::sys::fs::exists(cacheFilePath)) { llvm::errs() << "Cache file found. Loading DP tables from cache.\n"; + loadedFromCache = true; llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFile(cacheFilePath); @@ -776,6 +1325,10 @@ bool accuracyDPSolver( } } + if (!loadedFromCache) + emitPoseidonReport(F.getName(), costToAccuracyMap, costToSolutionMap, COs, + CSs); + llvm::errs() << "Critical computation cost range: [" << costToAccuracyMap.begin()->first << ", " << costToAccuracyMap.rbegin()->first << "]\n"; diff --git a/enzyme/Enzyme/Poseidon/PoseidonSolvers.h b/enzyme/Enzyme/Poseidon/PoseidonSolvers.h index 8fb3445d3227..846c89d13b37 100644 --- a/enzyme/Enzyme/Poseidon/PoseidonSolvers.h +++ b/enzyme/Enzyme/Poseidon/PoseidonSolvers.h @@ -35,6 +35,7 @@ extern llvm::cl::list FPOptShowTableCosts; extern llvm::cl::opt FPOptEarlyPrune; extern llvm::cl::opt FPOptCostDominanceThreshold; extern llvm::cl::opt FPOptAccuracyDominanceThreshold; +extern llvm::cl::opt FPOptApplyRewrites; } bool accuracyGreedySolver( diff --git a/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp b/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp index 6900cd1fef4f..93604b9e9d39 100644 --- a/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp +++ b/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp @@ -720,9 +720,10 @@ void CandidateOutput::apply( // llvm::errs() << "Parsed Herbie output: " // << parsedNode->toFullExpression(valueToNodeMap) << "\n"; - IRBuilder<> builder(cast(oldOutput)->getParent(), - ++BasicBlock::iterator(cast(oldOutput))); - builder.setFastMathFlags(cast(oldOutput)->getFastMathFlags()); + auto *oldInst = cast(oldOutput); + + IRBuilder<> builder(oldInst->getParent(), ++BasicBlock::iterator(oldInst)); + builder.setFastMathFlags(oldInst->getFastMathFlags()); // auto *F = cast(oldOutput)->getParent()->getParent(); // llvm::errs() << "Before: " << *F << "\n"; diff --git a/enzyme/Enzyme/Poseidon/README.md b/enzyme/Enzyme/Poseidon/README.md index 40f1d5d5795c..474b4f6a3007 100644 --- a/enzyme/Enzyme/Poseidon/README.md +++ b/enzyme/Enzyme/Poseidon/README.md @@ -97,7 +97,6 @@ Add `--fpopt-report-path=` to emit structured optimization reports: ```bash clang++ mycode.cpp $CXXFLAGS ... \ -mllvm --fpopt-report-path=./report \ - -mllvm --fpopt-raptor-dir=/path/to/RAPTOR/build # optional, for validate.py ``` ### Report Contents @@ -281,6 +280,6 @@ Then pass `--gold-path gold_mpfr.txt` to `validate.py`. | Flag | Default | Description | |------|---------|-------------| | `--fpopt-report-path` | | Output directory for JSON/text reports + validate.py | -| `--fpopt-raptor-dir` | | RAPTOR build directory (baked into validate.py config) | +| `--fpopt-apply-rewrites` | | Comma-separated rewrite IDs from `_rewrites.json` (bypasses DP solver) | | `--fpopt-print` | false | Print debug info to stderr | | `--fpopt-show-table` | false | Print full DP table to stderr | From 815c3df7116f9b236a7c7e448fd7b1c25baa8092 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 17:06:49 -0500 Subject: [PATCH 14/18] bug fix --- enzyme/Enzyme/Enzyme.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 192a3401369c..098a03d0ec41 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -2067,7 +2067,7 @@ class EnzymeBase { "to create a profile or -fpprofile-use= to use an " "existing profile."); } else { - F = Logic.PPC.preprocessForClone(F, mode); + F = Logic.PPC.preprocessForClone(F, mode, /*profiled=*/true); setPoseidonMetadata(*F); SmallString<128> profilePath(FPProfileUse); From 76b1f33c2c36ded9b63538ec53be45be10e13b5b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 18:06:51 -0500 Subject: [PATCH 15/18] save doc --- enzyme/Enzyme/Poseidon/README.md | 157 +++++++--- enzyme/Enzyme/Poseidon/scripts/validate.py | 340 +++++++++++++++++++++ 2 files changed, 461 insertions(+), 36 deletions(-) create mode 100644 enzyme/Enzyme/Poseidon/scripts/validate.py diff --git a/enzyme/Enzyme/Poseidon/README.md b/enzyme/Enzyme/Poseidon/README.md index 474b4f6a3007..4f742dfb1363 100644 --- a/enzyme/Enzyme/Poseidon/README.md +++ b/enzyme/Enzyme/Poseidon/README.md @@ -1,12 +1,10 @@ # Poseidon -Poseidon is a modular and extensible framework that fully automates floating-point optimizations for real-world applications within a production compiler. It operates as an LLVM pass inside [Enzyme](https://enzyme.mit.edu/) and uses a PGO-like two-phase compilation to automatically extract numerical context (value ranges, sensitivities) from small surrogate profiling runs. It then synthesizes algebraic rewrites via [Herbie](https://herbie.uwplse.org/), generates precision tuning candidates, and uses a dynamic programming solver to find Pareto-optimal combinations that trade off computation cost against numerical accuracy. - -Unlike prior tools that require DSL inputs or manual code edits, Poseidon works directly on compiled LLVM IR and interoperates with standard compiler analyses and optimizations (mem2reg, inlining, loop unrolling, SimplifyCFG), enabling it to extract larger FP subgraphs than would be available in source code or binaries. +Poseidon is a modular and extensible framework that fully automates advanced floating-point rewriting techniques for real-world applications within a production compiler. It operates as a PGO-like two-phase compiler that automatically extract numerical context (e.g., value ranges, sensitivities) from small surrogate profiling runs. It synthesizes algebraic rewrites via [Herbie](https://herbie.uwplse.org/), generates precision tuning candidates, and uses a dynamic programming solver to find a Pareto frontier of optimized programs. For details, please read our paper [Thinking Fast and Correct: Automated Rewriting of Numerical Code through Compiler Augmentation](https://ece.is/assets/pdf/poseidon-cgo26.pdf) (CGO 2026). -If you use Poseidon in an academic setting, please cite: +If you use Poseidon in an academic setting, please kindly cite: ```bibtex @inproceedings{poseidon, @@ -21,23 +19,71 @@ If you use Poseidon in an academic setting, please cite: ## Build -See the [artifact repository](https://github.com/PRONTOLab/Poseidon) for full build instructions including Docker. In short: +We recommend building from source and generating a hardware-specific cost model for best results. The cost model estimates per-operation latencies on your machine and is used by the DP solver to estimate computation costs. + +### Prerequisites + +```bash +sudo apt install build-essential cmake ninja-build libmpfr-dev +pip install lit numpy matplotlib tqdm +``` + +Additionally, install [Racket](https://racket-lang.org/) and [Rust](https://www.rust-lang.org/tools/install). + +### Build LLVM + +```bash +cd llvm-project +mkdir build && cd build +cmake -G Ninja \ + -DLLVM_ENABLE_PROJECTS="clang" \ + -DLLVM_ENABLE_LLD=ON \ + -DLLVM_TARGETS_TO_BUILD="X86" \ + -DCMAKE_BUILD_TYPE=Release \ + ../llvm +ninja +cd ../.. +``` + +### Build Enzyme with Poseidon Enabled ```bash -cd Enzyme && mkdir build && cd build +cd Enzyme +mkdir build && cd build cmake -G Ninja ../enzyme/ \ - -DLLVM_DIR=/path/to/llvm/build/lib/cmake/llvm \ + -DLLVM_DIR=<...>/llvm-project/build/lib/cmake/llvm \ -DLLVM_EXTERNAL_LIT=$(which lit) \ -DCMAKE_BUILD_TYPE=Release \ - -DENABLE_POSEIDON=ON + -DENABLE_POSEIDON=ON \ + -DCMAKE_C_COMPILER=<...>/llvm-project/build/bin/clang \ + -DCMAKE_CXX_COMPILER=<...>/llvm-project/build/bin/clang++ ninja +cd ../.. ``` -Both phases must use identical compiler flags (e.g., `-O3`, `-ffast-math`, `-march=native`, etc.) to ensure profile indices match between compilations. +Replace `<...>` with the appropriate path prefix. + +A preconfigured Docker image is also available; see the [CGO artifact repository](https://github.com/PRONTOLab/Poseidon) for details. + +### Generate the Cost Model + +The cost model (`cm.csv`) is hardware-specific. To generate it for your machine: + +```bash +cd cost-model +python3 microbm.py +cp results.csv cm.csv +``` + +Then pass `--fpopt-cost-model-path=<...>/cost-model/cm.csv` during the optimization phase. Without a cost model, Poseidon falls back to LLVM's `TargetTransformInfo` estimates. ## Two-Phase Pipeline -### User Code +See the [dquat benchmark](https://github.com/PRONTOLab/Poseidon/tree/main/dquat) for an end-to-end example. + +:warning: Please Note: Both phases we introduce below must use identical compiler flags (e.g., `-O3`, `-ffast-math`, `-march=native`, etc.) to ensure profile indices match between compilations. + +### User Code Modification Annotate the function to optimize with `__enzyme_fp_optimize`: @@ -70,7 +116,7 @@ clang++ mycode.cpp $CXXFLAGS \ ENZYME_FPPROFILE_DIR=./fpprofile ./mycode-prof ``` -The instrumented binary records per-instruction value ranges, gradient sensitivities, and execution counts into a `.fpprofile` directory. A small surrogate input (e.g., 1000 samples) is often sufficient. +The instrumented binary records per-instruction value ranges, gradient sensitivities, and execution counts into a `.fpprofile` directory. A small surrogate input (e.g., 100 samples) is often sufficient. ### Phase 2: Optimization @@ -90,29 +136,38 @@ clang++ mycode.cpp $CXXFLAGS \ The first run invokes Herbie and the DP solver; results are cached in `--fpopt-cache-path`. Subsequent runs reuse the cache. The `--fpopt-comp-cost-budget` selects a point on the Pareto curve (0 = no-op baseline). +## How to Apply Rewrites? + +Poseidon offers two ways to apply numerical rewrites: + +1. **Select an optimized program from the Pareto frontier.** Pass `--fpopt-comp-cost-budget=N` to pick a point computed by the DP solver. Each budget value corresponds to a different combination of rewrites and precision changes that the solver determined to be Pareto-optimal. The full set of available budgets is listed in `validate_config.json` (when using `--fpopt-report-path`) and `cache/budgets.txt`. + +2. **Obtain a custom optimized program.** Generate a [report](#reporting), review the individual rewrites in `_rewrites.json`, and pass the IDs of the ones you want via `--fpopt-apply-rewrites=R3_0,PT1_0,...`. This bypasses the DP solver and gives fine-grained control over exactly which rewrites are applied. See [Applying User-Selected Rewrites](#applying-user-selected-rewrites) for details. + ## Reporting -Add `--fpopt-report-path=` to emit structured optimization reports: +By default, Poseidon silently applies the best solution within the given budget to the compiled binary. To understand *what* was changed and *why*, add `--fpopt-report-path=` (with `-g` for source locations): ```bash -clang++ mycode.cpp $CXXFLAGS ... \ - -mllvm --fpopt-report-path=./report \ +clang++ mycode.cpp $CXXFLAGS -g ... \ + -mllvm --fpopt-report-path=./report ``` ### Report Contents | File | Description | |------|-------------| -| `.json` | Full Pareto table: per-step source locations, symbolic expressions (original and rewritten), affected LLVM IR, Herbie accuracy bits, cost/accuracy deltas, gradient and execution count | -| `.txt` | Human-readable version of the same (suitable for feeding to an LLM for explanation) | -| `validate_config.json` | Baked-in configuration: all Pareto budgets, profile path, cache path, RAPTOR dir | -| `validate.py` | Self-contained validation script | +| `.json` | Full Pareto table with details | +| `.txt` | Plain-text version of the Pareto table | +| `_rewrites.json` | Detailed per-rewrite information: all rewrites categorized and ranked | + +Source locations (file, line, column) are populated when compiled with `-g`. Without `-g`, the symbolic expressions and affected IR are still available. -Source locations (file, line, column) are populated when the source is compiled with `-g`. Without `-g`, the symbolic expressions and affected IR are still available. +:bulb: For complex programs/rewrites, it may be beneficial to feed the `.txt` report and the program source code to an LLM and asking it to explain what each rewrite does and why it improves numerical accuracy. -### Understanding Rewrites +### Understanding the Pareto Report -The text report lists each Pareto point with the rewrites applied. For example: +The `.json` and `.txt` reports describe the full Pareto frontier. Each Pareto point represents one **optimized program** — a combination of rewrites and precision changes selected by the DP solver at a given computation cost budget. For example: ``` --- Pareto Point #5: Cost=-6863264, Accuracy=3.964900e-01 --- @@ -125,19 +180,53 @@ The text report lists each Pareto point with the rewrites applied. For example: %mul = fmul fast double %sqrt, 2.500000e-01 ``` -Each rewrite shows: -- **Original and rewritten symbolic expressions** in Herbie's S-expression format -- **Source locations** (when compiled with `-g`) pointing to the C/C++ lines affected -- **Affected LLVM IR instructions** that are replaced or erased -- **Herbie accuracy improvement** in bits +### Per-Rewrite Analysis -For complex rewrites, consider feeding the `.txt` report along with the source code to an LLM and asking it to explain what each rewrite does and why it improves accuracy. The symbolic expressions are self-contained and map directly to the source locations listed. +The `_rewrites.json` file lists every **individual** rewrite candidate, categorized by its estimated impact: -## Validation with validate.py +- **`free_win`** — improves both accuracy and speed. Always beneficial. +- **`accuracy_for_speed`** — improves accuracy at the cost of extra computation. +- **`speed_for_accuracy`** — improves speed at the cost of some accuracy. + +Each entry includes an `efficiency` score that ranks tradeoffs: higher means more benefit per unit of cost. Entries are sorted by category (free wins first) then by efficiency descending. + +Each rewrite has a stable `id` (e.g., `R3_0`, `PT1_2`) that can be used with `--fpopt-apply-rewrites` (see below). + +Example: +```json +{ + "id": "R6_3", + "category": "speed_for_accuracy", + "efficiency": 4.602e+11, + "computation_cost_delta": -22632, + "accuracy_cost_delta": 4.918e-08, + "original_expr": "(* (/ 1 (sqrt ...)) v5)", + "rewritten_expr": "#s(approx ...)", + "source_location": {"file": "dquat.cpp", "line": 116, "col": 18} +} +``` -The `validate.py` script (emitted alongside the report) automates accuracy and performance measurement across Pareto-optimal variants: +### Applying User-Selected Rewrites + +A user can pick specific rewrites from `_rewrites.json` by their IDs (this bypasses the DP solver): + +```bash +clang++ mycode.cpp $CXXFLAGS ... \ + -mllvm --fpopt-apply-rewrites=R5_5,R6_3,R7_3,PT1_0 +``` + +With a few constraints: +- At most one rewrite per expression (e.g., `R5_0` and `R5_1` conflict) +- At most one precision change per subgraph (e.g., `PT1_0` and `PT1_2` conflict) + +Duplicates are detected and skipped with a warning. + +## Optimized Program Validation + +To validate the actual accuracy and runtime of optimized programs, a reference validation script is provided at `Poseidon/scripts/validate.py`. Copy it to your report directory alongside `validate_config.json`, then run: ```bash +cp /Enzyme/Poseidon/scripts/validate.py ./report/ python3 report/validate.py mycode.cpp \ --enzyme-plugin /path/to/ClangEnzyme-XX.so \ --cxx /path/to/clang++ \ @@ -150,7 +239,7 @@ python3 report/validate.py mycode.cpp \ ### What it does -1. **Loads a gold reference** from `--gold-path` (MPFR ground truth; see [RAPTOR section](#generating-mpfr-gold-references-with-raptor) below) +1. **Loads a gold accuracy reference** from `--gold-path` (MPFR ground truth; see [RAPTOR section](#generating-mpfr-gold-references-with-raptor) below) 2. **Compiles the original** (unoptimized) binary, measures its runtime and accuracy against gold 3. **Uniformly samples** N budgets from the full Pareto table (87 points for dquat with `-ffast-math`) 4. **For each budget:** recompiles with Poseidon at that budget using the cached Herbie results + DP table, runs the resulting binary, captures output @@ -174,9 +263,9 @@ python3 report/validate.py mycode.cpp \ Here `ORIGINAL` is the unoptimized double-precision program. Negative budgets allow the solver to trade accuracy for speed; positive budgets allow extra computation to improve accuracy. The `MaxErr: 1179` in the original comes from catastrophic cancellation (`1 - cos(theta)` near zero), which Herbie's rewrites at positive budgets fix. -## Generating MPFR Gold References with RAPTOR +## (Optional) Generating MPFR References with RAPTOR -For meaningful accuracy validation, the gold reference should be computed at high precision (MPFR-2048 bits) rather than using the original double-precision output. [RAPTOR](https://github.com/RIKEN-RCCS/RAPTOR) provides this capability by running every floating-point operation through MPFR with garbage collection to avoid OOM on large applications. +For accuracy validation we recommend computing reference results at high floating-point precision (e.g., MPFR-2048 bits). [RAPTOR](https://github.com/RIKEN-RCCS/RAPTOR) provides this capability by running floating-point operations through MPFR. For the latest build and usage instructions of RAPTOR, see the [RAPTOR repository](https://github.com/RIKEN-RCCS/RAPTOR). ### Building RAPTOR @@ -187,10 +276,6 @@ cmake .. -DLLVM_DIR=/path/to/llvm -DCMAKE_BUILD_TYPE=Release make -j ``` -RAPTOR requires LLVM >= 15 and MPFR. For LLVM >= 23, two patches are needed in `pass/Raptor.cpp`: -- `#include "llvm/Passes/PassPlugin.h"` -> `#include "llvm/Plugins/PassPlugin.h"` -- Remove the third argument from `createFunctionToLoopPassAdaptor` - ### Source preparation Add a `#ifdef POSEIDON_GOLD` guard that wraps the target function call with RAPTOR's MPFR truncation: diff --git a/enzyme/Enzyme/Poseidon/scripts/validate.py b/enzyme/Enzyme/Poseidon/scripts/validate.py new file mode 100644 index 000000000000..4da43218adab --- /dev/null +++ b/enzyme/Enzyme/Poseidon/scripts/validate.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +""" +Poseidon validation reference script. + +Compiles a uniformly sampled subset of Pareto-optimal variants with +Enzyme/Poseidon, runs each, and reports geomean relative error against +a gold reference and runtime. + +This is a REFERENCE SCRIPT meant to be copied and adapted for your +application. Search for "CUSTOMIZE" to find the places that likely +need modification for a new benchmark. + +Usage: + python validate.py --enzyme-plugin \ + [--cxx ] [--extra-flags "..."] \ + [--num-samples 10] [--num-runs 5] \ + [--gold-path gold.txt] +""" + +import argparse, glob, json, math, os, re, subprocess, sys, time + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +CONFIG = json.load(open(os.path.join(SCRIPT_DIR, "validate_config.json"))) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_artifact(pattern): + hits = glob.glob(pattern) + return hits[0] if hits else None + + +def run(cmd, desc="", timeout=600, capture=False): + """Run a command and optionally capture stdout.""" + label = " ".join(cmd) if len(cmd) < 8 else f"{cmd[0]} ... ({desc})" + print(f" $ {label}") + try: + r = subprocess.run(cmd, capture_output=capture, text=True, timeout=timeout) + if r.returncode != 0: + print(f" FAILED ({desc}): exit {r.returncode}") + if capture and r.stderr: + print(r.stderr[:500]) + return None + return r.stdout if capture else "" + except subprocess.TimeoutExpired: + print(f" TIMEOUT ({desc})") + return None + + +# ---- CUSTOMIZE: output parsing ------------------------------------------- +# Replace this with a parser suited to your application's output format. +# The default extracts every floating-point number from stdout. + +def parse_output(text): + """Extract comparable numeric values from program output. + + Returns a list of floats in the order they appear. Adapt this for + your application -- e.g. skip header lines, parse specific columns, + or read a binary file instead. + """ + return [float(m) for m in + re.findall(r"[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?", text)] + + +# ---- CUSTOMIZE: error metric --------------------------------------------- +# The default is geomean relative error. Replace with ULP, L2 norm, or +# whatever is appropriate for your domain. + +def compute_error(gold_vals, test_vals): + """Compute geomean and max relative error between two value lists.""" + n = min(len(gold_vals), len(test_vals)) + if n == 0: + return float("nan"), float("nan") + log_sum = 0.0 + count = 0 + max_err = 0.0 + for i in range(n): + g, t = gold_vals[i], test_vals[i] + if g == 0 and t == 0: + continue + denom = abs(g) if g != 0 else abs(t) + rel = abs(g - t) / denom + max_err = max(max_err, rel) + if rel > 0: + log_sum += math.log(rel) + count += 1 + if count == 0: + return 0.0, 0.0 + return math.exp(log_sum / count), max_err + + +# ---- CUSTOMIZE: runtime measurement -------------------------------------- +# The default uses wall-clock time. If your program reports its own +# timing (e.g. "Elapsed: 1.23s"), parse it here for more stable results. + +def measure_runtime(exe, num_runs, extra_run_args=None): + """Measure median wall-clock runtime of an executable.""" + times = [] + args = [exe] + (extra_run_args or []) + for _ in range(num_runs): + t0 = time.perf_counter() + r = subprocess.run(args, capture_output=True, text=True, timeout=300) + t1 = time.perf_counter() + if r.returncode == 0: + times.append(t1 - t0) + if not times: + return float("nan") + times.sort() + return times[len(times) // 2] + + +# ---- CUSTOMIZE: how to run the program and capture output ----------------- +# The default runs the executable with extra_run_args and captures stdout. +# If your program writes to a file, reads from stdin, or needs environment +# variables, modify this function. + +def run_and_capture(exe, extra_run_args): + """Run executable and return its stdout as a string, or None on failure.""" + return run([exe] + extra_run_args, f"run {os.path.basename(exe)}", capture=True) + + +def uniform_sample(lst, n): + """Pick n uniformly spaced items from lst (always include first and last).""" + if n >= len(lst): + return list(range(len(lst))) + if n <= 0: + return [] + if n == 1: + return [0] + step = (len(lst) - 1) / (n - 1) + return sorted(set(int(round(i * step)) for i in range(n))) + + +# --------------------------------------------------------------------------- +# Compilation +# --------------------------------------------------------------------------- + +def compile_gold(source, cxx, raptor_dir, extra_flags, gold_flags, out_dir): + """Compile with RAPTOR for MPFR gold reference. + + The source must contain RAPTOR truncation calls (e.g. guarded by + -DPOSEIDON_GOLD). See the Poseidon README for the full pattern. + """ + plugin = find_artifact(os.path.join(raptor_dir, "pass", "ClangRaptor-*.so")) + rt = find_artifact(os.path.join(raptor_dir, "runtime", "libRaptor-RT-*.a")) + if not plugin or not rt: + print(f"Error: RAPTOR artifacts not found in {raptor_dir}") + return None + rt_dir = os.path.dirname(rt) + rt_name = os.path.basename(rt).replace("lib", "", 1).replace(".a", "") + + exe = os.path.join(out_dir, "gold.exe") + cmd = [cxx, source] + extra_flags.split() + gold_flags.split() + [ + f"-fpass-plugin={plugin}", + "-Xclang", "-load", "-Xclang", plugin, + f"-L{rt_dir}", f"-l{rt_name}", + "-lmpfr", "-lm", + "-o", exe, + ] + if run(cmd, "compile gold") is None: + return None + return exe + + +def compile_variant(source, cxx, enzyme_plugin, budget, config, extra_flags, out_dir): + """Compile a Poseidon-optimized variant at a given budget.""" + exe = os.path.join(out_dir, f"opt_{budget}.exe") + cmd = [cxx, source] + extra_flags.split() + [ + f"-fpass-plugin={enzyme_plugin}", + "-Xclang", "-load", "-Xclang", enzyme_plugin, + "-mllvm", f"--fpprofile-use={config['profile_path']}", + "-mllvm", "--fpopt-enable-solver", + "-mllvm", "--fpopt-enable-herbie=1", + "-mllvm", "--fpopt-enable-pt", + "-mllvm", f"--fpopt-comp-cost-budget={budget}", + "-mllvm", f"--fpopt-cache-path={config['cache_path']}", + "-lmpfr", "-lm", + "-o", exe, + ] + if run(cmd, f"compile budget={budget}") is None: + return None + return exe + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + p = argparse.ArgumentParser( + description="Poseidon accuracy/runtime validation (reference script)") + p.add_argument("source", help="C/C++ source file") + p.add_argument("--enzyme-plugin", required=True, + help="Path to ClangEnzyme-XX.so") + p.add_argument("--cxx", default="clang++", + help="C++ compiler (default: clang++)") + p.add_argument("--extra-flags", default="-O3 -ffast-math -lm", + help="Compile flags (must match profiling flags)") + p.add_argument("--extra-run-args", default="", + help="Args passed to each executable at runtime") + p.add_argument("--num-samples", type=int, default=10, + help="Number of Pareto points to validate") + p.add_argument("--num-runs", type=int, default=5, + help="Runtime measurement repetitions per variant") + p.add_argument("--gold-path", default="", + help="Pre-computed gold reference (skip RAPTOR compilation)") + p.add_argument("--raptor-dir", default="", + help="RAPTOR build directory (for gold compilation)") + p.add_argument("--gold-source", default="", + help="Source file for gold compilation (if different)") + p.add_argument("--gold-flags", default="", + help="Extra flags for gold compilation (e.g. -DPOSEIDON_GOLD)") + args = p.parse_args() + + budgets = CONFIG["budgets"] + est_acc = CONFIG.get("estimated_accuracy_costs", []) + run_args = args.extra_run_args.split() if args.extra_run_args else [] + + out_dir = os.path.join(SCRIPT_DIR, "validation_output") + os.makedirs(out_dir, exist_ok=True) + + # --- Gold reference --- + if args.gold_path: + print(f"=== Using pre-computed gold reference: {args.gold_path} ===") + gold_text = open(args.gold_path).read() + else: + print("=== Step 1: Compiling MPFR gold reference with RAPTOR ===") + if not args.raptor_dir: + print("Error: --raptor-dir required for gold compilation.") + print("Either provide --gold-path or --raptor-dir.") + sys.exit(1) + gold_src = args.gold_source if args.gold_source else args.source + gold_exe = compile_gold(gold_src, args.cxx, args.raptor_dir, + args.extra_flags, args.gold_flags, out_dir) + if not gold_exe: + print("Gold compilation failed.") + sys.exit(1) + print("Running gold binary...") + gold_text = run_and_capture(gold_exe, run_args) + if gold_text is None: + print("Error: gold run failed.") + sys.exit(1) + + gold_vals = parse_output(gold_text) + print(f"Gold reference: {len(gold_vals)} values extracted.\n") + + # --- Original (unoptimized) baseline --- + print("=== Step 2: Measuring baseline (original) ===") + orig_exe = os.path.join(out_dir, "original.exe") + orig_cmd = [args.cxx, args.source] + args.extra_flags.split() + [ + f"-fpass-plugin={args.enzyme_plugin}", + "-Xclang", "-load", "-Xclang", args.enzyme_plugin, + "-lmpfr", "-lm", "-o", orig_exe, + ] + orig_geo_err = float("nan") + orig_max_err = float("nan") + orig_runtime = float("nan") + if run(orig_cmd, "compile original") is not None: + orig_runtime = measure_runtime(orig_exe, args.num_runs, run_args) + orig_text = run_and_capture(orig_exe, run_args) + if orig_text is not None: + orig_vals = parse_output(orig_text) + orig_geo_err, orig_max_err = compute_error(gold_vals, orig_vals) + print(f"Baseline runtime: {orig_runtime:.6f}s, " + f"geomean error: {orig_geo_err:.6e}, max error: {orig_max_err:.6e}\n") + else: + print("Warning: could not compile original.\n") + + # --- Sample and compile Pareto variants --- + sample_indices = uniform_sample(budgets, args.num_samples) + sampled_budgets = [budgets[i] for i in sample_indices] + sampled_est = [est_acc[i] if i < len(est_acc) else float("nan") + for i in sample_indices] + + print(f"=== Step 3: Compiling {len(sampled_budgets)} Pareto variants ===") + + results = [] + for budget, est in zip(sampled_budgets, sampled_est): + print(f"\n--- Budget: {budget} (estimated acc cost: {est:.6e}) ---") + exe = compile_variant(args.source, args.cxx, args.enzyme_plugin, + budget, CONFIG, args.extra_flags, out_dir) + if not exe: + results.append({"budget": budget, "error": "compile_failed"}) + continue + + test_text = run_and_capture(exe, run_args) + if test_text is None: + results.append({"budget": budget, "error": "run_failed"}) + continue + + test_vals = parse_output(test_text) + geo_err, max_err = compute_error(gold_vals, test_vals) + rt = measure_runtime(exe, args.num_runs, run_args) + speedup = orig_runtime / rt if rt > 0 and not math.isnan(orig_runtime) else float("nan") + + results.append({ + "budget": budget, + "estimated_accuracy_cost": est, + "geomean_relative_error": geo_err, + "max_relative_error": max_err, + "runtime": rt, + "speedup": speedup, + }) + print(f" geomean error: {geo_err:.6e}, max error: {max_err:.6e}, " + f"runtime: {rt:.6f}s, speedup: {speedup:.2f}x") + + # --- Summary --- + print("\n" + "=" * 72) + print(f"{'Budget':>10} {'Est.AccCost':>12} {'GeomErr':>12} {'MaxErr':>12} " + f"{'Runtime':>10} {'Speedup':>8}") + print("-" * 72) + if not math.isnan(orig_runtime): + print(f"{'ORIGINAL':>10} {'--':>12} " + f"{orig_geo_err:>12.4e} {orig_max_err:>12.4e} " + f"{orig_runtime:>10.6f} {'1.00x':>8}") + print("-" * 72) + for r in results: + if "error" in r and isinstance(r.get("error"), str): + print(f"{r['budget']:>10} {'':>12} {r['error']:>12}") + continue + print(f"{r['budget']:>10} {r.get('estimated_accuracy_cost',0):>12.4e} " + f"{r['geomean_relative_error']:>12.4e} {r['max_relative_error']:>12.4e} " + f"{r['runtime']:>10.6f} {r['speedup']:>7.2f}x") + print("=" * 72) + + # Save results + results_path = os.path.join(SCRIPT_DIR, f"{CONFIG['function']}_validation.json") + with open(results_path, "w") as f: + json.dump({"config": CONFIG, + "baseline": {"runtime": orig_runtime, + "geomean_relative_error": orig_geo_err, + "max_relative_error": orig_max_err}, + "results": results}, f, indent=2, default=str) + print(f"\nResults saved to {results_path}") + + +if __name__ == "__main__": + main() From bc3a70fe5c9ac356314bf0d5c776c3fb0a4af275 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 18:10:55 -0500 Subject: [PATCH 16/18] toc --- enzyme/Enzyme/Poseidon/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/enzyme/Enzyme/Poseidon/README.md b/enzyme/Enzyme/Poseidon/README.md index 4f742dfb1363..ce2014dd0e1e 100644 --- a/enzyme/Enzyme/Poseidon/README.md +++ b/enzyme/Enzyme/Poseidon/README.md @@ -17,6 +17,19 @@ If you use Poseidon in an academic setting, please kindly cite: } ``` +## Table of Contents + +- [Build](#build) +- [Two-Phase Pipeline](#two-phase-pipeline) +- [How to Apply Rewrites?](#how-to-apply-rewrites) +- [Reporting](#reporting) + - [Per-Rewrite Analysis](#per-rewrite-analysis) + - [Applying User-Selected Rewrites](#applying-user-selected-rewrites) +- [Optimized Program Validation](#optimized-program-validation) +- [Generating MPFR References with RAPTOR](#optional-generating-mpfr-references-with-raptor) +- [Command-Line Reference](#command-line-reference) + + ## Build We recommend building from source and generating a hardware-specific cost model for best results. The cost model estimates per-operation latencies on your machine and is used by the DP solver to estimate computation costs. From 2b791b84212a9c6b62d685d3d88fc796dc724e73 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 18:12:14 -0500 Subject: [PATCH 17/18] save --- enzyme/Enzyme/Poseidon/Poseidon.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/Poseidon/Poseidon.cpp b/enzyme/Enzyme/Poseidon/Poseidon.cpp index 2def5da8bf82..d5dd8201723b 100644 --- a/enzyme/Enzyme/Poseidon/Poseidon.cpp +++ b/enzyme/Enzyme/Poseidon/Poseidon.cpp @@ -1170,8 +1170,9 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI, double errorTol) { // User-selected rewrites: parse IDs and apply the specified candidates. // IDs are R{coIdx}_{candIdx} for rewrites, PT{csIdx}_{candIdx} for PT. SmallVector ids; - StringRef(FPOptApplyRewrites).split(ids, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); + StringRef(FPOptApplyRewrites) + .split(ids, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); // Track which CO/CS has been applied to enforce at-most-one constraint SmallDenseSet appliedCOs, appliedCSs; @@ -1183,7 +1184,8 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI, double errorTol) { auto rest = id.drop_front(1); auto [coStr, candStr] = rest.split('_'); size_t coIdx, candIdx; - if (coStr.getAsInteger(10, coIdx) || candStr.getAsInteger(10, candIdx)) { + if (coStr.getAsInteger(10, coIdx) || + candStr.getAsInteger(10, candIdx)) { llvm::errs() << "FPOpt: Invalid rewrite ID '" << id << "'\n"; continue; } @@ -1213,7 +1215,8 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI, double errorTol) { auto rest = id.drop_front(2); auto [csStr, candStr] = rest.split('_'); size_t csIdx, candIdx; - if (csStr.getAsInteger(10, csIdx) || candStr.getAsInteger(10, candIdx)) { + if (csStr.getAsInteger(10, csIdx) || + candStr.getAsInteger(10, candIdx)) { llvm::errs() << "FPOpt: Invalid PT ID '" << id << "'\n"; continue; } From ab99402797f04b2eb52ef1a9c4f8db1232ffc313 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 18:24:39 -0500 Subject: [PATCH 18/18] save scripts --- enzyme/Enzyme/Poseidon/README.md | 7 +- enzyme/Enzyme/Poseidon/scripts/cm_example.csv | 98 +++ enzyme/Enzyme/Poseidon/scripts/microbm.py | 610 ++++++++++++++++++ 3 files changed, 711 insertions(+), 4 deletions(-) create mode 100644 enzyme/Enzyme/Poseidon/scripts/cm_example.csv create mode 100644 enzyme/Enzyme/Poseidon/scripts/microbm.py diff --git a/enzyme/Enzyme/Poseidon/README.md b/enzyme/Enzyme/Poseidon/README.md index ce2014dd0e1e..6b017f8da2be 100644 --- a/enzyme/Enzyme/Poseidon/README.md +++ b/enzyme/Enzyme/Poseidon/README.md @@ -80,15 +80,14 @@ A preconfigured Docker image is also available; see the [CGO artifact repository ### Generate the Cost Model -The cost model (`cm.csv`) is hardware-specific. To generate it for your machine: +The cost model (`cm.csv`) is hardware-specific. A benchmarking script is provided at `Poseidon/scripts/microbm.py`. To generate the cost model for your machine: ```bash -cd cost-model -python3 microbm.py +python3 /Enzyme/Poseidon/scripts/microbm.py cp results.csv cm.csv ``` -Then pass `--fpopt-cost-model-path=<...>/cost-model/cm.csv` during the optimization phase. Without a cost model, Poseidon falls back to LLVM's `TargetTransformInfo` estimates. +Then pass `--fpopt-cost-model-path=cm.csv` during the optimization phase. An example cost model (generated on AMD Ryzen Threadripper PRO 7995WX) is provided at `Poseidon/scripts/cm_example.csv` for reference. Without a cost model, Poseidon falls back to LLVM's `TargetTransformInfo` estimates. ## Two-Phase Pipeline diff --git a/enzyme/Enzyme/Poseidon/scripts/cm_example.csv b/enzyme/Enzyme/Poseidon/scripts/cm_example.csv new file mode 100644 index 000000000000..5f1c42c6a1d4 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/scripts/cm_example.csv @@ -0,0 +1,98 @@ +fneg,float,75 +fadd,float,77 +fsub,float,78 +fmul,float,77 +fdiv,float,77 +fcmp,float,62 +fpext_float_to_double,float,78 +fmuladd,float,78 +sin,float,421 +cos,float,418 +tan,float,877 +exp,float,458 +log,float,300 +sqrt,float,124 +expm1,float,499 +log1p,float,478 +cbrt,float,2777 +pow,float,545 +fabs,float,76 +fma,float,75 +maxnum,float,75 +minnum,float,76 +ceil,float,77 +floor,float,75 +exp2,float,458 +log10,float,493 +log2,float,301 +rint,float,298 +round,float,304 +trunc,float,76 +copysign,float,76 +fdim,float,300 +fmod,float,298 +asin,float,331 +acos,float,296 +atan,float,780 +atan2,float,2730 +sinh,float,481 +cosh,float,328 +tanh,float,344 +asinh,float,341 +acosh,float,340 +atanh,float,377 +hypot,float,1938 +erf,float,327 +lgamma,float,1678 +tgamma,float,3468 +remainder,float,3026 +powi,float,83 +fneg,double,82 +fadd,double,83 +fsub,double,83 +fmul,double,82 +fdiv,double,127 +fcmp,double,65 +fptrunc_double_to_float,double,83 +fmuladd,double,83 +sin,double,3715 +cos,double,3769 +tan,double,4499 +exp,double,748 +log,double,350 +sqrt,double,236 +expm1,double,327 +log1p,double,460 +cbrt,double,1047 +pow,double,1102 +fabs,double,83 +fma,double,83 +maxnum,double,83 +minnum,double,83 +ceil,double,83 +floor,double,83 +exp2,double,442 +log10,double,604 +log2,double,327 +rint,double,326 +round,double,328 +trunc,double,83 +copysign,double,83 +fdim,double,329 +fmod,double,328 +asin,double,576 +acos,double,566 +atan,double,822 +atan2,double,1178 +sinh,double,627 +cosh,double,546 +tanh,double,327 +asinh,double,493 +acosh,double,385 +atanh,double,330 +hypot,double,526 +erf,double,327 +lgamma,double,1981 +tgamma,double,2017 +remainder,double,1064 +powi,double,82 diff --git a/enzyme/Enzyme/Poseidon/scripts/microbm.py b/enzyme/Enzyme/Poseidon/scripts/microbm.py new file mode 100644 index 000000000000..d5847b7808d5 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/scripts/microbm.py @@ -0,0 +1,610 @@ +import time +import csv +import struct +import random +import numpy as np +import llvmlite.binding as llvm +import ctypes + + +random.seed(42) + +llvm.initialize_native_target() +llvm.initialize_native_asmprinter() + +FAST_MATH_FLAG = "fast" + +unrolled = 128 +iterations = 100000000 +AMPLIFIER = 10 + +precision_to_llvm_type = { + "double": "double", + "float": "float", + "half": "half", + "fp80": "x86_fp80", + "fp128": "fp128", + "bf16": "bfloat", +} + +precision_to_intrinsic_suffix = { + "double": "f64", + "float": "f32", + "half": "f16", + "fp80": "f80", + "fp128": "f128", + "bf16": "bf16", +} + +precision_ranks = {"bf16": 0, "half": 1, "float": 2, "double": 3, "fp80": 4, "fp128": 5} +precisions_ordered = ["bf16", "half", "float", "double", "fp80", "fp128"] +precisions = ["float", "double"] + + +def get_zero_literal(precision): + if precision in ("double", "float", "half"): + return "0.0" + elif precision == "bf16": + return "0xR0000" + elif precision == "fp80": + return "0xK00000000000000000000" + elif precision == "fp128": + return "0xL00000000000000000000000000000000" + return "0.0" + + +def float64_to_fp80_bytes(value: np.float64) -> bytes: + packed = struct.pack(">d", value) + (bits,) = struct.unpack(">Q", packed) + sign = (bits >> 63) & 0x1 + exponent = (bits >> 52) & 0x7FF + mantissa = bits & 0xFFFFFFFFFFFFF + + if exponent == 0: + if mantissa == 0: + fp80_exponent = 0 + fp80_mantissa = 0 + else: + shift = 0 + while (mantissa & (1 << 52)) == 0: + mantissa <<= 1 + shift += 1 + exponent = 1 - shift + exponent_bias_64 = 1023 + exponent_bias_80 = 16383 + fp80_exponent = exponent - exponent_bias_64 + exponent_bias_80 + fp80_mantissa = mantissa << (63 - 52) + elif exponent == 0x7FF: + fp80_exponent = 0x7FFF + if mantissa == 0: + fp80_mantissa = 0x8000000000000000 + else: + fp80_mantissa = 0xC000000000000000 | (mantissa << (63 - 52)) + else: + exponent_bias_64 = 1023 + exponent_bias_80 = 16383 + fp80_exponent = exponent - exponent_bias_64 + exponent_bias_80 + fp80_mantissa = (0x8000000000000000) | (mantissa << (63 - 52)) + + exponent_sign = (sign << 15) | fp80_exponent + fp80_bits = (exponent_sign << 64) | fp80_mantissa + fp80_bytes = fp80_bits.to_bytes(10, byteorder="big") + return fp80_bytes + + +def float64_to_fp128_bytes(value: np.float64) -> bytes: + packed = struct.pack(">d", value) + (bits,) = struct.unpack(">Q", packed) + sign = (bits >> 63) & 0x1 + exponent = (bits >> 52) & 0x7FF + mantissa = bits & 0xFFFFFFFFFFFFF + + if exponent == 0: + fp128_exponent = 0 + elif exponent == 0x7FF: + fp128_exponent = 0x7FFF + else: + exponent_bias_64 = 1023 + exponent_bias_128 = 16383 + fp128_exponent = exponent - exponent_bias_64 + exponent_bias_128 + + fp128_mantissa = mantissa << 60 + fp128_bits = (sign << 127) | (fp128_exponent << 112) | fp128_mantissa + fp128_bytes = fp128_bits.to_bytes(16, byteorder="big") + return fp128_bytes + + +def float_to_llvm_hex(f, precision): + if precision == "double": + f_cast = np.float64(f) + packed = struct.pack(">d", f_cast) + [i] = struct.unpack(">Q", packed) + return f"0x{i:016X}" + elif precision == "float": + f_cast = np.float32(f) + packed = struct.pack(">d", f_cast) + [i] = struct.unpack(">Q", packed) + return f"0x{i:016X}" + elif precision == "half": + f_cast = np.float16(f) + packed = f_cast.tobytes() + [i] = struct.unpack(">H", packed) + return f"0xH{i:04X}" + elif precision == "bf16": + f_cast = np.float32(f) + [bits] = struct.unpack(">I", struct.pack(">f", f_cast)) + bf16_bits = bits >> 16 + return f"0xR{bf16_bits:04X}" + elif precision == "fp80": + f_cast = np.float64(f) + fp80_bytes = float64_to_fp80_bytes(f_cast) + return f"0xK{fp80_bytes.hex().upper()}" + elif precision == "fp128": + f_cast = np.float64(f) + fp128_bytes = float64_to_fp128_bytes(f_cast) + swapped = fp128_bytes[8:] + fp128_bytes[:8] + return f"0xL{swapped.hex().upper()}" + else: + return str(f) + + +def generate_random_fp(precision): + if precision == "double": + f = random.uniform(-1e10, 1e10) + dtype = np.float64 + elif precision == "float": + f = random.uniform(-1e5, 1e5) + dtype = np.float32 + elif precision == "half": + f = random.uniform(-1e3, 1e3) + dtype = np.float16 + elif precision == "bf16": + f = random.uniform(-1e3, 1e3) + dtype = np.float32 + elif precision == "fp80": + f = random.uniform(-1e10, 1e10) + dtype = np.float64 + elif precision == "fp128": + f = random.uniform(-1e10, 1e10) + dtype = np.float64 + else: + f = random.uniform(-1e3, 1e3) + dtype = np.float64 + return dtype(f).item() + + +OP_INFO = { + "fneg": {"llvm_instr": "fneg", "num_operands": 1, "kind": "arithmetic"}, + "fadd": {"llvm_instr": "fadd", "num_operands": 2, "kind": "arithmetic"}, + "fsub": {"llvm_instr": "fsub", "num_operands": 2, "kind": "arithmetic"}, + "fmul": {"llvm_instr": "fmul", "num_operands": 2, "kind": "arithmetic"}, + "fdiv": {"llvm_instr": "fdiv", "num_operands": 2, "kind": "arithmetic"}, + "fcmp": {"llvm_instr": "fcmp", "num_operands": 2, "kind": "compare"}, + "fptrunc": {"llvm_instr": "fptrunc", "num_operands": 1, "kind": "cast"}, + "fpext": {"llvm_instr": "fpext", "num_operands": 1, "kind": "cast"}, +} + +FUNC_INFO = { + "fmuladd": {"intrinsic": "llvm.fmuladd", "num_operands": 3}, + "sin": {"intrinsic": "llvm.sin", "num_operands": 1}, + "cos": {"intrinsic": "llvm.cos", "num_operands": 1}, + "tan": {"intrinsic": None, "num_operands": 1}, + "exp": {"intrinsic": "llvm.exp", "num_operands": 1}, + "log": {"intrinsic": "llvm.log", "num_operands": 1}, + "sqrt": {"intrinsic": "llvm.sqrt", "num_operands": 1}, + "expm1": {"intrinsic": None, "num_operands": 1}, + "log1p": {"intrinsic": None, "num_operands": 1}, + "cbrt": {"intrinsic": None, "num_operands": 1}, + "pow": {"intrinsic": "llvm.pow", "num_operands": 2}, + "fabs": {"intrinsic": "llvm.fabs", "num_operands": 1}, + "fma": {"intrinsic": "llvm.fma", "num_operands": 3}, + "maxnum": {"intrinsic": "llvm.maxnum", "num_operands": 2}, + "minnum": {"intrinsic": "llvm.minnum", "num_operands": 2}, + "ceil": {"intrinsic": "llvm.ceil", "num_operands": 1}, + "floor": {"intrinsic": "llvm.floor", "num_operands": 1}, + "exp2": {"intrinsic": "llvm.exp2", "num_operands": 1}, + "log10": {"intrinsic": "llvm.log10", "num_operands": 1}, + "log2": {"intrinsic": "llvm.log2", "num_operands": 1}, + "rint": {"intrinsic": "llvm.rint", "num_operands": 1}, + "round": {"intrinsic": "llvm.round", "num_operands": 1}, + "trunc": {"intrinsic": "llvm.trunc", "num_operands": 1}, + "copysign": {"intrinsic": "llvm.copysign", "num_operands": 2}, + "fdim": {"intrinsic": None, "num_operands": 2}, + "fmod": {"intrinsic": None, "num_operands": 2}, + "asin": {"intrinsic": None, "num_operands": 1}, + "acos": {"intrinsic": None, "num_operands": 1}, + "atan": {"intrinsic": None, "num_operands": 1}, + "atan2": {"intrinsic": None, "num_operands": 2}, + "sinh": {"intrinsic": None, "num_operands": 1}, + "cosh": {"intrinsic": None, "num_operands": 1}, + "tanh": {"intrinsic": None, "num_operands": 1}, + "asinh": {"intrinsic": None, "num_operands": 1}, + "acosh": {"intrinsic": None, "num_operands": 1}, + "atanh": {"intrinsic": None, "num_operands": 1}, + "hypot": {"intrinsic": None, "num_operands": 2}, + "erf": {"intrinsic": None, "num_operands": 1}, + "lgamma": {"intrinsic": None, "num_operands": 1}, + "tgamma": {"intrinsic": None, "num_operands": 1}, + "remainder": {"intrinsic": None, "num_operands": 2}, + "powi": {"intrinsic": "llvm.powi", "num_operands": 2}, +} + + +def generate_loop_code(llvm_type, iterations, body_instructions, final_acc_reg): + zero_literal = get_zero_literal(llvm_type) + code = f""" +define i32 @main() optnone noinline {{ +entry: + %i = alloca i32 + %acc = alloca {llvm_type} + store i32 0, i32* %i + store {llvm_type} {zero_literal}, {llvm_type}* %acc + br label %loop + +loop: + %i_val = load i32, i32* %i + %cond = icmp slt i32 %i_val, {iterations} + br i1 %cond, label %body, label %exit + +body: + %acc_val0 = load {llvm_type}, {llvm_type}* %acc +{body_instructions} + store {llvm_type} {final_acc_reg}, {llvm_type}* %acc + %i_next = add i32 %i_val, 1 + store i32 %i_next, i32* %i + br label %loop + +exit: + %final_acc = load {llvm_type}, {llvm_type}* %acc + call void @use({llvm_type} %final_acc) + ret i32 0 +}} + +define void @use({llvm_type} %val) {{ + ret void +}} +""" + return code + + +def generate_arithmetic_op_code(op_key, precision, iterations): + """Generate LLVM IR for a basic arithmetic operator (or fneg) based on OP_INFO.""" + op_info = OP_INFO[op_key] + llvm_type = precision_to_llvm_type[precision] + body_lines = "" + for idx in range(unrolled): + operands = [] + for _ in range(op_info["num_operands"]): + f_val = generate_random_fp(precision) + operands.append(float_to_llvm_hex(f_val, precision)) + + if op_info["num_operands"] == 1: + line = f" %result{idx} = {op_info['llvm_instr']} {FAST_MATH_FLAG} {llvm_type} {operands[0]}" + body_lines += line + "\n" + body_lines += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n" + elif op_info["num_operands"] == 2: + line = f" %result{idx} = {op_info['llvm_instr']} {FAST_MATH_FLAG} {llvm_type} {operands[0]}, {operands[1]}" + body_lines += line + "\n" + body_lines += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n" + + final_acc = f"%acc_val{unrolled}" + return generate_loop_code(llvm_type, iterations, body_lines, final_acc) + + +def generate_compare_op_code(precision, iterations): + """Generate LLVM IR for an fcmp (comparison) operation.""" + llvm_type = precision_to_llvm_type[precision] + + body_lines = "" + for idx in range(unrolled): + f_a = generate_random_fp(precision) + f_b = generate_random_fp(precision) + a_hex = float_to_llvm_hex(f_a, precision) + b_hex = float_to_llvm_hex(f_b, precision) + line = f" %cmp{idx} = fcmp {FAST_MATH_FLAG} olt {llvm_type} {a_hex}, {b_hex}" + body_lines += line + "\n" + body_lines += f" %cmp_int{idx} = zext i1 %cmp{idx} to i32\n" + + code = f""" +define i32 @main() optnone noinline {{ +entry: + %i = alloca i32 + %acc = alloca i32 + store i32 0, i32* %i + store i32 0, i32* %acc + br label %loop + +loop: + %i_val = load i32, i32* %i + %cond = icmp slt i32 %i_val, {iterations} + br i1 %cond, label %body, label %exit + +body: + %acc_val0 = load i32, i32* %acc +{body_lines} +""" + + for idx in range(unrolled): + code += f" %acc_val{idx+1} = add i32 %acc_val{idx}, %cmp_int{idx}\n" + + final_acc = f"%acc_val{unrolled}" + code += f""" + store i32 {final_acc}, i32* %acc + %i_next = add i32 %i_val, 1 + store i32 %i_next, i32* %i + br label %loop + +exit: + %final_acc = load i32, i32* %acc + call void @use_i32(i32 %final_acc) + ret i32 0 +}} + +define void @use_i32(i32 %val) {{ + ret void +}} +""" + return code + + +def generate_cast_op_code(op_key, src_precision, dst_precision, iterations): + """Generate LLVM IR for a cast operation (fptrunc or fpext).""" + op_info = OP_INFO[op_key] + src_type = precision_to_llvm_type[src_precision] + dst_type = precision_to_llvm_type[dst_precision] + zero_literal = get_zero_literal(dst_precision) + + body_lines = "" + for idx in range(unrolled): + f_val = generate_random_fp(src_precision) + hex_val = float_to_llvm_hex(f_val, src_precision) + line = f" %result{idx} = {op_info['llvm_instr']} {src_type} {hex_val} to {dst_type}" + body_lines += line + "\n" + body_lines += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {dst_type} %acc_val{idx}, %result{idx}\n" + + final_acc = f"%acc_val{unrolled}" + code = f""" +define i32 @main() optnone noinline {{ +entry: + %i = alloca i32 + %acc = alloca {dst_type} + store i32 0, i32* %i + store {dst_type} {zero_literal}, {dst_type}* %acc + br label %loop + +loop: + %i_val = load i32, i32* %i + %cond = icmp slt i32 %i_val, {iterations} + br i1 %cond, label %body, label %exit + +body: + %acc_val0 = load {dst_type}, {dst_type}* %acc +{body_lines} + store {dst_type} {final_acc}, {dst_type}* %acc + %i_next = add i32 %i_val, 1 + store i32 %i_next, i32* %i + br label %loop + +exit: + %final_acc = load {dst_type}, {dst_type}* %acc + call void @use({dst_type} %final_acc) + ret i32 0 +}} + +define void @use({dst_type} %val) {{ + ret void +}} +""" + return code + + +def generate_function_call_code(func_name, precision, iterations): + """Generate LLVM IR for a function call based on FUNC_INFO.""" + + func_info = FUNC_INFO[func_name] + llvm_type = precision_to_llvm_type[precision] + intrinsic_suffix = precision_to_intrinsic_suffix.get(precision, "") + + if func_info["intrinsic"]: + fn = f"{func_info['intrinsic']}.{intrinsic_suffix}" + else: + fn = func_name + + num_operands = func_info["num_operands"] + + body_lines = " %acc_val0 = load " + llvm_type + ", " + llvm_type + "* %acc\n" + + for idx in range(unrolled): + if func_name == "powi": + f_val = generate_random_fp(precision) + i_val = random.randint(-10, 10) + f_hex = float_to_llvm_hex(f_val, precision) + line = f" %result{idx} = call {FAST_MATH_FLAG} {llvm_type} @{fn}(" f"{llvm_type} {f_hex}, i32 {i_val})" + body_lines += line + "\n" + body_lines += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n" + else: + operands = [] + for _ in range(num_operands): + f_val = generate_random_fp(precision) + operands.append(float_to_llvm_hex(f_val, precision)) + + if num_operands == 1: + call_str = f"call {FAST_MATH_FLAG} {llvm_type} @{fn}({llvm_type} {operands[0]})" + elif num_operands == 2: + call_str = ( + f"call {FAST_MATH_FLAG} {llvm_type} @{fn}({llvm_type} {operands[0]}, {llvm_type} {operands[1]})" + ) + elif num_operands == 3: + call_str = f"call {FAST_MATH_FLAG} {llvm_type} @{fn}({llvm_type} {operands[0]}, {llvm_type} {operands[1]}, {llvm_type} {operands[2]})" + else: + call_str = "" + + body_lines += f" %result{idx} = {call_str}\n" + body_lines += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n" + + if func_name == "powi": + decl = f"declare {llvm_type} @{fn}({llvm_type}, i32)" + else: + arg_types = ", ".join([llvm_type] * num_operands) + decl = f"declare {llvm_type} @{fn}({arg_types})" + + code = f""" +{decl} +define i32 @main() optnone noinline {{ +entry: + %i = alloca i32 + %acc = alloca {llvm_type} + store i32 0, i32* %i + store {llvm_type} {get_zero_literal(precision)}, {llvm_type}* %acc + br label %loop + +loop: + %i_val = load i32, i32* %i + %cond = icmp slt i32 %i_val, {iterations} + br i1 %cond, label %body, label %exit + +body: +{body_lines} + store {llvm_type} %acc_val{unrolled}, {llvm_type}* %acc + %i_next = add i32 %i_val, 1 + store i32 %i_next, i32* %i + br label %loop + +exit: + %final_acc = load {llvm_type}, {llvm_type}* %acc + call void @use({llvm_type} %final_acc) + ret i32 0 +}} + +define void @use({llvm_type} %val) {{ + ret void +}} +""" + return code + + +def generate_baseline_code(iterations): + return f""" +define i32 @main() optnone noinline {{ +entry: + %i = alloca i32 + store i32 0, i32* %i + br label %loop + +loop: + %i_val = load i32, i32* %i + %cond = icmp slt i32 %i_val, {iterations} + br i1 %cond, label %body, label %exit + +body: + %i_next = add i32 %i_val, 1 + store i32 %i_next, i32* %i + br label %loop + +exit: + ret i32 0 +}} +""" + + +def create_execution_engine(): + target = llvm.Target.from_default_triple() + target_machine = target.create_target_machine() + mod = llvm.parse_assembly("") + engine = llvm.create_mcjit_compiler(mod, target_machine) + return engine + + +def run_llvm_ir_jit(llvm_ir): + engine = create_execution_engine() + mod = llvm.parse_assembly(llvm_ir) + mod.verify() + engine.add_module(mod) + engine.finalize_object() + + func_ptr = engine.get_function_address("main") + cfunc = ctypes.CFUNCTYPE(ctypes.c_int)(func_ptr) + + start = time.perf_counter() + retval = cfunc() + end = time.perf_counter() + + return (end - start), retval + + +if __name__ == "__main__": + csv_file = "results.csv" + with open(csv_file, "w", newline="") as csvfile: + fieldnames = ["instruction", "precision", "cost"] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + llvm_code = generate_baseline_code(iterations) + baseline_time, _ = run_llvm_ir_jit(llvm_code) + + for precision in precisions: + for instr in OP_INFO: + op_kind = OP_INFO[instr]["kind"] + if op_kind == "cast": + src_precision = precision + src_rank = precision_ranks.get(src_precision) + if src_rank is None: + continue + + if instr == "fptrunc": + dst_precisions = [ + p for p in precisions_ordered if p in precisions and precision_ranks[p] < src_rank + ] + else: + dst_precisions = [ + p for p in precisions_ordered if p in precisions and precision_ranks[p] > src_rank + ] + + for dst_precision in dst_precisions: + if (src_precision, dst_precision) in [ + ("half", "bf16"), + ("bf16", "half"), + ]: + continue + code = generate_cast_op_code(instr, src_precision, dst_precision, iterations) + name = f"{instr}_{src_precision}_to_{dst_precision}" + elapsed, _ = run_llvm_ir_jit(code) + adjusted = (elapsed - baseline_time) * AMPLIFIER + writer.writerow( + { + "instruction": name, + "precision": src_precision, + "cost": int(adjusted), + } + ) + else: + if op_kind == "arithmetic": + code = generate_arithmetic_op_code(instr, precision, iterations) + elif op_kind == "compare": + code = generate_compare_op_code(precision, iterations) + else: + code = "" + if code.strip(): + elapsed, _ = run_llvm_ir_jit(code) + adjusted = (elapsed - baseline_time) * AMPLIFIER + writer.writerow( + { + "instruction": instr, + "precision": precision, + "cost": int(adjusted), + } + ) + + for func in FUNC_INFO: + code = generate_function_call_code(func, precision, iterations) + if code.strip(): + elapsed, _ = run_llvm_ir_jit(code) + adjusted = (elapsed - baseline_time) * AMPLIFIER + writer.writerow( + { + "instruction": func, + "precision": precision, + "cost": int(adjusted), + } + ) + + print(f"Results in '{csv_file}'. Baseline: {baseline_time:.6f}s")