wavedl-1.6.0-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
wavedl/models/resnet.py CHANGED
@@ -27,14 +27,10 @@ from typing import Any
  import torch
  import torch.nn as nn
  
- from wavedl.models.base import BaseModel
+ from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
  from wavedl.models.registry import register_model
  
  
- # Type alias for spatial shapes
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
- 
- 
  def _get_conv_layers(
      dim: int,
  ) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
@@ -49,36 +45,6 @@ def _get_conv_layers(
      raise ValueError(f"Unsupported dimensionality: {dim}D. Supported: 1D, 2D, 3D.")
  
  
- def _get_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
-     """
-     Get valid num_groups for GroupNorm that divides num_channels evenly.
- 
-     Args:
-         num_channels: Number of channels to normalize
-         preferred_groups: Preferred number of groups (default: 32)
- 
-     Returns:
-         Valid num_groups that divides num_channels
- 
-     Raises:
-         ValueError: If no valid divisor found (shouldn't happen with power-of-2 channels)
-     """
-     # Try preferred groups first, then decrease
-     for groups in [preferred_groups, 16, 8, 4, 2, 1]:
-         if groups <= num_channels and num_channels % groups == 0:
-             return groups
- 
-     # Fallback: find any valid divisor
-     for groups in range(min(32, num_channels), 0, -1):
-         if num_channels % groups == 0:
-             return groups
- 
-     raise ValueError(
-         f"Cannot find valid num_groups for {num_channels} channels. "
-         f"Consider using base_width that is a power of 2 (e.g., 32, 64, 128)."
-     )
- 
- 
  class BasicBlock(nn.Module):
      """
      Basic residual block for ResNet-18/34.
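Both the `SpatialShape` alias (first hunk) and the group-count helper now live in `wavedl.models.base`. Assuming the shared `compute_num_groups` keeps the behavior of the removed `_get_num_groups` above (preferred divisors 32, 16, 8, 4, 2, 1, then any divisor up to 32), a quick usage sketch:

import torch.nn as nn

from wavedl.models.base import compute_num_groups  # consolidated helper

# Assuming the removed logic carries over: 64 channels -> 32 groups (the
# preferred divisor); 24 channels -> 8 groups (16 does not divide 24).
gn = nn.GroupNorm(compute_num_groups(64), 64)
assert gn.num_groups == 32
gn_odd = nn.GroupNorm(compute_num_groups(24), 24)
assert gn_odd.num_groups == 8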
@@ -107,12 +73,12 @@ class BasicBlock(nn.Module):
107
73
  padding=1,
108
74
  bias=False,
109
75
  )
110
- self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
76
+ self.gn1 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
111
77
  self.relu = nn.ReLU(inplace=True)
112
78
  self.conv2 = Conv(
113
79
  out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
114
80
  )
115
- self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
81
+ self.gn2 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
116
82
  self.downsample = downsample
117
83
 
118
84
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -155,7 +121,7 @@ class Bottleneck(nn.Module):
  
          # 1x1 reduce
          self.conv1 = Conv(in_channels, out_channels, kernel_size=1, bias=False)
-         self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
+         self.gn1 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
  
          # 3x3 conv
          self.conv2 = Conv(
@@ -166,14 +132,16 @@ class Bottleneck(nn.Module):
              padding=1,
              bias=False,
          )
-         self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
+         self.gn2 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
  
          # 1x1 expand
          self.conv3 = Conv(
              out_channels, out_channels * self.expansion, kernel_size=1, bias=False
          )
          expanded_channels = out_channels * self.expansion
-         self.gn3 = nn.GroupNorm(_get_num_groups(expanded_channels), expanded_channels)
+         self.gn3 = nn.GroupNorm(
+             compute_num_groups(expanded_channels), expanded_channels
+         )
  
          self.relu = nn.ReLU(inplace=True)
          self.downsample = downsample
@@ -229,7 +197,7 @@ class ResNetBase(BaseModel):
  
          # Stem: 7x7 conv (or equivalent for 1D/3D)
          self.conv1 = Conv(1, base_width, kernel_size=7, stride=2, padding=3, bias=False)
-         self.gn1 = nn.GroupNorm(_get_num_groups(base_width), base_width)
+         self.gn1 = nn.GroupNorm(compute_num_groups(base_width), base_width)
          self.relu = nn.ReLU(inplace=True)
          self.maxpool = MaxPool(kernel_size=3, stride=2, padding=1)
  
@@ -275,7 +243,7 @@ class ResNetBase(BaseModel):
                      bias=False,
                  ),
                  nn.GroupNorm(
-                     _get_num_groups(out_channels * block.expansion),
+                     compute_num_groups(out_channels * block.expansion),
                      out_channels * block.expansion,
                  ),
              )
@@ -495,21 +463,11 @@ class PretrainedResNetBase(BaseModel):
  
          # Modify first conv for single-channel input
          # Original: Conv2d(3, 64, ...) → New: Conv2d(1, 64, ...)
-         old_conv = self.backbone.conv1
-         self.backbone.conv1 = nn.Conv2d(
-             1,
-             old_conv.out_channels,
-             kernel_size=old_conv.kernel_size,
-             stride=old_conv.stride,
-             padding=old_conv.padding,
-             bias=False,
+         from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+ 
+         adapt_first_conv_for_single_channel(
+             self.backbone, "conv1", pretrained=pretrained
          )
-         # Initialize new conv with mean of pretrained weights
-         if pretrained:
-             with torch.no_grad():
-                 self.backbone.conv1.weight = nn.Parameter(
-                     old_conv.weight.mean(dim=1, keepdim=True)
-                 )
  
          # Optionally freeze backbone
          if freeze_backbone:
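The inline channel adaptation is factored out into `wavedl.models._pretrained_utils`. The diff shows only the call site, so the body below is a sketch reconstructed from the deleted inline code, not the actual helper: it swaps the named first conv for a single-channel `nn.Conv2d` with the same geometry and, when `pretrained=True`, initializes it with the channel-mean of the old RGB kernel.

import torch
import torch.nn as nn


def adapt_first_conv_for_single_channel(
    backbone: nn.Module, attr: str, pretrained: bool = True
) -> None:
    # Reconstruction of the inline code this diff removed; the real helper in
    # wavedl.models._pretrained_utils may differ in detail.
    old_conv = getattr(backbone, attr)
    new_conv = nn.Conv2d(
        1,  # single-channel (grayscale) input
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=False,
    )
    if pretrained:
        # Collapse the pretrained RGB kernel to grayscale by averaging over
        # the input-channel dimension, as the removed code did.
        with torch.no_grad():
            new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
    setattr(backbone, attr, new_conv)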
wavedl/models/resnet3d.py CHANGED
@@ -1,258 +1,258 @@
  (All 258 lines are removed and re-added with identical text, i.e. a whitespace or line-ending-only change; the file content appears once below.)

"""
ResNet3D: 3D Residual Networks for Volumetric Data
===================================================

3D extension of ResNet for processing volumetric data such as C-scans,
3D wavefield imaging, and spatiotemporal cubes. Wraps torchvision's
video models adapted for regression tasks.

**Key Features**:
- Native 3D convolutions for volumetric processing
- Pretrained weights from Kinetics-400 (video action recognition)
- Adapted for single-channel input (grayscale volumes)
- Custom regression head for parameter estimation

**Variants**:
- resnet3d_18: Lightweight (33M params)
- resnet3d_34: Medium depth
- resnet3d_50: Higher capacity with bottleneck blocks

**Use Cases**:
- C-scan volume analysis (ultrasonic NDT)
- 3D wavefield imaging and inversion
- Spatiotemporal data cubes (time × space × space)
- Medical imaging (CT/MRI volumes)

**Note**: ResNet3D is 3D-only. For 1D/2D data, use TCN or standard ResNet.

References:
    Hara, K., Kataoka, H., & Satoh, Y. (2018). Can Spatiotemporal 3D CNNs
    Retrace the History of 2D CNNs and ImageNet? CVPR 2018.
    https://arxiv.org/abs/1711.09577

    He, K., et al. (2016). Deep Residual Learning for Image Recognition.
    CVPR 2016. https://arxiv.org/abs/1512.03385

Author: Ductho Le (ductho.le@outlook.com)
"""

from typing import Any

import torch
import torch.nn as nn


try:
    from torchvision.models.video import (
        MC3_18_Weights,
        R3D_18_Weights,
        mc3_18,
        r3d_18,
    )

    RESNET3D_AVAILABLE = True
except ImportError:
    RESNET3D_AVAILABLE = False

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


class ResNet3DBase(BaseModel):
    """
    Base ResNet3D class for volumetric regression tasks.

    Wraps torchvision 3D ResNet with:
    - Optional pretrained weights (Kinetics-400)
    - Automatic input channel adaptation (grayscale → 3ch)
    - Custom regression head

    Note: This is 3D-only. Input shape must be (D, H, W).
    """

    def __init__(
        self,
        in_shape: tuple[int, int, int],
        out_size: int,
        model_fn,
        weights_class,
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        freeze_backbone: bool = False,
        regression_hidden: int = 512,
        **kwargs,
    ):
        """
        Initialize ResNet3D for regression.

        Args:
            in_shape: (D, H, W) input volume dimensions
            out_size: Number of regression output targets
            model_fn: torchvision model constructor
            weights_class: Pretrained weights enum class
            pretrained: Use Kinetics-400 pretrained weights (default: True)
            dropout_rate: Dropout rate in regression head (default: 0.3)
            freeze_backbone: Freeze backbone for fine-tuning (default: False)
            regression_hidden: Hidden units in regression head (default: 512)
        """
        super().__init__(in_shape, out_size)

        if not RESNET3D_AVAILABLE:
            raise ImportError(
                "torchvision >= 0.12 is required for ResNet3D. "
                "Install with: pip install torchvision>=0.12"
            )

        if len(in_shape) != 3:
            raise ValueError(
                f"ResNet3D requires 3D input (D, H, W), got {len(in_shape)}D. "
                "For 1D data, use TCN. For 2D data, use standard ResNet."
            )

        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.regression_hidden = regression_hidden

        # Load pretrained backbone
        weights = weights_class.DEFAULT if pretrained else None
        self.backbone = model_fn(weights=weights)

        # Get the fc input features
        in_features = self.backbone.fc.in_features

        # Replace fc with regression head
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, regression_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(regression_hidden, regression_hidden // 2),
            nn.ReLU(inplace=True),
            nn.Linear(regression_hidden // 2, out_size),
        )

        # Optionally freeze backbone for fine-tuning
        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Freeze all backbone parameters except the fc head."""
        for name, param in self.backbone.named_parameters():
            if "fc" not in name:
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, C, D, H, W) where C is 1 or 3

        Returns:
            Output tensor of shape (B, out_size)
        """
        # Expand single channel to 3 channels for pretrained weights compatibility
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1, -1)

        return self.backbone(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for ResNet3D."""
        return {
            "pretrained": True,
            "dropout_rate": 0.3,
            "freeze_backbone": False,
            "regression_hidden": 512,
        }


# =============================================================================
# REGISTERED MODEL VARIANTS
# =============================================================================


@register_model("resnet3d_18")
class ResNet3D18(ResNet3DBase):
    """
    ResNet3D-18: Lightweight 3D ResNet for volumetric data.

    ~33.2M backbone parameters. Uses 3D convolutions throughout for true volumetric processing.
    Pretrained on Kinetics-400 (video action recognition).

    Recommended for:
    - C-scan ultrasonic inspection volumes
    - 3D wavefield data cubes
    - Medical imaging (CT/MRI)
    - Moderate compute budgets

    Args:
        in_shape: (D, H, W) volume dimensions
        out_size: Number of regression targets
        pretrained: Use Kinetics-400 pretrained weights (default: True)
        dropout_rate: Dropout rate in head (default: 0.3)
        freeze_backbone: Freeze backbone for fine-tuning (default: False)
        regression_hidden: Hidden units in regression head (default: 512)

    Example:
        >>> model = ResNet3D18(in_shape=(16, 112, 112), out_size=3)
        >>> x = torch.randn(2, 1, 16, 112, 112)
        >>> out = model(x)  # (2, 3)
    """

    def __init__(self, in_shape: tuple[int, int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_fn=r3d_18,
            weights_class=R3D_18_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        pt = "pretrained" if self.pretrained else "scratch"
        return f"ResNet3D_18({pt}, in={self.in_shape}, out={self.out_size})"


@register_model("mc3_18")
class MC3_18(ResNet3DBase):
    """
    MC3-18: Mixed Convolution 3D ResNet (3D stem + 2D residual blocks).

    ~11.5M backbone parameters. More efficient than pure 3D ResNet while maintaining
    good spatiotemporal modeling. Uses 3D convolutions in early layers
    and 2D convolutions in later layers.

    Recommended for:
    - When pure 3D is too expensive
    - Volumes with limited temporal/depth extent
    - Faster training with reasonable accuracy

    Args:
        in_shape: (D, H, W) volume dimensions
        out_size: Number of regression targets
        pretrained: Use Kinetics-400 pretrained weights (default: True)
        dropout_rate: Dropout rate in head (default: 0.3)
        freeze_backbone: Freeze backbone for fine-tuning (default: False)
        regression_hidden: Hidden units in regression head (default: 512)

    Example:
        >>> model = MC3_18(in_shape=(16, 112, 112), out_size=3)
        >>> x = torch.randn(2, 1, 16, 112, 112)
        >>> out = model(x)  # (2, 3)
    """

    def __init__(self, in_shape: tuple[int, int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_fn=mc3_18,
            weights_class=MC3_18_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        pt = "pretrained" if self.pretrained else "scratch"
        return f"MC3_18({pt}, in={self.in_shape}, out={self.out_size})"