mps-carafe 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mps_carafe-0.2.1/PKG-INFO +153 -0
- mps_carafe-0.2.1/README.md +129 -0
- mps_carafe-0.2.1/mps_carafe/__init__.py +262 -0
- mps_carafe-0.2.1/mps_carafe/csrc/carafe_mps.mm +577 -0
- mps_carafe-0.2.1/mps_carafe.egg-info/PKG-INFO +153 -0
- mps_carafe-0.2.1/mps_carafe.egg-info/SOURCES.txt +11 -0
- mps_carafe-0.2.1/mps_carafe.egg-info/dependency_links.txt +1 -0
- mps_carafe-0.2.1/mps_carafe.egg-info/requires.txt +1 -0
- mps_carafe-0.2.1/mps_carafe.egg-info/top_level.txt +1 -0
- mps_carafe-0.2.1/pyproject.toml +30 -0
- mps_carafe-0.2.1/setup.cfg +4 -0
- mps_carafe-0.2.1/setup.py +64 -0
- mps_carafe-0.2.1/tests/test_carafe.py +369 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mps-carafe
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Summary: CARAFE content-aware upsampling for Apple Silicon (MPS)
|
|
5
|
+
Author: mpsops
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/mpsops/mps-carafe
|
|
8
|
+
Project-URL: Repository, https://github.com/mpsops/mps-carafe
|
|
9
|
+
Project-URL: Issues, https://github.com/mpsops/mps-carafe/issues
|
|
10
|
+
Keywords: pytorch,mps,apple-silicon,carafe,upsampling,segmentation
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Operating System :: MacOS :: MacOS X
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.10
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
Requires-Dist: torch>=2.0.0
|
|
23
|
+
Dynamic: requires-python
|
|
24
|
+
|
|
25
|
+
# MPS CARAFE
|
|
26
|
+
|
|
27
|
+
CARAFE (Content-Aware ReAssembly of FEatures) for Apple Silicon (M1/M2/M3/M4).
|
|
28
|
+
|
|
29
|
+
**Drop-in replacement** for mmcv's CARAFE op.
|
|
30
|
+
|
|
31
|
+
## Why?
|
|
32
|
+
|
|
33
|
+
CARAFE is a learnable upsampling operator used in:
|
|
34
|
+
- **Mask R-CNN**: Instance segmentation
|
|
35
|
+
- **FPN**: Feature Pyramid Networks
|
|
36
|
+
- **YOLACT**: Real-time instance segmentation
|
|
37
|
+
|
|
38
|
+
But mmcv's implementation is **CUDA-only**. On Mac you get:
|
|
39
|
+
```
|
|
40
|
+
NotImplementedError: carafe not implemented for MPS
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
This package provides a native Metal implementation.
|
|
44
|
+
|
|
45
|
+
## Installation
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install mps-carafe
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
Or from source:
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
git clone https://github.com/mpsops/mps-carafe
|
|
55
|
+
cd mps-carafe
|
|
56
|
+
pip install -e .
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## Quick Start
|
|
60
|
+
|
|
61
|
+
### Basic CARAFE Operation
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
import torch
|
|
65
|
+
from mps_carafe import carafe
|
|
66
|
+
|
|
67
|
+
# Input features (N, C, H, W)
|
|
68
|
+
features = torch.randn(1, 64, 32, 32, device='mps')
|
|
69
|
+
|
|
70
|
+
# Reassembly masks (N, group_size * k^2, H*scale, W*scale)
|
|
71
|
+
kernel_size = 5
|
|
72
|
+
group_size = 1
|
|
73
|
+
scale_factor = 2
|
|
74
|
+
masks = torch.softmax(
|
|
75
|
+
torch.randn(1, group_size * kernel_size**2, 64, 64, device='mps'),
|
|
76
|
+
dim=1
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
# Upsample with CARAFE
|
|
80
|
+
output = carafe(features, masks, kernel_size, group_size, scale_factor)
|
|
81
|
+
# Output: (1, 64, 64, 64)
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
### CARAFE Module
|
|
85
|
+
|
|
86
|
+
```python
|
|
87
|
+
from mps_carafe import CARAFE
|
|
88
|
+
|
|
89
|
+
carafe_layer = CARAFE(kernel_size=5, group_size=1, scale_factor=2)
|
|
90
|
+
output = carafe_layer(features, masks)
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### CARAFEPack (with mask predictor)
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
from mps_carafe import CARAFEPack
|
|
97
|
+
|
|
98
|
+
# Complete upsampling block with built-in mask predictor
|
|
99
|
+
upsample = CARAFEPack(
|
|
100
|
+
channels=64,
|
|
101
|
+
kernel_size=5,
|
|
102
|
+
group_size=1,
|
|
103
|
+
scale_factor=2
|
|
104
|
+
).to('mps')
|
|
105
|
+
|
|
106
|
+
output = upsample(features) # No need to provide masks
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
## API Reference
|
|
110
|
+
|
|
111
|
+
### `carafe(features, masks, kernel_size, group_size, scale_factor)`
|
|
112
|
+
|
|
113
|
+
| Parameter | Type | Description |
|
|
114
|
+
|-----------|------|-------------|
|
|
115
|
+
| `features` | Tensor | Input features (N, C, H, W) |
|
|
116
|
+
| `masks` | Tensor | Reassembly kernels (N, group_size * k^2, H*scale, W*scale) |
|
|
117
|
+
| `kernel_size` | int | Size of reassembly kernel (typically 5) |
|
|
118
|
+
| `group_size` | int | Number of channel groups (typically 1) |
|
|
119
|
+
| `scale_factor` | int | Upsampling factor (typically 2) |
|
|
120
|
+
|
|
121
|
+
### `CARAFEPack`
|
|
122
|
+
|
|
123
|
+
Complete CARAFE block with mask prediction convolutions.
|
|
124
|
+
|
|
125
|
+
## How It Works
|
|
126
|
+
|
|
127
|
+
CARAFE upsamples by:
|
|
128
|
+
1. For each output pixel, identify the corresponding input neighborhood
|
|
129
|
+
2. Use learned reassembly kernels (masks) to weight input pixels
|
|
130
|
+
3. Sum the weighted inputs to produce the output
|
|
131
|
+
|
|
132
|
+
Unlike bilinear upsampling which uses fixed weights, CARAFE learns content-aware weights that adapt to the image content.
|
|
133
|
+
|
|
134
|
+
## Compatibility
|
|
135
|
+
|
|
136
|
+
- **PyTorch**: 2.0+
|
|
137
|
+
- **macOS**: 12.0+ (Monterey)
|
|
138
|
+
- **Hardware**: Apple Silicon (M1/M2/M3/M4)
|
|
139
|
+
|
|
140
|
+
## Features
|
|
141
|
+
|
|
142
|
+
- Full forward and backward pass (training supported)
|
|
143
|
+
- fp32 and fp16 supported
|
|
144
|
+
- Compatible with mmcv CARAFE API
|
|
145
|
+
|
|
146
|
+
## Credits
|
|
147
|
+
|
|
148
|
+
- [CARAFE Paper](https://arxiv.org/abs/1905.02188) - Original research
|
|
149
|
+
- [mmcv](https://github.com/open-mmlab/mmcv) - Reference implementation
|
|
150
|
+
|
|
151
|
+
## License
|
|
152
|
+
|
|
153
|
+
MIT
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# MPS CARAFE
|
|
2
|
+
|
|
3
|
+
CARAFE (Content-Aware ReAssembly of FEatures) for Apple Silicon (M1/M2/M3/M4).
|
|
4
|
+
|
|
5
|
+
**Drop-in replacement** for mmcv's CARAFE op.
|
|
6
|
+
|
|
7
|
+
## Why?
|
|
8
|
+
|
|
9
|
+
CARAFE is a learnable upsampling operator used in:
|
|
10
|
+
- **Mask R-CNN**: Instance segmentation
|
|
11
|
+
- **FPN**: Feature Pyramid Networks
|
|
12
|
+
- **YOLACT**: Real-time instance segmentation
|
|
13
|
+
|
|
14
|
+
But mmcv's implementation is **CUDA-only**. On Mac you get:
|
|
15
|
+
```
|
|
16
|
+
NotImplementedError: carafe not implemented for MPS
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
This package provides a native Metal implementation.
|
|
20
|
+
|
|
21
|
+
## Installation
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
pip install mps-carafe
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
Or from source:
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
git clone https://github.com/mpsops/mps-carafe
|
|
31
|
+
cd mps-carafe
|
|
32
|
+
pip install -e .
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
## Quick Start
|
|
36
|
+
|
|
37
|
+
### Basic CARAFE Operation
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
import torch
|
|
41
|
+
from mps_carafe import carafe
|
|
42
|
+
|
|
43
|
+
# Input features (N, C, H, W)
|
|
44
|
+
features = torch.randn(1, 64, 32, 32, device='mps')
|
|
45
|
+
|
|
46
|
+
# Reassembly masks (N, group_size * k^2, H*scale, W*scale)
|
|
47
|
+
kernel_size = 5
|
|
48
|
+
group_size = 1
|
|
49
|
+
scale_factor = 2
|
|
50
|
+
masks = torch.softmax(
|
|
51
|
+
torch.randn(1, group_size * kernel_size**2, 64, 64, device='mps'),
|
|
52
|
+
dim=1
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
# Upsample with CARAFE
|
|
56
|
+
output = carafe(features, masks, kernel_size, group_size, scale_factor)
|
|
57
|
+
# Output: (1, 64, 64, 64)
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
### CARAFE Module
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
from mps_carafe import CARAFE
|
|
64
|
+
|
|
65
|
+
carafe_layer = CARAFE(kernel_size=5, group_size=1, scale_factor=2)
|
|
66
|
+
output = carafe_layer(features, masks)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
### CARAFEPack (with mask predictor)
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
from mps_carafe import CARAFEPack
|
|
73
|
+
|
|
74
|
+
# Complete upsampling block with built-in mask predictor
|
|
75
|
+
upsample = CARAFEPack(
|
|
76
|
+
channels=64,
|
|
77
|
+
kernel_size=5,
|
|
78
|
+
group_size=1,
|
|
79
|
+
scale_factor=2
|
|
80
|
+
).to('mps')
|
|
81
|
+
|
|
82
|
+
output = upsample(features) # No need to provide masks
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
## API Reference
|
|
86
|
+
|
|
87
|
+
### `carafe(features, masks, kernel_size, group_size, scale_factor)`
|
|
88
|
+
|
|
89
|
+
| Parameter | Type | Description |
|
|
90
|
+
|-----------|------|-------------|
|
|
91
|
+
| `features` | Tensor | Input features (N, C, H, W) |
|
|
92
|
+
| `masks` | Tensor | Reassembly kernels (N, group_size * k^2, H*scale, W*scale) |
|
|
93
|
+
| `kernel_size` | int | Size of reassembly kernel (typically 5) |
|
|
94
|
+
| `group_size` | int | Number of channel groups (typically 1) |
|
|
95
|
+
| `scale_factor` | int | Upsampling factor (typically 2) |
|
|
96
|
+
|
|
97
|
+
### `CARAFEPack`
|
|
98
|
+
|
|
99
|
+
Complete CARAFE block with mask prediction convolutions.
|
|
100
|
+
|
|
101
|
+
## How It Works
|
|
102
|
+
|
|
103
|
+
CARAFE upsamples by:
|
|
104
|
+
1. For each output pixel, identify the corresponding input neighborhood
|
|
105
|
+
2. Use learned reassembly kernels (masks) to weight input pixels
|
|
106
|
+
3. Sum the weighted inputs to produce the output
|
|
107
|
+
|
|
108
|
+
Unlike bilinear upsampling which uses fixed weights, CARAFE learns content-aware weights that adapt to the image content.
|
|
109
|
+
|
|
110
|
+
## Compatibility
|
|
111
|
+
|
|
112
|
+
- **PyTorch**: 2.0+
|
|
113
|
+
- **macOS**: 12.0+ (Monterey)
|
|
114
|
+
- **Hardware**: Apple Silicon (M1/M2/M3/M4)
|
|
115
|
+
|
|
116
|
+
## Features
|
|
117
|
+
|
|
118
|
+
- Full forward and backward pass (training supported)
|
|
119
|
+
- fp32 and fp16 supported
|
|
120
|
+
- Compatible with mmcv CARAFE API
|
|
121
|
+
|
|
122
|
+
## Credits
|
|
123
|
+
|
|
124
|
+
- [CARAFE Paper](https://arxiv.org/abs/1905.02188) - Original research
|
|
125
|
+
- [mmcv](https://github.com/open-mmlab/mmcv) - Reference implementation
|
|
126
|
+
|
|
127
|
+
## License
|
|
128
|
+
|
|
129
|
+
MIT
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MPS CARAFE - Content-Aware ReAssembly of FEatures for Apple Silicon
|
|
3
|
+
|
|
4
|
+
Drop-in replacement for mmcv's CARAFE op.
|
|
5
|
+
Used in Mask R-CNN, FPN, and other detection/segmentation networks.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn, Tensor
|
|
10
|
+
from torch.autograd import Function
|
|
11
|
+
from typing import Tuple
|
|
12
|
+
import threading
|
|
13
|
+
|
|
14
|
+
__version__ = "0.2.0"
|
|
15
|
+
|
|
16
|
+
# Thread-safe library loading
|
|
17
|
+
_lib_lock = threading.Lock()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _load_library():
    """Load the native extension.

    Prefers the ahead-of-time compiled binary module ``mps_carafe._C``
    shipped with the wheel; if it is missing, falls back to JIT-compiling
    the bundled Objective-C++ source with torch's cpp_extension loader
    (slow on first call; torch caches the build afterwards).
    """
    try:
        # Fast path: pre-built extension installed alongside the package.
        from mps_carafe import _C as _lib
        return _lib
    except ImportError:
        import os
        from torch.utils.cpp_extension import load

        # Fallback: compile csrc/carafe_mps.mm on the fly, linking the
        # Metal and Foundation frameworks required by the MPS kernels.
        src_dir = os.path.join(os.path.dirname(__file__), "csrc")
        _lib = load(
            name="mps_carafe",
            sources=[os.path.join(src_dir, "carafe_mps.mm")],
            extra_cflags=["-std=c++17"],
            extra_ldflags=["-framework", "Metal", "-framework", "Foundation"],
            verbose=False,
        )
        return _lib
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Cached handle to the native extension; populated lazily by _get_lib().
_lib_cache = None


def _get_lib():
    """Thread-safe library loading.

    Returns the native extension module, loading it on first use.
    Double-checked locking: the common already-loaded path skips the
    lock entirely, and the re-check under the lock prevents two threads
    from both triggering a (potentially expensive JIT) load.
    """
    global _lib_cache
    if _lib_cache is None:
        with _lib_lock:
            # Double-check after acquiring lock
            if _lib_cache is None:
                _lib_cache = _load_library()
    return _lib_cache
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _validate_params(kernel_size: int, group_size: int, scale_factor: int, channels: int) -> None:
    """Validate CARAFE hyper-parameters, raising ValueError on the first bad one.

    Args:
        kernel_size: Reassembly kernel size; must be positive.
        group_size: Channel group count; must be positive and divide ``channels``.
        scale_factor: Upsampling factor; must be positive.
        channels: Number of feature channels.

    Raises:
        ValueError: If any parameter is non-positive or ``channels`` is not
            a multiple of ``group_size``.
    """
    for label, value in (
        ("kernel_size", kernel_size),
        ("group_size", group_size),
        ("scale_factor", scale_factor),
    ):
        if value <= 0:
            raise ValueError(f"{label} must be positive, got {value}")
    if channels % group_size != 0:
        raise ValueError(f"channels ({channels}) must be divisible by group_size ({group_size})")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _check_tensor_device(tensor: Tensor, name: str, expected_device: torch.device) -> None:
    """Raise ValueError unless ``tensor`` lives on ``expected_device``.

    Args:
        tensor: The tensor whose placement is being verified.
        name: Human-readable tensor name used in the error message.
        expected_device: Device the features tensor lives on.

    Raises:
        ValueError: If the tensor is on a different device.
    """
    if tensor.device == expected_device:
        return
    raise ValueError(
        f"{name} must be on same device as features ({expected_device}), got {tensor.device}"
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class _CARAFEFunction(Function):
    """Autograd function for CARAFE.

    Bridges the native extension's ``carafe_forward`` / ``carafe_backward``
    kernels into PyTorch's autograd graph. The three int hyper-parameters
    are stashed on ``ctx`` and receive ``None`` gradients.
    """

    @staticmethod
    def forward(
        ctx,
        features: Tensor,
        masks: Tensor,
        kernel_size: int,
        group_size: int,
        scale_factor: int,
    ) -> Tensor:
        # Device validation: masks must live on the same device as features.
        _check_tensor_device(masks, "masks", features.device)

        # Keep raw inputs and hyper-parameters for the backward pass.
        ctx.save_for_backward(features, masks)
        ctx.kernel_size = kernel_size
        ctx.group_size = group_size
        ctx.scale_factor = scale_factor

        lib = _get_lib()
        # NOTE(review): inputs are passed as-is here while backward() calls
        # .contiguous() on everything — presumably the native forward handles
        # (or callers already supply) contiguous tensors; confirm against
        # csrc/carafe_mps.mm.
        output = lib.carafe_forward(
            features, masks,
            kernel_size, group_size, scale_factor
        )
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        features, masks = ctx.saved_tensors

        # Device validation for backward
        if not grad_output.device.type == "mps":
            raise ValueError(f"grad_output must be on MPS device, got {grad_output.device}")

        lib = _get_lib()

        # Native kernel returns gradients w.r.t. both features and masks.
        grad_features, grad_masks = lib.carafe_backward(
            grad_output.contiguous(),
            features.contiguous(),
            masks.contiguous(),
            ctx.kernel_size,
            ctx.group_size,
            ctx.scale_factor
        )

        # None gradients for kernel_size, group_size, scale_factor.
        return grad_features, grad_masks, None, None, None
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def carafe(
    features: Tensor,
    masks: Tensor,
    kernel_size: int,
    group_size: int,
    scale_factor: int,
) -> Tensor:
    """Content-aware upsampling via CARAFE reassembly kernels.

    Each output pixel is a weighted sum of an input neighborhood, with
    weights supplied per-pixel by ``masks`` rather than fixed (as in
    bilinear upsampling).

    Args:
        features: Input feature map (N, C, H, W), on the MPS device.
        masks: Reassembly kernels (N, group_size * kernel_size^2,
            H*scale_factor, W*scale_factor).
        kernel_size: Size of the reassembly kernel (typically 5).
        group_size: Number of channel groups; must divide C.
        scale_factor: Upsampling factor (typically 2).

    Returns:
        Upsampled feature map (N, C, H*scale_factor, W*scale_factor).

    Raises:
        ValueError: If ``features`` is not on MPS or the hyper-parameters
            are invalid.
    """
    if features.device.type != "mps":
        raise ValueError(f"Input must be on MPS device, got {features.device}")

    _validate_params(kernel_size, group_size, scale_factor, features.size(1))

    return _CARAFEFunction.apply(features, masks, kernel_size, group_size, scale_factor)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class CARAFE(nn.Module):
    """Content-Aware ReAssembly of FEatures (CARAFE) upsampling layer.

    Thin ``nn.Module`` wrapper around :func:`carafe` that stores the
    operator's hyper-parameters; better than bilinear/nearest for dense
    prediction tasks.

    Args:
        kernel_size: Size of reassembly kernel (default: 5)
        group_size: Number of channel groups (default: 1)
        scale_factor: Upsampling factor (default: 2)
    """

    def __init__(self, kernel_size: int = 5, group_size: int = 1, scale_factor: int = 2):
        super().__init__()
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features: Tensor, masks: Tensor) -> Tensor:
        """Upsample ``features`` using the reassembly kernels in ``masks``."""
        return carafe(features, masks, self.kernel_size, self.group_size, self.scale_factor)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class CARAFEPack(nn.Module):
    """
    CARAFE with built-in mask predictor.

    This module includes the convolutions to predict reassembly masks
    from input features, making it a complete upsampling block.

    Args:
        channels: Number of input/output channels
        kernel_size: Size of reassembly kernel (default: 5)
        group_size: Number of channel groups (default: 1)
        scale_factor: Upsampling factor (default: 2)
        compressed_channels: Channels after compression (default: 64)
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 5,
        group_size: int = 1,
        scale_factor: int = 2,
        compressed_channels: int = 64,
    ):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

        # Channel compressor: 1x1 conv shrinking the feature width before
        # mask prediction (cheaper content encoder).
        self.channel_compressor = nn.Conv2d(
            channels, compressed_channels, kernel_size=1
        )

        # Content encoder - predicts reassembly masks.
        # Emits scale^2 sets of group_size * k^2 mask logits per input pixel;
        # forward() rearranges them to the upsampled resolution.
        self.content_encoder = nn.Conv2d(
            compressed_channels,
            group_size * kernel_size * kernel_size * scale_factor * scale_factor,
            kernel_size=3,
            padding=1,
        )

        # CARAFE operator
        self.carafe = CARAFE(kernel_size, group_size, scale_factor)

    def forward(self, x: Tensor) -> Tensor:
        # Compress channels
        compressed = self.channel_compressor(x)

        # Predict masks
        masks = self.content_encoder(compressed)

        # Upsample masks to output resolution
        # Pixel-shuffle-style rearrangement: split the channel dim into
        # (group, k^2, s, s) and interleave the two s axes into H and W,
        # giving (n, group*k^2, h*s, w*s).
        n, _, h, w = masks.shape
        masks = masks.view(n, self.group_size, self.kernel_size * self.kernel_size,
                           self.scale_factor, self.scale_factor, h, w)
        masks = masks.permute(0, 1, 2, 5, 3, 6, 4).contiguous()
        masks = masks.view(n, self.group_size * self.kernel_size * self.kernel_size,
                           h * self.scale_factor, w * self.scale_factor)

        # Softmax over kernel positions
        # (dim=2 is the k^2 axis after the temporary 5-D view, so each
        # output pixel's reassembly weights sum to 1 per group).
        masks = torch.softmax(masks.view(n, self.group_size, -1,
                                         h * self.scale_factor, w * self.scale_factor), dim=2)
        masks = masks.view(n, -1, h * self.scale_factor, w * self.scale_factor)

        # Apply CARAFE
        return self.carafe(x, masks)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def is_available() -> bool:
    """Report whether the MPS backend (and thus this package) can be used."""
    return torch.backends.mps.is_available()
|