wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1430 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/METADATA +93 -53
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/__init__.py
CHANGED
@@ -80,8 +80,24 @@ from .vit import ViTBase_, ViTSmall, ViTTiny
 # Optional timm-based models (imported conditionally)
 try:
     from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
+    from .efficientvit import (
+        EfficientViTB0,
+        EfficientViTB1,
+        EfficientViTB2,
+        EfficientViTB3,
+        EfficientViTL1,
+        EfficientViTL2,
+        EfficientViTM0,
+        EfficientViTM1,
+        EfficientViTM2,
+    )
     from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
     from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
+    from .unireplknet import (
+        UniRepLKNetBaseLarge,
+        UniRepLKNetSmall,
+        UniRepLKNetTiny,
+    )
 
     _HAS_TIMM_MODELS = True
 except ImportError:
@@ -148,6 +164,15 @@ if _HAS_TIMM_MODELS:
         [
             "CaFormerS18",
             "CaFormerS36",
+            "EfficientViTB0",
+            "EfficientViTB1",
+            "EfficientViTB2",
+            "EfficientViTB3",
+            "EfficientViTL1",
+            "EfficientViTL2",
+            "EfficientViTM0",
+            "EfficientViTM1",
+            "EfficientViTM2",
             "FastViTS12",
             "FastViTSA12",
             "FastViTT8",
@@ -156,5 +181,8 @@ if _HAS_TIMM_MODELS:
             "MaxViTSmall",
             "MaxViTTiny",
             "PoolFormerS12",
+            "UniRepLKNetBaseLarge",
+            "UniRepLKNetSmall",
+            "UniRepLKNetTiny",
         ]
     )
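The new EfficientViT and UniRepLKNet families sit inside the existing try/except ImportError guard, so they are exported only when timm is installed. A minimal sketch of how a caller might probe for them; the fallback to a dependency-free CNN class is an assumption for illustration, not part of this diff:

try:
    from wavedl.models import EfficientViTB0  # timm-backed, new in 1.6.1
    model_cls = EfficientViTB0
except ImportError:
    # timm missing: the name is not exported, fall back to a built-in model
    from wavedl.models import CNN  # assumed re-export; see cnn.py below
    model_cls = CNN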
wavedl/models/{_timm_utils.py → _pretrained_utils.py}
RENAMED
@@ -236,3 +236,131 @@ def adapt_input_channels(
         return new_conv
     else:
         raise NotImplementedError(f"Unsupported layer type: {type(conv_layer)}")
+
+
+def adapt_first_conv_for_single_channel(
+    module: nn.Module,
+    conv_path: str,
+    pretrained: bool = True,
+) -> None:
+    """
+    Adapt the first convolutional layer of a pretrained model for single-channel input.
+
+    This is a convenience function for torchvision-style models where the path
+    to the first conv layer is known. It modifies the model in-place.
+
+    For pretrained models, the RGB weights are averaged to create grayscale weights,
+    which provides a reasonable initialization for single-channel inputs.
+
+    Args:
+        module: The model or submodule containing the conv layer
+        conv_path: Dot-separated path to the conv layer (e.g., "conv1", "features.0.0")
+        pretrained: Whether to adapt pretrained weights by averaging RGB channels
+
+    Example:
+        >>> # For torchvision ResNet
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "conv1", pretrained=True
+        ... )
+        >>> # For torchvision ConvNeXt
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "features.0.0", pretrained=True
+        ... )
+        >>> # For torchvision DenseNet
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "features.conv0", pretrained=True
+        ... )
+    """
+    # Navigate to parent and get the conv layer
+    parts = conv_path.split(".")
+    parent = module
+    for part in parts[:-1]:
+        if part.isdigit():
+            parent = parent[int(part)]
+        else:
+            parent = getattr(parent, part)
+
+    # Get the final attribute name and the old conv
+    final_attr = parts[-1]
+    if final_attr.isdigit():
+        old_conv = parent[int(final_attr)]
+    else:
+        old_conv = getattr(parent, final_attr)
+
+    # Create and set the new conv
+    new_conv = adapt_input_channels(old_conv, new_in_channels=1, pretrained=pretrained)
+
+    if final_attr.isdigit():
+        parent[int(final_attr)] = new_conv
+    else:
+        setattr(parent, final_attr, new_conv)
+
+
+def find_and_adapt_input_convs(
+    backbone: nn.Module,
+    pretrained: bool = True,
+    adapt_all: bool = False,
+) -> int:
+    """
+    Find and adapt Conv2d layers with 3 input channels for single-channel input.
+
+    This is useful for timm-style models where the exact path to the first
+    conv layer may vary or where multiple layers need adaptation.
+
+    Args:
+        backbone: The backbone model to adapt
+        pretrained: Whether to adapt pretrained weights by averaging RGB channels
+        adapt_all: If True, adapt all Conv2d layers with 3 input channels.
+            If False (default), only adapt the first one found.
+
+    Returns:
+        Number of layers adapted
+
+    Example:
+        >>> # For timm models (adapt first conv only)
+        >>> count = find_and_adapt_input_convs(model.backbone, pretrained=True)
+        >>> # For models with multiple input convs (e.g., FastViT)
+        >>> count = find_and_adapt_input_convs(
+        ...     model.backbone, pretrained=True, adapt_all=True
+        ... )
+    """
+    adapted_count = 0
+
+    for name, module in backbone.named_modules():
+        if not hasattr(module, "in_channels") or module.in_channels != 3:
+            continue
+
+        # Check if this is a wrapper with inner .conv attribute
+        if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
+            old_conv = module.conv
+            module.conv = adapt_input_channels(
+                old_conv, new_in_channels=1, pretrained=pretrained
+            )
+            adapted_count += 1
+
+        elif isinstance(module, nn.Conv2d):
+            # Direct Conv2d - need to replace it in parent
+            parts = name.split(".")
+            parent = backbone
+            for part in parts[:-1]:
+                if part.isdigit():
+                    parent = parent[int(part)]
+                else:
+                    parent = getattr(parent, part)
+
+            child_name = parts[-1]
+            new_conv = adapt_input_channels(
+                module, new_in_channels=1, pretrained=pretrained
+            )
+
+            if child_name.isdigit():
+                parent[int(child_name)] = new_conv
+            else:
+                setattr(parent, child_name, new_conv)
+
+            adapted_count += 1
+
+        if not adapt_all and adapted_count > 0:
+            break
+
+    return adapted_count
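A usage sketch for the two new helpers, run against torchvision's resnet18 (using torchvision here is an assumption for illustration; within wavedl, the pretrained wrappers shown below call these helpers internally):

import torch
from torchvision.models import resnet18

from wavedl.models._pretrained_utils import (
    adapt_first_conv_for_single_channel,
    find_and_adapt_input_convs,
)

backbone = resnet18(weights=None)

# Torchvision-style: the stem conv's path is known ("conv1").
adapt_first_conv_for_single_channel(backbone, "conv1", pretrained=False)
assert backbone.conv1.in_channels == 1

# timm-style: search the module tree for 3-channel convs instead.
# conv1 was already adapted, so nothing is left to find here.
assert find_and_adapt_input_convs(backbone, pretrained=False) == 0

# The adapted stem now accepts single-channel input.
out = backbone(torch.randn(2, 1, 64, 64))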
wavedl/models/base.py
CHANGED
@@ -15,6 +15,54 @@ import torch
 import torch.nn as nn
 
 
+# =============================================================================
+# TYPE ALIASES
+# =============================================================================
+
+# Spatial shape type aliases for model input dimensions
+SpatialShape1D = tuple[int]
+SpatialShape2D = tuple[int, int]
+SpatialShape3D = tuple[int, int, int]
+SpatialShape = SpatialShape1D | SpatialShape2D | SpatialShape3D
+
+
+# =============================================================================
+# UTILITY FUNCTIONS
+# =============================================================================
+
+
+def compute_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
+    """
+    Compute valid num_groups for GroupNorm that divides num_channels evenly.
+
+    GroupNorm requires num_channels to be divisible by num_groups. This function
+    finds the largest valid divisor up to preferred_groups.
+
+    Args:
+        num_channels: Number of channels to normalize (must be positive)
+        preferred_groups: Preferred number of groups (default: 32)
+
+    Returns:
+        Valid num_groups that satisfies num_channels % num_groups == 0
+
+    Example:
+        >>> compute_num_groups(64)  # Returns 32
+        >>> compute_num_groups(48)  # Returns 16 (48 % 32 != 0)
+        >>> compute_num_groups(7)  # Returns 1 (prime number)
+    """
+    # Try preferred groups first, then common divisors
+    for groups in [preferred_groups, 16, 8, 4, 2, 1]:
+        if groups <= num_channels and num_channels % groups == 0:
+            return groups
+
+    # Fallback: find any valid divisor (always returns at least 1)
+    for groups in range(min(32, num_channels), 0, -1):
+        if num_channels % groups == 0:
+            return groups
+
+    return 1  # Always valid
+
+
 class BaseModel(nn.Module, ABC):
     """
     Abstract base class for all regression models.
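A quick check that the new helper always hands nn.GroupNorm a valid divisor, mirroring the doctest above (plain Python, no assumptions beyond the function itself):

import torch.nn as nn
from wavedl.models.base import compute_num_groups

for ch in (64, 48, 7):
    g = compute_num_groups(ch)
    # GroupNorm would raise ValueError if g did not divide ch evenly.
    nn.GroupNorm(num_groups=g, num_channels=ch)
    print(ch, "->", g)  # 64 -> 32, 48 -> 16, 7 -> 1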
wavedl/models/caformer.py
CHANGED
@@ -33,7 +33,7 @@ Author: Ductho Le (ductho.le@outlook.com)
 import torch
 import torch.nn as nn
 
-from wavedl.models._timm_utils import build_regression_head
+from wavedl.models._pretrained_utils import build_regression_head
 from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
wavedl/models/cnn.py
CHANGED
@@ -24,14 +24,10 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_conv_layers(
     dim: int,
 ) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
@@ -163,27 +159,6 @@ class CNN(BaseModel):
             nn.Linear(64, out_size),
         )
 
-    @staticmethod
-    def _compute_num_groups(num_channels: int, target_groups: int = 4) -> int:
-        """
-        Compute valid num_groups for GroupNorm that divides num_channels.
-
-        Finds the largest divisor of num_channels that is <= target_groups,
-        or falls back to 1 if no suitable divisor exists.
-
-        Args:
-            num_channels: Number of channels (must be positive)
-            target_groups: Desired number of groups (default: 4)
-
-        Returns:
-            Valid num_groups that satisfies num_channels % num_groups == 0
-        """
-        # Try target_groups down to 1, return first valid divisor
-        for g in range(min(target_groups, num_channels), 0, -1):
-            if num_channels % g == 0:
-                return g
-        return 1  # Fallback (always valid)
-
     def _make_conv_block(
         self, in_channels: int, out_channels: int, dropout: float = 0.0
     ) -> nn.Sequential:
@@ -198,7 +173,7 @@ class CNN(BaseModel):
         Returns:
             Sequential block: Conv → GroupNorm → LeakyReLU → MaxPool [→ Dropout]
         """
-        num_groups = self._compute_num_groups(out_channels)
+        num_groups = compute_num_groups(out_channels, preferred_groups=4)
 
         layers = [
             self._Conv(in_channels, out_channels, kernel_size=3, padding=1),
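The removed static method and the shared helper agree whenever the channel count is divisible by 4 (both return 4), but they are not bit-identical: the old loop tried every divisor from target_groups downward, while compute_num_groups only probes the fixed ladder preferred_groups, 16, 8, 4, 2, 1. A comparison sketch, with the removed behaviour restated inline for illustration:

from wavedl.models.base import compute_num_groups

def removed_compute_num_groups(num_channels: int, target_groups: int = 4) -> int:
    # Behaviour of the deleted CNN._compute_num_groups, restated for comparison.
    for g in range(min(target_groups, num_channels), 0, -1):
        if num_channels % g == 0:
            return g
    return 1

for ch in (32, 64, 6):
    print(ch, removed_compute_num_groups(ch), compute_num_groups(ch, preferred_groups=4))
# 32: 4 4, 64: 4 4, 6: 3 2 (divergence only for counts off the 4/2/1 ladder)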
wavedl/models/convnext.py
CHANGED
@@ -28,14 +28,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_conv_layer(dim: int) -> type[nn.Module]:
     """Get dimension-appropriate Conv class."""
     if dim == 1:
@@ -468,20 +464,11 @@ class ConvNeXtTinyPretrained(BaseModel):
         )
 
         # Modify first conv for single-channel input
-        old_conv = self.backbone.features[0][0]
-        self.backbone.features[0][0] = nn.Conv2d(
-            1,
-            old_conv.out_channels,
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=old_conv.bias is not None,
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.0.0", pretrained=pretrained
         )
-        if pretrained:
-            with torch.no_grad():
-                self.backbone.features[0][0].weight = nn.Parameter(
-                    old_conv.weight.mean(dim=1, keepdim=True)
-                )
 
         if freeze_backbone:
             self._freeze_backbone()
wavedl/models/convnext_v2.py
CHANGED
@@ -31,20 +31,17 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models._timm_utils import (
+from wavedl.models._pretrained_utils import (
     LayerNormNd,
     build_regression_head,
     get_conv_layer,
     get_grn_layer,
     get_pool_layer,
 )
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
 __all__ = [
     "ConvNeXtV2Base",
     "ConvNeXtV2BaseLarge",
@@ -469,24 +466,11 @@ class ConvNeXtV2TinyPretrained(BaseModel):
 
     def _adapt_input_channels(self):
         """Adapt first conv layer for single-channel input."""
-        old_conv = self.backbone.features[0][0]
-        new_conv = nn.Conv2d(
-            1,
-            old_conv.out_channels,
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=old_conv.bias is not None,
-        )
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
 
-
-        if self.pretrained:
-            with torch.no_grad():
-                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
-                if old_conv.bias is not None:
-                    new_conv.bias.copy_(old_conv.bias)
-
-        self.backbone.features[0][0] = new_conv
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.0.0", pretrained=self.pretrained
+        )
 
     def _freeze_backbone(self):
         """Freeze all backbone parameters except classifier."""
wavedl/models/densenet.py
CHANGED
@@ -26,14 +26,10 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_layers(dim: int):
     """Get dimension-appropriate layer classes."""
     if dim == 1:
@@ -374,20 +370,11 @@ class DenseNet121Pretrained(BaseModel):
         )
 
         # Modify first conv for single-channel input
-        old_conv = self.backbone.features.conv0
-        self.backbone.features.conv0 = nn.Conv2d(
-            1,
-            old_conv.out_channels,
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=False,
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.conv0", pretrained=pretrained
        )
-        if pretrained:
-            with torch.no_grad():
-                self.backbone.features.conv0.weight = nn.Parameter(
-                    old_conv.weight.mean(dim=1, keepdim=True)
-                )
 
         if freeze_backbone:
             self._freeze_backbone()