mps-conv3d 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,148 @@
1
+ Metadata-Version: 2.4
2
+ Name: mps-conv3d
3
+ Version: 0.1.4
4
+ Summary: 3D Convolution for Apple Silicon (MPS)
5
+ Author: mpsops
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/mpsops/mps-conv3d
8
+ Project-URL: Repository, https://github.com/mpsops/mps-conv3d
9
+ Project-URL: Issues, https://github.com/mpsops/mps-conv3d/issues
10
+ Keywords: pytorch,mps,apple-silicon,conv3d,video,3d-convolution
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Operating System :: MacOS :: MacOS X
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Python: >=3.10
21
+ Description-Content-Type: text/markdown
22
+ Requires-Dist: torch>=2.0.0
23
+ Dynamic: requires-python
24
+
25
+ # MPS Conv3D
26
+
27
+ 3D Convolution for Apple Silicon (M1/M2/M3/M4).
28
+
29
+ **Drop-in replacement** for `torch.nn.functional.conv3d` on MPS.
30
+
31
+ ## Why?
32
+
33
+ 3D convolutions are essential for video models:
34
+ - **Synchformer**: Audio-visual synchronization
35
+ - **I3D**: Video classification
36
+ - **SlowFast**: Action recognition
37
+ - **C3D**: Video feature extraction
38
+ - **MMAudio**: Audio generation from video
39
+
40
+ But PyTorch's MPS backend doesn't support 3D convolutions:
41
+ ```
42
+ NotImplementedError: aten::slow_conv3d_forward is not implemented for MPS
43
+ ```
44
+
45
+ This package provides a native Metal implementation.
46
+
47
+ ## Installation
48
+
49
+ ```bash
50
+ pip install mps-conv3d
51
+ ```
52
+
53
+ Or from source:
54
+
55
+ ```bash
56
+ git clone https://github.com/mpsops/mps-conv3d
57
+ cd mps-conv3d
58
+ pip install -e .
59
+ ```
60
+
61
+ ## Quick Start
62
+
63
+ ### Patch All Conv3D Operations (Recommended)
64
+
65
+ ```python
66
+ from mps_conv3d import patch_conv3d
67
+
68
+ # Patch at the start of your script
69
+ patch_conv3d()
70
+
71
+ # Now all conv3d operations use MPS!
72
+ import torch
73
+ import torch.nn.functional as F
74
+
75
+ x = torch.randn(1, 3, 16, 112, 112, device='mps')
76
+ w = torch.randn(64, 3, 3, 7, 7, device='mps')
77
+ out = F.conv3d(x, w, padding=(1, 3, 3)) # Uses MPS!
78
+ ```
79
+
80
+ ### Direct Usage
81
+
82
+ ```python
83
+ import torch
84
+ from mps_conv3d import conv3d
85
+
86
+ x = torch.randn(1, 3, 16, 112, 112, device='mps')
87
+ w = torch.randn(64, 3, 3, 7, 7, device='mps')
88
+
89
+ out = conv3d(x, w, stride=1, padding=(1, 3, 3))
90
+ ```
91
+
92
+ ### Conv3d Module
93
+
94
+ ```python
95
+ from mps_conv3d import Conv3d
96
+
97
+ conv = Conv3d(
98
+ in_channels=3,
99
+ out_channels=64,
100
+ kernel_size=(3, 7, 7),
101
+ stride=(1, 2, 2),
102
+ padding=(1, 3, 3)
103
+ ).to('mps')
104
+
105
+ x = torch.randn(1, 3, 16, 112, 112, device='mps')
106
+ out = conv(x)
107
+ ```
108
+
109
+ ## API Reference
110
+
111
+ ### `conv3d(input, weight, bias, stride, padding, dilation, groups)`
112
+
113
+ Same signature as `torch.nn.functional.conv3d`.
114
+
115
+ | Parameter | Type | Description |
116
+ |-----------|------|-------------|
117
+ | `input` | Tensor | Input tensor (N, C_in, D, H, W) |
118
+ | `weight` | Tensor | Weight tensor (C_out, C_in/groups, kD, kH, kW) |
119
+ | `bias` | Tensor | Optional bias (C_out,) |
120
+ | `stride` | int/tuple | Stride of convolution |
121
+ | `padding` | int/tuple | Padding added to input |
122
+ | `dilation` | int/tuple | Dilation of kernel |
123
+ | `groups` | int | Number of groups |
124
+
125
+ ### `patch_conv3d()`
126
+
127
+ Monkey-patches `torch.nn.functional.conv3d` to use MPS implementation for MPS tensors.
128
+
129
+ ### `unpatch_conv3d()`
130
+
131
+ Restores original `torch.nn.functional.conv3d`.
132
+
133
+ ## Compatibility
134
+
135
+ - **PyTorch**: 2.0+
136
+ - **macOS**: 12.0+ (Monterey)
137
+ - **Hardware**: Apple Silicon (M1/M2/M3/M4)
138
+
139
+ ## Features
140
+
141
+ - Full forward and backward pass (training supported)
142
+ - fp32 and fp16 supported
143
+ - Groups and dilation supported
144
+ - Drop-in compatible with PyTorch API
145
+
146
+ ## License
147
+
148
+ MIT
@@ -0,0 +1,124 @@
1
+ # MPS Conv3D
2
+
3
+ 3D Convolution for Apple Silicon (M1/M2/M3/M4).
4
+
5
+ **Drop-in replacement** for `torch.nn.functional.conv3d` on MPS.
6
+
7
+ ## Why?
8
+
9
+ 3D convolutions are essential for video models:
10
+ - **Synchformer**: Audio-visual synchronization
11
+ - **I3D**: Video classification
12
+ - **SlowFast**: Action recognition
13
+ - **C3D**: Video feature extraction
14
+ - **MMAudio**: Audio generation from video
15
+
16
+ But PyTorch's MPS backend doesn't support 3D convolutions:
17
+ ```
18
+ NotImplementedError: aten::slow_conv3d_forward is not implemented for MPS
19
+ ```
20
+
21
+ This package provides a native Metal implementation.
22
+
23
+ ## Installation
24
+
25
+ ```bash
26
+ pip install mps-conv3d
27
+ ```
28
+
29
+ Or from source:
30
+
31
+ ```bash
32
+ git clone https://github.com/mpsops/mps-conv3d
33
+ cd mps-conv3d
34
+ pip install -e .
35
+ ```
36
+
37
+ ## Quick Start
38
+
39
+ ### Patch All Conv3D Operations (Recommended)
40
+
41
+ ```python
42
+ from mps_conv3d import patch_conv3d
43
+
44
+ # Patch at the start of your script
45
+ patch_conv3d()
46
+
47
+ # Now all conv3d operations use MPS!
48
+ import torch
49
+ import torch.nn.functional as F
50
+
51
+ x = torch.randn(1, 3, 16, 112, 112, device='mps')
52
+ w = torch.randn(64, 3, 3, 7, 7, device='mps')
53
+ out = F.conv3d(x, w, padding=(1, 3, 3)) # Uses MPS!
54
+ ```
55
+
56
+ ### Direct Usage
57
+
58
+ ```python
59
+ import torch
60
+ from mps_conv3d import conv3d
61
+
62
+ x = torch.randn(1, 3, 16, 112, 112, device='mps')
63
+ w = torch.randn(64, 3, 3, 7, 7, device='mps')
64
+
65
+ out = conv3d(x, w, stride=1, padding=(1, 3, 3))
66
+ ```
67
+
68
+ ### Conv3d Module
69
+
70
+ ```python
71
+ from mps_conv3d import Conv3d
72
+
73
+ conv = Conv3d(
74
+ in_channels=3,
75
+ out_channels=64,
76
+ kernel_size=(3, 7, 7),
77
+ stride=(1, 2, 2),
78
+ padding=(1, 3, 3)
79
+ ).to('mps')
80
+
81
+ x = torch.randn(1, 3, 16, 112, 112, device='mps')
82
+ out = conv(x)
83
+ ```
84
+
85
+ ## API Reference
86
+
87
+ ### `conv3d(input, weight, bias, stride, padding, dilation, groups)`
88
+
89
+ Same signature as `torch.nn.functional.conv3d`.
90
+
91
+ | Parameter | Type | Description |
92
+ |-----------|------|-------------|
93
+ | `input` | Tensor | Input tensor (N, C_in, D, H, W) |
94
+ | `weight` | Tensor | Weight tensor (C_out, C_in/groups, kD, kH, kW) |
95
+ | `bias` | Tensor | Optional bias (C_out,) |
96
+ | `stride` | int/tuple | Stride of convolution |
97
+ | `padding` | int/tuple | Padding added to input |
98
+ | `dilation` | int/tuple | Dilation of kernel |
99
+ | `groups` | int | Number of groups |
100
+
101
+ ### `patch_conv3d()`
102
+
103
+ Monkey-patches `torch.nn.functional.conv3d` to use MPS implementation for MPS tensors.
104
+
105
+ ### `unpatch_conv3d()`
106
+
107
+ Restores original `torch.nn.functional.conv3d`.
108
+
109
+ ## Compatibility
110
+
111
+ - **PyTorch**: 2.0+
112
+ - **macOS**: 12.0+ (Monterey)
113
+ - **Hardware**: Apple Silicon (M1/M2/M3/M4)
114
+
115
+ ## Features
116
+
117
+ - Full forward and backward pass (training supported)
118
+ - fp32 and fp16 supported
119
+ - Groups and dilation supported
120
+ - Drop-in compatible with PyTorch API
121
+
122
+ ## License
123
+
124
+ MIT
@@ -0,0 +1,259 @@
1
+ """
2
+ MPS Conv3D - 3D Convolution for Apple Silicon
3
+
4
+ Drop-in replacement for torch.nn.functional.conv3d on MPS.
5
+ Used in video models: Synchformer, I3D, SlowFast, C3D, etc.
6
+ """
7
+
8
+ import torch
9
+ from torch import nn, Tensor
10
+ from torch.autograd import Function
11
+ from typing import Optional, Tuple, Union
12
+ import torch.nn.functional as F
13
+
14
# Keep in sync with the distribution metadata (PKG-INFO says 0.1.4; this
# module previously lagged behind at 0.1.2).
__version__ = "0.1.4"
15
+
16
+
17
def _load_library():
    """Return the native extension module.

    Prefers a prebuilt binary (``mps_conv3d._C``); when none is installed,
    falls back to JIT-compiling the Metal/Objective-C++ source shipped in
    the package's ``csrc`` directory via torch's cpp_extension loader.
    """
    try:
        from mps_conv3d import _C as _lib
    except ImportError:
        # No prebuilt wheel for this environment: compile on first use.
        import os
        from torch.utils.cpp_extension import load

        source_path = os.path.join(
            os.path.dirname(__file__), "csrc", "conv3d_mps.mm"
        )
        _lib = load(
            name="mps_conv3d",
            sources=[source_path],
            extra_cflags=["-std=c++17"],
            extra_ldflags=["-framework", "Metal", "-framework", "Foundation"],
            verbose=False,
        )
    return _lib
