diff --git a/test/test_ops.py b/test/test_ops.py
index 8f8756fe675..0a412d6930b 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -474,6 +474,33 @@ def expected_fn(
     def test_boxes_shape(self):
         self._helper_boxes_shape(ops.roi_align)
 
+    @needs_mps
+    @pytest.mark.parametrize("seed", range(3))
+    def test_backward_mps_consistency(self, seed):
+        # Regression test for over-accumulation in the MPS roi_align backward
+        # kernel. The kernel used a grid-stride loop dispatched over multiple
+        # threadgroups, so each pooled-output element's gradient was scattered
+        # once per threadgroup. The error is invisible for a handful of RoIs
+        # (output fits in one threadgroup) but grows with the RoI count, so use
+        # enough RoIs for the output to span several threadgroups.
+        torch.random.manual_seed(seed)
+        pool_size = 5
+        num_rois = 100
+        x = torch.rand(1, 2 * (pool_size**2), 40, 40, dtype=torch.float32)
+        rois = torch.empty(num_rois, 5)
+        rois[:, 0] = 0
+        xy = torch.rand(num_rois, 2) * 20
+        wh = torch.rand(num_rois, 2) * 20
+        rois[:, 1:3] = xy
+        rois[:, 3:5] = xy + wh + 1  # ensure x2 > x1 and y2 > y1
+
+        def grad_on(device):
+            xd = x.to(device).detach().requires_grad_(True)
+            self.fn(xd, rois.to(device), pool_size, pool_size, spatial_scale=1, sampling_ratio=2).sum().backward()
+            return xd.grad.cpu()
+
+        torch.testing.assert_close(grad_on("mps"), grad_on("cpu"), rtol=1e-4, atol=1e-4)
+
     @pytest.mark.parametrize("aligned", (True, False))
     @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64))  # , ids=str)
diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h
index 90c0aee9eb7..e06235fe284 100644
--- a/torchvision/csrc/ops/mps/mps_kernels.h
+++ b/torchvision/csrc/ops/mps/mps_kernels.h
@@ -486,11 +486,15 @@ kernel void roi_align_backward(
     constant int64_t & c_stride [[buffer(13)]],
     constant int64_t & h_stride [[buffer(14)]],
     constant int64_t & w_stride [[buffer(15)]],
-    uint2 tgid [[threadgroup_position_in_grid]],
-    uint2 tptg [[threads_per_threadgroup]],
-    uint2 tid2 [[thread_position_in_threadgroup]]){
+    uint index [[thread_position_in_grid]]){
 
-  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
+  // One thread per pooled-output element. Dispatched via dispatchThreads, so
+  // each element is processed exactly once; redundant passes would otherwise
+  // be silently summed by the atomic_add below (overlapping RoIs share pixels).
+  if (static_cast<int64_t>(index) >= output_size) {
+    return;
+  }
+  {
     // (n, c, ph, pw) is an element in the pooled output
     integer_t pw = index % pooled_width;
     integer_t ph = (index / pooled_width) % pooled_height;
@@ -602,9 +606,7 @@ kernel void roi_align_backward( \
     constant int64_t & c_stride [[buffer(13)]], \
     constant int64_t & h_stride [[buffer(14)]], \
     constant int64_t & w_stride [[buffer(15)]], \
-    uint2 tgid [[threadgroup_position_in_grid]], \
-    uint2 tptg [[threads_per_threadgroup]], \
-    uint2 tid2 [[thread_position_in_threadgroup]]);
+    uint index [[thread_position_in_grid]]);
 
 template<typename T, typename integer_t>
 kernel void roi_pool(
diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm
index 6cf07761060..918f69a5ffa 100644
--- a/torchvision/csrc/ops/mps/roi_align_kernel.mm
+++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm
@@ -132,10 +132,6 @@
   dispatch_sync(mpsStream->queue(), ^() {
     @autoreleasepool {
       id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
-      MTLSize threadgroupsPerGrid = MTLSizeMake(
-          std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
-          1,
-          1);
 
       const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
       id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
@@ -163,14 +159,13 @@
       [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
       [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
 
-      // A threadGroup is equivalent to a cuda's block.
-      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
-      if (tgSize > threadsPerBlock) {
-        tgSize = threadsPerBlock;
-      }
-
-      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
-      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
+      // One thread per pooled-output element. dispatchThreads guarantees each
+      // index is handled exactly once; the kernel scatters into grad_input with
+      // atomic_add for overlapping RoIs.
+      MTLSize threadsPerGrid = MTLSizeMake(output_size, 1, 1);
+      NSUInteger tgSize = std::min(static_cast<int64_t>(visionPSO.maxTotalThreadsPerThreadgroup), output_size);
+      MTLSize threadGroupSize = MTLSizeMake(std::max<NSUInteger>(tgSize, 1), 1, 1);
+      [computeEncoder dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadGroupSize];
 
       getMPSProfiler().endProfileKernel(visionPSO);
     }