wavedl-1.6.2-py3-none-any.whl → wavedl-1.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +115 -9
- wavedl/models/__init__.py +22 -0
- wavedl/models/_pretrained_utils.py +72 -0
- wavedl/models/_template.py +7 -6
- wavedl/models/cnn.py +20 -0
- wavedl/models/convnext.py +3 -70
- wavedl/models/convnext_v2.py +1 -18
- wavedl/models/mamba.py +126 -38
- wavedl/models/resnet3d.py +23 -5
- wavedl/models/unireplknet.py +1 -18
- wavedl/models/vit.py +18 -8
- wavedl/test.py +13 -23
- wavedl/train.py +494 -28
- wavedl/utils/__init__.py +49 -9
- wavedl/utils/config.py +6 -8
- wavedl/utils/cross_validation.py +17 -4
- wavedl/utils/data.py +176 -180
- wavedl/utils/metrics.py +26 -5
- wavedl/utils/schedulers.py +2 -2
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/METADATA +37 -18
- wavedl-1.7.0.dist-info/RECORD +46 -0
- wavedl-1.6.2.dist-info/RECORD +0 -46
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/LICENSE +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/WHEEL +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.7.0.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpo.py
CHANGED
@@ -10,12 +10,28 @@ Usage:
     # Quick search (fewer parameters)
     wavedl-hpo --data_path train.npz --n_trials 30 --quick
 
+    # Medium search (balanced)
+    wavedl-hpo --data_path train.npz --n_trials 50 --medium
+
     # Full search with specific models
     wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
 
     # Parallel trials on multiple GPUs
     wavedl-hpo --data_path train.npz --n_trials 100 --n_jobs 4
 
+    # In-process mode (enables pruning, faster, single-GPU)
+    wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
+
+Execution Modes:
+    --inprocess: Runs trials in the same Python process. Enables pruning
+                 (MedianPruner) for early stopping of unpromising trials.
+                 Faster due to no subprocess overhead, but trials share
+                 GPU memory (no isolation between trials).
+
+    Default (subprocess): Launches each trial as a separate process.
+                          Provides GPU memory isolation but prevents pruning
+                          (subprocess can't report intermediate results).
+
 Author: Ductho Le (ductho.le@outlook.com)
 """
 
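
Note: the in-process path relies on Optuna's standard report/prune handshake inside the training loop. A minimal sketch of that pattern follows (illustrative only; the training loop below is a placeholder, not wavedl's trainer, which lives in wavedl.train.train_single_trial):

    import optuna


    def objective(trial: optuna.Trial) -> float:
        best_val_loss = float("inf")
        for epoch in range(50):
            val_loss = 1.0 / (epoch + 1)  # placeholder for a real validation pass
            best_val_loss = min(best_val_loss, val_loss)

            trial.report(val_loss, step=epoch)  # intermediate value for the pruner
            if trial.should_prune():            # MedianPruner decides here
                raise optuna.TrialPruned()
        return best_val_loss


    study = optuna.create_study(
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
    )
    study.optimize(objective, n_trials=20)

Subprocess trials never call trial.report(), which is why pruning only takes effect with --inprocess.
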
@@ -41,10 +57,12 @@ except ImportError:
 
 DEFAULT_MODELS = ["cnn", "resnet18", "resnet34"]
 QUICK_MODELS = ["cnn"]
+MEDIUM_MODELS = ["cnn", "resnet18"]
 
 # All 6 optimizers
 DEFAULT_OPTIMIZERS = ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]
 QUICK_OPTIMIZERS = ["adamw"]
+MEDIUM_OPTIMIZERS = ["adamw", "adam", "sgd"]
 
 # All 8 schedulers
 DEFAULT_SCHEDULERS = [
@@ -58,10 +76,12 @@ DEFAULT_SCHEDULERS = [
     "linear_warmup",
 ]
 QUICK_SCHEDULERS = ["plateau"]
+MEDIUM_SCHEDULERS = ["plateau", "cosine", "onecycle"]
 
 # All 6 losses
 DEFAULT_LOSSES = ["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"]
 QUICK_LOSSES = ["mse"]
+MEDIUM_LOSSES = ["mse", "mae", "huber"]
 
 
 # =============================================================================
@@ -70,16 +90,28 @@ QUICK_LOSSES = ["mse"]
 
 
 def create_objective(args):
-    """Create Optuna objective function with configurable search space.
+    """Create Optuna objective function with configurable search space.
+
+    Supports two execution modes:
+    - Subprocess (default): Launches wavedl.train via subprocess. Provides GPU
+      memory isolation but prevents pruning (MedianPruner has no effect).
+    - In-process (--inprocess): Calls train_single_trial() directly. Enables
+      pruning and reduces overhead, but trials share GPU memory.
+    """
 
     def objective(trial):
-        # Select search space based on mode
+        # Select search space based on mode (quick < medium < full)
         # CLI arguments always take precedence over defaults
         if args.quick:
            models = args.models or QUICK_MODELS
            optimizers = args.optimizers or QUICK_OPTIMIZERS
            schedulers = args.schedulers or QUICK_SCHEDULERS
            losses = args.losses or QUICK_LOSSES
+        elif args.medium:
+            models = args.models or MEDIUM_MODELS
+            optimizers = args.optimizers or MEDIUM_OPTIMIZERS
+            schedulers = args.schedulers or MEDIUM_SCHEDULERS
+            losses = args.losses or MEDIUM_LOSSES
         else:
            models = args.models or DEFAULT_MODELS
            optimizers = args.optimizers or DEFAULT_OPTIMIZERS
@@ -101,13 +133,59 @@ def create_objective(args):
         if loss == "huber":
             huber_delta = trial.suggest_float("huber_delta", 0.1, 2.0)
         else:
-            huber_delta =
+            huber_delta = 1.0  # default
 
         if optimizer == "sgd":
             momentum = trial.suggest_float("momentum", 0.8, 0.99)
         else:
-            momentum =
+            momentum = 0.9  # default
+
+        # ==================================================================
+        # IN-PROCESS MODE: Direct function call with pruning support
+        # ==================================================================
+        if args.inprocess:
+            from wavedl.train import train_single_trial
+
+            try:
+                result = train_single_trial(
+                    data_path=args.data_path,
+                    model_name=model,
+                    lr=lr,
+                    batch_size=batch_size,
+                    epochs=args.max_epochs,
+                    patience=patience,
+                    optimizer_name=optimizer,
+                    scheduler_name=scheduler,
+                    loss_name=loss,
+                    weight_decay=weight_decay,
+                    seed=args.seed,
+                    huber_delta=huber_delta,
+                    momentum=momentum,
+                    trial=trial,  # Enable pruning via trial.report/should_prune
+                    verbose=False,
+                )
+
+                if result["pruned"]:
+                    print(
+                        f"Trial {trial.number}: Pruned at epoch {result['epochs_trained']}"
+                    )
+                    raise optuna.TrialPruned()
+
+                val_loss = result["best_val_loss"]
+                print(
+                    f"Trial {trial.number}: val_loss={val_loss:.6f} ({result['epochs_trained']} epochs)"
+                )
+                return val_loss
 
+            except optuna.TrialPruned:
+                raise  # Re-raise for Optuna to handle
+            except Exception as e:
+                print(f"Trial {trial.number}: Error - {e}")
+                return float("inf")
+
+        # ==================================================================
+        # SUBPROCESS MODE (default): GPU memory isolation, no pruning
+        # ==================================================================
         # Build command
         cmd = [
             sys.executable,
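
Note: this call site fixes the contract the objective relies on from train_single_trial (added in wavedl/train.py, not shown in this section): keyword-only training arguments, an optional trial for pruning, and a result dict with "pruned", "best_val_loss", and "epochs_trained". A hedged stub of that inferred contract, for illustration only:

    # Inferred from the call site above; the real implementation is in wavedl/train.py.
    def train_single_trial(*, data_path, model_name, lr, batch_size, epochs, patience,
                           optimizer_name, scheduler_name, loss_name, weight_decay,
                           seed, huber_delta, momentum, trial=None, verbose=False):
        # ... training loop; with a trial, it should call trial.report() each epoch ...
        return {
            "pruned": False,          # True if trial.should_prune() fired mid-training
            "best_val_loss": 0.0123,  # value minimized by the Optuna study
            "epochs_trained": 42,     # epochs actually run (early stopping / pruning)
        }
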
@@ -138,9 +216,9 @@ def create_objective(args):
         ]
 
         # Add conditional args
-        if
+        if loss == "huber":
             cmd.extend(["--huber_delta", str(huber_delta)])
-        if
+        if optimizer == "sgd":
             cmd.extend(["--momentum", str(momentum)])
 
         # Use temporary directory for trial output
@@ -285,7 +363,17 @@ Examples:
     parser.add_argument(
         "--quick",
         action="store_true",
-        help="Quick mode: search fewer parameters",
+        help="Quick mode: search fewer parameters (fastest, least thorough)",
+    )
+    parser.add_argument(
+        "--medium",
+        action="store_true",
+        help="Medium mode: balanced parameter search (between --quick and full)",
+    )
+    parser.add_argument(
+        "--inprocess",
+        action="store_true",
+        help="Run trials in-process (enables pruning, faster, but no GPU memory isolation)",
     )
     parser.add_argument(
         "--timeout",
@@ -384,14 +472,32 @@ Examples:
     print("=" * 60)
     print(f"Data: {args.data_path}")
     print(f"Trials: {args.n_trials}")
-
+    # Determine mode name for display
+    if args.quick:
+        mode_name = "Quick"
+    elif args.medium:
+        mode_name = "Medium"
+    else:
+        mode_name = "Full"
+
+    print(
+        f"Mode: {mode_name}"
+        + (" (in-process, pruning enabled)" if args.inprocess else " (subprocess)")
+    )
     print(f"Parallel jobs: {args.n_jobs}")
     print("=" * 60)
 
+    # Use MedianPruner only for in-process mode (subprocess trials can't report)
+    if args.inprocess:
+        pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
+    else:
+        # NopPruner for subprocess mode - pruning has no effect there
+        pruner = optuna.pruners.NopPruner()
+
     study = optuna.create_study(
         study_name=args.study_name,
         direction="minimize",
-        pruner=
+        pruner=pruner,
     )
 
     # Run optimization
wavedl/models/__init__.py
CHANGED
@@ -77,6 +77,15 @@ from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny
 
 
+# Optional RATENet (unpublished, may be gitignored)
+try:
+    from .ratenet import RATENet, RATENetLite, RATENetTiny, RATENetV2
+
+    _HAS_RATENET = True
+except ImportError:
+    _HAS_RATENET = False
+
+
 # Optional timm-based models (imported conditionally)
 try:
     from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
@@ -111,6 +120,7 @@ __all__ = [
     "MC3_18",
     "MODEL_REGISTRY",
     "TCN",
+    # Classes (uppercase first, alphabetically)
     "BaseModel",
     "ConvNeXtBase_",
     "ConvNeXtSmall",
@@ -152,6 +162,7 @@ __all__ = [
     "VimBase",
     "VimSmall",
     "VimTiny",
+    # Functions (lowercase, alphabetically)
     "build_model",
     "get_model",
     "list_models",
@@ -186,3 +197,14 @@ if _HAS_TIMM_MODELS:
             "UniRepLKNetTiny",
         ]
     )
+
+# Add RATENet models to __all__ if available (unpublished)
+if _HAS_RATENET:
+    __all__.extend(
+        [
+            "RATENet",
+            "RATENetLite",
+            "RATENetTiny",
+            "RATENetV2",
+        ]
+    )
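
Note: because RATENet is an optional export, downstream code can feature-detect it rather than import it unconditionally. A small sketch (the getattr guard is just one way to do this, not a wavedl API):

    from wavedl import models

    RATENet = getattr(models, "RATENet", None)
    if RATENet is None:
        print("RATENet not available in this install; falling back to a registered model")
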
wavedl/models/_pretrained_utils.py
CHANGED
@@ -166,6 +166,78 @@ class LayerNormNd(nn.Module):
         return x
 
 
+# =============================================================================
+# STOCHASTIC DEPTH (DropPath)
+# =============================================================================
+
+
+class DropPath(nn.Module):
+    """
+    Stochastic Depth (drop path) regularization for residual networks.
+
+    Randomly drops entire residual branches during training. Used in modern
+    architectures like ConvNeXt, Swin Transformer, UniRepLKNet.
+
+    Args:
+        drop_prob: Probability of dropping the path (default: 0.0)
+
+    Reference:
+        Huang, G., et al. (2016). Deep Networks with Stochastic Depth.
+        https://arxiv.org/abs/1603.09382
+    """
+
+    def __init__(self, drop_prob: float = 0.0):
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+
+        keep_prob = 1 - self.drop_prob
+        # Shape: (batch_size, 1, 1, ...) for broadcasting
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # Binarize
+        return x.div(keep_prob) * random_tensor
+
+
+# =============================================================================
+# BACKBONE FREEZING UTILITIES
+# =============================================================================
+
+
+def freeze_backbone(
+    model: nn.Module,
+    exclude_patterns: list[str] | None = None,
+) -> int:
+    """
+    Freeze backbone parameters, keeping specified layers trainable.
+
+    Args:
+        model: The model whose parameters to freeze
+        exclude_patterns: List of patterns to exclude from freezing.
+            Parameters with names containing any of these patterns stay trainable.
+            Default: ["classifier", "head", "fc"]
+
+    Returns:
+        Number of parameters frozen
+
+    Example:
+        >>> freeze_backbone(model.backbone, exclude_patterns=["fc", "classifier"])
+    """
+    if exclude_patterns is None:
+        exclude_patterns = ["classifier", "head", "fc"]
+
+    frozen_count = 0
+    for name, param in model.named_parameters():
+        if not any(pattern in name for pattern in exclude_patterns):
+            param.requires_grad = False
+            frozen_count += param.numel()
+
+    return frozen_count
+
+
 # =============================================================================
 # REGRESSION HEAD BUILDERS
 # =============================================================================
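
Note: a usage sketch for the two new utilities added above (TinyNet is illustrative only; the imports match the definitions in _pretrained_utils.py):

    import torch
    import torch.nn as nn

    from wavedl.models._pretrained_utils import DropPath, freeze_backbone


    class TinyNet(nn.Module):
        def __init__(self, channels: int = 8, drop_prob: float = 0.1):
            super().__init__()
            self.backbone = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.drop_path = DropPath(drop_prob)  # stochastic depth on the residual branch
            self.fc = nn.Linear(channels, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + self.drop_path(self.backbone(x))  # branch dropped per-sample in training
            return self.fc(x.mean(dim=(2, 3)))        # global-average-pool, then regress


    model = TinyNet()
    out = model(torch.randn(4, 8, 16, 16))            # (4, 2)

    # Fine-tuning: default exclude_patterns keep "classifier"/"head"/"fc" trainable
    n_frozen = freeze_backbone(model)                 # freezes backbone.*, leaves fc.*
    print(f"Frozen parameters: {n_frozen}")
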
wavedl/models/_template.py
CHANGED
@@ -31,22 +31,23 @@ from wavedl.models.base import BaseModel
 # @register_model("my_model")
 class TemplateModel(BaseModel):
     """
-    Template Model Architecture.
+    Template Model Architecture (2D only).
 
     Replace this docstring with your model description.
     The first line will appear in --list_models output.
 
+    NOTE: This template is hardcoded for 2D inputs using Conv2d/MaxPool2d.
+    For 1D/3D support, use dimension-agnostic layer factories from
+    _pretrained_utils.py (get_conv_layer, get_pool_layer, get_norm_layer).
+
     Args:
-        in_shape: Input spatial dimensions (
-            - 1D: (L,) for signals
-            - 2D: (H, W) for images
-            - 3D: (D, H, W) for volumes
+        in_shape: Input spatial dimensions as (H, W) for 2D images
         out_size: Number of regression targets (auto-detected from data)
         hidden_dim: Size of hidden layers (default: 256)
         dropout: Dropout rate (default: 0.1)
 
     Input Shape:
-        (B, 1,
+        (B, 1, H, W) - 2D grayscale images
 
     Output Shape:
         (B, out_size) - Regression predictions
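
Note: the dimension-agnostic pattern that the new template note points to looks roughly like the sketch below. Only get_conv_layer is confirmed by this diff (convnext.py now imports it); the surrounding helper is illustrative:

    import torch.nn as nn

    from wavedl.models._pretrained_utils import get_conv_layer


    def make_stem(dim: int, in_ch: int, out_ch: int) -> nn.Module:
        Conv = get_conv_layer(dim)  # nn.Conv1d / nn.Conv2d / nn.Conv3d
        return nn.Sequential(
            Conv(in_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU(),
        )


    stem_2d = make_stem(dim=2, in_ch=1, out_ch=16)  # for (B, 1, H, W) inputs
    stem_3d = make_stem(dim=3, in_ch=1, out_ch=16)  # for (B, 1, D, H, W) inputs
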
wavedl/models/cnn.py
CHANGED
@@ -159,6 +159,26 @@ class CNN(BaseModel):
             nn.Linear(64, out_size),
         )
 
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize weights with Kaiming for conv, Xavier for linear."""
+        for m in self.modules():
+            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+                nn.init.kaiming_normal_(
+                    m.weight, mode="fan_out", nonlinearity="leaky_relu"
+                )
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, (nn.GroupNorm, nn.LayerNorm)):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+
     def _make_conv_block(
         self, in_channels: int, out_channels: int, dropout: float = 0.0
     ) -> nn.Sequential:
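
Note: the same initialization scheme can be restated as a free function and applied to any module tree via Module.apply; a standalone sketch (illustrative only; CNN._init_weights iterates self.modules() directly):

    import torch.nn as nn


    def init_weights(m: nn.Module) -> None:
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)


    net = nn.Sequential(
        nn.Conv2d(1, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4)
    )
    net.apply(init_weights)  # recursively visits every submodule
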
wavedl/models/convnext.py
CHANGED
@@ -26,79 +26,12 @@ from typing import Any
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
+from wavedl.models._pretrained_utils import LayerNormNd, get_conv_layer
 from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-def _get_conv_layer(dim: int) -> type[nn.Module]:
-    """Get dimension-appropriate Conv class."""
-    if dim == 1:
-        return nn.Conv1d
-    elif dim == 2:
-        return nn.Conv2d
-    elif dim == 3:
-        return nn.Conv3d
-    else:
-        raise ValueError(f"Unsupported dimensionality: {dim}D")
-
-
-class LayerNormNd(nn.Module):
-    """
-    LayerNorm for N-dimensional tensors (channels-first format).
-
-    Implements channels-last LayerNorm as used in the original ConvNeXt paper.
-    Permutes data to channels-last, applies LayerNorm per-channel over spatial
-    dimensions, and permutes back to channels-first format.
-
-    This matches PyTorch's nn.LayerNorm behavior when applied to the channel
-    dimension, providing stable gradients for deep ConvNeXt networks.
-
-    References:
-        Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
-        https://github.com/facebookresearch/ConvNeXt
-    """
-
-    def __init__(self, num_channels: int, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.dim = dim
-        self.num_channels = num_channels
-        self.weight = nn.Parameter(torch.ones(num_channels))
-        self.bias = nn.Parameter(torch.zeros(num_channels))
-        self.eps = eps
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Apply LayerNorm in channels-last format.
-
-        Args:
-            x: Input tensor in channels-first format
-                - 1D: (B, C, L)
-                - 2D: (B, C, H, W)
-                - 3D: (B, C, D, H, W)
-
-        Returns:
-            Normalized tensor in same format as input
-        """
-        if self.dim == 1:
-            # (B, C, L) -> (B, L, C) -> LayerNorm -> (B, C, L)
-            x = x.permute(0, 2, 1)
-            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
-            x = x.permute(0, 2, 1)
-        elif self.dim == 2:
-            # (B, C, H, W) -> (B, H, W, C) -> LayerNorm -> (B, C, H, W)
-            x = x.permute(0, 2, 3, 1)
-            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
-            x = x.permute(0, 3, 1, 2)
-        else:
-            # (B, C, D, H, W) -> (B, D, H, W, C) -> LayerNorm -> (B, C, D, H, W)
-            x = x.permute(0, 2, 3, 4, 1)
-            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
-            x = x.permute(0, 4, 1, 2, 3)
-        return x
-
-
 class ConvNeXtBlock(nn.Module):
     """
     ConvNeXt block matching the official Facebook implementation.
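
Note: the removed channels-first LayerNorm now lives in _pretrained_utils.py. A quick equivalence check against nn.LayerNorm over the channel dimension, as the removed docstring claims (assumes the shared LayerNormNd keeps the (num_channels, dim, eps) constructor shown in the code removed above):

    import torch
    import torch.nn as nn

    from wavedl.models._pretrained_utils import LayerNormNd

    x = torch.randn(2, 16, 8, 8)  # (B, C, H, W), channels-first

    ln_nd = LayerNormNd(num_channels=16, dim=2)       # constructor assumed, see note
    ref = nn.LayerNorm(16, eps=1e-6)

    out_nd = ln_nd(x)
    out_ref = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    print(torch.allclose(out_nd, out_ref, atol=1e-6))  # expected: True
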
@@ -129,7 +62,7 @@ class ConvNeXtBlock(nn.Module):
     ):
         super().__init__()
         self.dim = dim
-        Conv =
+        Conv = get_conv_layer(dim)
         hidden_dim = int(channels * expansion_ratio)
 
         # Depthwise conv (7x7) - operates in channels-first
@@ -223,7 +156,7 @@ class ConvNeXtBase(BaseModel):
         self.dims = dims
         self.dropout_rate = dropout_rate
 
-        Conv =
+        Conv = get_conv_layer(self.dim)
 
         # Stem: Patchify with stride-4 conv (like ViT patch embedding)
         self.stem = nn.Sequential(
wavedl/models/convnext_v2.py
CHANGED
@@ -32,6 +32,7 @@ import torch
 import torch.nn as nn
 
 from wavedl.models._pretrained_utils import (
+    DropPath,
     LayerNormNd,
     build_regression_head,
     get_conv_layer,
@@ -151,24 +152,6 @@ class ConvNeXtV2Block(nn.Module):
         return x
 
 
-class DropPath(nn.Module):
-    """Stochastic Depth (drop path) regularization."""
-
-    def __init__(self, drop_prob: float = 0.0):
-        super().__init__()
-        self.drop_prob = drop_prob
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.drop_prob == 0.0 or not self.training:
-            return x
-
-        keep_prob = 1 - self.drop_prob
-        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
-        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
-        random_tensor.floor_()
-        return x.div(keep_prob) * random_tensor
-
-
 # =============================================================================
 # CONVNEXT V2 BASE CLASS
 # =============================================================================