mps_conv3d-0.1.5-cp314-cp314-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mps_conv3d/_C.cpython-314-darwin.so ADDED
Binary file
mps_conv3d/__init__.py ADDED
@@ -0,0 +1,259 @@
"""
MPS Conv3D - 3D Convolution for Apple Silicon

Drop-in replacement for torch.nn.functional.conv3d on MPS.
Used in video models: Synchformer, I3D, SlowFast, C3D, etc.
"""

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.autograd import Function
from typing import Optional, Tuple, Union

__version__ = "0.1.5"


def _load_library():
    """Load the native extension, JIT-compiling it from the bundled
    sources if the prebuilt binary is unavailable."""
    try:
        from mps_conv3d import _C as _lib
        return _lib
    except ImportError:
        import os
        from torch.utils.cpp_extension import load

        src_dir = os.path.join(os.path.dirname(__file__), "csrc")
        _lib = load(
            name="mps_conv3d",
            sources=[os.path.join(src_dir, "conv3d_mps.mm")],
            extra_cflags=["-std=c++17"],
            extra_ldflags=["-framework", "Metal", "-framework", "Foundation"],
            verbose=False,
        )
        return _lib


_lib_cache = None


def _get_lib():
    """Return the native library, loading it on first use."""
    global _lib_cache
    if _lib_cache is None:
        _lib_cache = _load_library()
    return _lib_cache


class _Conv3DFunction(Function):
    """Autograd function wrapping the native Conv3D kernels."""

    @staticmethod
    def forward(
        ctx,
        input: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        stride: Tuple[int, int, int],
        padding: Tuple[int, int, int],
        dilation: Tuple[int, int, int],
        groups: int,
    ) -> Tensor:
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        lib = _get_lib()
        output = lib.conv3d_forward(
            input, weight,
            stride[0], stride[1], stride[2],
            padding[0], padding[1], padding[2],
            dilation[0], dilation[1], dilation[2],
            groups
        )

        if bias is not None:
            # Bias is added here rather than in the kernel: one value per
            # output channel, broadcast over (N, C_out, D, H, W).
            output = output + bias.view(1, -1, 1, 1, 1)

        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        input, weight, bias = ctx.saved_tensors
        lib = _get_lib()

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = lib.conv3d_backward_input(
                grad_output, weight, input.shape,
                ctx.stride[0], ctx.stride[1], ctx.stride[2],
                ctx.padding[0], ctx.padding[1], ctx.padding[2],
                ctx.dilation[0], ctx.dilation[1], ctx.dilation[2],
                ctx.groups
            )

        if ctx.needs_input_grad[1]:
            grad_weight = lib.conv3d_backward_weight(
                grad_output, input, weight.shape,
                ctx.stride[0], ctx.stride[1], ctx.stride[2],
                ctx.padding[0], ctx.padding[1], ctx.padding[2],
                ctx.dilation[0], ctx.dilation[1], ctx.dilation[2],
                ctx.groups
            )

        if bias is not None and ctx.needs_input_grad[2]:
            # Bias gradient: sum grad_output over every dim except channels.
            grad_bias = grad_output.sum(dim=[0, 2, 3, 4])

        # One return value per forward argument; the four hyperparameters get None.
        return grad_input, grad_weight, grad_bias, None, None, None, None


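The Function above can be sanity-checked against the CPU reference implementation. A minimal parity sketch for both passes (ours, for illustration; tolerances are arbitrary):

```python
import torch
import torch.nn.functional as F
import mps_conv3d

x = torch.randn(2, 3, 8, 16, 16, device="mps", requires_grad=True)
w = torch.randn(4, 3, 3, 3, 3, device="mps", requires_grad=True)

out = mps_conv3d.conv3d(x, w, padding=1)
out.sum().backward()

# CPU reference with independent leaf copies of the same values.
x_ref = x.detach().cpu().requires_grad_(True)
w_ref = w.detach().cpu().requires_grad_(True)
ref = F.conv3d(x_ref, w_ref, padding=1)
ref.sum().backward()

print(torch.allclose(out.cpu(), ref, atol=1e-4))
print(torch.allclose(x.grad.cpu(), x_ref.grad, atol=1e-4))
print(torch.allclose(w.grad.cpu(), w_ref.grad, atol=1e-4))
```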
def _normalize_tuple(value, n, name):
    """Convert an int or a length-1/length-n sequence to an n-tuple."""
    if isinstance(value, int):
        return (value,) * n
    if isinstance(value, (list, tuple)):
        if len(value) == n:
            return tuple(value)
        elif len(value) == 1:
            return (value[0],) * n
    raise ValueError(f"{name} must be an int or {n}-tuple")


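For reference, the accepted forms (illustrative assertions; `_normalize_tuple` is a private helper, imported here only to show the contract):

```python
from mps_conv3d import _normalize_tuple  # private helper, for illustration

assert _normalize_tuple(2, 3, "stride") == (2, 2, 2)
assert _normalize_tuple([3], 3, "padding") == (3, 3, 3)
assert _normalize_tuple((1, 2, 2), 3, "stride") == (1, 2, 2)
```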
def conv3d(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
    groups: int = 1,
) -> Tensor:
    """
    3D convolution on MPS.

    Drop-in replacement for torch.nn.functional.conv3d.

    Args:
        input: Input tensor (N, C_in, D, H, W)
        weight: Weight tensor (C_out, C_in/groups, kD, kH, kW)
        bias: Optional bias tensor (C_out,)
        stride: Stride of the convolution
        padding: Padding added to the input
        dilation: Dilation of the kernel
        groups: Number of groups

    Returns:
        Output tensor (N, C_out, D_out, H_out, W_out)
    """
    if input.device.type != "mps":
        # Fall back to PyTorch for non-MPS tensors.
        return F.conv3d(input, weight, bias, stride, padding, dilation, groups)

    stride = _normalize_tuple(stride, 3, "stride")
    padding = _normalize_tuple(padding, 3, "padding")
    dilation = _normalize_tuple(dilation, 3, "dilation")

    return _Conv3DFunction.apply(
        input.contiguous(), weight.contiguous(), bias,
        stride, padding, dilation, groups
    )


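The output sizes in the docstring follow the usual convolution arithmetic, `out = (in + 2*pad - dilation*(kernel - 1) - 1) // stride + 1` per axis. A small helper (ours, for illustration) makes that concrete:

```python
def conv3d_out_shape(in_dhw, kernel, stride, pad, dilation=(1, 1, 1)):
    """Spatial output size for one (D, H, W) triple."""
    return tuple(
        (i + 2 * p - d * (k - 1) - 1) // s + 1
        for i, k, s, p, d in zip(in_dhw, kernel, stride, pad, dilation)
    )

# e.g. a (16, 112, 112) clip with kernel (3, 7, 7),
# stride (1, 2, 2), padding (1, 3, 3) -> (16, 56, 56)
assert conv3d_out_shape((16, 112, 112), (3, 7, 7), (1, 2, 2), (1, 3, 3)) == (16, 56, 56)
```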
class Conv3d(nn.Module):
    """
    3D convolution layer for MPS.

    Drop-in replacement for torch.nn.Conv3d.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolving kernel
        stride: Stride of the convolution
        padding: Padding added to the input
        dilation: Dilation of the kernel elements
        groups: Number of blocked connections
        bias: If True, adds a learnable bias
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        kernel_size = _normalize_tuple(kernel_size, 3, "kernel_size")
        self.stride = _normalize_tuple(stride, 3, "stride")
        self.padding = _normalize_tuple(padding, 3, "padding")
        self.dilation = _normalize_tuple(dilation, 3, "dilation")
        self.groups = groups

        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        # Same scheme as torch.nn.Conv3d: Kaiming-uniform with a = sqrt(5).
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / (fan_in ** 0.5)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return conv3d(
            input, self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups
        )


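Because the constructor mirrors `torch.nn.Conv3d`, an existing (groups=1) model can be migrated by swapping its conv modules and copying parameters over. A hedged sketch; the `swap_conv3d` traversal helper is our own, not part of the package:

```python
import torch
from torch import nn

import mps_conv3d

def swap_conv3d(module: nn.Module) -> nn.Module:
    """Recursively replace nn.Conv3d children with mps_conv3d.Conv3d."""
    for name, child in module.named_children():
        if isinstance(child, nn.Conv3d):
            # Assumes numeric padding; string paddings like 'same'
            # are not handled by this sketch.
            repl = mps_conv3d.Conv3d(
                child.in_channels, child.out_channels, child.kernel_size,
                stride=child.stride, padding=child.padding,
                dilation=child.dilation, groups=child.groups,
                bias=child.bias is not None,
            )
            with torch.no_grad():
                repl.weight.copy_(child.weight)
                if child.bias is not None:
                    repl.bias.copy_(child.bias)
            setattr(module, name, repl)
        else:
            swap_conv3d(child)
    return module
```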
_original_conv3d = None


def patch_conv3d():
    """
    Monkey-patch torch.nn.functional.conv3d to use the MPS implementation.

    Call this at the start of your script to route all 3D convolutions on
    MPS tensors through this package; everything else falls through to the
    original PyTorch implementation.
    """
    global _original_conv3d

    if _original_conv3d is not None:
        return  # Already patched

    _original_conv3d = F.conv3d

    def patched_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        if input.device.type == 'mps':
            return conv3d(input, weight, bias, stride, padding, dilation, groups)
        return _original_conv3d(input, weight, bias, stride, padding, dilation, groups)

    F.conv3d = patched_conv3d
    print("MPS Conv3D: Patched F.conv3d")


def unpatch_conv3d():
    """Restore the original torch.nn.functional.conv3d."""
    global _original_conv3d
    if _original_conv3d is not None:
        F.conv3d = _original_conv3d
        _original_conv3d = None


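Since the patch is global process state, it can be convenient to scope it. A possible wrapper built on the two functions above (our own sketch; `mps_conv3d_patched` is not part of the package):

```python
from contextlib import contextmanager

import mps_conv3d

@contextmanager
def mps_conv3d_patched():
    """Apply the monkey-patch for the duration of a with-block."""
    mps_conv3d.patch_conv3d()
    try:
        yield
    finally:
        mps_conv3d.unpatch_conv3d()

# with mps_conv3d_patched():
#     out = model(video_batch)  # F.conv3d on MPS routed through Metal
```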
def is_available() -> bool:
    """Check whether the MPS backend is available."""
    return torch.backends.mps.is_available()
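A typical guard at startup, using `is_available()` to pick the device (illustrative):

```python
import torch
import mps_conv3d

device = "mps" if mps_conv3d.is_available() else "cpu"
if device == "mps":
    mps_conv3d.patch_conv3d()

x = torch.randn(1, 3, 8, 32, 32, device=device)
```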
mps_conv3d/csrc/conv3d_mps.mm ADDED
@@ -0,0 +1,505 @@
// MPS Conv3D - Metal implementation of 3D convolution
// Strategy: im2col + PyTorch matmul (leverage Apple's optimized GEMM)

#include <torch/extension.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <Metal/Metal.h>
#include <Foundation/Foundation.h>

static id<MTLDevice> g_device = nil;
static id<MTLLibrary> g_library = nil;
static id<MTLComputePipelineState> g_im2col3d_fp32 = nil;
static id<MTLComputePipelineState> g_im2col3d_fp16 = nil;
static id<MTLComputePipelineState> g_col2im3d_fp32 = nil;

static const char* METAL_SHADER = R"(
#include <metal_stdlib>
using namespace metal;

// im2col3d: Extract 3D patches into column matrix for GEMM
// Input: (N, C_in, D, H, W)
// Output: (N * D_out * H_out * W_out, C_in * kD * kH * kW)
// Each thread handles one output spatial position
kernel void im2col3d_fp32(
    device const float* input [[buffer(0)]],
    device float* col [[buffer(1)]],
    constant int& batch [[buffer(2)]],
    constant int& in_channels [[buffer(3)]],
    constant int& in_depth [[buffer(4)]],
    constant int& in_height [[buffer(5)]],
    constant int& in_width [[buffer(6)]],
    constant int& out_depth [[buffer(7)]],
    constant int& out_height [[buffer(8)]],
    constant int& out_width [[buffer(9)]],
    constant int& kernel_d [[buffer(10)]],
    constant int& kernel_h [[buffer(11)]],
    constant int& kernel_w [[buffer(12)]],
    constant int& stride_d [[buffer(13)]],
    constant int& stride_h [[buffer(14)]],
    constant int& stride_w [[buffer(15)]],
    constant int& pad_d [[buffer(16)]],
    constant int& pad_h [[buffer(17)]],
    constant int& pad_w [[buffer(18)]],
    constant int& dilation_d [[buffer(19)]],
    constant int& dilation_h [[buffer(20)]],
    constant int& dilation_w [[buffer(21)]],
    uint gid [[thread_position_in_grid]]
) {
    // gid = linear index into (batch, out_d, out_h, out_w)
    int out_spatial = out_depth * out_height * out_width;
    int total = batch * out_spatial;
    if (int(gid) >= total) return;

    int b = gid / out_spatial;
    int spatial_idx = gid % out_spatial;
    int od = spatial_idx / (out_height * out_width);
    int oh = (spatial_idx / out_width) % out_height;
    int ow = spatial_idx % out_width;

    int col_width = in_channels * kernel_d * kernel_h * kernel_w;
    int col_row = gid;  // Row in column matrix

    // Input base for this batch
    int input_batch_offset = b * in_channels * in_depth * in_height * in_width;

    // Fill one row of column matrix
    int col_idx = 0;
    for (int ic = 0; ic < in_channels; ic++) {
        for (int kd = 0; kd < kernel_d; kd++) {
            int id = od * stride_d - pad_d + kd * dilation_d;
            for (int kh = 0; kh < kernel_h; kh++) {
                int ih = oh * stride_h - pad_h + kh * dilation_h;
                for (int kw = 0; kw < kernel_w; kw++) {
                    int iw = ow * stride_w - pad_w + kw * dilation_w;

                    float val = 0.0f;
                    if (id >= 0 && id < in_depth && ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
                        int input_idx = input_batch_offset +
                            ic * in_depth * in_height * in_width +
                            id * in_height * in_width +
                            ih * in_width + iw;
                        val = input[input_idx];
                    }

                    col[col_row * col_width + col_idx] = val;
                    col_idx++;
                }
            }
        }
    }
}

kernel void im2col3d_fp16(
    device const half* input [[buffer(0)]],
    device half* col [[buffer(1)]],
    constant int& batch [[buffer(2)]],
    constant int& in_channels [[buffer(3)]],
    constant int& in_depth [[buffer(4)]],
    constant int& in_height [[buffer(5)]],
    constant int& in_width [[buffer(6)]],
    constant int& out_depth [[buffer(7)]],
    constant int& out_height [[buffer(8)]],
    constant int& out_width [[buffer(9)]],
    constant int& kernel_d [[buffer(10)]],
    constant int& kernel_h [[buffer(11)]],
    constant int& kernel_w [[buffer(12)]],
    constant int& stride_d [[buffer(13)]],
    constant int& stride_h [[buffer(14)]],
    constant int& stride_w [[buffer(15)]],
    constant int& pad_d [[buffer(16)]],
    constant int& pad_h [[buffer(17)]],
    constant int& pad_w [[buffer(18)]],
    constant int& dilation_d [[buffer(19)]],
    constant int& dilation_h [[buffer(20)]],
    constant int& dilation_w [[buffer(21)]],
    uint gid [[thread_position_in_grid]]
) {
    int out_spatial = out_depth * out_height * out_width;
    int total = batch * out_spatial;
    if (int(gid) >= total) return;

    int b = gid / out_spatial;
    int spatial_idx = gid % out_spatial;
    int od = spatial_idx / (out_height * out_width);
    int oh = (spatial_idx / out_width) % out_height;
    int ow = spatial_idx % out_width;

    int col_width = in_channels * kernel_d * kernel_h * kernel_w;
    int col_row = gid;

    int input_batch_offset = b * in_channels * in_depth * in_height * in_width;

    int col_idx = 0;
    for (int ic = 0; ic < in_channels; ic++) {
        for (int kd = 0; kd < kernel_d; kd++) {
            int id = od * stride_d - pad_d + kd * dilation_d;
            for (int kh = 0; kh < kernel_h; kh++) {
                int ih = oh * stride_h - pad_h + kh * dilation_h;
                for (int kw = 0; kw < kernel_w; kw++) {
                    int iw = ow * stride_w - pad_w + kw * dilation_w;

                    half val = 0.0h;
                    if (id >= 0 && id < in_depth && ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
                        int input_idx = input_batch_offset +
                            ic * in_depth * in_height * in_width +
                            id * in_height * in_width +
                            ih * in_width + iw;
                        val = input[input_idx];
                    }

                    col[col_row * col_width + col_idx] = val;
                    col_idx++;
                }
            }
        }
    }
}

// col2im3d: Scatter column matrix back to image (for backward)
// Atomic add since multiple output positions may write to same input
inline void atomic_add_float(device atomic_uint* addr, float value) {
    uint expected = atomic_load_explicit(addr, memory_order_relaxed);
    float current_val = as_type<float>(expected);
    float new_val = current_val + value;
    uint new_bits = as_type<uint>(new_val);

    while (!atomic_compare_exchange_weak_explicit(
            addr, &expected, new_bits,
            memory_order_relaxed, memory_order_relaxed)) {
        current_val = as_type<float>(expected);
        new_val = current_val + value;
        new_bits = as_type<uint>(new_val);
    }
}

kernel void col2im3d_fp32(
    device const float* col [[buffer(0)]],
    device atomic_uint* output [[buffer(1)]],
    constant int& batch [[buffer(2)]],
    constant int& in_channels [[buffer(3)]],
    constant int& in_depth [[buffer(4)]],
    constant int& in_height [[buffer(5)]],
    constant int& in_width [[buffer(6)]],
    constant int& out_depth [[buffer(7)]],
    constant int& out_height [[buffer(8)]],
    constant int& out_width [[buffer(9)]],
    constant int& kernel_d [[buffer(10)]],
    constant int& kernel_h [[buffer(11)]],
    constant int& kernel_w [[buffer(12)]],
    constant int& stride_d [[buffer(13)]],
    constant int& stride_h [[buffer(14)]],
    constant int& stride_w [[buffer(15)]],
    constant int& pad_d [[buffer(16)]],
    constant int& pad_h [[buffer(17)]],
    constant int& pad_w [[buffer(18)]],
    constant int& dilation_d [[buffer(19)]],
    constant int& dilation_h [[buffer(20)]],
    constant int& dilation_w [[buffer(21)]],
    uint gid [[thread_position_in_grid]]
) {
    int out_spatial = out_depth * out_height * out_width;
    int total = batch * out_spatial;
    if (int(gid) >= total) return;

    int b = gid / out_spatial;
    int spatial_idx = gid % out_spatial;
    int od = spatial_idx / (out_height * out_width);
    int oh = (spatial_idx / out_width) % out_height;
    int ow = spatial_idx % out_width;

    int col_width = in_channels * kernel_d * kernel_h * kernel_w;
    int col_row = gid;

    int output_batch_offset = b * in_channels * in_depth * in_height * in_width;

    int col_idx = 0;
    for (int ic = 0; ic < in_channels; ic++) {
        for (int kd = 0; kd < kernel_d; kd++) {
            int id = od * stride_d - pad_d + kd * dilation_d;
            for (int kh = 0; kh < kernel_h; kh++) {
                int ih = oh * stride_h - pad_h + kh * dilation_h;
                for (int kw = 0; kw < kernel_w; kw++) {
                    int iw = ow * stride_w - pad_w + kw * dilation_w;

                    if (id >= 0 && id < in_depth && ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
                        int output_idx = output_batch_offset +
                            ic * in_depth * in_height * in_width +
                            id * in_height * in_width +
                            ih * in_width + iw;
                        float val = col[col_row * col_width + col_idx];
                        atomic_add_float(&output[output_idx], val);
                    }
                    col_idx++;
                }
            }
        }
    }
}
)";

static void ensure_initialized() {
    if (g_device != nil) return;

    g_device = MTLCreateSystemDefaultDevice();
    NSError* error = nil;

    NSString* source = [NSString stringWithUTF8String:METAL_SHADER];
    g_library = [g_device newLibraryWithSource:source options:nil error:&error];

    if (error) {
        NSLog(@"Failed to compile Metal shader: %@", error);
        throw std::runtime_error("Failed to compile Metal shader");
    }

    id<MTLFunction> im2col_fp32 = [g_library newFunctionWithName:@"im2col3d_fp32"];
    id<MTLFunction> im2col_fp16 = [g_library newFunctionWithName:@"im2col3d_fp16"];
    id<MTLFunction> col2im_fp32 = [g_library newFunctionWithName:@"col2im3d_fp32"];

    g_im2col3d_fp32 = [g_device newComputePipelineStateWithFunction:im2col_fp32 error:&error];
    g_im2col3d_fp16 = [g_device newComputePipelineStateWithFunction:im2col_fp16 error:&error];
    g_col2im3d_fp32 = [g_device newComputePipelineStateWithFunction:col2im_fp32 error:&error];

    // Fail loudly here rather than dispatching with a nil pipeline later.
    TORCH_CHECK(g_im2col3d_fp32 != nil && g_im2col3d_fp16 != nil && g_col2im3d_fp32 != nil,
                "Failed to create Metal compute pipeline states");
}

// im2col3d: Metal kernel to extract patches
torch::Tensor im2col3d_mps(
    const torch::Tensor& input,
    int kernel_d, int kernel_h, int kernel_w,
    int stride_d, int stride_h, int stride_w,
    int pad_d, int pad_h, int pad_w,
    int dilation_d, int dilation_h, int dilation_w
) {
    ensure_initialized();

    int batch = input.size(0);
    int in_channels = input.size(1);
    int in_depth = input.size(2);
    int in_height = input.size(3);
    int in_width = input.size(4);

    int out_depth = (in_depth + 2 * pad_d - dilation_d * (kernel_d - 1) - 1) / stride_d + 1;
    int out_height = (in_height + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
    int out_width = (in_width + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;

    int col_height = batch * out_depth * out_height * out_width;
    int col_width = in_channels * kernel_d * kernel_h * kernel_w;

    auto input_contig = input.contiguous();
    auto col = torch::empty({col_height, col_width}, input_contig.options());

    id<MTLBuffer> input_buf = at::native::mps::getMTLBufferStorage(input_contig);
    id<MTLBuffer> col_buf = at::native::mps::getMTLBufferStorage(col);

    @autoreleasepool {
        auto stream = at::mps::getCurrentMPSStream();
        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();

        id<MTLComputePipelineState> pso = (input_contig.scalar_type() == at::kHalf)
            ? g_im2col3d_fp16 : g_im2col3d_fp32;

        [encoder setComputePipelineState:pso];
        [encoder setBuffer:input_buf offset:input_contig.storage_offset() * input_contig.element_size() atIndex:0];
        [encoder setBuffer:col_buf offset:0 atIndex:1];

        [encoder setBytes:&batch length:sizeof(int) atIndex:2];
        [encoder setBytes:&in_channels length:sizeof(int) atIndex:3];
        [encoder setBytes:&in_depth length:sizeof(int) atIndex:4];
        [encoder setBytes:&in_height length:sizeof(int) atIndex:5];
        [encoder setBytes:&in_width length:sizeof(int) atIndex:6];
        [encoder setBytes:&out_depth length:sizeof(int) atIndex:7];
        [encoder setBytes:&out_height length:sizeof(int) atIndex:8];
        [encoder setBytes:&out_width length:sizeof(int) atIndex:9];
        [encoder setBytes:&kernel_d length:sizeof(int) atIndex:10];
        [encoder setBytes:&kernel_h length:sizeof(int) atIndex:11];
        [encoder setBytes:&kernel_w length:sizeof(int) atIndex:12];
        [encoder setBytes:&stride_d length:sizeof(int) atIndex:13];
        [encoder setBytes:&stride_h length:sizeof(int) atIndex:14];
        [encoder setBytes:&stride_w length:sizeof(int) atIndex:15];
        [encoder setBytes:&pad_d length:sizeof(int) atIndex:16];
        [encoder setBytes:&pad_h length:sizeof(int) atIndex:17];
        [encoder setBytes:&pad_w length:sizeof(int) atIndex:18];
        [encoder setBytes:&dilation_d length:sizeof(int) atIndex:19];
        [encoder setBytes:&dilation_h length:sizeof(int) atIndex:20];
        [encoder setBytes:&dilation_w length:sizeof(int) atIndex:21];

        MTLSize gridSize = MTLSizeMake(col_height, 1, 1);
        NSUInteger threadGroupSize = std::min((NSUInteger)256, pso.maxTotalThreadsPerThreadgroup);
        MTLSize tgSize = MTLSizeMake(threadGroupSize, 1, 1);
        [encoder dispatchThreads:gridSize threadsPerThreadgroup:tgSize];
    }

    return col;
}

// Forward: im2col + matmul
torch::Tensor conv3d_forward_mps(
    const torch::Tensor& input,
    const torch::Tensor& weight,
    int stride_d, int stride_h, int stride_w,
    int pad_d, int pad_h, int pad_w,
    int dilation_d, int dilation_h, int dilation_w,
    int groups
) {
    TORCH_CHECK(input.device().type() == torch::kMPS, "input must be on MPS");
    TORCH_CHECK(weight.device().type() == torch::kMPS, "weight must be on MPS");
    TORCH_CHECK(groups == 1, "groups > 1 not yet supported in im2col path");

    int batch = input.size(0);
    int out_channels = weight.size(0);
    int in_channels = weight.size(1);
    int kernel_d = weight.size(2);
    int kernel_h = weight.size(3);
    int kernel_w = weight.size(4);

    int in_depth = input.size(2);
    int in_height = input.size(3);
    int in_width = input.size(4);

    int out_depth = (in_depth + 2 * pad_d - dilation_d * (kernel_d - 1) - 1) / stride_d + 1;
    int out_height = (in_height + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
    int out_width = (in_width + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;

    // im2col: (B*D_out*H_out*W_out, C_in*kD*kH*kW)
    auto col = im2col3d_mps(input, kernel_d, kernel_h, kernel_w,
                            stride_d, stride_h, stride_w,
                            pad_d, pad_h, pad_w,
                            dilation_d, dilation_h, dilation_w);

    // Weight: (C_out, C_in*kD*kH*kW) -> transposed for matmul
    auto weight_col = weight.view({out_channels, -1});

    // Matmul: (B*D*H*W, C_in*k*k*k) @ (C_in*k*k*k, C_out) = (B*D*H*W, C_out)
    auto output_col = torch::mm(col, weight_col.t());

    // Reshape: (B, D_out, H_out, W_out, C_out) -> (B, C_out, D_out, H_out, W_out)
    auto output = output_col.view({batch, out_depth, out_height, out_width, out_channels});
    output = output.permute({0, 4, 1, 2, 3}).contiguous();

    return output;
}

// Backward input: matmul + col2im
torch::Tensor conv3d_backward_input_mps(
    const torch::Tensor& grad_output,
    const torch::Tensor& weight,
    std::vector<int64_t> input_shape,
    int stride_d, int stride_h, int stride_w,
    int pad_d, int pad_h, int pad_w,
    int dilation_d, int dilation_h, int dilation_w,
    int groups
) {
    ensure_initialized();
    TORCH_CHECK(groups == 1, "groups > 1 not yet supported");

    int batch = input_shape[0];
    int in_channels = input_shape[1];
    int in_depth = input_shape[2];
    int in_height = input_shape[3];
    int in_width = input_shape[4];

    int out_channels = weight.size(0);
    int kernel_d = weight.size(2);
    int kernel_h = weight.size(3);
    int kernel_w = weight.size(4);

    int out_depth = grad_output.size(2);
    int out_height = grad_output.size(3);
    int out_width = grad_output.size(4);

    // grad_output: (B, C_out, D_out, H_out, W_out) -> (B*D*H*W, C_out)
    auto grad_out_col = grad_output.permute({0, 2, 3, 4, 1}).contiguous();
    grad_out_col = grad_out_col.view({-1, out_channels});

    // Weight: (C_out, C_in*k*k*k)
    auto weight_col = weight.view({out_channels, -1});

    // grad_col = grad_out_col @ weight_col: (B*D*H*W, C_in*k*k*k)
    // Accumulate in fp32 because the col2im kernel uses float atomics.
    auto grad_col = torch::mm(grad_out_col.to(at::kFloat), weight_col.to(at::kFloat));

    // col2im: scatter-add the columns back into the input layout
    auto grad_input = torch::zeros(input_shape, grad_col.options());

    int col_height = batch * out_depth * out_height * out_width;

    id<MTLBuffer> col_buf = at::native::mps::getMTLBufferStorage(grad_col);
    id<MTLBuffer> output_buf = at::native::mps::getMTLBufferStorage(grad_input);

    @autoreleasepool {
        auto stream = at::mps::getCurrentMPSStream();
        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();

        [encoder setComputePipelineState:g_col2im3d_fp32];
        [encoder setBuffer:col_buf offset:0 atIndex:0];
        [encoder setBuffer:output_buf offset:0 atIndex:1];

        [encoder setBytes:&batch length:sizeof(int) atIndex:2];
        [encoder setBytes:&in_channels length:sizeof(int) atIndex:3];
        [encoder setBytes:&in_depth length:sizeof(int) atIndex:4];
        [encoder setBytes:&in_height length:sizeof(int) atIndex:5];
        [encoder setBytes:&in_width length:sizeof(int) atIndex:6];
        [encoder setBytes:&out_depth length:sizeof(int) atIndex:7];
        [encoder setBytes:&out_height length:sizeof(int) atIndex:8];
        [encoder setBytes:&out_width length:sizeof(int) atIndex:9];
        [encoder setBytes:&kernel_d length:sizeof(int) atIndex:10];
        [encoder setBytes:&kernel_h length:sizeof(int) atIndex:11];
        [encoder setBytes:&kernel_w length:sizeof(int) atIndex:12];
        [encoder setBytes:&stride_d length:sizeof(int) atIndex:13];
        [encoder setBytes:&stride_h length:sizeof(int) atIndex:14];
        [encoder setBytes:&stride_w length:sizeof(int) atIndex:15];
        [encoder setBytes:&pad_d length:sizeof(int) atIndex:16];
        [encoder setBytes:&pad_h length:sizeof(int) atIndex:17];
        [encoder setBytes:&pad_w length:sizeof(int) atIndex:18];
        [encoder setBytes:&dilation_d length:sizeof(int) atIndex:19];
        [encoder setBytes:&dilation_h length:sizeof(int) atIndex:20];
        [encoder setBytes:&dilation_w length:sizeof(int) atIndex:21];

        MTLSize gridSize = MTLSizeMake(col_height, 1, 1);
        // Clamp to the pipeline's limit, as in the forward path.
        NSUInteger threadGroupSize = std::min((NSUInteger)256, g_col2im3d_fp32.maxTotalThreadsPerThreadgroup);
        MTLSize tgSize = MTLSizeMake(threadGroupSize, 1, 1);
        [encoder dispatchThreads:gridSize threadsPerThreadgroup:tgSize];
    }

    return grad_input.to(grad_output.scalar_type());
}

// Backward weight: grad_output^T @ im2col(input)
torch::Tensor conv3d_backward_weight_mps(
    const torch::Tensor& grad_output,
    const torch::Tensor& input,
    std::vector<int64_t> weight_shape,
    int stride_d, int stride_h, int stride_w,
    int pad_d, int pad_h, int pad_w,
    int dilation_d, int dilation_h, int dilation_w,
    int groups
) {
    TORCH_CHECK(groups == 1, "groups > 1 not yet supported");

    int out_channels = weight_shape[0];
    int kernel_d = weight_shape[2];
    int kernel_h = weight_shape[3];
    int kernel_w = weight_shape[4];

    // im2col on the input (accumulated in fp32)
    auto col = im2col3d_mps(input.to(at::kFloat), kernel_d, kernel_h, kernel_w,
                            stride_d, stride_h, stride_w,
                            pad_d, pad_h, pad_w,
                            dilation_d, dilation_h, dilation_w);

    // grad_output: (B, C_out, D_out, H_out, W_out) -> (B*D*H*W, C_out)
    auto grad_out_col = grad_output.permute({0, 2, 3, 4, 1}).contiguous();
    grad_out_col = grad_out_col.view({-1, out_channels}).to(at::kFloat);

    // grad_weight = grad_out_col^T @ col: (C_out, C_in*k*k*k)
    auto grad_weight_col = torch::mm(grad_out_col.t(), col);

    // Reshape to the weight layout
    auto grad_weight = grad_weight_col.view(weight_shape);

    return grad_weight.to(grad_output.scalar_type());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("conv3d_forward", &conv3d_forward_mps, "Conv3D forward (MPS) - im2col + GEMM");
    m.def("conv3d_backward_input", &conv3d_backward_input_mps, "Conv3D backward input (MPS)");
    m.def("conv3d_backward_weight", &conv3d_backward_weight_mps, "Conv3D backward weight (MPS)");
    m.def("im2col3d", &im2col3d_mps, "im2col3d (MPS)");
}
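The forward path above can be written in a few lines of pure PyTorch, which makes the im2col layout easy to see. This sketch (ours, for illustration; CPU-only and unoptimized) builds the same column matrix — one row per output position, columns ordered by (channel, kd, kh, kw) like the Metal kernel — and checks the GEMM result against `F.conv3d`:

```python
import torch
import torch.nn.functional as F

N, C, D, H, W = 2, 3, 6, 8, 8
Co, k, pad, stride = 4, 3, 1, 1

x = torch.randn(N, C, D, H, W)
w = torch.randn(Co, C, k, k, k)

Do = (D + 2 * pad - k) // stride + 1
Ho = (H + 2 * pad - k) // stride + 1
Wo = (W + 2 * pad - k) // stride + 1

# im2col: one row per output position, one column per (channel, kd, kh, kw).
cols = (
    F.pad(x, (pad,) * 6)                 # pad W, H, D symmetrically
    .unfold(2, k, stride)                # slide over depth
    .unfold(3, k, stride)                # slide over height
    .unfold(4, k, stride)                # slide over width -> (N, C, Do, Ho, Wo, k, k, k)
    .permute(0, 2, 3, 4, 1, 5, 6, 7)     # rows ordered like the Metal kernel
    .reshape(N * Do * Ho * Wo, C * k * k * k)
)

# GEMM against the flattened weight, then restore (N, Co, Do, Ho, Wo).
out = (cols @ w.reshape(Co, -1).t()).view(N, Do, Ho, Wo, Co).permute(0, 4, 1, 2, 3)

assert torch.allclose(out, F.conv3d(x, w, padding=pad), atol=1e-5)
```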
mps_conv3d-0.1.5.dist-info/METADATA ADDED
@@ -0,0 +1,148 @@
Metadata-Version: 2.4
Name: mps-conv3d
Version: 0.1.5
Summary: 3D Convolution for Apple Silicon (MPS)
Author: mpsops
License-Expression: MIT
Project-URL: Homepage, https://github.com/mpsops/mps-conv3d
Project-URL: Repository, https://github.com/mpsops/mps-conv3d
Project-URL: Issues, https://github.com/mpsops/mps-conv3d/issues
Keywords: pytorch,mps,apple-silicon,conv3d,video,3d-convolution
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.0.0
Dynamic: requires-python

# MPS Conv3D

3D Convolution for Apple Silicon (M1/M2/M3/M4).

**Drop-in replacement** for `torch.nn.functional.conv3d` on MPS.

## Why?

3D convolutions are essential for video models:
- **Synchformer**: Audio-visual synchronization
- **I3D**: Video classification
- **SlowFast**: Action recognition
- **C3D**: Video feature extraction
- **MMAudio**: Audio generation from video

But PyTorch's MPS backend doesn't support 3D convolutions:
```
NotImplementedError: aten::slow_conv3d_forward is not implemented for MPS
```

This package provides a native Metal implementation.

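Depending on the PyTorch build, the stock operator may or may not be present; a runtime probe (our sketch, not part of the package) can apply the patch only when it is actually needed:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 32, 32, device="mps")
w = torch.randn(8, 3, 3, 3, 3, device="mps")

try:
    out = F.conv3d(x, w, padding=1)   # stock operator, missing on some builds
except NotImplementedError:
    from mps_conv3d import patch_conv3d
    patch_conv3d()                    # route conv3d on MPS through Metal
    out = F.conv3d(x, w, padding=1)
```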
## Installation

```bash
pip install mps-conv3d
```

Or from source:

```bash
git clone https://github.com/mpsops/mps-conv3d
cd mps-conv3d
pip install -e .
```

## Quick Start

### Patch All Conv3D Operations (Recommended)

```python
from mps_conv3d import patch_conv3d

# Patch at the start of your script
patch_conv3d()

# Now all conv3d operations use MPS!
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16, 112, 112, device='mps')
w = torch.randn(64, 3, 3, 7, 7, device='mps')
out = F.conv3d(x, w, padding=(1, 3, 3))  # Uses MPS!
```

### Direct Usage

```python
import torch
from mps_conv3d import conv3d

x = torch.randn(1, 3, 16, 112, 112, device='mps')
w = torch.randn(64, 3, 3, 7, 7, device='mps')

out = conv3d(x, w, stride=1, padding=(1, 3, 3))
```

### Conv3d Module

```python
import torch
from mps_conv3d import Conv3d

conv = Conv3d(
    in_channels=3,
    out_channels=64,
    kernel_size=(3, 7, 7),
    stride=(1, 2, 2),
    padding=(1, 3, 3)
).to('mps')

x = torch.randn(1, 3, 16, 112, 112, device='mps')
out = conv(x)
```

## API Reference

### `conv3d(input, weight, bias, stride, padding, dilation, groups)`

Same signature as `torch.nn.functional.conv3d`.

| Parameter | Type | Description |
|-----------|------|-------------|
| `input` | Tensor | Input tensor (N, C_in, D, H, W) |
| `weight` | Tensor | Weight tensor (C_out, C_in/groups, kD, kH, kW) |
| `bias` | Tensor | Optional bias (C_out,) |
| `stride` | int/tuple | Stride of the convolution |
| `padding` | int/tuple | Padding added to the input |
| `dilation` | int/tuple | Dilation of the kernel |
| `groups` | int | Number of groups (currently must be 1; see Features) |

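An end-to-end example using every parameter from the table (shapes chosen arbitrarily; `groups` left at 1, which is what the current kernel supports):

```python
import torch
from mps_conv3d import conv3d

x = torch.randn(2, 16, 8, 32, 32, device="mps")   # (N, C_in, D, H, W)
w = torch.randn(32, 16, 3, 3, 3, device="mps")    # (C_out, C_in/groups, kD, kH, kW)
b = torch.randn(32, device="mps")

out = conv3d(x, w, bias=b, stride=1, padding=1, dilation=1, groups=1)
assert out.shape == (2, 32, 8, 32, 32)
```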
### `patch_conv3d()`

Monkey-patches `torch.nn.functional.conv3d` to use the MPS implementation for MPS tensors.

### `unpatch_conv3d()`

Restores the original `torch.nn.functional.conv3d`.

## Compatibility

- **PyTorch**: 2.0+
- **macOS**: 12.0+ (Monterey)
- **Hardware**: Apple Silicon (M1/M2/M3/M4)

## Features

- Full forward and backward pass (training supported; see the sketch below)
- fp32 and fp16 supported
- Dilation supported; groups > 1 is not yet supported (the im2col path raises an error)
- Drop-in compatible with the PyTorch API

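To exercise the backward pass end to end, a minimal training step might look like this (our sketch; shapes and hyperparameters are arbitrary):

```python
import torch
from mps_conv3d import Conv3d

conv = Conv3d(3, 8, kernel_size=3, padding=1).to("mps")
opt = torch.optim.SGD(conv.parameters(), lr=1e-2)

x = torch.randn(2, 3, 8, 16, 16, device="mps")
target = torch.randn(2, 8, 8, 16, 16, device="mps")

loss = torch.nn.functional.mse_loss(conv(x), target)
loss.backward()   # runs the custom Metal backward kernels
opt.step()
```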
## License

MIT
mps_conv3d-0.1.5.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
mps_conv3d/_C.cpython-314-darwin.so,sha256=ZG2iBtp1Z12ChIEiW46iFB7CFH48XUBVxciWjuyqJ7A,267336
mps_conv3d/__init__.py,sha256=UxjYD3ud-FU2ogQDKtgrXiPtSeiOkWbxsrZG73k0JWQ,7729
mps_conv3d/csrc/conv3d_mps.mm,sha256=kSYI07eLvWxwO4LLBWG1-A0UvCdxt-7dMGS0hF8kD30,20998
mps_conv3d-0.1.5.dist-info/METADATA,sha256=rE525f6iQecRG5ZSqr2rBXqY2cCdMXHb5I1mPBNvXX4,3636
mps_conv3d-0.1.5.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
mps_conv3d-0.1.5.dist-info/top_level.txt,sha256=TjxUUmzGDVNDsIKTolOKbln2HXVSrh9M3DU84ug60II,11
mps_conv3d-0.1.5.dist-info/RECORD,,
mps_conv3d-0.1.5.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (80.10.2)
Root-Is-Purelib: false
Tag: cp314-cp314-macosx_15_0_arm64

mps_conv3d-0.1.5.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
mps_conv3d