35
+
36
+
37
# Module-level cache so the (potentially expensive) JIT build runs once.
_lib_cache = None


def _get_lib():
    """Return the native library, loading and caching it on first call."""
    global _lib_cache
    if _lib_cache is not None:
        return _lib_cache
    _lib_cache = _load_library()
    return _lib_cache
45
+
46
+
47
class _Conv3DFunction(Function):
    """Autograd function for Conv3D.

    Thin autograd wrapper around the native Metal kernels: the forward pass
    and both backward passes are delegated to the compiled extension from
    ``_get_lib()``; only the bias add and the bias gradient are computed in
    Python.
    """

    @staticmethod
    def forward(
        ctx,
        input: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        stride: Tuple[int, int, int],
        padding: Tuple[int, int, int],
        dilation: Tuple[int, int, int],
        groups: int,
    ) -> Tensor:
        # Stash what backward() needs. Hyper-parameters are plain tuples/ints,
        # so they live directly on ctx rather than in save_for_backward.
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        lib = _get_lib()
        # The native kernel takes the stride/padding/dilation components as
        # nine separate scalars rather than tuples.
        output = lib.conv3d_forward(
            input, weight,
            stride[0], stride[1], stride[2],
            padding[0], padding[1], padding[2],
            dilation[0], dilation[1], dilation[2],
            groups
        )

        # Bias is applied here in Python, broadcast over (N, D, H, W) —
        # presumably the Metal kernel itself is bias-free; confirm in csrc.
        if bias is not None:
            output = output + bias.view(1, -1, 1, 1, 1)

        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        """Compute gradients w.r.t. input, weight, and bias, as requested."""
        input, weight, bias = ctx.saved_tensors
        lib = _get_lib()

        grad_input = grad_weight = grad_bias = None

        # Native transposed-conv pass reconstructs grad w.r.t. the input.
        if ctx.needs_input_grad[0]:
            grad_input = lib.conv3d_backward_input(
                grad_output, weight, input.shape,
                ctx.stride[0], ctx.stride[1], ctx.stride[2],
                ctx.padding[0], ctx.padding[1], ctx.padding[2],
                ctx.dilation[0], ctx.dilation[1], ctx.dilation[2],
                ctx.groups
            )

        if ctx.needs_input_grad[1]:
            grad_weight = lib.conv3d_backward_weight(
                grad_output, input, weight.shape,
                ctx.stride[0], ctx.stride[1], ctx.stride[2],
                ctx.padding[0], ctx.padding[1], ctx.padding[2],
                ctx.dilation[0], ctx.dilation[1], ctx.dilation[2],
                ctx.groups
            )

        # Bias gradient: sum grad_output over every dim except channels,
        # mirroring the broadcast add done in forward().
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=[0, 2, 3, 4])

        # One gradient slot per forward() argument; the four hyper-parameter
        # slots (stride, padding, dilation, groups) are non-differentiable.
        return grad_input, grad_weight, grad_bias, None, None, None, None
