wavedl 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +48 -28
- wavedl/models/__init__.py +33 -7
- wavedl/models/_template.py +28 -41
- wavedl/models/base.py +49 -2
- wavedl/models/cnn.py +0 -1
- wavedl/models/convnext.py +4 -1
- wavedl/models/densenet.py +4 -1
- wavedl/models/efficientnet.py +9 -5
- wavedl/models/efficientnetv2.py +292 -0
- wavedl/models/mobilenetv3.py +272 -0
- wavedl/models/registry.py +0 -1
- wavedl/models/regnet.py +383 -0
- wavedl/models/resnet.py +7 -4
- wavedl/models/resnet3d.py +258 -0
- wavedl/models/swin.py +390 -0
- wavedl/models/tcn.py +389 -0
- wavedl/models/unet.py +44 -110
- wavedl/models/vit.py +8 -4
- wavedl/train.py +1144 -1116
- wavedl/utils/config.py +88 -2
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/METADATA +136 -98
- wavedl-1.4.1.dist-info/RECORD +37 -0
- wavedl-1.3.1.dist-info/RECORD +0 -31
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/LICENSE +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/WHEEL +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/top_level.txt +0 -0
wavedl/models/regnet.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RegNet: Designing Network Design Spaces
|
|
3
|
+
========================================
|
|
4
|
+
|
|
5
|
+
RegNet provides a family of models with predictable scaling behavior,
|
|
6
|
+
designed through systematic exploration of network design spaces.
|
|
7
|
+
Models scale smoothly from mobile to server deployments.
|
|
8
|
+
|
|
9
|
+
**Key Features**:
|
|
10
|
+
- Predictable scaling: accuracy increases linearly with compute
|
|
11
|
+
- Simple, uniform architecture (no complex compound scaling)
|
|
12
|
+
- Group convolutions for efficiency
|
|
13
|
+
- Optional Squeeze-and-Excitation (SE) attention
|
|
14
|
+
|
|
15
|
+
**Variants** (RegNetY includes SE attention):
|
|
16
|
+
- regnet_y_400mf: Ultra-light (~4.0M params, 0.4 GFLOPs)
|
|
17
|
+
- regnet_y_800mf: Light (~5.8M params, 0.8 GFLOPs)
|
|
18
|
+
- regnet_y_1_6gf: Medium (~10.5M params, 1.6 GFLOPs) - Recommended
|
|
19
|
+
- regnet_y_3_2gf: Large (~18.3M params, 3.2 GFLOPs)
|
|
20
|
+
- regnet_y_8gf: Very large (~37.9M params, 8.0 GFLOPs)
|
|
21
|
+
|
|
22
|
+
**When to Use RegNet**:
|
|
23
|
+
- When you need predictable performance at a given compute budget
|
|
24
|
+
- For systematic model selection experiments
|
|
25
|
+
- When interpretability of design choices matters
|
|
26
|
+
- As an efficient alternative to ResNet
|
|
27
|
+
|
|
28
|
+
**Note**: RegNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
|
|
29
|
+
|
|
30
|
+
References:
|
|
31
|
+
Radosavovic, I., et al. (2020). Designing Network Design Spaces.
|
|
32
|
+
CVPR 2020. https://arxiv.org/abs/2003.13678
|
|
33
|
+
|
|
34
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
from typing import Any
|
|
38
|
+
|
|
39
|
+
import torch
|
|
40
|
+
import torch.nn as nn
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
try:
|
|
44
|
+
from torchvision.models import (
|
|
45
|
+
RegNet_Y_1_6GF_Weights,
|
|
46
|
+
RegNet_Y_3_2GF_Weights,
|
|
47
|
+
RegNet_Y_8GF_Weights,
|
|
48
|
+
RegNet_Y_400MF_Weights,
|
|
49
|
+
RegNet_Y_800MF_Weights,
|
|
50
|
+
regnet_y_1_6gf,
|
|
51
|
+
regnet_y_3_2gf,
|
|
52
|
+
regnet_y_8gf,
|
|
53
|
+
regnet_y_400mf,
|
|
54
|
+
regnet_y_800mf,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
REGNET_AVAILABLE = True
|
|
58
|
+
except ImportError:
|
|
59
|
+
REGNET_AVAILABLE = False
|
|
60
|
+
|
|
61
|
+
from wavedl.models.base import BaseModel
|
|
62
|
+
from wavedl.models.registry import register_model
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class RegNetBase(BaseModel):
    """
    Shared RegNet backbone wrapper for regression.

    Adapts a torchvision RegNetY model (SE attention included) to this
    package's regression interface:
    - Optional ImageNet-1K pretrained weights
    - Grayscale inputs are expanded to 3 channels in ``forward``
    - The classifier ``fc`` is swapped for a small MLP regression head

    RegNet advantages: simple uniform design, predictable accuracy/compute
    trade-off, efficient group convolutions, and SE channel attention in
    the RegNetY family.

    Note: 2D-only — ``in_shape`` must be (H, W).
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_fn,
        weights_class,
        pretrained: bool = True,
        dropout_rate: float = 0.2,
        freeze_backbone: bool = False,
        regression_hidden: int = 256,
        **kwargs,
    ):
        """
        Build the RegNet regression model.

        Args:
            in_shape: (H, W) input image dimensions.
            out_size: Number of regression output targets.
            model_fn: torchvision model constructor (e.g. ``regnet_y_1_6gf``).
            weights_class: Matching pretrained-weights enum class.
            pretrained: Use ImageNet pretrained weights (default: True).
            dropout_rate: Dropout rate in regression head (default: 0.2).
            freeze_backbone: Freeze backbone for fine-tuning (default: False).
            regression_hidden: Hidden units in regression head (default: 256).

        Raises:
            ImportError: If torchvision is not installed.
            ValueError: If ``in_shape`` is not 2-dimensional.
        """
        super().__init__(in_shape, out_size)

        if not REGNET_AVAILABLE:
            raise ImportError(
                "torchvision is required for RegNet. "
                "Install with: pip install torchvision"
            )

        if len(in_shape) != 2:
            raise ValueError(
                f"RegNet requires 2D input (H, W), got {len(in_shape)}D. "
                "For 1D data, use TCN. For 3D data, use ResNet3D."
            )

        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.regression_hidden = regression_hidden

        # Instantiate the torchvision backbone, optionally with ImageNet weights.
        selected_weights = weights_class.IMAGENET1K_V1 if pretrained else None
        self.backbone = model_fn(weights=selected_weights)

        # torchvision RegNet exposes its classifier as ``.fc``; record its
        # input width so the replacement head matches.
        feat_dim = self.backbone.fc.in_features

        # Two-layer MLP regression head (lighter dropout on the second layer).
        head_layers = [
            nn.Dropout(dropout_rate),
            nn.Linear(feat_dim, regression_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(regression_hidden, out_size),
        ]
        self.backbone.fc = nn.Sequential(*head_layers)

        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Freeze all backbone parameters except the fc layer."""
        frozen_params = (
            param
            for name, param in self.backbone.named_parameters()
            if "fc" not in name
        )
        for param in frozen_params:
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone on a batch of images.

        Args:
            x: Input tensor of shape (B, C, H, W) where C is 1 or 3.

        Returns:
            Output tensor of shape (B, out_size).
        """
        # Pretrained weights expect 3 channels; broadcast grayscale input.
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)
        return self.backbone(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for RegNet."""
        return dict(
            pretrained=True,
            dropout_rate=0.2,
            freeze_backbone=False,
            regression_hidden=256,
        )
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# =============================================================================
|
|
181
|
+
# REGISTERED MODEL VARIANTS
|
|
182
|
+
# =============================================================================
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@register_model("regnet_y_400mf")
|
|
186
|
+
class RegNetY400MF(RegNetBase):
|
|
187
|
+
"""
|
|
188
|
+
RegNetY-400MF: Ultra-lightweight for constrained environments.
|
|
189
|
+
|
|
190
|
+
~4.0M parameters, 0.4 GFLOPs. Smallest RegNet variant with SE attention.
|
|
191
|
+
|
|
192
|
+
Recommended for:
|
|
193
|
+
- Edge deployment with moderate accuracy needs
|
|
194
|
+
- Quick training experiments
|
|
195
|
+
- Baseline comparisons
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
in_shape: (H, W) image dimensions
|
|
199
|
+
out_size: Number of regression targets
|
|
200
|
+
pretrained: Use ImageNet pretrained weights (default: True)
|
|
201
|
+
dropout_rate: Dropout rate in head (default: 0.2)
|
|
202
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
203
|
+
regression_hidden: Hidden units in regression head (default: 256)
|
|
204
|
+
|
|
205
|
+
Example:
|
|
206
|
+
>>> model = RegNetY400MF(in_shape=(224, 224), out_size=3)
|
|
207
|
+
>>> x = torch.randn(4, 1, 224, 224)
|
|
208
|
+
>>> out = model(x) # (4, 3)
|
|
209
|
+
"""
|
|
210
|
+
|
|
211
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
212
|
+
super().__init__(
|
|
213
|
+
in_shape=in_shape,
|
|
214
|
+
out_size=out_size,
|
|
215
|
+
model_fn=regnet_y_400mf,
|
|
216
|
+
weights_class=RegNet_Y_400MF_Weights,
|
|
217
|
+
**kwargs,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
def __repr__(self) -> str:
|
|
221
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
222
|
+
return f"RegNetY_400MF({pt}, in={self.in_shape}, out={self.out_size})"
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@register_model("regnet_y_800mf")
|
|
226
|
+
class RegNetY800MF(RegNetBase):
|
|
227
|
+
"""
|
|
228
|
+
RegNetY-800MF: Light variant with good accuracy.
|
|
229
|
+
|
|
230
|
+
~6.4M parameters, 0.8 GFLOPs. Good balance for mobile deployment.
|
|
231
|
+
|
|
232
|
+
Recommended for:
|
|
233
|
+
- Mobile/portable devices
|
|
234
|
+
- When MobileNet isn't accurate enough
|
|
235
|
+
- Moderate compute budgets
|
|
236
|
+
|
|
237
|
+
Args:
|
|
238
|
+
in_shape: (H, W) image dimensions
|
|
239
|
+
out_size: Number of regression targets
|
|
240
|
+
pretrained: Use ImageNet pretrained weights (default: True)
|
|
241
|
+
dropout_rate: Dropout rate in head (default: 0.2)
|
|
242
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
243
|
+
regression_hidden: Hidden units in regression head (default: 256)
|
|
244
|
+
|
|
245
|
+
Example:
|
|
246
|
+
>>> model = RegNetY800MF(in_shape=(224, 224), out_size=3)
|
|
247
|
+
>>> x = torch.randn(4, 1, 224, 224)
|
|
248
|
+
>>> out = model(x) # (4, 3)
|
|
249
|
+
"""
|
|
250
|
+
|
|
251
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
252
|
+
super().__init__(
|
|
253
|
+
in_shape=in_shape,
|
|
254
|
+
out_size=out_size,
|
|
255
|
+
model_fn=regnet_y_800mf,
|
|
256
|
+
weights_class=RegNet_Y_800MF_Weights,
|
|
257
|
+
**kwargs,
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
def __repr__(self) -> str:
|
|
261
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
262
|
+
return f"RegNetY_800MF({pt}, in={self.in_shape}, out={self.out_size})"
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
@register_model("regnet_y_1_6gf")
|
|
266
|
+
class RegNetY1_6GF(RegNetBase):
|
|
267
|
+
"""
|
|
268
|
+
RegNetY-1.6GF: Recommended default for balanced performance.
|
|
269
|
+
|
|
270
|
+
~11.2M parameters, 1.6 GFLOPs. Best trade-off of accuracy and efficiency.
|
|
271
|
+
Comparable to ResNet50 but more efficient.
|
|
272
|
+
|
|
273
|
+
Recommended for:
|
|
274
|
+
- Default choice for general wave-based tasks
|
|
275
|
+
- When you want predictable scaling
|
|
276
|
+
- Server deployment with efficiency needs
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
in_shape: (H, W) image dimensions
|
|
280
|
+
out_size: Number of regression targets
|
|
281
|
+
pretrained: Use ImageNet pretrained weights (default: True)
|
|
282
|
+
dropout_rate: Dropout rate in head (default: 0.2)
|
|
283
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
284
|
+
regression_hidden: Hidden units in regression head (default: 256)
|
|
285
|
+
|
|
286
|
+
Example:
|
|
287
|
+
>>> model = RegNetY1_6GF(in_shape=(224, 224), out_size=3)
|
|
288
|
+
>>> x = torch.randn(4, 1, 224, 224)
|
|
289
|
+
>>> out = model(x) # (4, 3)
|
|
290
|
+
"""
|
|
291
|
+
|
|
292
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
293
|
+
super().__init__(
|
|
294
|
+
in_shape=in_shape,
|
|
295
|
+
out_size=out_size,
|
|
296
|
+
model_fn=regnet_y_1_6gf,
|
|
297
|
+
weights_class=RegNet_Y_1_6GF_Weights,
|
|
298
|
+
**kwargs,
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
def __repr__(self) -> str:
|
|
302
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
303
|
+
return f"RegNetY_1.6GF({pt}, in={self.in_shape}, out={self.out_size})"
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@register_model("regnet_y_3_2gf")
|
|
307
|
+
class RegNetY3_2GF(RegNetBase):
|
|
308
|
+
"""
|
|
309
|
+
RegNetY-3.2GF: Higher accuracy for demanding tasks.
|
|
310
|
+
|
|
311
|
+
~19.4M parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.
|
|
312
|
+
|
|
313
|
+
Recommended for:
|
|
314
|
+
- Larger datasets requiring more capacity
|
|
315
|
+
- When accuracy is more important than efficiency
|
|
316
|
+
- Research experiments with multiple model sizes
|
|
317
|
+
|
|
318
|
+
Args:
|
|
319
|
+
in_shape: (H, W) image dimensions
|
|
320
|
+
out_size: Number of regression targets
|
|
321
|
+
pretrained: Use ImageNet pretrained weights (default: True)
|
|
322
|
+
dropout_rate: Dropout rate in head (default: 0.2)
|
|
323
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
324
|
+
regression_hidden: Hidden units in regression head (default: 256)
|
|
325
|
+
|
|
326
|
+
Example:
|
|
327
|
+
>>> model = RegNetY3_2GF(in_shape=(224, 224), out_size=3)
|
|
328
|
+
>>> x = torch.randn(4, 1, 224, 224)
|
|
329
|
+
>>> out = model(x) # (4, 3)
|
|
330
|
+
"""
|
|
331
|
+
|
|
332
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
333
|
+
super().__init__(
|
|
334
|
+
in_shape=in_shape,
|
|
335
|
+
out_size=out_size,
|
|
336
|
+
model_fn=regnet_y_3_2gf,
|
|
337
|
+
weights_class=RegNet_Y_3_2GF_Weights,
|
|
338
|
+
**kwargs,
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
def __repr__(self) -> str:
|
|
342
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
343
|
+
return f"RegNetY_3.2GF({pt}, in={self.in_shape}, out={self.out_size})"
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
@register_model("regnet_y_8gf")
|
|
347
|
+
class RegNetY8GF(RegNetBase):
|
|
348
|
+
"""
|
|
349
|
+
RegNetY-8GF: High capacity for large-scale tasks.
|
|
350
|
+
|
|
351
|
+
~39.2M parameters, 8.0 GFLOPs. Use for maximum accuracy needs.
|
|
352
|
+
|
|
353
|
+
Recommended for:
|
|
354
|
+
- Very large datasets (>50k samples)
|
|
355
|
+
- Complex wave patterns
|
|
356
|
+
- HPC environments with ample GPU memory
|
|
357
|
+
|
|
358
|
+
Args:
|
|
359
|
+
in_shape: (H, W) image dimensions
|
|
360
|
+
out_size: Number of regression targets
|
|
361
|
+
pretrained: Use ImageNet pretrained weights (default: True)
|
|
362
|
+
dropout_rate: Dropout rate in head (default: 0.2)
|
|
363
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
364
|
+
regression_hidden: Hidden units in regression head (default: 256)
|
|
365
|
+
|
|
366
|
+
Example:
|
|
367
|
+
>>> model = RegNetY8GF(in_shape=(224, 224), out_size=3)
|
|
368
|
+
>>> x = torch.randn(4, 1, 224, 224)
|
|
369
|
+
>>> out = model(x) # (4, 3)
|
|
370
|
+
"""
|
|
371
|
+
|
|
372
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
373
|
+
super().__init__(
|
|
374
|
+
in_shape=in_shape,
|
|
375
|
+
out_size=out_size,
|
|
376
|
+
model_fn=regnet_y_8gf,
|
|
377
|
+
weights_class=RegNet_Y_8GF_Weights,
|
|
378
|
+
**kwargs,
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
def __repr__(self) -> str:
|
|
382
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
383
|
+
return f"RegNetY_8GF({pt}, in={self.in_shape}, out={self.out_size})"
|
wavedl/models/resnet.py
CHANGED
|
@@ -11,12 +11,15 @@ Provides multiple depth variants (18, 34, 50) with optional pretrained weights f
|
|
|
11
11
|
- 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
|
|
12
12
|
|
|
13
13
|
**Variants**:
|
|
14
|
-
- resnet18: Lightweight, fast training
|
|
15
|
-
- resnet34: Balanced capacity
|
|
16
|
-
- resnet50: Higher capacity with bottleneck blocks
|
|
14
|
+
- resnet18: Lightweight, fast training (~11M params)
|
|
15
|
+
- resnet34: Balanced capacity (~21M params)
|
|
16
|
+
- resnet50: Higher capacity with bottleneck blocks (~25M params)
|
|
17
|
+
|
|
18
|
+
References:
|
|
19
|
+
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning
|
|
20
|
+
for Image Recognition. CVPR 2016. https://arxiv.org/abs/1512.03385
|
|
17
21
|
|
|
18
22
|
Author: Ductho Le (ductho.le@outlook.com)
|
|
19
|
-
Version: 1.0.0
|
|
20
23
|
"""
|
|
21
24
|
|
|
22
25
|
from typing import Any
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ResNet3D: 3D Residual Networks for Volumetric Data
|
|
3
|
+
===================================================
|
|
4
|
+
|
|
5
|
+
3D extension of ResNet for processing volumetric data such as C-scans,
|
|
6
|
+
3D wavefield imaging, and spatiotemporal cubes. Wraps torchvision's
|
|
7
|
+
video models adapted for regression tasks.
|
|
8
|
+
|
|
9
|
+
**Key Features**:
|
|
10
|
+
- Native 3D convolutions for volumetric processing
|
|
11
|
+
- Pretrained weights from Kinetics-400 (video action recognition)
|
|
12
|
+
- Adapted for single-channel input (grayscale volumes)
|
|
13
|
+
- Custom regression head for parameter estimation
|
|
14
|
+
|
|
15
|
+
**Variants**:
|
|
16
|
+
- resnet3d_18: Lightweight (33M params)
|
|
17
|
+
- resnet3d_34: Medium depth
|
|
18
|
+
- resnet3d_50: Higher capacity with bottleneck blocks
|
|
19
|
+
|
|
20
|
+
**Use Cases**:
|
|
21
|
+
- C-scan volume analysis (ultrasonic NDT)
|
|
22
|
+
- 3D wavefield imaging and inversion
|
|
23
|
+
- Spatiotemporal data cubes (time × space × space)
|
|
24
|
+
- Medical imaging (CT/MRI volumes)
|
|
25
|
+
|
|
26
|
+
**Note**: ResNet3D is 3D-only. For 1D/2D data, use TCN or standard ResNet.
|
|
27
|
+
|
|
28
|
+
References:
|
|
29
|
+
Hara, K., Kataoka, H., & Satoh, Y. (2018). Can Spatiotemporal 3D CNNs
|
|
30
|
+
Retrace the History of 2D CNNs and ImageNet? CVPR 2018.
|
|
31
|
+
https://arxiv.org/abs/1711.09577
|
|
32
|
+
|
|
33
|
+
He, K., et al. (2016). Deep Residual Learning for Image Recognition.
|
|
34
|
+
CVPR 2016. https://arxiv.org/abs/1512.03385
|
|
35
|
+
|
|
36
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
from typing import Any
|
|
40
|
+
|
|
41
|
+
import torch
|
|
42
|
+
import torch.nn as nn
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
from torchvision.models.video import (
|
|
47
|
+
MC3_18_Weights,
|
|
48
|
+
R3D_18_Weights,
|
|
49
|
+
mc3_18,
|
|
50
|
+
r3d_18,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
RESNET3D_AVAILABLE = True
|
|
54
|
+
except ImportError:
|
|
55
|
+
RESNET3D_AVAILABLE = False
|
|
56
|
+
|
|
57
|
+
from wavedl.models.base import BaseModel
|
|
58
|
+
from wavedl.models.registry import register_model
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ResNet3DBase(BaseModel):
    """
    Shared 3D-ResNet backbone wrapper for volumetric regression.

    Adapts a torchvision video ResNet to this package's regression
    interface:
    - Optional Kinetics-400 pretrained weights
    - Grayscale volumes are expanded to 3 channels in ``forward``
    - The classifier ``fc`` is swapped for a three-layer MLP head

    Note: 3D-only — ``in_shape`` must be (D, H, W).
    """

    def __init__(
        self,
        in_shape: tuple[int, int, int],
        out_size: int,
        model_fn,
        weights_class,
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        freeze_backbone: bool = False,
        regression_hidden: int = 512,
        **kwargs,
    ):
        """
        Build the 3D ResNet regression model.

        Args:
            in_shape: (D, H, W) input volume dimensions.
            out_size: Number of regression output targets.
            model_fn: torchvision model constructor (e.g. ``r3d_18``).
            weights_class: Matching pretrained-weights enum class.
            pretrained: Use Kinetics-400 pretrained weights (default: True).
            dropout_rate: Dropout rate in regression head (default: 0.3).
            freeze_backbone: Freeze backbone for fine-tuning (default: False).
            regression_hidden: Hidden units in regression head (default: 512).

        Raises:
            ImportError: If torchvision's video models are unavailable.
            ValueError: If ``in_shape`` is not 3-dimensional.
        """
        super().__init__(in_shape, out_size)

        if not RESNET3D_AVAILABLE:
            raise ImportError(
                "torchvision >= 0.12 is required for ResNet3D. "
                "Install with: pip install torchvision>=0.12"
            )

        if len(in_shape) != 3:
            raise ValueError(
                f"ResNet3D requires 3D input (D, H, W), got {len(in_shape)}D. "
                "For 1D data, use TCN. For 2D data, use standard ResNet."
            )

        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.regression_hidden = regression_hidden

        # Instantiate the torchvision backbone, optionally pretrained.
        selected_weights = weights_class.DEFAULT if pretrained else None
        self.backbone = model_fn(weights=selected_weights)

        # Width of the classifier the backbone ships with.
        feat_dim = self.backbone.fc.in_features

        # Three-layer MLP regression head; dropout is halved after the
        # first hidden layer.
        head_layers = [
            nn.Dropout(dropout_rate),
            nn.Linear(feat_dim, regression_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(regression_hidden, regression_hidden // 2),
            nn.ReLU(inplace=True),
            nn.Linear(regression_hidden // 2, out_size),
        ]
        self.backbone.fc = nn.Sequential(*head_layers)

        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Freeze all backbone parameters except the fc head."""
        frozen_params = (
            param
            for name, param in self.backbone.named_parameters()
            if "fc" not in name
        )
        for param in frozen_params:
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone on a batch of volumes.

        Args:
            x: Input tensor of shape (B, C, D, H, W) where C is 1 or 3.

        Returns:
            Output tensor of shape (B, out_size).
        """
        # Pretrained weights expect 3 channels; broadcast grayscale input.
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1, -1)
        return self.backbone(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for ResNet3D."""
        return dict(
            pretrained=True,
            dropout_rate=0.3,
            freeze_backbone=False,
            regression_hidden=512,
        )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# =============================================================================
|
|
173
|
+
# REGISTERED MODEL VARIANTS
|
|
174
|
+
# =============================================================================
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@register_model("resnet3d_18")
|
|
178
|
+
class ResNet3D18(ResNet3DBase):
|
|
179
|
+
"""
|
|
180
|
+
ResNet3D-18: Lightweight 3D ResNet for volumetric data.
|
|
181
|
+
|
|
182
|
+
~33M parameters. Uses 3D convolutions throughout for true volumetric processing.
|
|
183
|
+
Pretrained on Kinetics-400 (video action recognition).
|
|
184
|
+
|
|
185
|
+
Recommended for:
|
|
186
|
+
- C-scan ultrasonic inspection volumes
|
|
187
|
+
- 3D wavefield data cubes
|
|
188
|
+
- Medical imaging (CT/MRI)
|
|
189
|
+
- Moderate compute budgets
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
in_shape: (D, H, W) volume dimensions
|
|
193
|
+
out_size: Number of regression targets
|
|
194
|
+
pretrained: Use Kinetics-400 pretrained weights (default: True)
|
|
195
|
+
dropout_rate: Dropout rate in head (default: 0.3)
|
|
196
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
197
|
+
regression_hidden: Hidden units in regression head (default: 512)
|
|
198
|
+
|
|
199
|
+
Example:
|
|
200
|
+
>>> model = ResNet3D18(in_shape=(16, 112, 112), out_size=3)
|
|
201
|
+
>>> x = torch.randn(2, 1, 16, 112, 112)
|
|
202
|
+
>>> out = model(x) # (2, 3)
|
|
203
|
+
"""
|
|
204
|
+
|
|
205
|
+
def __init__(self, in_shape: tuple[int, int, int], out_size: int, **kwargs):
|
|
206
|
+
super().__init__(
|
|
207
|
+
in_shape=in_shape,
|
|
208
|
+
out_size=out_size,
|
|
209
|
+
model_fn=r3d_18,
|
|
210
|
+
weights_class=R3D_18_Weights,
|
|
211
|
+
**kwargs,
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
def __repr__(self) -> str:
|
|
215
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
216
|
+
return f"ResNet3D_18({pt}, in={self.in_shape}, out={self.out_size})"
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@register_model("mc3_18")
|
|
220
|
+
class MC3_18(ResNet3DBase):
|
|
221
|
+
"""
|
|
222
|
+
MC3-18: Mixed Convolution 3D ResNet (3D stem + 2D residual blocks).
|
|
223
|
+
|
|
224
|
+
~11M parameters. More efficient than pure 3D ResNet while maintaining
|
|
225
|
+
good spatiotemporal modeling. Uses 3D convolutions in early layers
|
|
226
|
+
and 2D convolutions in later layers.
|
|
227
|
+
|
|
228
|
+
Recommended for:
|
|
229
|
+
- When pure 3D is too expensive
|
|
230
|
+
- Volumes with limited temporal/depth extent
|
|
231
|
+
- Faster training with reasonable accuracy
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
in_shape: (D, H, W) volume dimensions
|
|
235
|
+
out_size: Number of regression targets
|
|
236
|
+
pretrained: Use Kinetics-400 pretrained weights (default: True)
|
|
237
|
+
dropout_rate: Dropout rate in head (default: 0.3)
|
|
238
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
239
|
+
regression_hidden: Hidden units in regression head (default: 512)
|
|
240
|
+
|
|
241
|
+
Example:
|
|
242
|
+
>>> model = MC3_18(in_shape=(16, 112, 112), out_size=3)
|
|
243
|
+
>>> x = torch.randn(2, 1, 16, 112, 112)
|
|
244
|
+
>>> out = model(x) # (2, 3)
|
|
245
|
+
"""
|
|
246
|
+
|
|
247
|
+
def __init__(self, in_shape: tuple[int, int, int], out_size: int, **kwargs):
|
|
248
|
+
super().__init__(
|
|
249
|
+
in_shape=in_shape,
|
|
250
|
+
out_size=out_size,
|
|
251
|
+
model_fn=mc3_18,
|
|
252
|
+
weights_class=MC3_18_Weights,
|
|
253
|
+
**kwargs,
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
def __repr__(self) -> str:
|
|
257
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
258
|
+
return f"MC3_18({pt}, in={self.in_shape}, out={self.out_size})"
|