[MPS] Fix roi_align backward gradient over-accumulation #9510
The backward kernel ran a grid-stride loop dispatched over multiple threadgroups but strided by only one threadgroup's width, so every threadgroup re-processed overlapping output elements and atomic-added each gradient repeatedly.
This is harmless for the forward (idempotent writes, fixed for perf in #9100) but the backward over-accumulated input gradients by a factor that grows with the RoI count (~100x at 512 RoIs), making detection training diverge to NaN on MPS.
Fix mirrors #9100: one thread per pooled-output element via dispatchThreads.
Test Plan:
python -m pytest test/test_ops.py -k "test_backward_mps_consistency"
Authored with assistance from Claude.
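The over-counting described above can be reproduced with a small CPU-side simulation. This is a minimal sketch and not the Metal kernel itself: `visit_counts`, `one_thread_per_element`, and their parameters are hypothetical names, and the sketch assumes each thread starts at its global index (`tg * tg_width + lane`) and advances by `stride`. Each "visit" stands in for one atomic add into the input gradient.

```python
def visit_counts(n, tg_width, num_tgs, stride):
    """Count how often each of n output elements is processed.

    Simulates num_tgs threadgroups of tg_width threads, each running a
    strided loop that starts at the thread's global index (an assumption).
    """
    counts = [0] * n
    for tg in range(num_tgs):
        for lane in range(tg_width):
            i = tg * tg_width + lane  # global thread index
            while i < n:
                counts[i] += 1        # stands in for one atomic add
                i += stride
    return counts


def one_thread_per_element(n):
    """Simulate a dispatch with exactly one thread per output element."""
    counts = [0] * n
    for i in range(n):                # thread i handles element i only
        counts[i] += 1
    return counts


if __name__ == "__main__":
    n, tg_width, num_tgs = 16, 4, 4

    # Buggy: stride is one threadgroup's width, so later elements are
    # revisited by threads from every earlier threadgroup.
    buggy = visit_counts(n, tg_width, num_tgs, stride=tg_width)
    print(buggy)       # [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
    print(sum(buggy))  # 40 accumulations instead of 16

    # A correct grid-stride loop strides by the whole grid's width.
    print(visit_counts(n, tg_width, num_tgs, stride=tg_width * num_tgs))

    # The approach named in the fix: one thread per element.
    print(one_thread_per_element(n))
```

In this toy setup the buggy stride performs 40 accumulations where 16 are expected, and the excess grows as more threadgroups are dispatched, which is consistent with an error factor that scales with the RoI count.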
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9510
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Unrelated Failures as of commit 6ff0eeb with merge base 85a1dcc: one new failing job, one job that failed likely due to flakiness present on trunk, and one job that was already failing on the merge base (rebase onto the `viable/strict` branch to avoid the latter).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks @malfet, this LGTM. I'll merge once the CI is kind enough to be green (macOS jobs are failing on a new conda issue right now; I'm looking into it).
Hi @malfet! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted.
Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!