wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1430 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/METADATA +93 -53
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/resnet.py
CHANGED
|
@@ -27,14 +27,10 @@ from typing import Any
|
|
|
27
27
|
import torch
|
|
28
28
|
import torch.nn as nn
|
|
29
29
|
|
|
30
|
-
from wavedl.models.base import BaseModel
|
|
30
|
+
from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
|
|
31
31
|
from wavedl.models.registry import register_model
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
# Type alias for spatial shapes
|
|
35
|
-
SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
|
|
36
|
-
|
|
37
|
-
|
|
38
34
|
def _get_conv_layers(
|
|
39
35
|
dim: int,
|
|
40
36
|
) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
|
|
@@ -49,36 +45,6 @@ def _get_conv_layers(
|
|
|
49
45
|
raise ValueError(f"Unsupported dimensionality: {dim}D. Supported: 1D, 2D, 3D.")
|
|
50
46
|
|
|
51
47
|
|
|
52
|
-
def _get_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
|
|
53
|
-
"""
|
|
54
|
-
Get valid num_groups for GroupNorm that divides num_channels evenly.
|
|
55
|
-
|
|
56
|
-
Args:
|
|
57
|
-
num_channels: Number of channels to normalize
|
|
58
|
-
preferred_groups: Preferred number of groups (default: 32)
|
|
59
|
-
|
|
60
|
-
Returns:
|
|
61
|
-
Valid num_groups that divides num_channels
|
|
62
|
-
|
|
63
|
-
Raises:
|
|
64
|
-
ValueError: If no valid divisor found (shouldn't happen with power-of-2 channels)
|
|
65
|
-
"""
|
|
66
|
-
# Try preferred groups first, then decrease
|
|
67
|
-
for groups in [preferred_groups, 16, 8, 4, 2, 1]:
|
|
68
|
-
if groups <= num_channels and num_channels % groups == 0:
|
|
69
|
-
return groups
|
|
70
|
-
|
|
71
|
-
# Fallback: find any valid divisor
|
|
72
|
-
for groups in range(min(32, num_channels), 0, -1):
|
|
73
|
-
if num_channels % groups == 0:
|
|
74
|
-
return groups
|
|
75
|
-
|
|
76
|
-
raise ValueError(
|
|
77
|
-
f"Cannot find valid num_groups for {num_channels} channels. "
|
|
78
|
-
f"Consider using base_width that is a power of 2 (e.g., 32, 64, 128)."
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
|
|
82
48
|
class BasicBlock(nn.Module):
|
|
83
49
|
"""
|
|
84
50
|
Basic residual block for ResNet-18/34.
|
|
@@ -107,12 +73,12 @@ class BasicBlock(nn.Module):
|
|
|
107
73
|
padding=1,
|
|
108
74
|
bias=False,
|
|
109
75
|
)
|
|
110
|
-
self.gn1 = nn.GroupNorm(
|
|
76
|
+
self.gn1 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
|
|
111
77
|
self.relu = nn.ReLU(inplace=True)
|
|
112
78
|
self.conv2 = Conv(
|
|
113
79
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
|
|
114
80
|
)
|
|
115
|
-
self.gn2 = nn.GroupNorm(
|
|
81
|
+
self.gn2 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
|
|
116
82
|
self.downsample = downsample
|
|
117
83
|
|
|
118
84
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
@@ -155,7 +121,7 @@ class Bottleneck(nn.Module):
|
|
|
155
121
|
|
|
156
122
|
# 1x1 reduce
|
|
157
123
|
self.conv1 = Conv(in_channels, out_channels, kernel_size=1, bias=False)
|
|
158
|
-
self.gn1 = nn.GroupNorm(
|
|
124
|
+
self.gn1 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
|
|
159
125
|
|
|
160
126
|
# 3x3 conv
|
|
161
127
|
self.conv2 = Conv(
|
|
@@ -166,14 +132,16 @@ class Bottleneck(nn.Module):
|
|
|
166
132
|
padding=1,
|
|
167
133
|
bias=False,
|
|
168
134
|
)
|
|
169
|
-
self.gn2 = nn.GroupNorm(
|
|
135
|
+
self.gn2 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
|
|
170
136
|
|
|
171
137
|
# 1x1 expand
|
|
172
138
|
self.conv3 = Conv(
|
|
173
139
|
out_channels, out_channels * self.expansion, kernel_size=1, bias=False
|
|
174
140
|
)
|
|
175
141
|
expanded_channels = out_channels * self.expansion
|
|
176
|
-
self.gn3 = nn.GroupNorm(
|
|
142
|
+
self.gn3 = nn.GroupNorm(
|
|
143
|
+
compute_num_groups(expanded_channels), expanded_channels
|
|
144
|
+
)
|
|
177
145
|
|
|
178
146
|
self.relu = nn.ReLU(inplace=True)
|
|
179
147
|
self.downsample = downsample
|
|
@@ -229,7 +197,7 @@ class ResNetBase(BaseModel):
|
|
|
229
197
|
|
|
230
198
|
# Stem: 7x7 conv (or equivalent for 1D/3D)
|
|
231
199
|
self.conv1 = Conv(1, base_width, kernel_size=7, stride=2, padding=3, bias=False)
|
|
232
|
-
self.gn1 = nn.GroupNorm(
|
|
200
|
+
self.gn1 = nn.GroupNorm(compute_num_groups(base_width), base_width)
|
|
233
201
|
self.relu = nn.ReLU(inplace=True)
|
|
234
202
|
self.maxpool = MaxPool(kernel_size=3, stride=2, padding=1)
|
|
235
203
|
|
|
@@ -275,7 +243,7 @@ class ResNetBase(BaseModel):
|
|
|
275
243
|
bias=False,
|
|
276
244
|
),
|
|
277
245
|
nn.GroupNorm(
|
|
278
|
-
|
|
246
|
+
compute_num_groups(out_channels * block.expansion),
|
|
279
247
|
out_channels * block.expansion,
|
|
280
248
|
),
|
|
281
249
|
)
|
|
@@ -495,21 +463,11 @@ class PretrainedResNetBase(BaseModel):
|
|
|
495
463
|
|
|
496
464
|
# Modify first conv for single-channel input
|
|
497
465
|
# Original: Conv2d(3, 64, ...) → New: Conv2d(1, 64, ...)
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
kernel_size=old_conv.kernel_size,
|
|
503
|
-
stride=old_conv.stride,
|
|
504
|
-
padding=old_conv.padding,
|
|
505
|
-
bias=False,
|
|
466
|
+
from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
|
|
467
|
+
|
|
468
|
+
adapt_first_conv_for_single_channel(
|
|
469
|
+
self.backbone, "conv1", pretrained=pretrained
|
|
506
470
|
)
|
|
507
|
-
# Initialize new conv with mean of pretrained weights
|
|
508
|
-
if pretrained:
|
|
509
|
-
with torch.no_grad():
|
|
510
|
-
self.backbone.conv1.weight = nn.Parameter(
|
|
511
|
-
old_conv.weight.mean(dim=1, keepdim=True)
|
|
512
|
-
)
|
|
513
471
|
|
|
514
472
|
# Optionally freeze backbone
|
|
515
473
|
if freeze_backbone:
|
wavedl/models/resnet3d.py
CHANGED
|
@@ -1,258 +1,258 @@
|
|
|
1
|
-
"""
|
|
2
|
-
ResNet3D: 3D Residual Networks for Volumetric Data
|
|
3
|
-
===================================================
|
|
4
|
-
|
|
5
|
-
3D extension of ResNet for processing volumetric data such as C-scans,
|
|
6
|
-
3D wavefield imaging, and spatiotemporal cubes. Wraps torchvision's
|
|
7
|
-
video models adapted for regression tasks.
|
|
8
|
-
|
|
9
|
-
**Key Features**:
|
|
10
|
-
- Native 3D convolutions for volumetric processing
|
|
11
|
-
- Pretrained weights from Kinetics-400 (video action recognition)
|
|
12
|
-
- Adapted for single-channel input (grayscale volumes)
|
|
13
|
-
- Custom regression head for parameter estimation
|
|
14
|
-
|
|
15
|
-
**Variants**:
|
|
16
|
-
- resnet3d_18: Lightweight (33M params)
|
|
17
|
-
- resnet3d_34: Medium depth
|
|
18
|
-
- resnet3d_50: Higher capacity with bottleneck blocks
|
|
19
|
-
|
|
20
|
-
**Use Cases**:
|
|
21
|
-
- C-scan volume analysis (ultrasonic NDT)
|
|
22
|
-
- 3D wavefield imaging and inversion
|
|
23
|
-
- Spatiotemporal data cubes (time × space × space)
|
|
24
|
-
- Medical imaging (CT/MRI volumes)
|
|
25
|
-
|
|
26
|
-
**Note**: ResNet3D is 3D-only. For 1D/2D data, use TCN or standard ResNet.
|
|
27
|
-
|
|
28
|
-
References:
|
|
29
|
-
Hara, K., Kataoka, H., & Satoh, Y. (2018). Can Spatiotemporal 3D CNNs
|
|
30
|
-
Retrace the History of 2D CNNs and ImageNet? CVPR 2018.
|
|
31
|
-
https://arxiv.org/abs/1711.09577
|
|
32
|
-
|
|
33
|
-
He, K., et al. (2016). Deep Residual Learning for Image Recognition.
|
|
34
|
-
CVPR 2016. https://arxiv.org/abs/1512.03385
|
|
35
|
-
|
|
36
|
-
Author: Ductho Le (ductho.le@outlook.com)
|
|
37
|
-
"""
|
|
38
|
-
|
|
39
|
-
from typing import Any
|
|
40
|
-
|
|
41
|
-
import torch
|
|
42
|
-
import torch.nn as nn
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
try:
|
|
46
|
-
from torchvision.models.video import (
|
|
47
|
-
MC3_18_Weights,
|
|
48
|
-
R3D_18_Weights,
|
|
49
|
-
mc3_18,
|
|
50
|
-
r3d_18,
|
|
51
|
-
)
|
|
52
|
-
|
|
53
|
-
RESNET3D_AVAILABLE = True
|
|
54
|
-
except ImportError:
|
|
55
|
-
RESNET3D_AVAILABLE = False
|
|
56
|
-
|
|
57
|
-
from wavedl.models.base import BaseModel
|
|
58
|
-
from wavedl.models.registry import register_model
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
class ResNet3DBase(BaseModel):
|
|
62
|
-
"""
|
|
63
|
-
Base ResNet3D class for volumetric regression tasks.
|
|
64
|
-
|
|
65
|
-
Wraps torchvision 3D ResNet with:
|
|
66
|
-
- Optional pretrained weights (Kinetics-400)
|
|
67
|
-
- Automatic input channel adaptation (grayscale → 3ch)
|
|
68
|
-
- Custom regression head
|
|
69
|
-
|
|
70
|
-
Note: This is 3D-only. Input shape must be (D, H, W).
|
|
71
|
-
"""
|
|
72
|
-
|
|
73
|
-
def __init__(
|
|
74
|
-
self,
|
|
75
|
-
in_shape: tuple[int, int, int],
|
|
76
|
-
out_size: int,
|
|
77
|
-
model_fn,
|
|
78
|
-
weights_class,
|
|
79
|
-
pretrained: bool = True,
|
|
80
|
-
dropout_rate: float = 0.3,
|
|
81
|
-
freeze_backbone: bool = False,
|
|
82
|
-
regression_hidden: int = 512,
|
|
83
|
-
**kwargs,
|
|
84
|
-
):
|
|
85
|
-
"""
|
|
86
|
-
Initialize ResNet3D for regression.
|
|
87
|
-
|
|
88
|
-
Args:
|
|
89
|
-
in_shape: (D, H, W) input volume dimensions
|
|
90
|
-
out_size: Number of regression output targets
|
|
91
|
-
model_fn: torchvision model constructor
|
|
92
|
-
weights_class: Pretrained weights enum class
|
|
93
|
-
pretrained: Use Kinetics-400 pretrained weights (default: True)
|
|
94
|
-
dropout_rate: Dropout rate in regression head (default: 0.3)
|
|
95
|
-
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
96
|
-
regression_hidden: Hidden units in regression head (default: 512)
|
|
97
|
-
"""
|
|
98
|
-
super().__init__(in_shape, out_size)
|
|
99
|
-
|
|
100
|
-
if not RESNET3D_AVAILABLE:
|
|
101
|
-
raise ImportError(
|
|
102
|
-
"torchvision >= 0.12 is required for ResNet3D. "
|
|
103
|
-
"Install with: pip install torchvision>=0.12"
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
if len(in_shape) != 3:
|
|
107
|
-
raise ValueError(
|
|
108
|
-
f"ResNet3D requires 3D input (D, H, W), got {len(in_shape)}D. "
|
|
109
|
-
"For 1D data, use TCN. For 2D data, use standard ResNet."
|
|
110
|
-
)
|
|
111
|
-
|
|
112
|
-
self.pretrained = pretrained
|
|
113
|
-
self.dropout_rate = dropout_rate
|
|
114
|
-
self.freeze_backbone = freeze_backbone
|
|
115
|
-
self.regression_hidden = regression_hidden
|
|
116
|
-
|
|
117
|
-
# Load pretrained backbone
|
|
118
|
-
weights = weights_class.DEFAULT if pretrained else None
|
|
119
|
-
self.backbone = model_fn(weights=weights)
|
|
120
|
-
|
|
121
|
-
# Get the fc input features
|
|
122
|
-
in_features = self.backbone.fc.in_features
|
|
123
|
-
|
|
124
|
-
# Replace fc with regression head
|
|
125
|
-
self.backbone.fc = nn.Sequential(
|
|
126
|
-
nn.Dropout(dropout_rate),
|
|
127
|
-
nn.Linear(in_features, regression_hidden),
|
|
128
|
-
nn.ReLU(inplace=True),
|
|
129
|
-
nn.Dropout(dropout_rate * 0.5),
|
|
130
|
-
nn.Linear(regression_hidden, regression_hidden // 2),
|
|
131
|
-
nn.ReLU(inplace=True),
|
|
132
|
-
nn.Linear(regression_hidden // 2, out_size),
|
|
133
|
-
)
|
|
134
|
-
|
|
135
|
-
# Optionally freeze backbone for fine-tuning
|
|
136
|
-
if freeze_backbone:
|
|
137
|
-
self._freeze_backbone()
|
|
138
|
-
|
|
139
|
-
def _freeze_backbone(self):
|
|
140
|
-
"""Freeze all backbone parameters except the fc head."""
|
|
141
|
-
for name, param in self.backbone.named_parameters():
|
|
142
|
-
if "fc" not in name:
|
|
143
|
-
param.requires_grad = False
|
|
144
|
-
|
|
145
|
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
146
|
-
"""
|
|
147
|
-
Forward pass.
|
|
148
|
-
|
|
149
|
-
Args:
|
|
150
|
-
x: Input tensor of shape (B, C, D, H, W) where C is 1 or 3
|
|
151
|
-
|
|
152
|
-
Returns:
|
|
153
|
-
Output tensor of shape (B, out_size)
|
|
154
|
-
"""
|
|
155
|
-
# Expand single channel to 3 channels for pretrained weights compatibility
|
|
156
|
-
if x.size(1) == 1:
|
|
157
|
-
x = x.expand(-1, 3, -1, -1, -1)
|
|
158
|
-
|
|
159
|
-
return self.backbone(x)
|
|
160
|
-
|
|
161
|
-
@classmethod
|
|
162
|
-
def get_default_config(cls) -> dict[str, Any]:
|
|
163
|
-
"""Return default configuration for ResNet3D."""
|
|
164
|
-
return {
|
|
165
|
-
"pretrained": True,
|
|
166
|
-
"dropout_rate": 0.3,
|
|
167
|
-
"freeze_backbone": False,
|
|
168
|
-
"regression_hidden": 512,
|
|
169
|
-
}
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
# =============================================================================
|
|
173
|
-
# REGISTERED MODEL VARIANTS
|
|
174
|
-
# =============================================================================
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
@register_model("resnet3d_18")
|
|
178
|
-
class ResNet3D18(ResNet3DBase):
|
|
179
|
-
"""
|
|
180
|
-
ResNet3D-18: Lightweight 3D ResNet for volumetric data.
|
|
181
|
-
|
|
182
|
-
~33.2M backbone parameters. Uses 3D convolutions throughout for true volumetric processing.
|
|
183
|
-
Pretrained on Kinetics-400 (video action recognition).
|
|
184
|
-
|
|
185
|
-
Recommended for:
|
|
186
|
-
- C-scan ultrasonic inspection volumes
|
|
187
|
-
- 3D wavefield data cubes
|
|
188
|
-
- Medical imaging (CT/MRI)
|
|
189
|
-
- Moderate compute budgets
|
|
190
|
-
|
|
191
|
-
Args:
|
|
192
|
-
in_shape: (D, H, W) volume dimensions
|
|
193
|
-
out_size: Number of regression targets
|
|
194
|
-
pretrained: Use Kinetics-400 pretrained weights (default: True)
|
|
195
|
-
dropout_rate: Dropout rate in head (default: 0.3)
|
|
196
|
-
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
197
|
-
regression_hidden: Hidden units in regression head (default: 512)
|
|
198
|
-
|
|
199
|
-
Example:
|
|
200
|
-
>>> model = ResNet3D18(in_shape=(16, 112, 112), out_size=3)
|
|
201
|
-
>>> x = torch.randn(2, 1, 16, 112, 112)
|
|
202
|
-
>>> out = model(x) # (2, 3)
|
|
203
|
-
"""
|
|
204
|
-
|
|
205
|
-
def __init__(self, in_shape: tuple[int, int, int], out_size: int, **kwargs):
|
|
206
|
-
super().__init__(
|
|
207
|
-
in_shape=in_shape,
|
|
208
|
-
out_size=out_size,
|
|
209
|
-
model_fn=r3d_18,
|
|
210
|
-
weights_class=R3D_18_Weights,
|
|
211
|
-
**kwargs,
|
|
212
|
-
)
|
|
213
|
-
|
|
214
|
-
def __repr__(self) -> str:
|
|
215
|
-
pt = "pretrained" if self.pretrained else "scratch"
|
|
216
|
-
return f"ResNet3D_18({pt}, in={self.in_shape}, out={self.out_size})"
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
@register_model("mc3_18")
|
|
220
|
-
class MC3_18(ResNet3DBase):
|
|
221
|
-
"""
|
|
222
|
-
MC3-18: Mixed Convolution 3D ResNet (3D stem + 2D residual blocks).
|
|
223
|
-
|
|
224
|
-
~11.5M backbone parameters. More efficient than pure 3D ResNet while maintaining
|
|
225
|
-
good spatiotemporal modeling. Uses 3D convolutions in early layers
|
|
226
|
-
and 2D convolutions in later layers.
|
|
227
|
-
|
|
228
|
-
Recommended for:
|
|
229
|
-
- When pure 3D is too expensive
|
|
230
|
-
- Volumes with limited temporal/depth extent
|
|
231
|
-
- Faster training with reasonable accuracy
|
|
232
|
-
|
|
233
|
-
Args:
|
|
234
|
-
in_shape: (D, H, W) volume dimensions
|
|
235
|
-
out_size: Number of regression targets
|
|
236
|
-
pretrained: Use Kinetics-400 pretrained weights (default: True)
|
|
237
|
-
dropout_rate: Dropout rate in head (default: 0.3)
|
|
238
|
-
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
239
|
-
regression_hidden: Hidden units in regression head (default: 512)
|
|
240
|
-
|
|
241
|
-
Example:
|
|
242
|
-
>>> model = MC3_18(in_shape=(16, 112, 112), out_size=3)
|
|
243
|
-
>>> x = torch.randn(2, 1, 16, 112, 112)
|
|
244
|
-
>>> out = model(x) # (2, 3)
|
|
245
|
-
"""
|
|
246
|
-
|
|
247
|
-
def __init__(self, in_shape: tuple[int, int, int], out_size: int, **kwargs):
|
|
248
|
-
super().__init__(
|
|
249
|
-
in_shape=in_shape,
|
|
250
|
-
out_size=out_size,
|
|
251
|
-
model_fn=mc3_18,
|
|
252
|
-
weights_class=MC3_18_Weights,
|
|
253
|
-
**kwargs,
|
|
254
|
-
)
|
|
255
|
-
|
|
256
|
-
def __repr__(self) -> str:
|
|
257
|
-
pt = "pretrained" if self.pretrained else "scratch"
|
|
258
|
-
return f"MC3_18({pt}, in={self.in_shape}, out={self.out_size})"
|
|
1
|
+
"""
|
|
2
|
+
ResNet3D: 3D Residual Networks for Volumetric Data
|
|
3
|
+
===================================================
|
|
4
|
+
|
|
5
|
+
3D extension of ResNet for processing volumetric data such as C-scans,
|
|
6
|
+
3D wavefield imaging, and spatiotemporal cubes. Wraps torchvision's
|
|
7
|
+
video models adapted for regression tasks.
|
|
8
|
+
|
|
9
|
+
**Key Features**:
|
|
10
|
+
- Native 3D convolutions for volumetric processing
|
|
11
|
+
- Pretrained weights from Kinetics-400 (video action recognition)
|
|
12
|
+
- Adapted for single-channel input (grayscale volumes)
|
|
13
|
+
- Custom regression head for parameter estimation
|
|
14
|
+
|
|
15
|
+
**Variants**:
|
|
16
|
+
- resnet3d_18: Lightweight (33M params)
|
|
17
|
+
- resnet3d_34: Medium depth
|
|
18
|
+
- resnet3d_50: Higher capacity with bottleneck blocks
|
|
19
|
+
|
|
20
|
+
**Use Cases**:
|
|
21
|
+
- C-scan volume analysis (ultrasonic NDT)
|
|
22
|
+
- 3D wavefield imaging and inversion
|
|
23
|
+
- Spatiotemporal data cubes (time × space × space)
|
|
24
|
+
- Medical imaging (CT/MRI volumes)
|
|
25
|
+
|
|
26
|
+
**Note**: ResNet3D is 3D-only. For 1D/2D data, use TCN or standard ResNet.
|
|
27
|
+
|
|
28
|
+
References:
|
|
29
|
+
Hara, K., Kataoka, H., & Satoh, Y. (2018). Can Spatiotemporal 3D CNNs
|
|
30
|
+
Retrace the History of 2D CNNs and ImageNet? CVPR 2018.
|
|
31
|
+
https://arxiv.org/abs/1711.09577
|
|
32
|
+
|
|
33
|
+
He, K., et al. (2016). Deep Residual Learning for Image Recognition.
|
|
34
|
+
CVPR 2016. https://arxiv.org/abs/1512.03385
|
|
35
|
+
|
|
36
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
from typing import Any
|
|
40
|
+
|
|
41
|
+
import torch
|
|
42
|
+
import torch.nn as nn
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
from torchvision.models.video import (
|
|
47
|
+
MC3_18_Weights,
|
|
48
|
+
R3D_18_Weights,
|
|
49
|
+
mc3_18,
|
|
50
|
+
r3d_18,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
RESNET3D_AVAILABLE = True
|
|
54
|
+
except ImportError:
|
|
55
|
+
RESNET3D_AVAILABLE = False
|
|
56
|
+
|
|
57
|
+
from wavedl.models.base import BaseModel
|
|
58
|
+
from wavedl.models.registry import register_model
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ResNet3DBase(BaseModel):
|
|
62
|
+
"""
|
|
63
|
+
Base ResNet3D class for volumetric regression tasks.
|
|
64
|
+
|
|
65
|
+
Wraps torchvision 3D ResNet with:
|
|
66
|
+
- Optional pretrained weights (Kinetics-400)
|
|
67
|
+
- Automatic input channel adaptation (grayscale → 3ch)
|
|
68
|
+
- Custom regression head
|
|
69
|
+
|
|
70
|
+
Note: This is 3D-only. Input shape must be (D, H, W).
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
def __init__(
|
|
74
|
+
self,
|
|
75
|
+
in_shape: tuple[int, int, int],
|
|
76
|
+
out_size: int,
|
|
77
|
+
model_fn,
|
|
78
|
+
weights_class,
|
|
79
|
+
pretrained: bool = True,
|
|
80
|
+
dropout_rate: float = 0.3,
|
|
81
|
+
freeze_backbone: bool = False,
|
|
82
|
+
regression_hidden: int = 512,
|
|
83
|
+
**kwargs,
|
|
84
|
+
):
|
|
85
|
+
"""
|
|
86
|
+
Initialize ResNet3D for regression.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
in_shape: (D, H, W) input volume dimensions
|
|
90
|
+
out_size: Number of regression output targets
|
|
91
|
+
model_fn: torchvision model constructor
|
|
92
|
+
weights_class: Pretrained weights enum class
|
|
93
|
+
pretrained: Use Kinetics-400 pretrained weights (default: True)
|
|
94
|
+
dropout_rate: Dropout rate in regression head (default: 0.3)
|
|
95
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
96
|
+
regression_hidden: Hidden units in regression head (default: 512)
|
|
97
|
+
"""
|
|
98
|
+
super().__init__(in_shape, out_size)
|
|
99
|
+
|
|
100
|
+
if not RESNET3D_AVAILABLE:
|
|
101
|
+
raise ImportError(
|
|
102
|
+
"torchvision >= 0.12 is required for ResNet3D. "
|
|
103
|
+
"Install with: pip install torchvision>=0.12"
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
if len(in_shape) != 3:
|
|
107
|
+
raise ValueError(
|
|
108
|
+
f"ResNet3D requires 3D input (D, H, W), got {len(in_shape)}D. "
|
|
109
|
+
"For 1D data, use TCN. For 2D data, use standard ResNet."
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
self.pretrained = pretrained
|
|
113
|
+
self.dropout_rate = dropout_rate
|
|
114
|
+
self.freeze_backbone = freeze_backbone
|
|
115
|
+
self.regression_hidden = regression_hidden
|
|
116
|
+
|
|
117
|
+
# Load pretrained backbone
|
|
118
|
+
weights = weights_class.DEFAULT if pretrained else None
|
|
119
|
+
self.backbone = model_fn(weights=weights)
|
|
120
|
+
|
|
121
|
+
# Get the fc input features
|
|
122
|
+
in_features = self.backbone.fc.in_features
|
|
123
|
+
|
|
124
|
+
# Replace fc with regression head
|
|
125
|
+
self.backbone.fc = nn.Sequential(
|
|
126
|
+
nn.Dropout(dropout_rate),
|
|
127
|
+
nn.Linear(in_features, regression_hidden),
|
|
128
|
+
nn.ReLU(inplace=True),
|
|
129
|
+
nn.Dropout(dropout_rate * 0.5),
|
|
130
|
+
nn.Linear(regression_hidden, regression_hidden // 2),
|
|
131
|
+
nn.ReLU(inplace=True),
|
|
132
|
+
nn.Linear(regression_hidden // 2, out_size),
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
# Optionally freeze backbone for fine-tuning
|
|
136
|
+
if freeze_backbone:
|
|
137
|
+
self._freeze_backbone()
|
|
138
|
+
|
|
139
|
+
def _freeze_backbone(self):
|
|
140
|
+
"""Freeze all backbone parameters except the fc head."""
|
|
141
|
+
for name, param in self.backbone.named_parameters():
|
|
142
|
+
if "fc" not in name:
|
|
143
|
+
param.requires_grad = False
|
|
144
|
+
|
|
145
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
146
|
+
"""
|
|
147
|
+
Forward pass.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
x: Input tensor of shape (B, C, D, H, W) where C is 1 or 3
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
Output tensor of shape (B, out_size)
|
|
154
|
+
"""
|
|
155
|
+
# Expand single channel to 3 channels for pretrained weights compatibility
|
|
156
|
+
if x.size(1) == 1:
|
|
157
|
+
x = x.expand(-1, 3, -1, -1, -1)
|
|
158
|
+
|
|
159
|
+
return self.backbone(x)
|
|
160
|
+
|
|
161
|
+
@classmethod
|
|
162
|
+
def get_default_config(cls) -> dict[str, Any]:
|
|
163
|
+
"""Return default configuration for ResNet3D."""
|
|
164
|
+
return {
|
|
165
|
+
"pretrained": True,
|
|
166
|
+
"dropout_rate": 0.3,
|
|
167
|
+
"freeze_backbone": False,
|
|
168
|
+
"regression_hidden": 512,
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# =============================================================================
|
|
173
|
+
# REGISTERED MODEL VARIANTS
|
|
174
|
+
# =============================================================================
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@register_model("resnet3d_18")
|
|
178
|
+
class ResNet3D18(ResNet3DBase):
|
|
179
|
+
"""
|
|
180
|
+
ResNet3D-18: Lightweight 3D ResNet for volumetric data.
|
|
181
|
+
|
|
182
|
+
~33.2M backbone parameters. Uses 3D convolutions throughout for true volumetric processing.
|
|
183
|
+
Pretrained on Kinetics-400 (video action recognition).
|
|
184
|
+
|
|
185
|
+
Recommended for:
|
|
186
|
+
- C-scan ultrasonic inspection volumes
|
|
187
|
+
- 3D wavefield data cubes
|
|
188
|
+
- Medical imaging (CT/MRI)
|
|
189
|
+
- Moderate compute budgets
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
in_shape: (D, H, W) volume dimensions
|
|
193
|
+
out_size: Number of regression targets
|
|
194
|
+
pretrained: Use Kinetics-400 pretrained weights (default: True)
|
|
195
|
+
dropout_rate: Dropout rate in head (default: 0.3)
|
|
196
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
197
|
+
regression_hidden: Hidden units in regression head (default: 512)
|
|
198
|
+
|
|
199
|
+
Example:
|
|
200
|
+
>>> model = ResNet3D18(in_shape=(16, 112, 112), out_size=3)
|
|
201
|
+
>>> x = torch.randn(2, 1, 16, 112, 112)
|
|
202
|
+
>>> out = model(x) # (2, 3)
|
|
203
|
+
"""
|
|
204
|
+
|
|
205
|
+
def __init__(self, in_shape: tuple[int, int, int], out_size: int, **kwargs):
|
|
206
|
+
super().__init__(
|
|
207
|
+
in_shape=in_shape,
|
|
208
|
+
out_size=out_size,
|
|
209
|
+
model_fn=r3d_18,
|
|
210
|
+
weights_class=R3D_18_Weights,
|
|
211
|
+
**kwargs,
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
def __repr__(self) -> str:
|
|
215
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
216
|
+
return f"ResNet3D_18({pt}, in={self.in_shape}, out={self.out_size})"
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@register_model("mc3_18")
|
|
220
|
+
class MC3_18(ResNet3DBase):
|
|
221
|
+
"""
|
|
222
|
+
MC3-18: Mixed Convolution 3D ResNet (3D stem + 2D residual blocks).
|
|
223
|
+
|
|
224
|
+
~11.5M backbone parameters. More efficient than pure 3D ResNet while maintaining
|
|
225
|
+
good spatiotemporal modeling. Uses 3D convolutions in early layers
|
|
226
|
+
and 2D convolutions in later layers.
|
|
227
|
+
|
|
228
|
+
Recommended for:
|
|
229
|
+
- When pure 3D is too expensive
|
|
230
|
+
- Volumes with limited temporal/depth extent
|
|
231
|
+
- Faster training with reasonable accuracy
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
in_shape: (D, H, W) volume dimensions
|
|
235
|
+
out_size: Number of regression targets
|
|
236
|
+
pretrained: Use Kinetics-400 pretrained weights (default: True)
|
|
237
|
+
dropout_rate: Dropout rate in head (default: 0.3)
|
|
238
|
+
freeze_backbone: Freeze backbone for fine-tuning (default: False)
|
|
239
|
+
regression_hidden: Hidden units in regression head (default: 512)
|
|
240
|
+
|
|
241
|
+
Example:
|
|
242
|
+
>>> model = MC3_18(in_shape=(16, 112, 112), out_size=3)
|
|
243
|
+
>>> x = torch.randn(2, 1, 16, 112, 112)
|
|
244
|
+
>>> out = model(x) # (2, 3)
|
|
245
|
+
"""
|
|
246
|
+
|
|
247
|
+
def __init__(self, in_shape: tuple[int, int, int], out_size: int, **kwargs):
|
|
248
|
+
super().__init__(
|
|
249
|
+
in_shape=in_shape,
|
|
250
|
+
out_size=out_size,
|
|
251
|
+
model_fn=mc3_18,
|
|
252
|
+
weights_class=MC3_18_Weights,
|
|
253
|
+
**kwargs,
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
def __repr__(self) -> str:
|
|
257
|
+
pt = "pretrained" if self.pretrained else "scratch"
|
|
258
|
+
return f"MC3_18({pt}, in={self.in_shape}, out={self.out_size})"
|