Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,38 @@

#define DEBUG_TYPE "enzyme"

// Walk an LLVM type and, when every leaf is the same floating-point
// type, return that leaf. Returns nullptr otherwise (mixed types,
// non-FP, empty aggregate). Used by the looseTypeAnalysis fallback in
// visitExtractValueInst to recover type info for aggregates returned
// by opaque external calls (#2630, #2463). Lives at the consumer site
// (this header) rather than inside TypeAnalysis itself because
// looseTypeAnalysis is a consumer-policy override, not a dataflow rule
// — TypeAnalysis core stays a pure analysis.
static inline llvm::Type *uniformFPLeafType(llvm::Type *T) {
if (T->isFloatingPointTy())
return T;
if (auto VT = llvm::dyn_cast<llvm::VectorType>(T))
return uniformFPLeafType(VT->getElementType());
if (auto AT = llvm::dyn_cast<llvm::ArrayType>(T)) {
if (AT->getNumElements() == 0)
return nullptr;
return uniformFPLeafType(AT->getElementType());
}
if (auto ST = llvm::dyn_cast<llvm::StructType>(T)) {
if (ST->getNumElements() == 0)
return nullptr;
auto first = uniformFPLeafType(ST->getElementType(0));
if (!first)
return nullptr;
for (unsigned i = 1; i < ST->getNumElements(); ++i)
if (uniformFPLeafType(ST->getElementType(i)) != first)
return nullptr;
return first;
}
return nullptr;
}

// Helper instruction visitor that generates adjoints
class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
private:
Expand Down Expand Up @@ -1898,6 +1930,16 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
if (EVI.getType()->isFPOrFPVectorTy()) {
dt = ConcreteType(EVI.getType()->getScalarType());
found = true;
} else if (auto *LeafTy = uniformFPLeafType(EVI.getType())) {
// Aggregate-of-uniform-FP (e.g. [2 x float],
// [N x [M x float]]): seed from the leaf FP type.
// Pre-fix this branch only handled primitive FP/vector
// and fell through to EmitNoTypeError for aggregates
// (#2630, #2463). Per maintainer review (wsmoses on
// PR #2819), the right place for this fallback is here
// — at the consumer — not inside TypeAnalysis itself.
dt = ConcreteType(LeafTy);
found = true;
} else if (EVI.getType()->isIntOrIntVectorTy() ||
EVI.getType()->isPointerTy()) {
dt = BaseType::Integer;
Expand Down
49 changes: 49 additions & 0 deletions enzyme/test/TypeAnalysis/extractvalue_aggregate.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
; End-to-end check: with loose-types, the full enzyme pass succeeds (no
; "Cannot deduce type of extract" diagnostic) and emits a reverse-mode
; entry point. Pre-fix this run failed because the extracted [2 x float]
; aggregates had empty TypeTrees; the AdjointGenerator's loose-types
; fallback (AdjointGenerator.h:1894) only handled primitive FP/int and
; couldn't fill aggregate types. Fix extends that fallback to walk
; aggregate types via uniformFPLeafType (same file scope).
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-loose-types -enzyme-type-warning=0 -enzyme-preopt=false --enzyme-assume-unknown-nofree=1 -S | FileCheck %s --check-prefix=ENZYME

; Regression test for EnzymeAD/Enzyme#2630 (and #2463): with looseTypeAnalysis,
; the AdjointGenerator's loose-types fallback walks aggregate types so
; extractvalue results whose source aggregate comes from an opaque external
; call get seeded with the leaf FP type.

target triple = "x86_64-unknown-linux-gnu"

%struct.CV = type { [2 x float], [2 x [3 x float]] }

declare %struct.CV @pre_work(i64) #0
declare void @opaque_sink(float, float) #0
declare float @__enzyme_autodiff(...)

attributes #0 = { "enzyme_inactive" }

; Active reverse-mode entry. Mirrors #2630's pattern: an opaque struct
; return whose elements feed an active fmul. Without the fix the enzyme
; pass aborts with "Cannot deduce type of extract" on the [2 x float]
; aggregates extracted from %r.
define float @ad_compute(float %seed) {
entry:
%r = call %struct.CV @pre_work(i64 0)
%a = extractvalue %struct.CV %r, 0
%b = extractvalue %struct.CV %r, 1
%a0 = extractvalue [2 x float] %a, 0
%a1 = extractvalue [2 x float] %a, 1
%b00 = extractvalue [2 x [3 x float]] %b, 0, 0
call void @opaque_sink(float %a0, float %a1)
%m1 = fmul float %a0, %seed
%m2 = fmul float %m1, %b00
ret float %m2
}

define float @caller(float %seed) {
%d = call float (...) @__enzyme_autodiff(float (float)* @ad_compute, float %seed)
ret float %d
}

; ENZYME: define internal {{.*}} @diffead_compute
; ENZYME-NOT: Cannot deduce type of extract
Loading