110
+
111
+
112
+ def _normalize_tuple(value, n, name):
113
+ """Convert int or tuple to n-tuple."""
114
+ if isinstance(value, int):
115
+ return (value,) * n
116
+ if isinstance(value, (list, tuple)):
117
+ if len(value) == n:
118
+ return tuple(value)
119
+ elif len(value) == 1:
120
+ return (value[0],) * n
121
+ raise ValueError(f"{name} must be an int or {n}-tuple")
122
+
123
+
124
def conv3d(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
    groups: int = 1,
) -> Tensor:
    """3D convolution with a native Metal kernel on MPS.

    Drop-in replacement for ``torch.nn.functional.conv3d``. Tensors on any
    non-MPS device are delegated to the stock PyTorch implementation.

    Args:
        input: Input tensor of shape (N, C_in, D, H, W).
        weight: Weight tensor of shape (C_out, C_in/groups, kD, kH, kW).
        bias: Optional bias tensor of shape (C_out,).
        stride: Stride of the convolution (int or 3-tuple).
        padding: Padding added to the input (int or 3-tuple).
        dilation: Dilation of the kernel (int or 3-tuple).
        groups: Number of blocked connections.

    Returns:
        Output tensor of shape (N, C_out, D_out, H_out, W_out).

    Raises:
        ValueError: If stride/padding/dilation are neither ints nor 3-tuples
            (MPS path only).
    """
    # Anything not on MPS goes straight to the built-in implementation.
    if input.device.type != "mps":
        return F.conv3d(input, weight, bias, stride, padding, dilation, groups)

    s = _normalize_tuple(stride, 3, "stride")
    p = _normalize_tuple(padding, 3, "padding")
    d = _normalize_tuple(dilation, 3, "dilation")

    # The native kernel assumes contiguous layouts.
    return _Conv3DFunction.apply(
        input.contiguous(), weight.contiguous(), bias, s, p, d, groups
    )
