Fix non-functional jitter in Warp.get_reference_grid#8953
Conversation
Warp.get_reference_grid built the grid from torch.arange (integer dtype), assigned self.ref_grid = grid.to(ddf) before the jitter block, then jittered the original integer grid in place. Three defects resulted: the jitter was applied to a dead local and never returned; torch.rand_like on the Long grid raised NotImplementedError, so jitter=True crashed outright; and fork_rng(enabled=seed) disabled RNG forking whenever seed defaulted to 0, leaking the seeded state into the global RNG. Cast the grid to ddf first, jitter that float tensor, assign it to self.ref_grid after jittering, and use fork_rng() so the seeded draw is isolated from the global RNG. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Asserts the jittered grid is floating point with non-integer values, the unjittered grid stays integer valued, and jitter is reproducible per seed. Fails before the fix with NotImplementedError on torch.rand_like. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
📝 WalkthroughWalkthrough
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/blocks/warp.py (1)
115-120: 🎯 Functional Correctness | 🟠 MajorCache ignores
jitter/seedparameters.The early return in
get_reference_grid(lines 115-120) returns the cachedself.ref_gridsolely based on shape. IfWarpinstance is reused with differentjitterorseedsettings, stale grids are returned.Internal call at line 152 passes
jitter=self.jitterbut ignoresseed. The cache key lacks these parameters, rendering them ineffective on subsequent calls with mismatched settings.if ( self.ref_grid is not None and self.ref_grid.shape[0] == ddf.shape[0] and self.ref_grid.shape[1:] == ddf.shape[2:] ): return self.ref_grid # type: ignore🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/blocks/warp.py` around lines 115 - 120, The cache in get_reference_grid only checks tensor shape, so it can return a stale self.ref_grid when Warp is reused with different jitter or seed values. Update the cache validation so it also accounts for the jitter/seed inputs used to build the grid, and make sure the call site in Warp that passes self.jitter also propagates seed consistently. If the stored grid was created with different settings, regenerate it instead of returning the cached value.
🧹 Nitpick comments (1)
tests/networks/blocks/warp/test_warp.py (1)
141-154: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueGood coverage. Float dtype, non-integer jitter, integer-aligned baseline, and seed determinism all checked.
Optional: add a Google-style docstring; path instructions ask for docstrings on all definitions.
As per path instructions: "Docstrings should be present for all definition".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/blocks/warp/test_warp.py` around lines 141 - 154, The test coverage in test_jitter is solid, but the new path-level requirement says every definition should have a Google-style docstring. Add a concise docstring to the test_jitter method in the Warp test class, describing the jitter behavior it verifies and keeping it consistent with the existing test naming and structure.Source: Path instructions
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@monai/networks/blocks/warp.py`:
- Around line 115-120: The cache in get_reference_grid only checks tensor shape,
so it can return a stale self.ref_grid when Warp is reused with different jitter
or seed values. Update the cache validation so it also accounts for the
jitter/seed inputs used to build the grid, and make sure the call site in Warp
that passes self.jitter also propagates seed consistently. If the stored grid
was created with different settings, regenerate it instead of returning the
cached value.
---
Nitpick comments:
In `@tests/networks/blocks/warp/test_warp.py`:
- Around line 141-154: The test coverage in test_jitter is solid, but the new
path-level requirement says every definition should have a Google-style
docstring. Add a concise docstring to the test_jitter method in the Warp test
class, describing the jitter behavior it verifies and keeping it consistent with
the existing test naming and structure.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 5f4e987b-6ff9-4a26-b3f3-eafcff5b7b94
📒 Files selected for processing (2)
monai/networks/blocks/warp.pytests/networks/blocks/warp/test_warp.py
Description
Little messy description as trying to fix multiple things but the jist is:
Warp.get_reference_gridnever applied thejitterit advertises and crashed wheneverjitter=True.The grid is built from
torch.arange(integer dtype) andself.ref_gridwas assignedgrid.to(ddf)before the jitter block, sogrid += torch.rand_like(grid)mutated a local that was never returned, andtorch.rand_likeraisesNotImplementedErroron an integer tensor anyway. Separately,fork_rng(enabled=seed)disabled RNG forking whenseedtook its default of0, leaking the seeded state into the global RNG.The grid is now cast to
ddfbefore jittering, the jittered tensor is assigned toself.ref_grid, andfork_rng()isolates the seeded draw.The non-jitter path is unchanged. A regression test covers the float/non-integer jittered grid, the integer un-jittered grid, and per-seed reproducibility; it fails before this change with
NotImplementedError.Types of changes