mps-correlation 0.1.2__tar.gz

@@ -0,0 +1,163 @@
+ Metadata-Version: 2.4
+ Name: mps-correlation
+ Version: 0.1.2
+ Summary: Correlation layer for optical flow on Apple Silicon (MPS)
+ Author: mpsops
+ License-Expression: MIT
+ Project-URL: Homepage, https://github.com/mpsops/mps-correlation
+ Project-URL: Repository, https://github.com/mpsops/mps-correlation
+ Project-URL: Issues, https://github.com/mpsops/mps-correlation/issues
+ Keywords: correlation,optical-flow,raft,pwc-net,flownet,apple-silicon,pytorch,mps,metal
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Operating System :: MacOS :: MacOS X
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.0.0
+ Dynamic: requires-python
+
+ # MPS Correlation
+
+ Correlation layer for optical flow on Apple Silicon (M1/M2/M3/M4).
+
+ **Drop-in replacement** for `spatial-correlation-sampler` and mmcv's correlation op.
+
+ ## Why?
+
+ Correlation layers are essential for optical flow estimation:
+ - **RAFT**: State-of-the-art optical flow
+ - **PWC-Net**: Efficient optical flow
+ - **FlowNet/FlowNet2**: Classic deep optical flow
+
+ But existing implementations are **CUDA-only**. On Mac you get:
+ ```
+ NotImplementedError: correlation not implemented for MPS
+ ```
+
+ This package provides a native Metal implementation.
+
+ ## Installation
+
+ ```bash
+ pip install mps-correlation
+ ```
+
+ Or from source:
+
+ ```bash
+ git clone https://github.com/mpsops/mps-correlation
+ cd mps-correlation
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ ### Basic Usage
+
+ ```python
+ import torch
+ from mps_correlation import correlation
+
+ # Two feature maps from consecutive frames
+ fmap1 = torch.randn(1, 256, 64, 64, device='mps')
+ fmap2 = torch.randn(1, 256, 64, 64, device='mps')
+
+ # Compute correlation volume
+ corr = correlation(
+     fmap1, fmap2,
+     kernel_size=1,
+     max_displacement=4,
+     stride1=1,
+     stride2=1,
+     pad_size=4
+ )
+ # Output: (1, 81, 64, 64) - 81 = (2*4+1)^2 displacement channels
+ ```
+
+ ### Correlation Module
+
+ ```python
+ from mps_correlation import Correlation
+
+ corr_layer = Correlation(
+     kernel_size=1,
+     max_displacement=4,
+     stride1=1,
+     stride2=1,
+     pad_size=4
+ )
+
+ corr = corr_layer(fmap1, fmap2)
+ ```
+
+ ### RAFT-style All-Pairs Correlation
+
+ ```python
+ from mps_correlation import CorrBlock
+
+ # Build correlation pyramid
+ corr_block = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
+
+ # Lookup at specific coordinates
+ coords = torch.zeros(1, 2, 64, 64, device='mps')  # (x, y) coordinates
+ corr_features = corr_block(coords)
+ ```
+
+ ## API Reference
+
+ ### `correlation(input1, input2, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)`
+
+ | Parameter | Type | Description |
+ |-----------|------|-------------|
+ | `input1` | Tensor | First feature map (N, C, H, W) |
+ | `input2` | Tensor | Second feature map (N, C, H, W) |
+ | `kernel_size` | int | Size of correlation kernel (default: 1) |
+ | `max_displacement` | int | Maximum displacement to search (default: 4) |
+ | `stride1` | int | Stride for input1 (default: 1) |
+ | `stride2` | int | Stride for displacement (default: 1) |
+ | `pad_size` | int | Padding size (default: 4) |
+ | `is_multiply` | bool | Use multiplication (True) or subtraction (False) (default: True) |
+
+ ### `CorrBlock`
+
+ RAFT-style correlation block with pyramid and lookup.
+
+ ## How It Works
+
+ Correlation computes similarity between patches at different displacements:
+
+ ```
+ For each position (x, y) in output:
+     For each displacement (dx, dy) in [-max_disp, max_disp]:
+         corr[x, y, dx, dy] = sum(fmap1[x, y, :] * fmap2[x+dx, y+dy, :])
+ ```
+
+ This creates a cost volume indexed by position and displacement, which optical flow networks use to estimate motion. `correlation` returns it with the two displacement axes flattened into channels, giving a tensor of shape (N, D*D, H, W).
+
+ ## Compatibility
+
+ - **PyTorch**: 2.0+
+ - **macOS**: 12.0+ (Monterey)
+ - **Hardware**: Apple Silicon (M1/M2/M3/M4)
+
+ ## Features
+
+ - Full forward and backward pass (training supported)
+ - fp32 and fp16 supported
+ - Compatible with RAFT, PWC-Net, FlowNet architectures
+
+ ## Credits
+
+ - [spatial-correlation-sampler](https://github.com/ClementPinard/Pytorch-Correlation-extension) - Reference implementation
+ - [RAFT](https://github.com/princeton-vl/RAFT) - State-of-the-art optical flow
+ - [PWC-Net](https://github.com/NVlabs/PWC-Net) - Efficient optical flow
+
+ ## License
+
+ MIT
@@ -0,0 +1,139 @@
+ # MPS Correlation
+
+ Correlation layer for optical flow on Apple Silicon (M1/M2/M3/M4).
+
+ **Drop-in replacement** for `spatial-correlation-sampler` and mmcv's correlation op.
+
+ ## Why?
+
+ Correlation layers are essential for optical flow estimation:
+ - **RAFT**: State-of-the-art optical flow
+ - **PWC-Net**: Efficient optical flow
+ - **FlowNet/FlowNet2**: Classic deep optical flow
+
+ But existing implementations are **CUDA-only**. On Mac you get:
+ ```
+ NotImplementedError: correlation not implemented for MPS
+ ```
+
+ This package provides a native Metal implementation.
+
+ ## Installation
+
+ ```bash
+ pip install mps-correlation
+ ```
+
+ Or from source:
+
+ ```bash
+ git clone https://github.com/mpsops/mps-correlation
+ cd mps-correlation
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ ### Basic Usage
+
+ ```python
+ import torch
+ from mps_correlation import correlation
+
+ # Two feature maps from consecutive frames
+ fmap1 = torch.randn(1, 256, 64, 64, device='mps')
+ fmap2 = torch.randn(1, 256, 64, 64, device='mps')
+
+ # Compute correlation volume
+ corr = correlation(
+     fmap1, fmap2,
+     kernel_size=1,
+     max_displacement=4,
+     stride1=1,
+     stride2=1,
+     pad_size=4
+ )
+ # Output: (1, 81, 64, 64) - 81 = (2*4+1)^2 displacement channels
+ ```
+
+ ### Correlation Module
+
+ ```python
+ from mps_correlation import Correlation
+
+ corr_layer = Correlation(
+     kernel_size=1,
+     max_displacement=4,
+     stride1=1,
+     stride2=1,
+     pad_size=4
+ )
+
+ corr = corr_layer(fmap1, fmap2)
+ ```
+
+ ### RAFT-style All-Pairs Correlation
+
+ ```python
+ from mps_correlation import CorrBlock
+
+ # Build correlation pyramid
+ corr_block = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
+
+ # Lookup at specific coordinates
+ coords = torch.zeros(1, 2, 64, 64, device='mps')  # (x, y) coordinates
+ corr_features = corr_block(coords)
+ ```
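Note on the all-pairs example: the sketch below shows, in plain PyTorch on CPU, what an all-pairs correlation computes and how many channels a pyramid lookup produces. It follows the `CorrBlock` source in this package (dot products scaled by the square root of the channel count, one `(2r+1) x (2r+1)` window per level), but the helper names `all_pairs_correlation` and `lookup_channels` are hypothetical and not part of the package API.

```python
import torch


def all_pairs_correlation(fmap1: torch.Tensor, fmap2: torch.Tensor) -> torch.Tensor:
    """Dot product between every pixel of fmap1 and every pixel of fmap2."""
    _, c, _, _ = fmap1.shape
    # (N, C, H, W) x (N, C, U, V) -> (N, H, W, U, V), summing over channels
    corr = torch.einsum("nchw,ncuv->nhwuv", fmap1, fmap2)
    # Scale by sqrt(C), as CorrBlock does
    return corr / c ** 0.5


def lookup_channels(num_levels: int, radius: int) -> int:
    """Channels returned by a lookup: one (2r+1) x (2r+1) window per pyramid level."""
    return num_levels * (2 * radius + 1) ** 2


torch.manual_seed(0)
fmap1 = torch.randn(1, 8, 4, 5)
fmap2 = torch.randn(1, 8, 4, 5)
corr = all_pairs_correlation(fmap1, fmap2)
```

With `num_levels=4` and `radius=4`, as in the example above, `lookup_channels` gives 324, so `corr_features` would be expected to have shape (1, 324, 64, 64).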
+
+ ## API Reference
+
+ ### `correlation(input1, input2, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)`
+
+ | Parameter | Type | Description |
+ |-----------|------|-------------|
+ | `input1` | Tensor | First feature map (N, C, H, W) |
+ | `input2` | Tensor | Second feature map (N, C, H, W) |
+ | `kernel_size` | int | Size of correlation kernel (default: 1) |
+ | `max_displacement` | int | Maximum displacement to search (default: 4) |
+ | `stride1` | int | Stride for input1 (default: 1) |
+ | `stride2` | int | Stride for displacement (default: 1) |
+ | `pad_size` | int | Padding size (default: 4) |
+ | `is_multiply` | bool | Use multiplication (True) or subtraction (False) (default: True) |
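Note on the parameters: a small helper can make their effect on the output shape concrete. The sketch below is an assumption, not the package's code. It reads the docstring's `D = 2*max_displacement/stride2 + 1` as integer division and borrows the spatial-size rule from the original FlowNet correlation layer, which the Metal kernel may or may not follow. `correlation_output_shape` is a hypothetical name.

```python
import math


def correlation_output_shape(n, h, w, kernel_size=1, max_displacement=4,
                             stride1=1, stride2=1, pad_size=4):
    """Predict (N, D*D, H_out, W_out) under FlowNet-style conventions (assumed)."""
    # Number of displacement steps on each side of the center
    radius = max_displacement // stride2
    d = 2 * radius + 1
    # Border that cannot host a full kernel at the largest displacement
    kernel_radius = (kernel_size - 1) // 2
    border = max_displacement + kernel_radius
    out_h = math.ceil((h + 2 * pad_size - 2 * border) / stride1)
    out_w = math.ceil((w + 2 * pad_size - 2 * border) / stride1)
    return (n, d * d, out_h, out_w)


shape = correlation_output_shape(1, 64, 64)
```

For the Quick Start settings this reproduces the documented (1, 81, 64, 64). Choosing `pad_size == max_displacement` with `kernel_size=1` keeps the spatial size unchanged under these assumptions.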
+
+ ### `CorrBlock`
+
+ RAFT-style correlation block with pyramid and lookup.
+
+ ## How It Works
+
+ Correlation computes similarity between patches at different displacements:
+
+ ```
+ For each position (x, y) in output:
+     For each displacement (dx, dy) in [-max_disp, max_disp]:
+         corr[x, y, dx, dy] = sum(fmap1[x, y, :] * fmap2[x+dx, y+dy, :])
+ ```
+
+ This creates a cost volume indexed by position and displacement, which optical flow networks use to estimate motion. `correlation` returns it with the two displacement axes flattened into channels, giving a tensor of shape (N, D*D, H, W).
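Note on the pseudocode: it can be turned into a slow but runnable reference in plain PyTorch, which is handy for checking a native kernel on small inputs. This is a minimal sketch under stated assumptions: `kernel_size=1`, both strides equal to 1, `pad_size == max_displacement`, multiplication mode, no normalization, and displacement channels ordered with `dy` as the outer loop. The package's actual channel order and scaling may differ, and `correlation_reference` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def correlation_reference(fmap1, fmap2, max_displacement=4):
    """Naive correlation: one output channel per (dy, dx) displacement."""
    _, _, h, w = fmap1.shape
    d = max_displacement
    # Zero-pad the second map so every shifted window stays in bounds
    padded = F.pad(fmap2, (d, d, d, d))
    channels = []
    for dy in range(-d, d + 1):
        for dx in range(-d, d + 1):
            # shifted[..., y, x] == fmap2[..., y + dy, x + dx] (zero outside)
            shifted = padded[:, :, d + dy:d + dy + h, d + dx:d + dx + w]
            # Dot product over the channel axis
            channels.append((fmap1 * shifted).sum(dim=1))
    return torch.stack(channels, dim=1)


torch.manual_seed(0)
fmap1 = torch.randn(1, 16, 8, 8)
fmap2 = torch.randn(1, 16, 8, 8)
out = correlation_reference(fmap1, fmap2)
```

Under this ordering the zero-displacement channel sits at index `d * (2 * d + 1) + d`, which is 40 for `max_displacement=4`, and equals the per-pixel dot product of the two maps.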
+
+ ## Compatibility
+
+ - **PyTorch**: 2.0+
+ - **macOS**: 12.0+ (Monterey)
+ - **Hardware**: Apple Silicon (M1/M2/M3/M4)
+
+ ## Features
+
+ - Full forward and backward pass (training supported)
+ - fp32 and fp16 supported
+ - Compatible with RAFT, PWC-Net, FlowNet architectures
+
+ ## Credits
+
+ - [spatial-correlation-sampler](https://github.com/ClementPinard/Pytorch-Correlation-extension) - Reference implementation
+ - [RAFT](https://github.com/princeton-vl/RAFT) - State-of-the-art optical flow
+ - [PWC-Net](https://github.com/NVlabs/PWC-Net) - Efficient optical flow
+
+ ## License
+
+ MIT
@@ -0,0 +1,246 @@
+ """
+ MPS Correlation - Optical flow correlation layer for Apple Silicon
+
+ Drop-in replacement for spatial-correlation-sampler and mmcv's correlation op.
+ Used in RAFT, PWC-Net, FlowNet, etc.
+ """
+
+ import torch
+ from torch import nn, Tensor
+ from torch.autograd import Function
+ from typing import Optional, Tuple
+ import math
+
+ __version__ = "0.1.2"
+
+
+ def _load_library():
+     """Load the native extension."""
+     try:
+         from mps_correlation import _C as _lib
+         return _lib
+     except ImportError:
+         import os
+         from torch.utils.cpp_extension import load
+
+         src_dir = os.path.join(os.path.dirname(__file__), "csrc")
+         _lib = load(
+             name="mps_correlation",
+             sources=[os.path.join(src_dir, "correlation_mps.mm")],
+             extra_cflags=["-std=c++17"],
+             extra_ldflags=["-framework", "Metal", "-framework", "Foundation"],
+             verbose=False,
+         )
+         return _lib
+
+
+ _lib_cache = None
+
+
+ def _get_lib():
+     global _lib_cache
+     if _lib_cache is None:
+         _lib_cache = _load_library()
+     return _lib_cache
+
+
+ class _CorrelationFunction(Function):
+     """Autograd function for correlation."""
+
+     @staticmethod
+     def forward(
+         ctx,
+         input1: Tensor,
+         input2: Tensor,
+         kernel_size: int,
+         max_displacement: int,
+         stride1: int,
+         stride2: int,
+         pad_size: int,
+         is_multiply: bool,
+     ) -> Tensor:
+         ctx.save_for_backward(input1, input2)
+         ctx.kernel_size = kernel_size
+         ctx.max_displacement = max_displacement
+         ctx.stride1 = stride1
+         ctx.stride2 = stride2
+         ctx.pad_size = pad_size
+         ctx.is_multiply = is_multiply
+
+         lib = _get_lib()
+         output = lib.correlation_forward(
+             input1, input2,
+             kernel_size, max_displacement,
+             stride1, stride2, pad_size,
+             is_multiply
+         )
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output: Tensor):
+         input1, input2 = ctx.saved_tensors
+         lib = _get_lib()
+
+         grad_input1, grad_input2 = lib.correlation_backward(
+             grad_output.contiguous(), input1.contiguous(), input2.contiguous(),
+             ctx.kernel_size, ctx.max_displacement,
+             ctx.stride1, ctx.stride2, ctx.pad_size,
+             ctx.is_multiply
+         )
+
+         return grad_input1, grad_input2, None, None, None, None, None, None
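Note on the `backward` return value: a custom `torch.autograd.Function` must return one gradient per argument of `forward`, which is why six `None` entries follow the two tensor gradients. The toy sketch below illustrates the same pattern on CPU and checks it numerically with `torch.autograd.gradcheck`. `ScaledProduct` is a hypothetical stand-in for illustration, not part of the package.

```python
import torch
from torch.autograd import Function, gradcheck


class ScaledProduct(Function):
    """Toy op: out = scale * a * b, with a non-tensor `scale` option."""

    @staticmethod
    def forward(ctx, a, b, scale: int):
        ctx.save_for_backward(a, b)  # tensors needed again in backward
        ctx.scale = scale            # plain Python state lives on ctx directly
        return scale * a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_output * ctx.scale * b
        grad_b = grad_output * ctx.scale * a
        # One return value per forward() argument; `scale` is not differentiable
        return grad_a, grad_b, None


# gradcheck compares the analytic backward against finite differences
a = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
b = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
ok = gradcheck(lambda x, y: ScaledProduct.apply(x, y, 3), (a, b))
```

The same kind of check could in principle be run against a native kernel on small inputs, precision permitting, to validate its backward pass.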
+
+
+ def correlation(
+     input1: Tensor,
+     input2: Tensor,
+     kernel_size: int = 1,
+     max_displacement: int = 4,
+     stride1: int = 1,
+     stride2: int = 1,
+     pad_size: int = 4,
+     is_multiply: bool = True,
+ ) -> Tensor:
+     """
+     Compute correlation between two feature maps.
+
+     Used in optical flow estimation (RAFT, PWC-Net, FlowNet).
+
+     Args:
+         input1: First feature map (N, C, H, W)
+         input2: Second feature map (N, C, H, W)
+         kernel_size: Size of the correlation kernel
+         max_displacement: Maximum displacement for correlation search
+         stride1: Stride for input1
+         stride2: Stride for input2 (displacement stride)
+         pad_size: Padding size
+         is_multiply: If True, use multiplication. If False, use subtraction.
+
+     Returns:
+         Correlation volume (N, D*D, H, W) where D = 2 * (max_displacement // stride2) + 1
+     """
+     if input1.device.type != "mps":
+         raise ValueError(f"Input must be on MPS device, got {input1.device}")
+
+     return _CorrelationFunction.apply(
+         input1, input2,
+         kernel_size, max_displacement,
+         stride1, stride2, pad_size,
+         is_multiply
+     )
+
+
+ class Correlation(nn.Module):
+     """
+     Correlation layer for optical flow.
+
+     Computes a cost volume by correlating features from two images.
+
+     Args:
+         kernel_size: Size of correlation kernel
+         max_displacement: Maximum displacement to search
+         stride1: Stride for first input
+         stride2: Stride for displacement
+         pad_size: Padding size
+         is_multiply: Use multiplication (True) or subtraction (False)
+     """
+
+     def __init__(
+         self,
+         kernel_size: int = 1,
+         max_displacement: int = 4,
+         stride1: int = 1,
+         stride2: int = 1,
+         pad_size: int = 4,
+         is_multiply: bool = True,
+     ):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.max_displacement = max_displacement
+         self.stride1 = stride1
+         self.stride2 = stride2
+         self.pad_size = pad_size
+         self.is_multiply = is_multiply
+
+     def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
+         return correlation(
+             input1, input2,
+             self.kernel_size, self.max_displacement,
+             self.stride1, self.stride2, self.pad_size,
+             self.is_multiply
+         )
+
+
+ # RAFT-style correlation (all-pairs)
+ class CorrBlock:
+     """
+     RAFT-style correlation block.
+
+     Computes all-pairs correlation and provides lookup functionality.
+     """
+
+     def __init__(self, fmap1: Tensor, fmap2: Tensor, num_levels: int = 4, radius: int = 4):
+         self.num_levels = num_levels
+         self.radius = radius
+         self.corr_pyramid = []
+
+         # All-pairs correlation
+         batch, dim, ht, wd = fmap1.shape
+         fmap1 = fmap1.view(batch, dim, ht * wd)
+         fmap2 = fmap2.view(batch, dim, ht * wd)
+
+         corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
+         corr = corr.view(batch, ht, wd, 1, ht, wd)
+         corr = corr / torch.sqrt(torch.tensor(dim, dtype=torch.float32, device=fmap1.device))
+
+         self.corr_pyramid.append(corr)
+         for _ in range(num_levels - 1):
+             # Get current spatial dims from last two dimensions
+             curr_h, curr_w = corr.shape[-2], corr.shape[-1]
+             corr = torch.nn.functional.avg_pool2d(
+                 corr.view(batch * ht * wd, 1, curr_h, curr_w),
+                 kernel_size=2, stride=2
+             )
+             _, _, h, w = corr.shape
+             corr = corr.view(batch, ht, wd, 1, h, w)
+             self.corr_pyramid.append(corr)
+
+     def __call__(self, coords: Tensor) -> Tensor:
+         """Lookup correlation values at given coordinates."""
+         r = self.radius
+         batch, _, ht, wd = coords.shape
+         coords = coords.permute(0, 2, 3, 1)
+
+         out_pyramid = []
+         for i, corr in enumerate(self.corr_pyramid):
+             # Build lookup grid
+             dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
+             dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
+             delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), dim=-1)
+
+             centroid_lvl = coords.reshape(batch * ht * wd, 1, 1, 2) / 2**i
+             delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
+             coords_lvl = centroid_lvl + delta_lvl
+
+             # Sample from correlation volume
+             corr_lvl = corr.view(batch * ht * wd, 1, *corr.shape[-2:])
+
+             # Normalize coords to [-1, 1]
+             _, _, h, w = corr_lvl.shape
+             coords_lvl[..., 0] = 2 * coords_lvl[..., 0] / (w - 1) - 1
+             coords_lvl[..., 1] = 2 * coords_lvl[..., 1] / (h - 1) - 1
+
+             corr_lvl = torch.nn.functional.grid_sample(
+                 corr_lvl, coords_lvl,
+                 align_corners=True, mode='bilinear', padding_mode='zeros'
+             )
+             corr_lvl = corr_lvl.view(batch, ht, wd, -1)
+             out_pyramid.append(corr_lvl)
+
+         out = torch.cat(out_pyramid, dim=-1)
+         return out.permute(0, 3, 1, 2).contiguous()
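Note on the lookup: it relies on `grid_sample` with `align_corners=True`, where a pixel coordinate `x` in `[0, w-1]` maps to `2 * x / (w - 1) - 1` in `[-1, 1]`, the same normalization applied to `coords_lvl` above. The CPU sketch below isolates that step on a tiny image so the mapping can be verified by hand. `sample_at_pixels` is a hypothetical helper, not part of the package.

```python
import torch
import torch.nn.functional as F


def sample_at_pixels(image: torch.Tensor, xy: torch.Tensor) -> torch.Tensor:
    """Bilinearly sample image (N, C, H, W) at pixel coords xy (N, Ho, Wo, 2), ordered (x, y)."""
    _, _, h, w = image.shape
    grid = xy.clone().float()
    # Map pixel coordinates to grid_sample's [-1, 1] range (align_corners=True)
    grid[..., 0] = 2 * grid[..., 0] / (w - 1) - 1
    grid[..., 1] = 2 * grid[..., 1] / (h - 1) - 1
    return F.grid_sample(
        image, grid, mode="bilinear", padding_mode="zeros", align_corners=True
    )


# 3 x 4 image whose value at (y, x) is 4 * y + x
image = torch.arange(12, dtype=torch.float32).view(1, 1, 3, 4)
# Two query points: integer pixel (x=1, y=2) and a half-pixel offset (x=2.5, y=0)
xy = torch.tensor([[[[1.0, 2.0], [2.5, 0.0]]]])
out = sample_at_pixels(image, xy)
```

The first query lands exactly on a stored pixel and the second interpolates halfway between two neighbors. Note that the normalization divides by `w - 1` and `h - 1`, so a pyramid level that shrinks to a single row or column would need special handling.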
+
+
+ def is_available() -> bool:
+     """Check if MPS is available."""
+     return torch.backends.mps.is_available()