162
+
163
+
164
class Conv3d(nn.Module):
    """
    3D Convolution layer for MPS.

    Drop-in replacement for torch.nn.Conv3d.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolving kernel
        stride: Stride of the convolution
        padding: Padding added to input
        dilation: Dilation of kernel elements
        groups: Number of blocked connections
        bias: If True, adds a learnable bias

    Raises:
        ValueError: If ``groups`` is not positive, or ``in_channels`` /
            ``out_channels`` is not divisible by ``groups`` (same contract
            as ``torch.nn.Conv3d``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        # Validate like torch.nn.Conv3d does: without these checks,
        # in_channels // groups truncates silently and builds a weight
        # tensor with the wrong channel dimension.
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")

        # Expose the same configuration attributes as torch.nn.Conv3d so
        # code that introspects conv layers keeps working.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _normalize_tuple(kernel_size, 3, "kernel_size")
        self.stride = _normalize_tuple(stride, 3, "stride")
        self.padding = _normalize_tuple(padding, 3, "padding")
        self.dilation = _normalize_tuple(dilation, 3, "dilation")
        self.groups = groups

        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias like torch.nn.Conv3d's defaults."""
        # a=sqrt(5) reproduces nn.Conv3d's default kaiming-uniform init.
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            # NOTE(review): uses the private torch helper
            # nn.init._calculate_fan_in_and_fan_out, mirroring torch's own
            # conv init; may need updating if torch changes internals.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / (fan_in ** 0.5)
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        """Configuration summary shown by repr(), as torch.nn.Conv3d does."""
        return (
            f"{self.in_channels}, {self.out_channels}, "
            f"kernel_size={self.kernel_size}, stride={self.stride}, "
            f"padding={self.padding}, dilation={self.dilation}, "
            f"groups={self.groups}, bias={self.bias is not None}"
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the 3D convolution to *input* of shape (N, C_in, D, H, W)."""
        return conv3d(
            input, self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups
        )
221
+
222
+
223
# Saved reference to the stock implementation; None means "not patched".
_original_conv3d = None


def patch_conv3d():
    """
    Monkey-patch torch.nn.functional.conv3d to use MPS implementation.

    Call this at the start of your script to automatically use MPS conv3d
    for all 3D convolution operations. Calling it again while patched is
    a no-op.
    """
    global _original_conv3d

    if _original_conv3d is not None:
        return  # already patched — keep the saved original intact

    _original_conv3d = F.conv3d

    def patched_conv3d(input, weight, bias=None, stride=1,
                       padding=0, dilation=1, groups=1):
        # Route MPS tensors through the Metal kernel; everything else goes
        # to the saved stock implementation.
        if input.device.type == 'mps':
            return conv3d(input, weight, bias, stride, padding, dilation, groups)
        return _original_conv3d(input, weight, bias, stride, padding, dilation, groups)

    F.conv3d = patched_conv3d
    print("MPS Conv3D: Patched F.conv3d")


def unpatch_conv3d():
    """Restore original torch.nn.functional.conv3d."""
    global _original_conv3d
    if _original_conv3d is None:
        return  # nothing to undo
    F.conv3d, _original_conv3d = _original_conv3d, None
255
+
256
+
257
def is_available() -> bool:
    """Return True when torch reports the MPS backend as available."""
    mps_backend = torch.backends.mps
    return bool(mps_backend.is_available())