mps-conv3d 0.1.4 (tar.gz)
- mps_conv3d-0.1.4/PKG-INFO +148 -0
- mps_conv3d-0.1.4/README.md +124 -0
- mps_conv3d-0.1.4/mps_conv3d/__init__.py +259 -0
- mps_conv3d-0.1.4/mps_conv3d/csrc/conv3d_mps.mm +726 -0
- mps_conv3d-0.1.4/mps_conv3d.egg-info/PKG-INFO +148 -0
- mps_conv3d-0.1.4/mps_conv3d.egg-info/SOURCES.txt +11 -0
- mps_conv3d-0.1.4/mps_conv3d.egg-info/dependency_links.txt +1 -0
- mps_conv3d-0.1.4/mps_conv3d.egg-info/requires.txt +1 -0
- mps_conv3d-0.1.4/mps_conv3d.egg-info/top_level.txt +1 -0
- mps_conv3d-0.1.4/pyproject.toml +30 -0
- mps_conv3d-0.1.4/setup.cfg +4 -0
- mps_conv3d-0.1.4/setup.py +64 -0
- mps_conv3d-0.1.4/tests/test_conv3d.py +270 -0
mps_conv3d-0.1.4/PKG-INFO

@@ -0,0 +1,148 @@
+Metadata-Version: 2.4
+Name: mps-conv3d
+Version: 0.1.4
+Summary: 3D Convolution for Apple Silicon (MPS)
+Author: mpsops
+License-Expression: MIT
+Project-URL: Homepage, https://github.com/mpsops/mps-conv3d
+Project-URL: Repository, https://github.com/mpsops/mps-conv3d
+Project-URL: Issues, https://github.com/mpsops/mps-conv3d/issues
+Keywords: pytorch,mps,apple-silicon,conv3d,video,3d-convolution
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: torch>=2.0.0
+Dynamic: requires-python
+
+# MPS Conv3D
+
+3D Convolution for Apple Silicon (M1/M2/M3/M4).
+
+**Drop-in replacement** for `torch.nn.functional.conv3d` on MPS.
+
+## Why?
+
+3D convolutions are essential for video models:
+- **Synchformer**: Audio-visual synchronization
+- **I3D**: Video classification
+- **SlowFast**: Action recognition
+- **C3D**: Video feature extraction
+- **MMAudio**: Audio generation from video
+
+But PyTorch's MPS backend doesn't support 3D convolutions:
+```
+NotImplementedError: aten::slow_conv3d_forward is not implemented for MPS
+```
+
+This package provides a native Metal implementation.
+
+## Installation
+
+```bash
+pip install mps-conv3d
+```
+
+Or from source:
+
+```bash
+git clone https://github.com/mpsops/mps-conv3d
+cd mps-conv3d
+pip install -e .
+```
+
+## Quick Start
+
+### Patch All Conv3D Operations (Recommended)
+
+```python
+from mps_conv3d import patch_conv3d
+
+# Patch at the start of your script
+patch_conv3d()
+
+# Now all conv3d operations use MPS!
+import torch
+import torch.nn.functional as F
+
+x = torch.randn(1, 3, 16, 112, 112, device='mps')
+w = torch.randn(64, 3, 3, 7, 7, device='mps')
+out = F.conv3d(x, w, padding=(1, 3, 3))  # Uses MPS!
+```
+
+### Direct Usage
+
+```python
+import torch
+from mps_conv3d import conv3d
+
+x = torch.randn(1, 3, 16, 112, 112, device='mps')
+w = torch.randn(64, 3, 3, 7, 7, device='mps')
+
+out = conv3d(x, w, stride=1, padding=(1, 3, 3))
+```
+
+### Conv3d Module
+
+```python
+from mps_conv3d import Conv3d
+
+conv = Conv3d(
+    in_channels=3,
+    out_channels=64,
+    kernel_size=(3, 7, 7),
+    stride=(1, 2, 2),
+    padding=(1, 3, 3)
+).to('mps')
+
+x = torch.randn(1, 3, 16, 112, 112, device='mps')
+out = conv(x)
+```
+
+## API Reference
+
+### `conv3d(input, weight, bias, stride, padding, dilation, groups)`
+
+Same signature as `torch.nn.functional.conv3d`.
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `input` | Tensor | Input tensor (N, C_in, D, H, W) |
+| `weight` | Tensor | Weight tensor (C_out, C_in/groups, kD, kH, kW) |
+| `bias` | Tensor | Optional bias (C_out,) |
+| `stride` | int/tuple | Stride of convolution |
+| `padding` | int/tuple | Padding added to input |
+| `dilation` | int/tuple | Dilation of kernel |
+| `groups` | int | Number of groups |
+
+### `patch_conv3d()`
+
+Monkey-patches `torch.nn.functional.conv3d` to use MPS implementation for MPS tensors.
+
+### `unpatch_conv3d()`
+
+Restores original `torch.nn.functional.conv3d`.
+
+## Compatibility
+
+- **PyTorch**: 2.0+
+- **macOS**: 12.0+ (Monterey)
+- **Hardware**: Apple Silicon (M1/M2/M3/M4)
+
+## Features
+
+- Full forward and backward pass (training supported)
+- fp32 and fp16 supported
+- Groups and dilation supported
+- Drop-in compatible with PyTorch API
+
+## License
+
+MIT
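The description above documents `patch_conv3d()`, `unpatch_conv3d()`, and the MPS-only dispatch. As a reviewer's sketch of that lifecycle (the `is_available()` guard is this editor's suggestion, not something the package requires; the shapes are arbitrary):

```python
import torch
import torch.nn.functional as F
from mps_conv3d import patch_conv3d, unpatch_conv3d, is_available

if is_available():
    patch_conv3d()                 # F.conv3d now routes MPS tensors to the extension
    x = torch.randn(2, 4, 8, 32, 32, device='mps')
    w = torch.randn(8, 4, 3, 3, 3, device='mps')
    y = F.conv3d(x, w, padding=1)  # dispatched to mps_conv3d.conv3d
    print(y.shape)                 # torch.Size([2, 8, 8, 32, 32])
    unpatch_conv3d()               # restore the stock implementation
```

Per the patch logic shown later in `__init__.py`, non-MPS tensors keep going through the original `F.conv3d`, so the patch is safe to leave in place in mixed-device code.
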
mps_conv3d-0.1.4/README.md

@@ -0,0 +1,124 @@
+# MPS Conv3D
+
+3D Convolution for Apple Silicon (M1/M2/M3/M4).
+
+**Drop-in replacement** for `torch.nn.functional.conv3d` on MPS.
+
+## Why?
+
+3D convolutions are essential for video models:
+- **Synchformer**: Audio-visual synchronization
+- **I3D**: Video classification
+- **SlowFast**: Action recognition
+- **C3D**: Video feature extraction
+- **MMAudio**: Audio generation from video
+
+But PyTorch's MPS backend doesn't support 3D convolutions:
+```
+NotImplementedError: aten::slow_conv3d_forward is not implemented for MPS
+```
+
+This package provides a native Metal implementation.
+
+## Installation
+
+```bash
+pip install mps-conv3d
+```
+
+Or from source:
+
+```bash
+git clone https://github.com/mpsops/mps-conv3d
+cd mps-conv3d
+pip install -e .
+```
+
+## Quick Start
+
+### Patch All Conv3D Operations (Recommended)
+
+```python
+from mps_conv3d import patch_conv3d
+
+# Patch at the start of your script
+patch_conv3d()
+
+# Now all conv3d operations use MPS!
+import torch
+import torch.nn.functional as F
+
+x = torch.randn(1, 3, 16, 112, 112, device='mps')
+w = torch.randn(64, 3, 3, 7, 7, device='mps')
+out = F.conv3d(x, w, padding=(1, 3, 3))  # Uses MPS!
+```
+
+### Direct Usage
+
+```python
+import torch
+from mps_conv3d import conv3d
+
+x = torch.randn(1, 3, 16, 112, 112, device='mps')
+w = torch.randn(64, 3, 3, 7, 7, device='mps')
+
+out = conv3d(x, w, stride=1, padding=(1, 3, 3))
+```
+
+### Conv3d Module
+
+```python
+from mps_conv3d import Conv3d
+
+conv = Conv3d(
+    in_channels=3,
+    out_channels=64,
+    kernel_size=(3, 7, 7),
+    stride=(1, 2, 2),
+    padding=(1, 3, 3)
+).to('mps')
+
+x = torch.randn(1, 3, 16, 112, 112, device='mps')
+out = conv(x)
+```
+
+## API Reference
+
+### `conv3d(input, weight, bias, stride, padding, dilation, groups)`
+
+Same signature as `torch.nn.functional.conv3d`.
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `input` | Tensor | Input tensor (N, C_in, D, H, W) |
+| `weight` | Tensor | Weight tensor (C_out, C_in/groups, kD, kH, kW) |
+| `bias` | Tensor | Optional bias (C_out,) |
+| `stride` | int/tuple | Stride of convolution |
+| `padding` | int/tuple | Padding added to input |
+| `dilation` | int/tuple | Dilation of kernel |
+| `groups` | int | Number of groups |
+
+### `patch_conv3d()`
+
+Monkey-patches `torch.nn.functional.conv3d` to use MPS implementation for MPS tensors.
+
+### `unpatch_conv3d()`
+
+Restores original `torch.nn.functional.conv3d`.
+
+## Compatibility
+
+- **PyTorch**: 2.0+
+- **macOS**: 12.0+ (Monterey)
+- **Hardware**: Apple Silicon (M1/M2/M3/M4)
+
+## Features
+
+- Full forward and backward pass (training supported)
+- fp32 and fp16 supported
+- Groups and dilation supported
+- Drop-in compatible with PyTorch API
+
+## License
+
+MIT
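The parameter table above mirrors `torch.nn.functional.conv3d`, so output sizes follow the standard convolution arithmetic, `D_out = floor((D + 2*pad_d - dil_d*(kD - 1) - 1) / stride_d) + 1`, and likewise for H and W. A CPU-only sketch (independent of the MPS kernel) checking the README's example shapes against that formula:

```python
import torch
import torch.nn.functional as F

# README example: kernel (3, 7, 7) with padding (1, 3, 3), stride 1 preserves D, H, W:
#   D:    (16  + 2*1 - (3 - 1) - 1) // 1 + 1 = 16
#   H, W: (112 + 2*3 - (7 - 1) - 1) // 1 + 1 = 112
x = torch.randn(1, 3, 16, 112, 112)   # CPU tensors are enough for a shape check
w = torch.randn(64, 3, 3, 7, 7)
out = F.conv3d(x, w, padding=(1, 3, 3))
print(out.shape)                      # torch.Size([1, 64, 16, 112, 112])
```
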
mps_conv3d-0.1.4/mps_conv3d/__init__.py

@@ -0,0 +1,259 @@
+"""
+MPS Conv3D - 3D Convolution for Apple Silicon
+
+Drop-in replacement for torch.nn.functional.conv3d on MPS.
+Used in video models: Synchformer, I3D, SlowFast, C3D, etc.
+"""
+
+import torch
+from torch import nn, Tensor
+from torch.autograd import Function
+from typing import Optional, Tuple, Union
+import torch.nn.functional as F
+
+__version__ = "0.1.2"
+
+
+def _load_library():
+    """Load the native extension."""
+    try:
+        from mps_conv3d import _C as _lib
+        return _lib
+    except ImportError:
+        import os
+        from torch.utils.cpp_extension import load
+
+        src_dir = os.path.join(os.path.dirname(__file__), "csrc")
+        _lib = load(
+            name="mps_conv3d",
+            sources=[os.path.join(src_dir, "conv3d_mps.mm")],
+            extra_cflags=["-std=c++17"],
+            extra_ldflags=["-framework", "Metal", "-framework", "Foundation"],
+            verbose=False,
+        )
+        return _lib
+
+
+_lib_cache = None
+
+
+def _get_lib():
+    global _lib_cache
+    if _lib_cache is None:
+        _lib_cache = _load_library()
+    return _lib_cache
+
+
+class _Conv3DFunction(Function):
+    """Autograd function for Conv3D."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        input: Tensor,
+        weight: Tensor,
+        bias: Optional[Tensor],
+        stride: Tuple[int, int, int],
+        padding: Tuple[int, int, int],
+        dilation: Tuple[int, int, int],
+        groups: int,
+    ) -> Tensor:
+        ctx.save_for_backward(input, weight, bias)
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.dilation = dilation
+        ctx.groups = groups
+
+        lib = _get_lib()
+        output = lib.conv3d_forward(
+            input, weight,
+            stride[0], stride[1], stride[2],
+            padding[0], padding[1], padding[2],
+            dilation[0], dilation[1], dilation[2],
+            groups
+        )
+
+        if bias is not None:
+            output = output + bias.view(1, -1, 1, 1, 1)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        input, weight, bias = ctx.saved_tensors
+        lib = _get_lib()
+
+        grad_input = grad_weight = grad_bias = None
+
+        if ctx.needs_input_grad[0]:
+            grad_input = lib.conv3d_backward_input(
+                grad_output, weight, input.shape,
+                ctx.stride[0], ctx.stride[1], ctx.stride[2],
+                ctx.padding[0], ctx.padding[1], ctx.padding[2],
+                ctx.dilation[0], ctx.dilation[1], ctx.dilation[2],
+                ctx.groups
+            )
+
+        if ctx.needs_input_grad[1]:
+            grad_weight = lib.conv3d_backward_weight(
+                grad_output, input, weight.shape,
+                ctx.stride[0], ctx.stride[1], ctx.stride[2],
+                ctx.padding[0], ctx.padding[1], ctx.padding[2],
+                ctx.dilation[0], ctx.dilation[1], ctx.dilation[2],
+                ctx.groups
+            )
+
+        if bias is not None and ctx.needs_input_grad[2]:
+            grad_bias = grad_output.sum(dim=[0, 2, 3, 4])
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None
+
+
+def _normalize_tuple(value, n, name):
+    """Convert int or tuple to n-tuple."""
+    if isinstance(value, int):
+        return (value,) * n
+    if isinstance(value, (list, tuple)):
+        if len(value) == n:
+            return tuple(value)
+        elif len(value) == 1:
+            return (value[0],) * n
+    raise ValueError(f"{name} must be an int or {n}-tuple")
+
+
+def conv3d(
+    input: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Union[int, Tuple[int, int, int]] = 1,
+    padding: Union[int, Tuple[int, int, int]] = 0,
+    dilation: Union[int, Tuple[int, int, int]] = 1,
+    groups: int = 1,
+) -> Tensor:
+    """
+    3D convolution on MPS.
+
+    Drop-in replacement for torch.nn.functional.conv3d.
+
+    Args:
+        input: Input tensor (N, C_in, D, H, W)
+        weight: Weight tensor (C_out, C_in/groups, kD, kH, kW)
+        bias: Optional bias tensor (C_out,)
+        stride: Stride of convolution
+        padding: Padding added to input
+        dilation: Dilation of kernel
+        groups: Number of groups
+
+    Returns:
+        Output tensor (N, C_out, D_out, H_out, W_out)
+    """
+    if input.device.type != "mps":
+        # Fallback to PyTorch for non-MPS
+        return F.conv3d(input, weight, bias, stride, padding, dilation, groups)
+
+    stride = _normalize_tuple(stride, 3, "stride")
+    padding = _normalize_tuple(padding, 3, "padding")
+    dilation = _normalize_tuple(dilation, 3, "dilation")
+
+    return _Conv3DFunction.apply(
+        input.contiguous(), weight.contiguous(), bias,
+        stride, padding, dilation, groups
+    )
+
+
+class Conv3d(nn.Module):
+    """
+    3D Convolution layer for MPS.
+
+    Drop-in replacement for torch.nn.Conv3d.
+
+    Args:
+        in_channels: Number of input channels
+        out_channels: Number of output channels
+        kernel_size: Size of the convolving kernel
+        stride: Stride of the convolution
+        padding: Padding added to input
+        dilation: Dilation of kernel elements
+        groups: Number of blocked connections
+        bias: If True, adds a learnable bias
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        padding: Union[int, Tuple[int, int, int]] = 0,
+        dilation: Union[int, Tuple[int, int, int]] = 1,
+        groups: int = 1,
+        bias: bool = True,
+    ):
+        super().__init__()
+        kernel_size = _normalize_tuple(kernel_size, 3, "kernel_size")
+        self.stride = _normalize_tuple(stride, 3, "stride")
+        self.padding = _normalize_tuple(padding, 3, "padding")
+        self.dilation = _normalize_tuple(dilation, 3, "dilation")
+        self.groups = groups
+
+        self.weight = nn.Parameter(
+            torch.empty(out_channels, in_channels // groups, *kernel_size)
+        )
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_channels))
+        else:
+            self.register_parameter('bias', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / (fan_in ** 0.5)
+            nn.init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return conv3d(
+            input, self.weight, self.bias,
+            self.stride, self.padding, self.dilation, self.groups
+        )
+
+
+_original_conv3d = None
+
+
+def patch_conv3d():
+    """
+    Monkey-patch torch.nn.functional.conv3d to use MPS implementation.
+
+    Call this at the start of your script to automatically use MPS conv3d
+    for all 3D convolution operations.
+    """
+    global _original_conv3d
+
+    if _original_conv3d is not None:
+        return  # Already patched
+
+    _original_conv3d = F.conv3d
+
+    def patched_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+        if input.device.type == 'mps':
+            return conv3d(input, weight, bias, stride, padding, dilation, groups)
+        return _original_conv3d(input, weight, bias, stride, padding, dilation, groups)
+
+    F.conv3d = patched_conv3d
+    print("MPS Conv3D: Patched F.conv3d")
+
+
+def unpatch_conv3d():
+    """Restore original torch.nn.functional.conv3d."""
+    global _original_conv3d
+    if _original_conv3d is not None:
+        F.conv3d = _original_conv3d
+        _original_conv3d = None
+
+
+def is_available() -> bool:
+    """Check if MPS is available."""
+    return torch.backends.mps.is_available()