mps-correlation 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mps_correlation-0.1.2/PKG-INFO +163 -0
- mps_correlation-0.1.2/README.md +139 -0
- mps_correlation-0.1.2/mps_correlation/__init__.py +246 -0
- mps_correlation-0.1.2/mps_correlation/csrc/correlation_mps.mm +543 -0
- mps_correlation-0.1.2/mps_correlation.egg-info/PKG-INFO +163 -0
- mps_correlation-0.1.2/mps_correlation.egg-info/SOURCES.txt +11 -0
- mps_correlation-0.1.2/mps_correlation.egg-info/dependency_links.txt +1 -0
- mps_correlation-0.1.2/mps_correlation.egg-info/requires.txt +1 -0
- mps_correlation-0.1.2/mps_correlation.egg-info/top_level.txt +1 -0
- mps_correlation-0.1.2/pyproject.toml +38 -0
- mps_correlation-0.1.2/setup.cfg +4 -0
- mps_correlation-0.1.2/setup.py +64 -0
- mps_correlation-0.1.2/tests/test_correlation.py +273 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mps-correlation
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Correlation layer for optical flow on Apple Silicon (MPS)
|
|
5
|
+
Author: mpsops
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/mpsops/mps-correlation
|
|
8
|
+
Project-URL: Repository, https://github.com/mpsops/mps-correlation
|
|
9
|
+
Project-URL: Issues, https://github.com/mpsops/mps-correlation/issues
|
|
10
|
+
Keywords: correlation,optical-flow,raft,pwc-net,flownet,apple-silicon,pytorch,mps,metal
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Operating System :: MacOS :: MacOS X
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.10
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
Requires-Dist: torch>=2.0.0
|
|
23
|
+
Dynamic: requires-python
|
|
24
|
+
|
|
25
|
+
# MPS Correlation
|
|
26
|
+
|
|
27
|
+
Correlation layer for optical flow on Apple Silicon (M1/M2/M3/M4).
|
|
28
|
+
|
|
29
|
+
**Drop-in replacement** for `spatial-correlation-sampler` and mmcv's correlation op.
|
|
30
|
+
|
|
31
|
+
## Why?
|
|
32
|
+
|
|
33
|
+
Correlation layers are essential for optical flow estimation:
|
|
34
|
+
- **RAFT**: State-of-the-art optical flow
|
|
35
|
+
- **PWC-Net**: Efficient optical flow
|
|
36
|
+
- **FlowNet/FlowNet2**: Classic deep optical flow
|
|
37
|
+
|
|
38
|
+
But existing implementations are **CUDA-only**. On Mac you get:
|
|
39
|
+
```
|
|
40
|
+
NotImplementedError: correlation not implemented for MPS
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
This package provides a native Metal implementation.
|
|
44
|
+
|
|
45
|
+
## Installation
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install mps-correlation
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
Or from source:
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
git clone https://github.com/mpsops/mps-correlation
|
|
55
|
+
cd mps-correlation
|
|
56
|
+
pip install -e .
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## Quick Start
|
|
60
|
+
|
|
61
|
+
### Basic Usage
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
import torch
|
|
65
|
+
from mps_correlation import correlation
|
|
66
|
+
|
|
67
|
+
# Two feature maps from consecutive frames
|
|
68
|
+
fmap1 = torch.randn(1, 256, 64, 64, device='mps')
|
|
69
|
+
fmap2 = torch.randn(1, 256, 64, 64, device='mps')
|
|
70
|
+
|
|
71
|
+
# Compute correlation volume
|
|
72
|
+
corr = correlation(
|
|
73
|
+
fmap1, fmap2,
|
|
74
|
+
kernel_size=1,
|
|
75
|
+
max_displacement=4,
|
|
76
|
+
stride1=1,
|
|
77
|
+
stride2=1,
|
|
78
|
+
pad_size=4
|
|
79
|
+
)
|
|
80
|
+
# Output: (1, 81, 64, 64) - 81 = (2*4+1)^2 displacement channels
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
### Correlation Module
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
from mps_correlation import Correlation
|
|
87
|
+
|
|
88
|
+
corr_layer = Correlation(
|
|
89
|
+
kernel_size=1,
|
|
90
|
+
max_displacement=4,
|
|
91
|
+
stride1=1,
|
|
92
|
+
stride2=1,
|
|
93
|
+
pad_size=4
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
corr = corr_layer(fmap1, fmap2)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
### RAFT-style All-Pairs Correlation
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
from mps_correlation import CorrBlock
|
|
103
|
+
|
|
104
|
+
# Build correlation pyramid
|
|
105
|
+
corr_block = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
|
|
106
|
+
|
|
107
|
+
# Lookup at specific coordinates
|
|
108
|
+
coords = torch.zeros(1, 2, 64, 64, device='mps') # (x, y) coordinates
|
|
109
|
+
corr_features = corr_block(coords)
|
|
110
|
+
```
|
|
111
|
+
|
|
112
|
+
## API Reference
|
|
113
|
+
|
|
114
|
+
### `correlation(input1, input2, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)`
|
|
115
|
+
|
|
116
|
+
| Parameter | Type | Description |
|
|
117
|
+
|-----------|------|-------------|
|
|
118
|
+
| `input1` | Tensor | First feature map (N, C, H, W) |
|
|
119
|
+
| `input2` | Tensor | Second feature map (N, C, H, W) |
|
|
120
|
+
| `kernel_size` | int | Size of correlation kernel (default: 1) |
|
|
121
|
+
| `max_displacement` | int | Maximum displacement to search (default: 4) |
|
|
122
|
+
| `stride1` | int | Stride for input1 (default: 1) |
|
|
123
|
+
| `stride2` | int | Stride for displacement (default: 1) |
|
|
124
|
+
| `pad_size` | int | Padding size (default: 4) |
|
|
125
|
+
| `is_multiply` | bool | Use multiplication (True) or subtraction (False) (default: True) |
|
|
126
|
+
|
|
127
|
+
### `CorrBlock`
|
|
128
|
+
|
|
129
|
+
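RAFT-style all-pairs correlation block with a multi-level pyramid and windowed lookup, built from standard PyTorch ops (`matmul`, `avg_pool2d`, `grid_sample`) rather than the Metal kernel. Calling it with `coords` of shape (N, 2, H, W) holding (x, y) pixel positions returns features of shape (N, num_levels * (2*radius+1)^2, H, W).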
|
|
130
|
+
|
|
131
|
+
## How It Works
|
|
132
|
+
|
|
133
|
+
Correlation computes similarity between patches at different displacements:
|
|
134
|
+
|
|
135
|
+
```
|
|
136
|
+
For each position (x, y) in output:
|
|
137
|
+
For each displacement (dx, dy) in [-max_disp, max_disp]:
|
|
138
|
+
corr[x, y, dx, dy] = sum(fmap1[x, y, :] * fmap2[x+dx, y+dy, :])
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
This creates a 4D cost volume that optical flow networks use to estimate motion.
|
|
142
|
+
|
|
143
|
+
## Compatibility
|
|
144
|
+
|
|
145
|
+
- **PyTorch**: 2.0+
|
|
146
|
+
- **macOS**: 12.3+ (Monterey), the minimum PyTorch supports for the MPS backend
|
|
147
|
+
- **Hardware**: Apple Silicon (M1/M2/M3/M4)
|
|
148
|
+
|
|
149
|
+
## Features
|
|
150
|
+
|
|
151
|
+
- Full forward and backward pass (training supported)
|
|
152
|
+
- fp32 and fp16 supported
|
|
153
|
+
- Compatible with RAFT, PWC-Net, FlowNet architectures
|
|
154
|
+
|
|
155
|
+
## Credits
|
|
156
|
+
|
|
157
|
+
- [spatial-correlation-sampler](https://github.com/ClementPinard/Pytorch-Correlation-extension) - Reference implementation
|
|
158
|
+
- [RAFT](https://github.com/princeton-vl/RAFT) - State-of-the-art optical flow
|
|
159
|
+
- [PWC-Net](https://github.com/NVlabs/PWC-Net) - Efficient optical flow
|
|
160
|
+
|
|
161
|
+
## License
|
|
162
|
+
|
|
163
|
+
MIT
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# MPS Correlation
|
|
2
|
+
|
|
3
|
+
Correlation layer for optical flow on Apple Silicon (M1/M2/M3/M4).
|
|
4
|
+
|
|
5
|
+
**Drop-in replacement** for `spatial-correlation-sampler` and mmcv's correlation op.
|
|
6
|
+
|
|
7
|
+
## Why?
|
|
8
|
+
|
|
9
|
+
Correlation layers are essential for optical flow estimation:
|
|
10
|
+
- **RAFT**: State-of-the-art optical flow
|
|
11
|
+
- **PWC-Net**: Efficient optical flow
|
|
12
|
+
- **FlowNet/FlowNet2**: Classic deep optical flow
|
|
13
|
+
|
|
14
|
+
But existing implementations are **CUDA-only**. On Mac you get:
|
|
15
|
+
```
|
|
16
|
+
NotImplementedError: correlation not implemented for MPS
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
This package provides a native Metal implementation.
|
|
20
|
+
|
|
21
|
+
## Installation
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
pip install mps-correlation
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
Or from source:
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
git clone https://github.com/mpsops/mps-correlation
|
|
31
|
+
cd mps-correlation
|
|
32
|
+
pip install -e .
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
## Quick Start
|
|
36
|
+
|
|
37
|
+
### Basic Usage
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
import torch
|
|
41
|
+
from mps_correlation import correlation
|
|
42
|
+
|
|
43
|
+
# Two feature maps from consecutive frames
|
|
44
|
+
fmap1 = torch.randn(1, 256, 64, 64, device='mps')
|
|
45
|
+
fmap2 = torch.randn(1, 256, 64, 64, device='mps')
|
|
46
|
+
|
|
47
|
+
# Compute correlation volume
|
|
48
|
+
corr = correlation(
|
|
49
|
+
fmap1, fmap2,
|
|
50
|
+
kernel_size=1,
|
|
51
|
+
max_displacement=4,
|
|
52
|
+
stride1=1,
|
|
53
|
+
stride2=1,
|
|
54
|
+
pad_size=4
|
|
55
|
+
)
|
|
56
|
+
# Output: (1, 81, 64, 64) - 81 = (2*4+1)^2 displacement channels
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
### Correlation Module
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
from mps_correlation import Correlation
|
|
63
|
+
|
|
64
|
+
corr_layer = Correlation(
|
|
65
|
+
kernel_size=1,
|
|
66
|
+
max_displacement=4,
|
|
67
|
+
stride1=1,
|
|
68
|
+
stride2=1,
|
|
69
|
+
pad_size=4
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
corr = corr_layer(fmap1, fmap2)
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
### RAFT-style All-Pairs Correlation
|
|
76
|
+
|
|
77
|
+
```python
|
|
78
|
+
from mps_correlation import CorrBlock
|
|
79
|
+
|
|
80
|
+
# Build correlation pyramid
|
|
81
|
+
corr_block = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
|
|
82
|
+
|
|
83
|
+
# Lookup at specific coordinates
|
|
84
|
+
coords = torch.zeros(1, 2, 64, 64, device='mps') # (x, y) coordinates
|
|
85
|
+
corr_features = corr_block(coords)
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
## API Reference
|
|
89
|
+
|
|
90
|
+
### `correlation(input1, input2, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)`
|
|
91
|
+
|
|
92
|
+
| Parameter | Type | Description |
|
|
93
|
+
|-----------|------|-------------|
|
|
94
|
+
| `input1` | Tensor | First feature map (N, C, H, W) |
|
|
95
|
+
| `input2` | Tensor | Second feature map (N, C, H, W) |
|
|
96
|
+
| `kernel_size` | int | Size of correlation kernel (default: 1) |
|
|
97
|
+
| `max_displacement` | int | Maximum displacement to search (default: 4) |
|
|
98
|
+
| `stride1` | int | Stride for input1 (default: 1) |
|
|
99
|
+
| `stride2` | int | Stride for displacement (default: 1) |
|
|
100
|
+
| `pad_size` | int | Padding size (default: 4) |
|
|
101
|
+
| `is_multiply` | bool | Use multiplication (True) or subtraction (False) (default: True) |
|
|
102
|
+
|
|
103
|
+
### `CorrBlock`
|
|
104
|
+
|
|
105
|
+
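RAFT-style all-pairs correlation block with a multi-level pyramid and windowed lookup, built from standard PyTorch ops (`matmul`, `avg_pool2d`, `grid_sample`) rather than the Metal kernel. Calling it with `coords` of shape (N, 2, H, W) holding (x, y) pixel positions returns features of shape (N, num_levels * (2*radius+1)^2, H, W).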
|
|
106
|
+
|
|
107
|
+
## How It Works
|
|
108
|
+
|
|
109
|
+
Correlation computes similarity between patches at different displacements:
|
|
110
|
+
|
|
111
|
+
```
|
|
112
|
+
For each position (x, y) in output:
|
|
113
|
+
For each displacement (dx, dy) in [-max_disp, max_disp]:
|
|
114
|
+
corr[x, y, dx, dy] = sum(fmap1[x, y, :] * fmap2[x+dx, y+dy, :])
|
|
115
|
+
```
|
|
116
|
+
|
|
117
|
+
This creates a 4D cost volume that optical flow networks use to estimate motion.
|
|
118
|
+
|
|
119
|
+
## Compatibility
|
|
120
|
+
|
|
121
|
+
- **PyTorch**: 2.0+
|
|
122
|
+
- **macOS**: 12.3+ (Monterey), the minimum PyTorch supports for the MPS backend
|
|
123
|
+
- **Hardware**: Apple Silicon (M1/M2/M3/M4)
|
|
124
|
+
|
|
125
|
+
## Features
|
|
126
|
+
|
|
127
|
+
- Full forward and backward pass (training supported)
|
|
128
|
+
- fp32 and fp16 supported
|
|
129
|
+
- Compatible with RAFT, PWC-Net, FlowNet architectures
|
|
130
|
+
|
|
131
|
+
## Credits
|
|
132
|
+
|
|
133
|
+
- [spatial-correlation-sampler](https://github.com/ClementPinard/Pytorch-Correlation-extension) - Reference implementation
|
|
134
|
+
- [RAFT](https://github.com/princeton-vl/RAFT) - State-of-the-art optical flow
|
|
135
|
+
- [PWC-Net](https://github.com/NVlabs/PWC-Net) - Efficient optical flow
|
|
136
|
+
|
|
137
|
+
## License
|
|
138
|
+
|
|
139
|
+
MIT
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MPS Correlation - Optical flow correlation layer for Apple Silicon
|
|
3
|
+
|
|
4
|
+
Drop-in replacement for spatial-correlation-sampler and mmcv's correlation op.
|
|
5
|
+
Used in RAFT, PWC-Net, FlowNet, etc.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn, Tensor
|
|
10
|
+
from torch.autograd import Function
|
|
11
|
+
from typing import Optional, Tuple
|
|
12
|
+
import math
|
|
13
|
+
|
|
14
|
+
# Package version string; matches the ``Version`` field of the published
# distribution metadata (0.1.2).
__version__ = "0.1.2"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _load_library():
    """Return the native correlation extension module.

    The prebuilt ``mps_correlation._C`` extension is preferred. If it cannot
    be imported, the Objective-C++ source ``csrc/correlation_mps.mm`` that
    ships next to this file is JIT-compiled with
    ``torch.utils.cpp_extension.load`` and linked against the Metal and
    Foundation frameworks.

    NOTE(review): the JIT module is built under the name "mps_correlation",
    the same name as this package; confirm that the installed torch version
    does not register it in ``sys.modules`` in a way that shadows the package.
    """
    try:
        from mps_correlation import _C
    except ImportError:
        pass  # no prebuilt binary -> fall through to the JIT build below
    else:
        return _C

    import os
    from torch.utils.cpp_extension import load

    csrc = os.path.join(os.path.dirname(__file__), "csrc")
    kernel_source = os.path.join(csrc, "correlation_mps.mm")
    return load(
        name="mps_correlation",
        sources=[kernel_source],
        extra_cflags=["-std=c++17"],
        extra_ldflags=["-framework", "Metal", "-framework", "Foundation"],
        verbose=False,
    )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Handle to the native extension module, populated lazily by ``_get_lib`` on
# first use so that importing this package does not load or build it.
_lib_cache = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_lib():
    """Return the native extension, loading it on the first call only.

    The loaded module is memoized in the module-level ``_lib_cache`` so later
    calls skip the (possibly expensive) import / JIT build.
    """
    global _lib_cache
    if _lib_cache is not None:
        return _lib_cache
    _lib_cache = _load_library()
    return _lib_cache
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class _CorrelationFunction(Function):
    """Autograd function for correlation.

    Both passes delegate to the native extension returned by ``_get_lib()``
    (``correlation_forward`` / ``correlation_backward``). The integer/bool
    hyper-parameters are stored as plain attributes on ``ctx`` so the backward
    pass can replay exactly the same configuration.
    """

    @staticmethod
    def forward(
        ctx,
        input1: Tensor,
        input2: Tensor,
        kernel_size: int,
        max_displacement: int,
        stride1: int,
        stride2: int,
        pad_size: int,
        is_multiply: bool,
    ) -> Tensor:
        # The backward pass hands the native kernel contiguous tensors; do the
        # same here so forward never receives strided views (e.g. produced by
        # permute/slicing). The native kernel presumably indexes raw memory
        # (its source is not shown here). No copy is made for contiguous inputs.
        input1 = input1.contiguous()
        input2 = input2.contiguous()

        ctx.save_for_backward(input1, input2)
        ctx.kernel_size = kernel_size
        ctx.max_displacement = max_displacement
        ctx.stride1 = stride1
        ctx.stride2 = stride2
        ctx.pad_size = pad_size
        ctx.is_multiply = is_multiply

        lib = _get_lib()
        return lib.correlation_forward(
            input1, input2,
            kernel_size, max_displacement,
            stride1, stride2, pad_size,
            is_multiply,
        )

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        input1, input2 = ctx.saved_tensors
        lib = _get_lib()

        grad_input1, grad_input2 = lib.correlation_backward(
            grad_output.contiguous(), input1.contiguous(), input2.contiguous(),
            ctx.kernel_size, ctx.max_displacement,
            ctx.stride1, ctx.stride2, ctx.pad_size,
            ctx.is_multiply,
        )

        # One slot per forward() argument: the two tensors receive gradients,
        # the six hyper-parameters are non-differentiable.
        return grad_input1, grad_input2, None, None, None, None, None, None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def correlation(
    input1: Tensor,
    input2: Tensor,
    kernel_size: int = 1,
    max_displacement: int = 4,
    stride1: int = 1,
    stride2: int = 1,
    pad_size: int = 4,
    is_multiply: bool = True,
) -> Tensor:
    """
    Compute correlation between two feature maps.

    Used in optical flow estimation (RAFT, PWC-Net, FlowNet).

    Args:
        input1: First feature map (N, C, H, W)
        input2: Second feature map (N, C, H, W), same shape as ``input1``
        kernel_size: Size of the correlation kernel (>= 1)
        max_displacement: Maximum displacement for correlation search (>= 0)
        stride1: Stride for input1 (>= 1)
        stride2: Stride for input2 (displacement stride, >= 1)
        pad_size: Padding size (>= 0)
        is_multiply: If True, use multiplication. If False, use subtraction.

    Returns:
        Correlation volume (N, D*D, H_out, W_out) where
        D = 2 * (max_displacement // stride2) + 1 (integer division assumed;
        TODO confirm against the native kernel). H_out/W_out are computed by
        the native kernel from pad_size, kernel_size and stride1.

    Raises:
        ValueError: if a hyper-parameter is out of range, the inputs are not
            4D tensors of identical shape, or either input is not on MPS.
    """
    # Reject nonsensical hyper-parameters before they reach native code,
    # where e.g. stride2 == 0 would divide by zero.
    for name, value in (
        ("kernel_size", kernel_size),
        ("stride1", stride1),
        ("stride2", stride2),
    ):
        if value < 1:
            raise ValueError(f"{name} must be a positive integer, got {value}")
    for name, value in (
        ("max_displacement", max_displacement),
        ("pad_size", pad_size),
    ):
        if value < 0:
            raise ValueError(f"{name} must be non-negative, got {value}")

    if input1.dim() != 4 or input2.dim() != 4:
        raise ValueError(
            f"Expected 4D inputs (N, C, H, W), got shapes "
            f"{tuple(input1.shape)} and {tuple(input2.shape)}"
        )
    if input1.shape != input2.shape:
        raise ValueError(
            f"input1 and input2 must have the same shape, got "
            f"{tuple(input1.shape)} and {tuple(input2.shape)}"
        )

    # Both tensors are handed to the Metal kernel, so both must live on MPS.
    if input1.device.type != "mps":
        raise ValueError(f"Input must be on MPS device, got {input1.device}")
    if input2.device.type != "mps":
        raise ValueError(f"Input must be on MPS device, got {input2.device}")

    return _CorrelationFunction.apply(
        input1, input2,
        kernel_size, max_displacement,
        stride1, stride2, pad_size,
        is_multiply,
    )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class Correlation(nn.Module):
    """
    Correlation layer for optical flow.

    Computes a cost volume by correlating features from two images. This is a
    thin ``nn.Module`` wrapper: it stores the hyper-parameters and forwards
    both inputs to :func:`correlation`.

    Args:
        kernel_size: Size of correlation kernel
        max_displacement: Maximum displacement to search
        stride1: Stride for first input
        stride2: Stride for displacement
        pad_size: Padding size
        is_multiply: Use multiplication (True) or subtraction (False)
    """

    def __init__(
        self,
        kernel_size: int = 1,
        max_displacement: int = 4,
        stride1: int = 1,
        stride2: int = 1,
        pad_size: int = 4,
        is_multiply: bool = True,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.pad_size = pad_size
        self.is_multiply = is_multiply

    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
        """Correlate ``input1`` with ``input2`` using the stored settings.

        See :func:`correlation` for input requirements and output shape.
        """
        return correlation(
            input1, input2,
            self.kernel_size, self.max_displacement,
            self.stride1, self.stride2, self.pad_size,
            self.is_multiply,
        )

    def extra_repr(self) -> str:
        """Show the hyper-parameters when the module (or a model) is printed."""
        return (
            f"kernel_size={self.kernel_size}, "
            f"max_displacement={self.max_displacement}, "
            f"stride1={self.stride1}, stride2={self.stride2}, "
            f"pad_size={self.pad_size}, is_multiply={self.is_multiply}"
        )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# RAFT-style correlation (all-pairs)
class CorrBlock:
    """
    RAFT-style correlation block.

    Computes all-pairs correlation and provides lookup functionality. The
    volume is built from plain PyTorch ops (``torch.matmul``, ``avg_pool2d``,
    ``grid_sample``) rather than the native Metal kernel.

    Args:
        fmap1: First feature map, shape (N, C, H, W).
        fmap2: Second feature map, shape (N, C, H, W).
        num_levels: Number of pyramid levels. Level 0 is the full all-pairs
            volume; each further level average-pools the fmap2 spatial
            dimensions of the previous level by 2.
        radius: Lookup radius; each level contributes (2*radius+1)**2
            channels to the output of ``__call__``.
    """

    def __init__(self, fmap1: Tensor, fmap2: Tensor, num_levels: int = 4, radius: int = 4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # All-pairs correlation
        batch, dim, ht, wd = fmap1.shape
        # reshape (not view): feature maps produced by permute/slicing are
        # often non-contiguous, and view() raises on them.
        fmap1 = fmap1.reshape(batch, dim, ht * wd)
        fmap2 = fmap2.reshape(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        # Plain Python scalar keeps the volume's dtype and avoids creating a
        # device tensor just to hold sqrt(dim).
        corr = corr / math.sqrt(dim)

        self.corr_pyramid.append(corr)
        for _ in range(num_levels - 1):
            # Pool only the fmap2 (target) dims, i.e. the last two dims.
            curr_h, curr_w = corr.shape[-2], corr.shape[-1]
            corr = torch.nn.functional.avg_pool2d(
                corr.view(batch * ht * wd, 1, curr_h, curr_w),
                kernel_size=2, stride=2
            )
            _, _, h, w = corr.shape
            corr = corr.view(batch, ht, wd, 1, h, w)
            self.corr_pyramid.append(corr)

    def __call__(self, coords: Tensor) -> Tensor:
        """Lookup correlation values at given coordinates.

        Args:
            coords: (N, 2, H, W) tensor whose channel 0 is the x and channel 1
                the y position (level-0 pixel units) for every fmap1 pixel.

        Returns:
            Tensor of shape (N, num_levels * (2*radius+1)**2, H, W), in the
            dtype of the correlation pyramid.
        """
        r = self.radius
        batch, _, ht, wd = coords.shape
        coords = coords.permute(0, 2, 3, 1)

        # The offset window is identical at every level, so build it once.
        # NOTE: delta[..., 0] holds the dy offsets yet is added to x below,
        # so the window is laid out transposed in the output channels. This
        # appears to mirror upstream RAFT (presumably kept for pretrained
        # weight compatibility) -- do not "fix" without checking.
        dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), dim=-1)
        delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

        out_pyramid = []
        for i, corr in enumerate(self.corr_pyramid):
            centroid_lvl = coords.reshape(batch * ht * wd, 1, 1, 2) / 2 ** i
            coords_lvl = centroid_lvl + delta_lvl

            # Sample from correlation volume
            corr_lvl = corr.view(batch * ht * wd, 1, *corr.shape[-2:])
            _, _, h, w = corr_lvl.shape

            # Pixel -> [-1, 1] using the align_corners=False convention. It
            # samples the same pixel locations as 2*x/(w-1) - 1 with
            # align_corners=True, but stays finite when a level is only one
            # pixel wide/tall (w - 1 == 0 would yield inf/NaN grids).
            grid_x = (2 * coords_lvl[..., 0] + 1) / w - 1
            grid_y = (2 * coords_lvl[..., 1] + 1) / h - 1
            # grid_sample requires grid and input to share a dtype (e.g. when
            # the feature maps are fp16/fp64 but coords are fp32).
            grid = torch.stack((grid_x, grid_y), dim=-1).to(corr_lvl.dtype)

            sampled = torch.nn.functional.grid_sample(
                corr_lvl, grid,
                align_corners=False, mode='bilinear', padding_mode='zeros'
            )
            out_pyramid.append(sampled.view(batch, ht, wd, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous()
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def is_available() -> bool:
    """Report whether PyTorch's MPS (Metal) backend is usable on this machine.

    Only PyTorch is queried; this does not try to load or build the package's
    native correlation extension.
    """
    mps_backend = torch.backends.mps
    return mps_backend.is_available()
|