wavedl-1.5.7-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/utils/data.py
CHANGED
@@ -474,9 +474,18 @@ class _TransposedH5Dataset:
         self.shape = tuple(reversed(h5_dataset.shape))
         self.dtype = h5_dataset.dtype

-
-
-
+    @property
+    def ndim(self) -> int:
+        """Number of dimensions (derived from shape for numpy compatibility)."""
+        return len(self.shape)
+
+    @property
+    def _transpose_axes(self) -> tuple[int, ...]:
+        """Transpose axis order for reversing dimensions.
+
+        For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0).
+        """
+        return tuple(range(len(self._dataset.shape) - 1, -1, -1))

     def __len__(self) -> int:
         return self.shape[0]
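To make the new view semantics concrete, here is a minimal, self-contained sketch: `TransposedView` is a hypothetical stand-in for `_TransposedH5Dataset` (the real class wraps an h5py dataset), backed by a plain NumPy array for illustration.

import numpy as np

class TransposedView:
    """Hypothetical stand-in for _TransposedH5Dataset: wraps an array-like
    and presents its shape with the axes reversed."""

    def __init__(self, dataset):
        self._dataset = dataset
        self.shape = tuple(reversed(dataset.shape))
        self.dtype = dataset.dtype

    @property
    def ndim(self) -> int:
        # Derived from shape, matching numpy's convention
        return len(self.shape)

    @property
    def _transpose_axes(self) -> tuple[int, ...]:
        # (A, B, C) -> (C, B, A) corresponds to transpose axes (2, 1, 0)
        return tuple(range(len(self._dataset.shape) - 1, -1, -1))

view = TransposedView(np.zeros((3, 4, 5)))
print(view.shape)            # (5, 4, 3)
print(view.ndim)             # 3
print(view._transpose_axes)  # (2, 1, 0)

Deriving `ndim` and `_transpose_axes` from the stored shapes, rather than caching them, keeps the view trivially consistent with the wrapped dataset.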
@@ -965,8 +974,17 @@ def load_test_data(
         else:
             # Fallback to default source.load() for unknown formats
             inp, outp = source.load(path)
-    except KeyError:
-        #
+    except KeyError as e:
+        # IMPORTANT: Only fall back to inference-only mode if outputs are
+        # genuinely missing (auto-detection failed). If user explicitly
+        # provided --output_key, they expect it to exist - don't silently drop.
+        if output_key is not None:
+            raise KeyError(
+                f"Explicit --output_key '{output_key}' not found in file. "
+                f"Available keys depend on file format. Original error: {e}"
+            ) from e
+
+        # Legitimate fallback: no explicit output_key, outputs just not present
         if format == "npz":
             # First pass to find keys
             with np.load(path, allow_pickle=False) as probe:
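The rule the new `except` block enforces can be sketched in isolation; `load_arrays` below is illustrative only (the real `load_test_data` parses npz/HDF5 files, faked here with a dict):

def load_arrays(path: str, output_key: str | None):
    """Illustrative sketch of the new fallback rule, not the real wavedl API."""
    data = {"inputs": [1, 2, 3]}  # stand-in for the parsed file contents
    try:
        return data["inputs"], data[output_key or "outputs"]
    except KeyError as e:
        if output_key is not None:
            # An explicitly requested key must exist: fail loudly instead of
            # silently dropping the outputs.
            raise KeyError(
                f"Explicit --output_key '{output_key}' not found in file. "
                f"Original error: {e}"
            ) from e
        # Auto-detection simply found no outputs: inference-only fallback
        return data["inputs"], None

print(load_arrays("test.npz", None))  # ([1, 2, 3], None)
# load_arrays("test.npz", "labels")   # raises KeyError with a clear message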
@@ -1083,11 +1101,26 @@ def load_test_data(
             raise ValueError(
                 f"Input appears to be channels-last format: {tuple(X.shape)}. "
                 "WaveDL expects channels-first (N, C, H, W). "
-                "Convert your data using: X = X.permute(0, 3, 1, 2)"
+                "Convert your data using: X = X.permute(0, 3, 1, 2). "
+                "If this is actually a 3D volume with small depth, "
+                "use --input_channels 1 to add a channel dimension."
             )
         elif X.shape[1] > 16:
             # Heuristic fallback: large dim 1 suggests 3D volume needing channel
             X = X.unsqueeze(1)  # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
+        else:
+            # Ambiguous case: shallow 3D volume (D <= 16) or multi-channel 2D
+            # Default to treating as multi-channel 2D (no modification needed)
+            # Log a warning so users know about the --input_channels option
+            import warnings
+
+            warnings.warn(
+                f"Ambiguous 4D input shape: {tuple(X.shape)}. "
+                f"Assuming {X.shape[1]} channels (multi-channel 2D). "
+                f"For 3D volumes with depth={X.shape[1]}, use --input_channels 1.",
+                UserWarning,
+                stacklevel=2,
+            )
     # X.ndim >= 5: assume channel dimension already exists

     return X, y
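A condensed sketch of the heuristic's two branches for 4D inputs (the channels-last check that precedes them in the real code is omitted; `add_channel_dim_if_needed` is a hypothetical name):

import warnings
import torch

def add_channel_dim_if_needed(X: torch.Tensor) -> torch.Tensor:
    """Sketch of the 4D heuristic: is dim 1 a channel axis, or the
    depth of an un-channeled 3D volume?"""
    if X.shape[1] > 16:
        # Large dim 1: treat as an un-channeled 3D volume
        return X.unsqueeze(1)  # (N, D, H, W) -> (N, 1, D, H, W)
    # Ambiguous: shallow volume or multi-channel 2D; assume the latter, warn
    warnings.warn(
        f"Ambiguous 4D input shape: {tuple(X.shape)}. "
        f"Assuming {X.shape[1]} channels (multi-channel 2D). "
        f"For 3D volumes with depth={X.shape[1]}, use --input_channels 1.",
        UserWarning,
        stacklevel=2,
    )
    return X

print(add_channel_dim_if_needed(torch.zeros(2, 32, 8, 8)).shape)  # (2, 1, 32, 8, 8)
print(add_channel_dim_if_needed(torch.zeros(2, 3, 8, 8)).shape)   # warns; (2, 3, 8, 8)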
wavedl/utils/losses.py
CHANGED
@@ -1,216 +1,216 @@
All 216 lines are flagged as removed and re-added, but both sides render identically, so this is most likely a whitespace- or line-ending-only rewrite. The file content, shown once:

"""
Loss Functions for Regression Tasks
====================================

Provides a comprehensive set of loss functions for regression problems,
with a factory function for easy selection via CLI arguments.

Supported Losses:
- mse: Mean Squared Error (default)
- mae: Mean Absolute Error (L1)
- huber: Huber Loss (smooth blend of MSE and MAE)
- smooth_l1: Smooth L1 Loss (PyTorch native Huber variant)
- log_cosh: Log-Cosh Loss (smooth approximation to MAE)
- weighted_mse: Per-target weighted MSE

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

import torch
import torch.nn as nn


# ==============================================================================
# CUSTOM LOSS FUNCTIONS
# ==============================================================================
class LogCoshLoss(nn.Module):
    """
    Log-Cosh Loss: A smooth approximation to Mean Absolute Error.

    The loss is defined as: loss = log(cosh(pred - target))

    Properties:
    - Smooth everywhere (differentiable)
    - Behaves like L2 for small errors, L1 for large errors
    - More robust to outliers than MSE

    Example:
        >>> criterion = LogCoshLoss()
        >>> loss = criterion(predictions, targets)
    """

    def __init__(self, reduction: str = "mean"):
        """
        Args:
            reduction: Specifies the reduction: 'none' | 'mean' | 'sum'
        """
        super().__init__()
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(f"Invalid reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute Log-Cosh loss.

        Args:
            pred: Predicted values of shape (N, *)
            target: Target values of shape (N, *)

        Returns:
            Loss value (scalar if reduction is 'mean' or 'sum')
        """
        diff = pred - target
        # log(cosh(x)) = x + softplus(-2x) - log(2)
        # This formulation is numerically stable
        loss = diff + torch.nn.functional.softplus(-2.0 * diff) - 0.693147  # log(2)

        if self.reduction == "none":
            return loss
        elif self.reduction == "sum":
            return loss.sum()
        else:  # mean
            return loss.mean()


class WeightedMSELoss(nn.Module):
    """
    Weighted Mean Squared Error Loss.

    Applies different weights to each target dimension, allowing
    prioritization of specific outputs (e.g., prioritize thickness
    over velocity in NDE applications).

    Example:
        >>> # 3 targets, prioritize first target
        >>> criterion = WeightedMSELoss(weights=[2.0, 1.0, 1.0])
        >>> loss = criterion(predictions, targets)
    """

    def __init__(
        self, weights: list[float] | torch.Tensor | None = None, reduction: str = "mean"
    ):
        """
        Args:
            weights: Per-target weights. If None, equal weights (standard MSE).
                Length must match number of output targets.
            reduction: Specifies the reduction: 'none' | 'mean' | 'sum'
        """
        super().__init__()
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(f"Invalid reduction mode: {reduction}")
        self.reduction = reduction

        if weights is not None:
            if isinstance(weights, list):
                weights = torch.tensor(weights, dtype=torch.float32)
            self.register_buffer("weights", weights)
        else:
            self.weights = None

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute weighted MSE loss.

        Args:
            pred: Predicted values of shape (N, T) where T is number of targets
            target: Target values of shape (N, T)

        Returns:
            Loss value (scalar if reduction is 'mean' or 'sum')

        Raises:
            ValueError: If weight dimension doesn't match target dimension
        """
        mse = (pred - target) ** 2

        if self.weights is not None:
            # Validate weight dimension matches target dimension
            if self.weights.shape[0] != pred.shape[-1]:
                raise ValueError(
                    f"Weight dimension ({self.weights.shape[0]}) must match "
                    f"output dimension ({pred.shape[-1]}). "
                    f"Check your --loss_weights argument."
                )
            # Use local variable to avoid mutating registered buffer
            # (mutating self.weights breaks state_dict semantics)
            weights = self.weights.to(mse.device)
            # Apply per-target weights with correct broadcasting: (N, T) * (T,) -> (N, T)
            mse = mse * weights

        if self.reduction == "none":
            return mse
        elif self.reduction == "sum":
            return mse.sum()
        else:  # mean
            return mse.mean()


# ==============================================================================
# LOSS REGISTRY
# ==============================================================================
_LOSS_REGISTRY = {
    "mse": nn.MSELoss,
    "mae": nn.L1Loss,
    "l1": nn.L1Loss,  # Alias for mae
    "huber": nn.HuberLoss,
    "smooth_l1": nn.SmoothL1Loss,
    "log_cosh": LogCoshLoss,
    "logcosh": LogCoshLoss,  # Alias
    "weighted_mse": WeightedMSELoss,
}


def list_losses() -> list[str]:
    """
    Return list of available loss function names.

    Returns:
        List of registered loss function names (excluding aliases)
    """
    # Return unique loss names (exclude aliases)
    primary_names = ["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"]
    return primary_names


def get_loss(
    name: str, weights: list[float] | None = None, delta: float = 1.0, **kwargs
) -> nn.Module:
    """
    Factory function to create loss function by name.

    Args:
        name: Loss function name (see list_losses())
        weights: Per-target weights for weighted_mse
        delta: Delta parameter for Huber loss (default: 1.0)
        **kwargs: Additional arguments passed to loss constructor

    Returns:
        Instantiated loss function (nn.Module)

    Raises:
        ValueError: If loss name is not recognized

    Example:
        >>> criterion = get_loss("mse")
        >>> criterion = get_loss("huber", delta=0.5)
        >>> criterion = get_loss("weighted_mse", weights=[2.0, 1.0, 1.0])
    """
    name_lower = name.lower().replace("-", "_")

    if name_lower not in _LOSS_REGISTRY:
        available = ", ".join(list_losses())
        raise ValueError(
            f"Unknown loss function: '{name}'. Available options: {available}"
        )

    loss_cls = _LOSS_REGISTRY[name_lower]

    # Special handling for specific loss types
    if name_lower == "huber":
        return loss_cls(delta=delta, **kwargs)
    elif name_lower == "weighted_mse":
        return loss_cls(weights=weights, **kwargs)
    else:
        return loss_cls(**kwargs)
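For completeness, a short usage sketch exercising the factory shown above (assumes the package is installed so `wavedl.utils.losses` is importable):

import torch
from wavedl.utils.losses import get_loss, list_losses

pred = torch.randn(8, 3)
target = torch.randn(8, 3)

print(list_losses())
# ['mse', 'mae', 'huber', 'smooth_l1', 'log_cosh', 'weighted_mse']

for name in ("mse", "huber", "log_cosh"):
    criterion = get_loss(name)
    print(name, criterion(pred, target).item())

# Per-target weighting: the first of three targets counts double
criterion = get_loss("weighted_mse", weights=[2.0, 1.0, 1.0])
print("weighted_mse", criterion(pred, target).item())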