Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2821,6 +2821,34 @@ void TypeAnalyzer::visitShuffleVectorInst(ShuffleVectorInst &I) {
}
}

// Returns the floating-point leaf type if every scalar leaf of T is the same
// FP type; otherwise nullptr. Used by visitExtractValueInst to seed type info
// for extractvalue results whose source aggregate has no prior type info
// (e.g. an opaque external call return) — see EnzymeAD/Enzyme#2630, #2463.
static llvm::Type *uniformFPLeafType(llvm::Type *T) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you shouldn't modify typeanalysis itself for a loosetypeanalysis change, but the user instructino, if no type was found

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the aggregate fallback out of TypeAnalysis

if (T->isFloatingPointTy())
return T;
if (auto VT = llvm::dyn_cast<llvm::VectorType>(T))
return uniformFPLeafType(VT->getElementType());
if (auto AT = llvm::dyn_cast<llvm::ArrayType>(T)) {
if (AT->getNumElements() == 0)
return nullptr;
return uniformFPLeafType(AT->getElementType());
}
if (auto ST = llvm::dyn_cast<llvm::StructType>(T)) {
if (ST->getNumElements() == 0)
return nullptr;
auto first = uniformFPLeafType(ST->getElementType(0));
if (!first)
return nullptr;
for (unsigned i = 1; i < ST->getNumElements(); ++i)
if (uniformFPLeafType(ST->getElementType(i)) != first)
return nullptr;
return first;
}
return nullptr;
}

void TypeAnalyzer::visitExtractValueInst(ExtractValueInst &I) {
auto &dl = fntypeinfo.Function->getParent()->getDataLayout();
SmallVector<Value *, 4> vec;
Expand All @@ -2839,6 +2867,16 @@ void TypeAnalyzer::visitExtractValueInst(ExtractValueInst &I) {
int off = (int)ai.getLimitedValue();
int size = dl.getTypeSizeInBits(I.getType()) / 8;

// Seed the result from the LLVM type when it is a uniform-FP aggregate
// (or a single FP). extractvalue is fully type-checked, so the LLVM type
// is authoritative about leaf float types — there is no type-punning
// possible. UP propagation then pushes the type back into the source
// aggregate, recovering type info even when the source originated from an
// opaque external call (#2630, #2463).
if (auto LeafTy = uniformFPLeafType(I.getType())) {
updateAnalysis(&I, TypeTree(ConcreteType(LeafTy)).Only(-1, &I), &I);
}

if (direction & DOWN)
updateAnalysis(&I,
getAnalysis(I.getOperand(0))
Expand Down
71 changes: 71 additions & 0 deletions enzyme/test/TypeAnalysis/extractvalue_aggregate.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=compute -o /dev/null | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=compute -S -o /dev/null | FileCheck %s

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test isn't necessarily helpful here because by default float input types will be assumed to be the correct float

@kimjune01 kimjune01 May 11, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right that float types from the struct are already propagated. The end-to-end test (ENZYME check) exercises the full pipeline path through AdjointGenerator, which is where the original crash happened (#2630). The WARN check might be redundant if the seeding fires after the type is already known from struct propagation. Happy to drop the WARN check and keep just the end-to-end test if that's cleaner.

; Regression test for EnzymeAD/Enzyme#2630 (and #2463): TypeAnalysis must
; recover float type info for extractvalue results whose source aggregate
; comes from an opaque external call (no body to interprocedurally analyse).
; Pre-fix, %a/%b/%a0.../%b12 all had empty type info ("{}"), and the
; AdjointGenerator emitted "Cannot deduce type of extract" when reverse-mode
; AD tried to consume them. Post-fix, the LLVM type of each extractvalue
; result (uniform-FP aggregate or single FP) seeds Float@float into the
; result's TypeTree, which UP propagation then back-fills onto the source.

target triple = "x86_64-unknown-linux-gnu"

%struct.CV = type { [2 x float], [2 x [3 x float]] }

declare %struct.CV @pre_work(float)

define float @compute(float %seed) {
entry:
%r = call %struct.CV @pre_work(float %seed)
%a = extractvalue %struct.CV %r, 0
%b = extractvalue %struct.CV %r, 1
%a0 = extractvalue [2 x float] %a, 0
%a1 = extractvalue [2 x float] %a, 1
%b00 = extractvalue [2 x [3 x float]] %b, 0, 0
%b01 = extractvalue [2 x [3 x float]] %b, 0, 1
%b02 = extractvalue [2 x [3 x float]] %b, 0, 2
%b10 = extractvalue [2 x [3 x float]] %b, 1, 0
%b11 = extractvalue [2 x [3 x float]] %b, 1, 1
%b12 = extractvalue [2 x [3 x float]] %b, 1, 2
%m = fmul float %a0, %seed
ret float %m
}

; CHECK: compute - {[-1]:Float@float} |{[-1]:Float@float}:{}
; CHECK-NEXT: float %seed: {[-1]:Float@float}
; CHECK-NEXT: entry
; CHECK-NEXT: %r = call %struct.CV @pre_work(float %seed): {[0]:Float@float, [4]:Float@float, [8]:Float@float, [12]:Float@float, [16]:Float@float, [20]:Float@float, [24]:Float@float, [28]:Float@float}
; CHECK-NEXT: %a = extractvalue %struct.CV %r, 0: {[-1]:Float@float}
; CHECK-NEXT: %b = extractvalue %struct.CV %r, 1: {[-1]:Float@float}
; CHECK-NEXT: %a0 = extractvalue [2 x float] %a, 0: {[-1]:Float@float}
; CHECK-NEXT: %a1 = extractvalue [2 x float] %a, 1: {[-1]:Float@float}
; CHECK-NEXT: %b00 = extractvalue [2 x [3 x float]] %b, 0, 0: {[-1]:Float@float}
; CHECK-NEXT: %b01 = extractvalue [2 x [3 x float]] %b, 0, 1: {[-1]:Float@float}
; CHECK-NEXT: %b02 = extractvalue [2 x [3 x float]] %b, 0, 2: {[-1]:Float@float}
; CHECK-NEXT: %b10 = extractvalue [2 x [3 x float]] %b, 1, 0: {[-1]:Float@float}
; CHECK-NEXT: %b11 = extractvalue [2 x [3 x float]] %b, 1, 1: {[-1]:Float@float}
; CHECK-NEXT: %b12 = extractvalue [2 x [3 x float]] %b, 1, 2: {[-1]:Float@float}
; CHECK-NEXT: %m = fmul float %a0, %seed: {[-1]:Float@float}
; CHECK-NEXT: ret float %m: {}

; Negative test: mixed aggregate (float + i32) must NOT seed Float on the i32 field.
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=compute_mixed -o /dev/null | FileCheck %s --check-prefix=MIXED; fi
; RUN: %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=compute_mixed -S -o /dev/null | FileCheck %s --check-prefix=MIXED

%struct.Mixed = type { float, i32 }

declare %struct.Mixed @opaque_mixed(float)

define float @compute_mixed(float %seed) {
entry:
%r = call %struct.Mixed @opaque_mixed(float %seed)
%f = extractvalue %struct.Mixed %r, 0
%i = extractvalue %struct.Mixed %r, 1
ret float %f
}

; MIXED: compute_mixed
; MIXED: %f = extractvalue %struct.Mixed %r, 0: {[-1]:Float@float}
; MIXED-NOT: %i = extractvalue %struct.Mixed %r, 1: {[-1]:Float@float}