wavedl-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,138 @@
+ """
+ Distributed Training Utilities
+ ==============================
+
+ Provides DDP-safe utilities for multi-GPU training including:
+ - Early stopping synchronization across ranks
+ - Value broadcasting from rank 0
+ - Tensor synchronization
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ Version: 1.0.0
+ """
+
+ import torch
+ import torch.distributed as dist
+ from accelerate import Accelerator
+
+
+ def broadcast_early_stop(should_stop: bool, accelerator: Accelerator) -> bool:
+     """
+     Broadcast early stopping decision from rank 0 to all processes.
+
+     In DDP training, early stopping state (patience counter, best loss) is typically
+     tracked only on rank 0. This function ensures all ranks receive the same stop
+     signal to prevent deadlocks from inconsistent termination.
+
+     Args:
+         should_stop: Whether to stop training (only meaningful on rank 0)
+         accelerator: Accelerator instance for device and process info
+
+     Returns:
+         True if training should stop (synchronized across all ranks)
+
+     Example:
+         # On main process: patience_ctr >= patience
+         # On other processes: unknown/stale value
+         should_stop = patience_ctr >= args.patience if accelerator.is_main_process else False
+         if broadcast_early_stop(should_stop, accelerator):
+             break  # All ranks exit loop together
+     """
+     stop_tensor = torch.tensor(
+         1 if should_stop else 0, device=accelerator.device, dtype=torch.int32
+     )
+
+     if accelerator.num_processes > 1:
+         dist.broadcast(stop_tensor, src=0)
+
+     return stop_tensor.item() == 1
+
+
+ def broadcast_value(value: int | float, accelerator: Accelerator) -> int | float:
+     """
+     Broadcast a scalar value from rank 0 to all processes.
+
+     Useful for synchronizing hyperparameters or computed values across ranks.
+
+     Args:
+         value: Scalar value to broadcast (only rank 0's value is used)
+         accelerator: Accelerator instance for device and process info
+
+     Returns:
+         Value from rank 0 (synchronized across all ranks)
+     """
+     is_int = isinstance(value, int)
+     dtype = torch.int64 if is_int else torch.float32
+
+     tensor = torch.tensor(value, device=accelerator.device, dtype=dtype)
+
+     if accelerator.num_processes > 1:
+         dist.broadcast(tensor, src=0)
+
+     result = tensor.item()
+     return int(result) if is_int else result
+
+
+ def sync_tensor(
+     tensor: torch.Tensor, accelerator: Accelerator, reduction: str = "sum"
+ ) -> torch.Tensor:
+     """
+     Synchronize a tensor across all processes with the specified reduction.
+
+     Wrapper around accelerator.reduce with additional validation.
+
+     Args:
+         tensor: Tensor to synchronize
+         accelerator: Accelerator instance
+         reduction: Reduction operation ("sum", "mean", "max", "min")
+
+     Returns:
+         Reduced tensor (synchronized across all ranks)
+
+     Raises:
+         ValueError: If reduction type is not recognized
+     """
+     valid_reductions = {"sum", "mean", "max", "min"}
+     if reduction not in valid_reductions:
+         raise ValueError(
+             f"Invalid reduction '{reduction}'. Must be one of {valid_reductions}"
+         )
+
+     return accelerator.reduce(tensor, reduction=reduction)
+
+
+ def get_world_info(accelerator: Accelerator) -> dict:
+     """
+     Get distributed training world information.
+
+     Args:
+         accelerator: Accelerator instance
+
+     Returns:
+         Dictionary with world_size, rank, local_rank, is_main, device
+     """
+     return {
+         "world_size": accelerator.num_processes,
+         "rank": accelerator.process_index,
+         "local_rank": accelerator.local_process_index,
+         "is_main": accelerator.is_main_process,
+         "device": str(accelerator.device),
+     }
+
+
+ def print_rank0(message: str, accelerator: Accelerator, logger=None):
+     """
+     Print a message only on rank 0.
+
+     Convenience function for logging in a distributed setting.
+
+     Args:
+         message: Message to print
+         accelerator: Accelerator instance
+         logger: Optional logger (uses print if None)
+     """
+     if accelerator.is_main_process:
+         if logger:
+             logger.info(message)
+         else:
+             print(message)
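
Taken together, these utilities cover the usual DDP coordination points of an Accelerate training loop: synchronized evaluation, a rank-0-owned patience counter, and a broadcast stop signal. The sketch below is illustrative only and is not part of the wheel; `model`, `optimizer`, `criterion`, the data loaders, and the `patience`/`max_epochs` values are assumptions, and the helpers are referenced directly as defined above since the diff does not show this module's import path.

# Illustrative sketch (not part of the package): wiring the helpers above into
# an Accelerate-driven training loop. `model`, `optimizer`, `criterion`, and the
# data loaders are assumed to be supplied by the caller.
import torch
from accelerate import Accelerator


def train_with_early_stopping(model, optimizer, criterion, train_loader, val_loader,
                              patience=10, max_epochs=100):
    accelerator = Accelerator()
    model, optimizer, train_loader, val_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader
    )
    print_rank0(f"World info: {get_world_info(accelerator)}", accelerator)

    best_loss, patience_ctr = float("inf"), 0
    for epoch in range(max_epochs):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            accelerator.backward(loss)
            optimizer.step()

        # Each rank evaluates its own shard; sync_tensor averages the result
        # across ranks so every process sees the same validation loss.
        model.eval()
        with torch.no_grad():
            local = torch.stack([criterion(model(x), y) for x, y in val_loader]).mean()
        val_loss = sync_tensor(local, accelerator, reduction="mean").item()

        # Rank 0 owns the patience counter; broadcast_early_stop ensures every
        # rank receives the same stop signal and exits the loop together.
        should_stop = False
        if accelerator.is_main_process:
            if val_loss < best_loss:
                best_loss, patience_ctr = val_loss, 0
            else:
                patience_ctr += 1
            should_stop = patience_ctr >= patience
        if broadcast_early_stop(should_stop, accelerator):
            print_rank0(f"Early stopping after epoch {epoch + 1}", accelerator)
            break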
wavedl/utils/losses.py ADDED
@@ -0,0 +1,216 @@
+ """
+ Loss Functions for Regression Tasks
+ ====================================
+
+ Provides a comprehensive set of loss functions for regression problems,
+ with a factory function for easy selection via CLI arguments.
+
+ Supported Losses:
+ - mse: Mean Squared Error (default)
+ - mae: Mean Absolute Error (L1)
+ - huber: Huber Loss (smooth blend of MSE and MAE)
+ - smooth_l1: Smooth L1 Loss (PyTorch native Huber variant)
+ - log_cosh: Log-Cosh Loss (smooth approximation to MAE)
+ - weighted_mse: Per-target weighted MSE
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ Version: 1.0.0
+ """
+
+ import torch
+ import torch.nn as nn
+
+
+ # ==============================================================================
+ # CUSTOM LOSS FUNCTIONS
+ # ==============================================================================
+ class LogCoshLoss(nn.Module):
+     """
+     Log-Cosh Loss: A smooth approximation to Mean Absolute Error.
+
+     The loss is defined as: loss = log(cosh(pred - target))
+
+     Properties:
+     - Smooth everywhere (differentiable)
+     - Behaves like L2 for small errors, L1 for large errors
+     - More robust to outliers than MSE
+
+     Example:
+         >>> criterion = LogCoshLoss()
+         >>> loss = criterion(predictions, targets)
+     """
+
+     def __init__(self, reduction: str = "mean"):
+         """
+         Args:
+             reduction: Specifies the reduction: 'none' | 'mean' | 'sum'
+         """
+         super().__init__()
+         if reduction not in ("none", "mean", "sum"):
+             raise ValueError(f"Invalid reduction mode: {reduction}")
+         self.reduction = reduction
+
+     def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """
+         Compute Log-Cosh loss.
+
+         Args:
+             pred: Predicted values of shape (N, *)
+             target: Target values of shape (N, *)
+
+         Returns:
+             Loss value (scalar if reduction is 'mean' or 'sum')
+         """
+         diff = pred - target
+         # log(cosh(x)) = x + softplus(-2x) - log(2)
+         # This formulation is numerically stable
+         loss = diff + torch.nn.functional.softplus(-2.0 * diff) - 0.693147  # log(2)
+
+         if self.reduction == "none":
+             return loss
+         elif self.reduction == "sum":
+             return loss.sum()
+         else:  # mean
+             return loss.mean()
+
+
+ class WeightedMSELoss(nn.Module):
+     """
+     Weighted Mean Squared Error Loss.
+
+     Applies different weights to each target dimension, allowing
+     prioritization of specific outputs (e.g., prioritize thickness
+     over velocity in NDE applications).
+
+     Example:
+         >>> # 3 targets, prioritize first target
+         >>> criterion = WeightedMSELoss(weights=[2.0, 1.0, 1.0])
+         >>> loss = criterion(predictions, targets)
+     """
+
+     def __init__(
+         self, weights: list[float] | torch.Tensor | None = None, reduction: str = "mean"
+     ):
+         """
+         Args:
+             weights: Per-target weights. If None, equal weights (standard MSE).
+                 Length must match number of output targets.
+             reduction: Specifies the reduction: 'none' | 'mean' | 'sum'
+         """
+         super().__init__()
+         if reduction not in ("none", "mean", "sum"):
+             raise ValueError(f"Invalid reduction mode: {reduction}")
+         self.reduction = reduction
+
+         if weights is not None:
+             if isinstance(weights, list):
+                 weights = torch.tensor(weights, dtype=torch.float32)
+             self.register_buffer("weights", weights)
+         else:
+             self.weights = None
+
+     def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """
+         Compute weighted MSE loss.
+
+         Args:
+             pred: Predicted values of shape (N, T) where T is number of targets
+             target: Target values of shape (N, T)
+
+         Returns:
+             Loss value (scalar if reduction is 'mean' or 'sum')
+
+         Raises:
+             ValueError: If weight dimension doesn't match target dimension
+         """
+         mse = (pred - target) ** 2
+
+         if self.weights is not None:
+             # Validate weight dimension matches target dimension
+             if self.weights.shape[0] != pred.shape[-1]:
+                 raise ValueError(
+                     f"Weight dimension ({self.weights.shape[0]}) must match "
+                     f"output dimension ({pred.shape[-1]}). "
+                     f"Check your --loss_weights argument."
+                 )
+             # Use local variable to avoid mutating registered buffer
+             # (mutating self.weights breaks state_dict semantics)
+             weights = self.weights.to(mse.device)
+             # Apply per-target weights with correct broadcasting: (N, T) * (T,) -> (N, T)
+             mse = mse * weights
+
+         if self.reduction == "none":
+             return mse
+         elif self.reduction == "sum":
+             return mse.sum()
+         else:  # mean
+             return mse.mean()
+
+
+ # ==============================================================================
+ # LOSS REGISTRY
+ # ==============================================================================
+ _LOSS_REGISTRY = {
+     "mse": nn.MSELoss,
+     "mae": nn.L1Loss,
+     "l1": nn.L1Loss,  # Alias for mae
+     "huber": nn.HuberLoss,
+     "smooth_l1": nn.SmoothL1Loss,
+     "log_cosh": LogCoshLoss,
+     "logcosh": LogCoshLoss,  # Alias
+     "weighted_mse": WeightedMSELoss,
+ }
+
+
+ def list_losses() -> list[str]:
+     """
+     Return list of available loss function names.
+
+     Returns:
+         List of registered loss function names (excluding aliases)
+     """
+     # Return unique loss names (exclude aliases)
+     primary_names = ["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"]
+     return primary_names
+
+
+ def get_loss(
+     name: str, weights: list[float] | None = None, delta: float = 1.0, **kwargs
+ ) -> nn.Module:
+     """
+     Factory function to create loss function by name.
+
+     Args:
+         name: Loss function name (see list_losses())
+         weights: Per-target weights for weighted_mse
+         delta: Delta parameter for Huber loss (default: 1.0)
+         **kwargs: Additional arguments passed to loss constructor
+
+     Returns:
+         Instantiated loss function (nn.Module)
+
+     Raises:
+         ValueError: If loss name is not recognized
+
+     Example:
+         >>> criterion = get_loss("mse")
+         >>> criterion = get_loss("huber", delta=0.5)
+         >>> criterion = get_loss("weighted_mse", weights=[2.0, 1.0, 1.0])
+     """
+     name_lower = name.lower().replace("-", "_")
+
+     if name_lower not in _LOSS_REGISTRY:
+         available = ", ".join(list_losses())
+         raise ValueError(
+             f"Unknown loss function: '{name}'. Available options: {available}"
+         )
+
+     loss_cls = _LOSS_REGISTRY[name_lower]
+
+     # Special handling for specific loss types
+     if name_lower == "huber":
+         return loss_cls(delta=delta, **kwargs)
+     elif name_lower == "weighted_mse":
+         return loss_cls(weights=weights, **kwargs)
+     else:
+         return loss_cls(**kwargs)
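
The module docstring mentions selecting losses via CLI arguments, and the WeightedMSELoss error message refers to a --loss_weights flag. A minimal sketch of how such a hookup could look is below; it is not part of the wheel, and the --loss and --huber_delta flag names are assumptions for illustration only.

# Illustrative sketch (not part of the package): selecting a loss from CLI flags.
# Only --loss_weights is mentioned by the module itself; --loss and --huber_delta
# are assumed flag names for this example.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--loss", default="mse", choices=list_losses())
parser.add_argument("--loss_weights", type=float, nargs="+", default=None)
parser.add_argument("--huber_delta", type=float, default=1.0)
args = parser.parse_args()

# get_loss ignores `weights`/`delta` for losses that do not use them.
criterion = get_loss(args.loss, weights=args.loss_weights, delta=args.huber_delta)
# e.g. --loss weighted_mse --loss_weights 2.0 1.0 1.0 weights the first of three
# regression targets twice as heavily as the other two.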