Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,33 @@ def expected_fn(
def test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)

@needs_mps
@pytest.mark.parametrize("seed", range(3))
def test_backward_mps_consistency(self, seed):
    """Check that the roi_align input gradient on MPS agrees with the CPU one.

    Regression test for over-accumulation in the MPS roi_align backward
    kernel: a grid-stride loop dispatched over several threadgroups made
    every pooled-output element scatter its gradient once per threadgroup.
    With only a few RoIs the output fits in a single threadgroup and the
    error does not show, so the test uses enough RoIs for the output to
    span several threadgroups.
    """
    torch.random.manual_seed(seed)

    bins = 5
    roi_count = 100

    # Random draws happen in a fixed order (features, corners, extents) so
    # that each seed always produces the same inputs.
    features = torch.rand(1, 2 * (bins**2), 40, 40, dtype=torch.float32)
    top_left = torch.rand(roi_count, 2) * 20
    extent = torch.rand(roi_count, 2) * 20

    # Box layout is (batch_index, x1, y1, x2, y2); every box refers to image 0.
    # The +1 keeps x2 > x1 and y2 > y1.
    boxes = torch.cat(
        [torch.zeros(roi_count, 1), top_left, top_left + extent + 1],
        dim=1,
    )

    def input_grad(device):
        # Fresh leaf tensor per device so the two gradients are independent.
        leaf = x_leaf = features.to(device).detach().requires_grad_(True)
        pooled = self.fn(leaf, boxes.to(device), bins, bins, spatial_scale=1, sampling_ratio=2)
        pooled.sum().backward()
        return x_leaf.grad.cpu()

    mps_grad = input_grad("mps")
    cpu_grad = input_grad("cpu")
    torch.testing.assert_close(mps_grad, cpu_grad, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64)) # , ids=str)
Expand Down
16 changes: 9 additions & 7 deletions torchvision/csrc/ops/mps/mps_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,15 @@ kernel void roi_align_backward(
constant int64_t & c_stride [[buffer(13)]],
constant int64_t & h_stride [[buffer(14)]],
constant int64_t & w_stride [[buffer(15)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
uint index [[thread_position_in_grid]]){

MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// One thread per pooled-output element. Dispatched via dispatchThreads, so
// each element is processed exactly once; redundant passes would otherwise
// be silently summed by the atomic_add below (overlapping RoIs share pixels).
if (index >= static_cast<uint>(output_size)) {
return;
}
{
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
Expand Down Expand Up @@ -602,9 +606,7 @@ kernel void roi_align_backward<DTYPE, INT_DTYPE>( \
constant int64_t & c_stride [[buffer(13)]], \
constant int64_t & h_stride [[buffer(14)]], \
constant int64_t & w_stride [[buffer(15)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
uint index [[thread_position_in_grid]]);

template<typename T, typename integer_t>
kernel void roi_pool(
Expand Down
19 changes: 7 additions & 12 deletions torchvision/csrc/ops/mps/roi_align_kernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,6 @@
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);

const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
Expand Down Expand Up @@ -163,14 +159,13 @@
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
[computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];

// A Metal threadgroup is the equivalent of a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
// One thread per pooled-output element. dispatchThreads guarantees each
// index is handled exactly once; the kernel scatters into grad_input with
// atomic_add for overlapping RoIs.
MTLSize threadsPerGrid = MTLSizeMake(output_size, 1, 1);
NSUInteger tgSize = std::min(static_cast<int64_t>(visionPSO.maxTotalThreadsPerThreadgroup), output_size);
MTLSize threadGroupSize = MTLSizeMake(std::max<NSUInteger>(tgSize, 1), 1, 1);
[computeEncoder dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadGroupSize];

getMPSProfiler().endProfileKernel(visionPSO);
}
Expand Down
Loading