torch-l1-snr 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,201 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-l1-snr
3
+ Version: 0.0.1
4
+ Summary: L1-SNR loss functions for audio source separation in PyTorch
5
+ Home-page: https://github.com/crlandsc/torch-l1-snr
6
+ Author: Christopher Landschoot
7
+ Author-email: crlandschoot@gmail.com
8
+ License: MIT
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.8
14
+ Classifier: Programming Language :: Python :: 3.9
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Operating System :: OS Independent
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
21
+ Requires-Python: >=3.8
22
+ Description-Content-Type: text/markdown
23
+ License-File: LICENSE
24
+ Requires-Dist: torch
25
+ Requires-Dist: torchaudio
26
+ Requires-Dist: numpy>=1.21.0
27
+ Dynamic: license-file
28
+
29
+ <!-- ![torch-l1-snr-logo](https://raw.githubusercontent.com/crlandsc/torch-l1-snr/main/images/logo.png) -->
30
+
31
+ # NOTE: Repo is currently a work-in-progress and not ready for installation & use.
32
+
33
+ [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1snr)](https://github.com/crlandsc/torch-l1snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1snr)](https://github.com/crlandsc/torch-l1snr/stargazers)
34
+
35
+ # torch-l1-snr
36
+
37
+ A PyTorch implementation of L1-based Signal-to-Noise Ratio (SNR) loss functions for audio source separation. This package provides implementations and novel extensions based on concepts from recent academic papers, offering flexible and robust loss functions that can be easily integrated into any PyTorch-based audio separation pipeline.
38
+
39
+ The core `L1SNRLoss` is based on the loss function described in [1], while `L1SNRDBLoss` and `STFTL1SNRDBLoss` extend the adaptive level-matching regularization technique proposed in [2].
40
+
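For reference, the D1 term that `L1SNRLoss` computes can be sketched in plain Python for a single signal (an illustration of the formula from [1], not the packaged, batched implementation; `eps` mirrors the default of `1e-3`):

```python
import math

def d1_l1snr(est, ref, eps=1e-3):
    """D1 = 10*log10((mean|est - ref| + eps) / (mean|ref| + eps))."""
    err = sum(abs(e - r) for e, r in zip(est, ref)) / len(ref)
    tgt = sum(abs(r) for r in ref) / len(ref)
    return 10.0 * math.log10((err + eps) / (tgt + eps))
```

A perfect reconstruction drives D1 toward its epsilon-limited floor (around -30 dB with `eps=1e-3` on a unit-level signal), while a silent estimate scores near 0 dB, so minimizing D1 rewards accurate, full-level reconstructions.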
41
+ ## Features
42
+
43
+ - **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [1].
44
+ - **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [2], plus an optional L1 loss component.
45
+ - **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [2], calculated over multiple STFT resolutions.
46
+ - **Modular Stem-based Loss**: A wrapper that combines time and spectrogram domain losses and can be configured to run on specific stems.
47
+ - **Efficient & Robust**: Includes optimizations for pure L1 loss calculation and robust handling of `NaN`/`inf` values and short audio segments.
48
+
49
+ ## Installation
50
+
51
+ <!-- Add PyPI badges once the package is published -->
52
+ <!-- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
53
+ <!-- [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
54
+ <!-- [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
55
+
56
+ You can install the package directly from GitHub:
57
+
58
+ ```bash
59
+ pip install git+https://github.com/crlandsc/torch-l1snr.git
60
+ ```
61
+
62
+ Or, you can clone the repository and install it in editable mode for development:
63
+
64
+ ```bash
65
+ git clone https://github.com/crlandsc/torch-l1snr.git
66
+ cd torch-l1snr
67
+ pip install -e .
68
+ ```
69
+
70
+ ## Dependencies
71
+
72
+ - [PyTorch](https://pytorch.org/)
73
+ - [torchaudio](https://pytorch.org/audio/stable/index.html)
+ - [NumPy](https://numpy.org/) (>=1.21.0)
74
+
75
+ ## Supported Tensor Shapes
76
+
77
+ All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept batch-first audio tensors, most commonly `(batch, samples)` or `(batch, channels, samples)`. For 3D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated; for the spectrogram-domain loss, a separate STFT is computed for each channel.
78
+
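As a quick sanity check, the flattening amounts to collapsing every dimension after the batch axis (a minimal shape-arithmetic sketch, not package code):

```python
def flatten_batch_first(shape):
    # (batch, channels, samples) -> (batch, channels * samples),
    # matching tensor.reshape(batch, -1) on a batch-first tensor
    batch, *rest = shape
    n = 1
    for dim in rest:
        n *= dim
    return (batch, n)
```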
79
+ ## Usage
80
+
81
+ The loss functions can be imported directly from the `torch_l1snr` package.
82
+
83
+ ### Example: `L1SNRDBLoss` (Time Domain)
84
+
85
+ ```python
86
+ import torch
87
+ from torch_l1snr import L1SNRDBLoss
88
+
89
+ # Create dummy audio signals
90
+ estimates = torch.randn(4, 32000) # Batch of 4, 32000 samples
91
+ actuals = torch.randn(4, 32000)
92
+
93
+ # Initialize the loss function
94
+ # l1_weight=0.1 blends L1SNR with 10% L1 loss
95
+ loss_fn = L1SNRDBLoss(name="l1snr_db", l1_weight=0.1)
96
+
97
+ # Calculate loss
98
+ loss = loss_fn(estimates, actuals)
99
+ loss.backward()
100
+
101
+ print(f"L1SNRDBLoss: {loss.item()}")
102
+ ```
103
+
104
+ ### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)
105
+
106
+ ```python
107
+ import torch
108
+ from torch_l1snr import STFTL1SNRDBLoss
109
+
110
+ # Create dummy audio signals
111
+ estimates = torch.randn(4, 32000)
112
+ actuals = torch.randn(4, 32000)
113
+
114
+ # Initialize the loss function
115
+ # Uses multiple STFT resolutions by default
116
+ loss_fn = STFTL1SNRDBLoss(name="stft_l1snr", l1_weight=0.0)  # Pure L1SNR
117
+
118
+ # Calculate loss
119
+ loss = loss_fn(estimates, actuals)
120
+ loss.backward()
121
+
122
+ print(f"STFTL1SNRDBLoss: {loss.item()}")
123
+ ```
124
+
125
+ ### Example: `MultiL1SNRDBLoss` for a Combined Time+Spectrogram Loss
126
+
127
+ This loss combines the time-domain and spectrogram-domain losses into a single, weighted objective function.
128
+
129
+ ```python
130
+ import torch
131
+ from torch_l1snr import MultiL1SNRDBLoss
132
+
133
+ # Create dummy audio signals
134
+ # Shape: (batch, channels, samples)
135
+ estimates = torch.randn(2, 2, 44100) # Batch of 2, stereo
136
+ actuals = torch.randn(2, 2, 44100)
137
+
138
+ # --- Configuration ---
139
+ loss_fn = MultiL1SNRDBLoss(
140
+ weight=1.0, # Overall weight for this loss
141
+ spec_weight=0.7, # 70% spectrogram loss, 30% time-domain loss
142
+ l1_weight=0.1, # Use 10% L1, 90% L1SNR+Reg
143
+ )
144
+ loss = loss_fn(estimates, actuals)
145
+ print(f"Multi-domain Loss: {loss.item()}")
146
+ ```
147
+
148
+ ## Motivation
149
+
150
+ The goal of these loss functions is to provide a perceptually-informed and robust alternative to common audio losses like L1, L2 (MSE), and SI-SDR for training audio source separation models.
151
+
152
+ - **Robustness**: The L1 norm is less sensitive to large outliers than the L2 norm, making it more suitable for audio signals, which can have sharp transients.
153
+ - **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
154
+ - **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.
155
+
156
+ #### Level-Matching Regularization
157
+
158
+ A key feature of `L1SNRDBLoss` is the adaptive regularization term described in [2]. This component computes the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, increasing when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and prevents it from collapsing to a trivial silent solution when uncertain.
159
+
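For a single signal, the adaptive weight can be sketched as follows (a scalar illustration of the scheme from [2]; the defaults mirror the package's `lambda0=0.1`, `delta_lambda=0.9`, and `lmin=-60.0`):

```python
def adaptive_lambda(L_pred, L_true, L_min=-60.0, lambda0=0.1, delta_lambda=0.9):
    """Level-matching weight: grows when a non-silent target is silenced."""
    R = abs(L_pred - L_true)                            # dBRMS mismatch
    eta = 1.0 if L_true > max(L_pred, L_min) else 0.0   # non-silent target under-reconstructed?
    ramp = min(max(R / max(L_true - L_min, 1e-6), 0.0), 1.0)
    return lambda0 + eta * delta_lambda * ramp          # in [lambda0, lambda0 + delta_lambda]
```

A model that outputs near-silence (-80 dBRMS) for a -20 dBRMS target receives the maximum weight of 1.0, while matched levels receive only the floor weight of 0.1.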
160
+ #### Multi-Resolution Spectrogram Analysis
161
+
162
+ The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts—from short, transient errors to longer, tonal discrepancies. This provides a more comprehensive error signal to the model during training.
163
+
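To see the trade-off concretely, each default resolution yields a different time-frequency grid (an illustrative calculation assuming a centered, one-sided STFT, as the package's transforms use):

```python
def stft_grid(n_samples, n_fft, hop_length):
    # Frequency bins and time frames for a centered (one-sided) STFT
    n_bins = n_fft // 2 + 1
    n_frames = n_samples // hop_length + 1
    return n_bins, n_frames

# Default resolutions: (n_fft, hop_length) pairs at 75% overlap
for n_fft, hop in [(512, 128), (1024, 256), (2048, 512)]:
    print(n_fft, stft_grid(32000, n_fft, hop))
```

Small windows give fine time resolution (many frames, few bins) for catching transient errors; large windows give fine frequency resolution for tonal discrepancies.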
164
+ #### "All-or-Nothing" Behavior and `l1_weight`
165
+
166
+ A characteristic of SNR-style losses is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimations. However, this can also lead to "confident errors," where the model completely removes a signal component it should have kept.
167
+
168
+ While the Level-Matching Regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften its "all-or-nothing"-style behavior and allow for more nuanced separation.
169
+
170
+ - `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
171
+ - `l1_weight=1.0`: Pure L1 loss.
172
+ - `0.0 < l1_weight < 1.0`: A weighted combination of the two.
173
+
174
+ The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
175
+
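The dispatch can be sketched as follows (an illustration of the shortcut, not the packaged code; the component losses are passed as thunks so the unused one is never evaluated):

```python
def blended_loss(w, compute_d1, compute_l1):
    # Skip whichever component has zero weight
    if w <= 0.0:
        return compute_d1()   # pure L1SNR: the L1 term is never computed
    if w >= 1.0:
        return compute_l1()   # pure L1: the D1 term is never computed
    return (1.0 - w) * compute_d1() + w * compute_l1()
```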
176
+ **Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), you may need to tune `l1_scale_time` and `l1_scale_spec` so that the gradients of the L1 and L1SNR components remain balanced, which is crucial for stable training. The defaults are a reasonable starting point, but monitor the individual loss components to confirm they are scaled appropriately.
177
+
178
+ ## Limitations
179
+
180
+ - The L1SNR loss is not scale-invariant. Unlike SI-SNR, it requires the model's output to be correctly scaled relative to the target.
181
+ - While the dB scaling and regularization are psychoacoustically motivated, the loss does not model more complex perceptual phenomena like auditory masking.
182
+
183
+ ## Contributing
184
+
185
+ Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.
186
+
187
+ ## License
188
+
189
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
190
+
191
+ ## Acknowledgments
192
+
193
+ The loss functions implemented here are based on the work of the authors of the referenced papers.
194
+
195
+ ## References
196
+
197
+ [1] K. N. Watcharasupat, C.-W. Wu, Y. Ding, I. Orife, A. J. Hipple, P. A. Williams, S. Kramer, A. Lerch, and W. Wolcott, "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation," IEEE Open Journal of Signal Processing, 2023. (arXiv:2309.02539)
198
+
199
+ [2] K. N. Watcharasupat and A. Lerch, "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries," arXiv:2501.16171.
200
+
201
+ [3] K. N. Watcharasupat and A. Lerch, "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems," Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024. (arXiv:2406.18747)
@@ -0,0 +1,7 @@
1
+ torch_l1_snr-0.0.1.dist-info/licenses/LICENSE,sha256=gBRgAD6TvJXVLS9LVkO0V1GzNHKX2poZLFGtyc6hwq0,1079
2
+ torch_l1snr/__init__.py,sha256=pR9jg3fjTKt_suZoVDC67tqB7EWRkbfaXaPP7pYQrlQ,220
3
+ torch_l1snr/l1snr.py,sha256=aqmtNfT_8A0IRI9jiVGwNse3igBvelQGKnjfe23Xh7w,35304
4
+ torch_l1_snr-0.0.1.dist-info/METADATA,sha256=0EuezNJH0APVDEWIsmqW_6Wc9B1J-Zk3ksg8VknU7kY,10374
5
+ torch_l1_snr-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ torch_l1_snr-0.0.1.dist-info/top_level.txt,sha256=NfaRND6pcjZ7-035d4XAg8xJuz31EEU210Y9xWeFOxc,12
7
+ torch_l1_snr-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Christopher Landschoot
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ torch_l1snr
@@ -0,0 +1,15 @@
1
+ from .l1snr import (
2
+ dbrms,
3
+ L1SNRLoss,
4
+ L1SNRDBLoss,
5
+ STFTL1SNRDBLoss,
6
+ MultiL1SNRDBLoss,
7
+ )
8
+
9
+ __all__ = [
10
+ "dbrms",
11
+ "L1SNRLoss",
12
+ "L1SNRDBLoss",
13
+ "STFTL1SNRDBLoss",
14
+ "MultiL1SNRDBLoss",
15
+ ]
torch_l1snr/l1snr.py ADDED
@@ -0,0 +1,786 @@
1
+ # PyTorch implementation of L1SNR loss functions for audio source separation
2
+ # https://github.com/crlandsc/torch-l1-snr
3
+ # Copyright (c) 2025 Christopher Landschoot
4
+ # MIT License
5
+ #
6
+ # This implementation is based on and extends the loss functions described in:
7
+ # [1] "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries"
8
+ # Karn N. Watcharasupat, Alexander Lerch
9
+ # arXiv:2501.16171
10
+ # [2] "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation"
11
+ # Karn N. Watcharasupat, Chih-Wei Wu, Yiwei Ding, Iroro Orife, Aaron J. Hipple, Phillip A. Williams, Scott Kramer, Alexander Lerch, William Wolcott
12
+ # IEEE Open Journal of Signal Processing, 2023
13
+ # arXiv:2309.02539
14
+ # [3] "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems"
15
+ # Karn N. Watcharasupat, Alexander Lerch
16
+ # Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024
17
+ # arXiv:2406.18747
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torchaudio.transforms import Spectrogram
22
+ import math
23
+ from typing import Union, Dict, Tuple, Optional, List
24
+
25
+ def dbrms(x, eps=1e-8):
26
+ """
27
+ Compute RMS level in decibels for a batch of signals.
28
+ Args:
29
+ x: (batch, time) or (batch, ...) tensor
30
+ eps: stability constant
31
+ Returns:
32
+ (batch,) tensor of dB RMS
33
+ """
34
+ x = x.reshape(x.shape[0], -1)
35
+ rms = torch.sqrt(torch.mean(x ** 2, dim=-1) + eps)
36
+ return 20.0 * torch.log10(rms + eps)
37
+
38
+
39
+ class L1SNRLoss(torch.nn.Module):
40
+ """
41
+ Implements the L1 Signal-to-Noise Ratio (SNR) loss with optional weighted L1 loss
42
+ component to balance "all-or-nothing" behavior.
43
+
44
+ Paper-aligned D1(ŷ; y) form:
45
+ D1 = 10 * log10( (||ŷ - y||_1 + eps) / (||y||_1 + eps) )
46
+ L1SNR_loss = mean(D1)
47
+
48
+ When l1_weight > 0, the loss combines L1SNR with scaled L1:
49
+ loss = (1 - l1_weight) * L1SNR_loss + l1_weight * L1_auto_scaled
50
+
51
+ Input Shape:
52
+ Accepts waveform tensors (time-domain audio) of any shape as long as they are batch-first.
53
+ Recommended shapes:
54
+ - [batch, time] for single-source audio
55
+ - [batch, num_sources, time] for multi-source audio
56
+ - [batch, num_sources, channels, time] for multi-channel multi-source audio
57
+
58
+ Attributes:
59
+ name (str): Name identifier for the loss.
60
+ weight (float): Global weight multiplier for the loss.
61
+ eps (float): Small epsilon for numerical stability in D1 (default 1e-3 per the papers).
62
+ l1_weight (float): Weight for the L1 term mixed into L1SNR.
63
+ """
64
+ def __init__(
65
+ self,
66
+ name,
67
+ weight: float = 1.0,
68
+ eps: float = 1e-3,
69
+ l1_weight: float = 0.0,
70
+ ):
71
+ super().__init__()
72
+ self.name = name
73
+ self.weight = weight
74
+ self.eps = eps
75
+ self.l1_weight = l1_weight
76
+
77
+ def forward(self, estimates, actuals, *args, **kwargs):
78
+ batch_size = estimates.shape[0]
79
+
80
+ est_source = estimates.reshape(batch_size, -1)
81
+ act_source = actuals.reshape(batch_size, -1)
82
+
83
+ # L1 errors and reference
84
+ l1_error = torch.mean(torch.abs(est_source - act_source), dim=-1)
85
+ l1_true = torch.mean(torch.abs(act_source), dim=-1)
86
+
87
+ # Auto-balanced L1/SNR mixing
88
+ w = float(self.l1_weight)
89
+
90
+ # Pure-L1 shortcut: avoid D1 computation
91
+ if w >= 1.0:
92
+ return torch.mean(l1_error) * self.weight
93
+
94
+ # If pure SNR (w == 0) we can skip L1 scaling math
95
+ if w <= 0.0:
96
+ d1 = 10.0 * torch.log10((l1_error + self.eps) / (l1_true + self.eps))
97
+ l1snr_loss = torch.mean(d1)
98
+ return l1snr_loss * self.weight
99
+
100
+ # Mixed path
101
+ d1 = 10.0 * torch.log10((l1_error + self.eps) / (l1_true + self.eps))
102
+ l1snr_loss = torch.mean(d1)
103
+
104
+ c = 10.0 / math.log(10.0)
105
+ inv_mean = torch.mean(1.0 / (l1_error.detach() + self.eps))
106
+ # w-independent scaling to match typical gradient magnitudes
107
+ scale_time = c * inv_mean
108
+ l1_term = torch.mean(l1_error) * scale_time
109
+
110
+ if getattr(self, "balance_per_sample", False):
111
+ # per-sample w-independent scaling
112
+ bal = c / (l1_error.detach() + self.eps)
113
+ l1_term = torch.mean(l1_error * bal)
114
+
115
+ if getattr(self, "debug_balance", False):
116
+ g_d1 = (1.0 - w) * c * inv_mean
117
+ if getattr(self, "balance_per_sample", False):
118
+ g_l1 = w * torch.mean(c / (l1_error.detach() + self.eps))
119
+ else:
120
+ g_l1 = w * c * inv_mean
121
+ ratio = (g_l1 / (g_d1 + 1e-12)).item()
122
+ setattr(self, "last_balance_ratio", ratio)
123
+
124
+ loss = (1.0 - w) * l1snr_loss + w * l1_term
125
+ return loss * self.weight
126
+
127
+
128
+ class L1SNRDBLoss(torch.nn.Module):
129
+ """
130
+ Implements L1SNR plus adaptive level-matching regularization in the time domain
131
+ as described in arXiv:2501.16171, with optional L1 loss component to balance
132
+ "all-or-nothing" behavior.
133
+
134
+ The loss combines three components:
135
+ 1. L1SNR loss: mean(10*log10((l1_error + eps) / (l1_true + eps)))
136
+ 2. Level-matching regularization: λ*|L_pred - L_true|
137
+ Where λ is adaptively computed based on the signal levels
138
+ 3. Optional L1 loss: mean(l1_error)
139
+
140
+ The complete loss is structured as:
141
+ When l1_weight < 1.0: total_loss = l1snr_loss + (1-l1_weight) * mean(reg_loss)
142
+ When l1_weight = 1.0: total_loss = l1_loss (pure L1, bypassing all other computations)
143
+
144
+ The adaptive weighting λ for regularization increases when loud parts of a stem aren't
145
+ reconstructed properly, helping balance between quality and level preservation.
146
+
147
+ When l1_weight=1.0, this loss efficiently switches to a pure L1 loss calculation,
148
+ bypassing all SNR and regularization computations for standard L1 behavior.
149
+ This is useful when you want to avoid the "all-or-nothing" behavior of the SNR-style loss.
150
+
151
+ Input Shape:
152
+ Accepts waveform tensors (time-domain audio) of any shape as long as they are batch-first.
153
+ Recommended shapes:
154
+ - [batch, time] for single-source audio
155
+ - [batch, num_sources, time] for multi-source audio
156
+ - [batch, num_sources, channels, time] for multi-channel multi-source audio
157
+
158
+ Attributes:
159
+ name (str): The name identifier for the loss.
160
+ weight (float): The overall weight multiplier for the loss.
161
+ lambda0 (float): Minimum regularization weight (λ_min).
162
+ delta_lambda (float): Range of extra weight for regularization (Δλ).
163
+ l1snr_eps (float): Epsilon value for the L1SNR component to avoid log(0).
164
+ dbrms_eps (float): Epsilon value for dBRMS calculation to avoid log(0).
165
+ lmin (float): Minimum dBRMS considered non-silent for adaptive weighting.
166
+ use_regularization (bool): Whether to use level-matching regularization.
167
+ If False, only the L1SNR (and optional L1) components are used.
168
+ l1_weight (float): Weight for the L1 loss component. Default 0 (disabled).
169
+ As this increases, the regularization term is also scaled down proportionally.
170
+ When set to 1.0, efficiently computes only L1 loss.
171
+ """
172
+ def __init__(
173
+ self,
174
+ name,
175
+ weight: float = 1.0,
176
+ lambda0: float = 0.1,
177
+ delta_lambda: float = 0.9,
178
+ l1snr_eps: float = 1e-3,
179
+ dbrms_eps: float = 1e-8,
180
+ lmin: float = -60.0,
181
+ use_regularization: bool = True,
182
+ l1_weight: float = 0.0,
183
+ ):
184
+ super().__init__()
185
+ self.name = name
186
+ self.weight = weight
187
+ self.lambda0 = lambda0 # minimum regularization weight
188
+ self.delta_lambda = delta_lambda # range of extra weight
189
+ self.l1snr_eps = l1snr_eps
190
+ self.dbrms_eps = dbrms_eps
191
+ self.lmin = lmin
192
+ self.use_regularization = use_regularization
193
+
194
+ # Validate l1_weight is between 0.0 and 1.0 inclusive
195
+ assert 0.0 <= l1_weight <= 1.0, "l1_weight must be between 0.0 and 1.0 inclusive"
196
+ self.l1_weight = l1_weight
197
+
198
+ # Initialize component losses based on l1_weight
199
+ if self.l1_weight == 1.0:
200
+ # Pure L1 mode - only need L1 loss
201
+ self.l1snr_loss = None
202
+ self.l1_loss = torch.nn.L1Loss()
203
+ else:
204
+ # Standard mode with L1SNR (and optional weighted L1 if l1_weight > 0)
205
+ self.l1snr_loss = L1SNRLoss(
206
+ name="l1_snr",
207
+ weight=1.0, # We'll apply the weight at the end
208
+ eps=l1snr_eps,
209
+ l1_weight=l1_weight,
210
+ )
211
+ self.l1_loss = None
212
+
213
+ @staticmethod
214
+ def compute_adaptive_weight(L_pred, L_true, L_min, lambda0, delta_lambda, R):
215
+ """
216
+ Implements the adaptive weighting of the level-matching regularization term, per arXiv:2501.16171.
217
+ Args:
218
+ L_pred: predicted dBRMS, shape (batch,)
219
+ L_true: reference dBRMS, shape (batch,)
220
+ L_min: minimum dBRMS considered non-silent (float)
221
+ lambda0: minimum weight for regularization
222
+ delta_lambda: range of extra weight for regularization
223
+ R: |L_pred - L_true|, shape (batch,)
224
+ Returns:
225
+ lambda_weight: shape (batch,)
226
+ """
227
+ # Compute eta: 1 if L_true > max(L_pred, L_min), else 0
228
+ max_val = torch.max(L_pred, torch.full_like(L_true, L_min))
229
+ eta = (L_true > max_val).float()
230
+ denom = (L_true - L_min).clamp(min=1e-6)
231
+ clamp_arg = (R / denom).clamp(0.0, 1.0)
232
+ lam = lambda0 + eta * delta_lambda * clamp_arg
233
+ return lam.detach() # Stop-gradient
234
+
235
+ def forward(self, estimates, actuals, *args, **kwargs):
236
+ batch_size = estimates.shape[0]
237
+
238
+ est_source = estimates.reshape(batch_size, -1)
239
+ act_source = actuals.reshape(batch_size, -1)
240
+
241
+ # Pure L1 mode - efficient path that bypasses SNR and regularization
242
+ if self.l1_loss is not None:
243
+ l1_loss = self.l1_loss(est_source, act_source)
244
+ return l1_loss * self.weight
245
+
246
+ # Standard mode with L1SNR, regularization, and optional weighted L1
247
+ # 1. L1SNR reconstruction loss (with L1 component if l1_weight > 0)
248
+ l1snr_loss = self.l1snr_loss(estimates, actuals, *args, **kwargs)
249
+
250
+ # Only compute and apply regularization if enabled
251
+ if self.use_regularization:
252
+ # 2. Level-matching regularization
253
+ L_true = dbrms(act_source, self.dbrms_eps) # (batch,)
254
+ L_pred = dbrms(est_source, self.dbrms_eps) # (batch,)
255
+ R = torch.abs(L_pred - L_true) # (batch,)
256
+
257
+ lambda_weight = self.compute_adaptive_weight(L_pred, L_true, self.lmin, self.lambda0, self.delta_lambda, R) # (batch,)
258
+
259
+ reg_loss = lambda_weight * R
260
+
261
+ # Scale regularization by the same factor as L1SNR loss
262
+ l1snr_weight = 1.0 - self.l1_weight
263
+ total_loss = l1snr_loss + (l1snr_weight * torch.mean(reg_loss))
264
+ else:
265
+ # Skip regularization calculation entirely
266
+ total_loss = l1snr_loss
267
+
268
+ return total_loss * self.weight
269
+
270
+
271
+ class STFTL1SNRDBLoss(torch.nn.Module):
272
+ """
273
+ Implements L1SNR plus adaptive level-matching regularization in the spectrogram domain
274
+ as described in arXiv:2501.16171, with multi-resolution STFT and optional L1 loss component
275
+ to balance "all-or-nothing" behavior.
276
+
277
+ This loss applies the same principles as L1SNRDBLoss but operates in the complex
278
+ spectrogram domain across multiple time-frequency resolutions. For each resolution:
279
+
280
+ 1. L1SNR loss: Computed on the complex STFT representation (real/imaginary parts)
281
+ 2. Level-matching regularization: Applied to STFT magnitudes with adaptive weighting
282
+ 3. Optional L1 loss: Direct L1 penalty on STFT differences
283
+
284
+ Multi-resolution processing helps capture both fine temporal details and frequency
285
+ characteristics. The loss averages results across all valid STFT resolutions.
286
+
287
+ The complete loss structure is similar to L1SNRDBLoss:
288
+ When l1_weight < 1.0: total_loss = l1snr_loss + (1-l1_weight) * spec_reg_coef * mean(reg_loss)
289
+ When l1_weight = 1.0: total_loss = l1_loss (pure L1 in spectrogram domain, bypassing all other computations)
290
+
291
+ When l1_weight=1.0, this loss efficiently switches to a pure L1 loss calculation in the
292
+ spectrogram domain, bypassing all SNR and regularization computations for standard L1 behavior.
293
+ This is useful when you want to avoid the "all-or-nothing" behavior of the SNR-style loss.
294
+
295
+ Input Shape:
296
+ Accepts waveform tensors (time-domain audio) of any shape as long as they are batch-first
297
+ and time-last. Recommended shapes:
298
+ - [batch, time] for single-source audio
299
+ - [batch, num_sources, time] for multi-source audio
300
+ - [batch, num_sources, channels, time] for multi-channel multi-source audio
301
+
302
+ Attributes:
303
+ name (str): The name identifier for the loss.
304
+ weight (float): The overall weight multiplier for the loss.
305
+ lambda0 (float): Minimum regularization weight (λ_min).
306
+ delta_lambda (float): Range of extra weight for regularization (Δλ).
307
+ l1snr_eps (float): Epsilon value for the L1SNR component to avoid log(0).
308
+ dbrms_eps (float): Epsilon value for dBRMS calculation to avoid log(0).
309
+ lmin (float): Minimum dBRMS considered non-silent for adaptive weighting.
310
+ n_ffts (List[int]): List of FFT sizes for multi-resolution STFT analysis.
311
+ hop_lengths (List[int]): List of hop lengths (STFT time steps) for each resolution.
312
+ win_lengths (List[int]): List of window lengths for each resolution.
313
+ window_fn (str): Window function for the STFT ('hann', 'hamming', etc.)
314
+ min_audio_length (int): Minimum audio length required for processing.
315
+ If audio is shorter, returns zero loss to avoid errors.
316
+ use_regularization (bool): Whether to use level-matching regularization.
317
+ If False, only the L1SNR (and optional L1) components are used.
318
+ l1_weight (float): Weight for the L1 loss component. Default 0 (disabled).
319
+ As this increases, the regularization term is also scaled down proportionally.
320
+ When set to 1.0, efficiently computes only L1 loss.
321
+ """
322
+ def __init__(
323
+ self,
324
+ name,
325
+ weight: float = 1.0,
326
+ lambda0: float = 0.1,
327
+ delta_lambda: float = 0.9,
328
+ l1snr_eps: float = 1e-3,
329
+ dbrms_eps: float = 1e-8,
330
+ lmin: float = -60.0,
331
+ n_ffts: List[int] = [512, 1024, 2048],
332
+ hop_lengths: List[int] = [128, 256, 512],
333
+ win_lengths: List[int] = [512, 1024, 2048],
334
+ window_fn: str = 'hann',
335
+ min_audio_length: int = 512,
336
+ use_regularization: bool = False,
337
+ spec_reg_coef: float = 0.1,
338
+ l1_weight: float = 0.0,
339
+ ):
340
+ super().__init__()
341
+ self.name = name
342
+ self.weight = weight
343
+ self.min_audio_length = min_audio_length
344
+
345
+ # Validate STFT parameters
346
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths), "All STFT parameter lists must have the same length"
347
+
348
+ # Store STFT parameters for validation during forward pass
349
+ self.n_ffts = n_ffts
350
+ self.hop_lengths = hop_lengths
351
+ self.win_lengths = win_lengths
352
+ self.window_fn_name = window_fn
353
+
354
+ # Validate window sizes
355
+ for n_fft, win_length in zip(n_ffts, win_lengths):
356
+ assert n_fft >= win_length, f"FFT size ({n_fft}) must be greater than or equal to window length ({win_length})"
357
+
358
+ # Pre-initialize Spectrogram transforms for maximum efficiency
359
+ self.spectrogram_transforms = nn.ModuleList()
360
+ window_fn_callable = getattr(torch, f"{window_fn}_window")
361
+
362
+ for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
363
+ # Create a spectrogram transform for each resolution
364
+ transform = Spectrogram(
365
+ n_fft=n_fft,
366
+ win_length=win_length,
367
+ hop_length=hop_length,
368
+ pad_mode="reflect",
369
+ center=True,
370
+ window_fn=window_fn_callable,
371
+ normalized=True,
372
+ power=None, # This ensures the output is complex
373
+ )
374
+ self.spectrogram_transforms.append(transform)
375
+
376
+ # Parameters for spectrogram domain level-matching
377
+ self.lambda0 = lambda0
378
+ self.delta_lambda = delta_lambda
379
+ self.lmin = lmin
380
+ self.dbrms_eps = dbrms_eps
381
+ self.l1snr_eps = l1snr_eps
382
+
383
+ # Add L1 loss parameters and validate
384
+ assert 0.0 <= l1_weight <= 1.0, "l1_weight must be between 0.0 and 1.0 inclusive"
385
+ self.l1_weight = l1_weight
386
+
387
+ # Flag for pure L1 mode
388
+ self.pure_l1_mode = (self.l1_weight == 1.0)
389
+ # Create L1 loss function for pure L1 mode
390
+ if self.pure_l1_mode:
391
+ self.l1_loss = torch.nn.L1Loss()
392
+ else:
393
+ self.l1_loss = None
394
+
395
+
396
+ # Add this parameter to control regularization
397
+ self.use_regularization = use_regularization
398
+ # Coefficient to scale spectral regularization (disabled by default)
399
+ self.spec_reg_coef = spec_reg_coef
400
+
401
+ # Fallback time-domain loss (used when audio is too short for TF processing)
402
+ self.fallback_time_loss = L1SNRDBLoss(
403
+ name=f"{name}_fallback_time",
404
+ weight=1.0,
405
+ lambda0=self.lambda0,
406
+ delta_lambda=self.delta_lambda,
407
+ l1snr_eps=self.l1snr_eps,
408
+ dbrms_eps=self.dbrms_eps,
409
+ lmin=self.lmin,
410
+ use_regularization=False, # regularizer belongs to TF for this class
411
+ l1_weight=self.l1_weight,
412
+ )
413
+
414
+ # Simplified tracking
415
+ self.nan_inf_counts = {"inputs": 0, "spec_loss": 0}
416
+
417
+ def _compute_complex_spec_l1snr_loss(self, est_spec, act_spec):
+         """
+         Compute TF-domain loss as per the papers:
+         - D1 on real part + D1 on imaginary part, summed.
+         - Optional L1 mixing applied symmetrically to Re/Im.
+         est_spec, act_spec: complex tensors with shape (B, C, F, T)
+         """
+         # Ensure same shape (assert to avoid silent mismatches)
+         assert est_spec.shape == act_spec.shape, f"Spec shapes must match: {est_spec.shape} vs {act_spec.shape}"
+
+         # Split real/imag
+         est_re, est_im = est_spec.real, est_spec.imag
+         act_re, act_im = act_spec.real, act_spec.imag
+
+         B = est_spec.shape[0]
+
+         # Flatten to (B, -1)
+         est_re = est_re.reshape(B, -1)
+         act_re = act_re.reshape(B, -1)
+         est_im = est_im.reshape(B, -1)
+         act_im = act_im.reshape(B, -1)
+
+         # L1 errors and refs
+         err_re = torch.mean(torch.abs(est_re - act_re), dim=1)
+         ref_re = torch.mean(torch.abs(act_re), dim=1)
+         err_im = torch.mean(torch.abs(est_im - act_im), dim=1)
+         ref_im = torch.mean(torch.abs(act_im), dim=1)
+
+         # Paper-aligned D1 = 10*log10((||e||_1 + eps)/(||y||_1 + eps))
+         d1_re = 10.0 * torch.log10((err_re + self.l1snr_eps) / (ref_re + self.l1snr_eps))
+         d1_im = 10.0 * torch.log10((err_im + self.l1snr_eps) / (ref_im + self.l1snr_eps))
+         d1_sum = torch.mean(d1_re + d1_im)  # mean over batch
+
+         # Pure L1 mode
+         if self.pure_l1_mode:
+             l1_re = torch.mean(err_re)
+             l1_im = torch.mean(err_im)
+             l1_term = 0.5 * (l1_re + l1_im)
+             return l1_term
+
+         # Mixed mode (auto-balanced L1/SNR) with per-batch scaling
+         w = float(self.l1_weight)
+         if 0.0 < w < 1.0:
+             c = 10.0 / math.log(10.0)
+             inv_mean_comp = torch.mean(0.5 * (1.0 / (err_re.detach() + self.l1snr_eps) +
+                                               1.0 / (err_im.detach() + self.l1snr_eps)))
+             # w-independent scaling to match typical gradient magnitudes (factor 2.0 for Re/Im symmetry)
+             scale_spec = 2.0 * c * inv_mean_comp
+             l1_term = 0.5 * (torch.mean(err_re) + torch.mean(err_im)) * scale_spec
+
+             if getattr(self, "balance_per_sample", False):
+                 bal_re = c / (err_re.detach() + self.l1snr_eps)
+                 bal_im = c / (err_im.detach() + self.l1snr_eps)
+                 l1_term = 0.5 * (torch.mean(err_re * bal_re) + torch.mean(err_im * bal_im))
+
+             loss = (1.0 - w) * d1_sum + w * l1_term
+             return loss
+         elif w >= 1.0:
+             # Pure L1 (defensive; the pure_l1_mode branch above already returns when w == 1.0)
+             l1_term = 0.5 * (torch.mean(err_re) + torch.mean(err_im))
+             return l1_term
+         else:
+             # Pure SNR (D1)
+             return d1_sum
+
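The D1 term above is the paper's log-ratio of mean absolute errors. As a sanity check, here is a minimal pure-Python sketch of the scalar formula (illustrative only, not part of the package; assumes the same eps handling as `l1snr_eps`):

```python
import math

def d1_db(err_l1: float, ref_l1: float, eps: float = 1e-3) -> float:
    # D1 = 10*log10((||e||_1 + eps) / (||y||_1 + eps)): 0 dB when the mean
    # absolute error equals the reference level, more negative as it improves.
    return 10.0 * math.log10((err_l1 + eps) / (ref_l1 + eps))
```

Lower (more negative) values indicate a better estimate; the eps floor bounds the score near -30 dB for this default when the error is exactly zero.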
+     def _compute_spec_level_matching(self, est_spec, act_spec):
+         """
+         Compute the level matching regularization term for a spectrogram.
+         """
+         batch_size = est_spec.shape[0]
+
+         # Make sure dimensions match before operations
+         if est_spec.shape != act_spec.shape:
+             # Resize to match the smaller of the two
+             min_freq = min(est_spec.shape[-2], act_spec.shape[-2])
+             min_time = min(est_spec.shape[-1], act_spec.shape[-1])
+             est_spec = est_spec[..., :min_freq, :min_time]
+             act_spec = act_spec[..., :min_freq, :min_time]
+
+         # For level-matching regularization, we use magnitude information
+         est_mag = torch.abs(est_spec)
+         act_mag = torch.abs(act_spec)
+
+         # Reshape once for efficiency
+         est_mag_flat = est_mag.reshape(batch_size, -1)
+         act_mag_flat = act_mag.reshape(batch_size, -1)
+
+         # Calculate dB levels
+         L_true = dbrms(act_mag_flat, self.dbrms_eps)
+         L_pred = dbrms(est_mag_flat, self.dbrms_eps)
+
+         R = torch.abs(L_pred - L_true)
+
+         # Use the adaptive weighting function
+         lambda_weight = L1SNRDBLoss.compute_adaptive_weight(
+             L_pred, L_true, self.lmin, self.lambda0, self.delta_lambda, R
+         )
+
+         return torch.mean(lambda_weight * R)
+
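The `dbrms` helper used above is defined elsewhere in the package; a plausible pure-Python sketch of a dB-RMS level (hypothetical stand-in, not the package's actual implementation) would be:

```python
import math

def dbrms_sketch(x, eps=1e-8):
    # RMS level in dB; eps guards log10(0) for silent inputs
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return 20.0 * math.log10(rms + eps)
```

With this shape, R = |L_pred - L_true| measures the level mismatch in dB, and the adaptive lambda weight scales it before averaging.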
+ def _validate_audio_length(self, audio_length):
+         """
+         Validates that the audio is long enough for the STFT parameters.
+         """
+         if audio_length < self.min_audio_length:
+             return False
+
+         for n_fft, hop_length in zip(self.n_ffts, self.hop_lengths):
+             n_frames = (audio_length // hop_length) + 1
+             if n_frames < 2:
+                 return False
+
+         return True
+
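The frame-count check in `_validate_audio_length` can be illustrated standalone (hypothetical helper mirroring the logic above; assumes the center-padded STFT framing where hop h over N samples yields N // h + 1 frames):

```python
def enough_frames(audio_length, hop_lengths, min_audio_length=512):
    # Mirrors _validate_audio_length: reject audio shorter than the minimum,
    # then require every resolution to produce at least 2 STFT frames.
    if audio_length < min_audio_length:
        return False
    return all(audio_length // hop + 1 >= 2 for hop in hop_lengths)
```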
+     def forward(self, estimates, actuals, *args, **kwargs):
+         device = estimates.device
+         batch_size = estimates.shape[0]
+
+         # Basic NaN/Inf handling (simplified)
+         if torch.isnan(estimates).any() or torch.isinf(estimates).any() or torch.isnan(actuals).any() or torch.isinf(actuals).any():
+             self.nan_inf_counts["inputs"] += 1
+             estimates = torch.nan_to_num(estimates, nan=0.0, posinf=1.0, neginf=-1.0)
+             actuals = torch.nan_to_num(actuals, nan=0.0, posinf=1.0, neginf=-1.0)
+
+         est_source = estimates.reshape(batch_size, -1, estimates.shape[-1])
+         act_source = actuals.reshape(batch_size, -1, actuals.shape[-1])
+
+         # Validate audio length
+         audio_length = est_source.shape[-1]
+         if not self._validate_audio_length(audio_length):
+             # Fallback to time-domain L1SNR-style loss instead of zero
+             return self.fallback_time_loss(estimates, actuals, *args, **kwargs) * self.weight
+
+         # Track losses (initialize as tensors on the correct device for stability)
+         total_spec_loss = torch.tensor(0.0, device=device)
+         total_spec_reg_loss = torch.tensor(0.0, device=device)
+         valid_transforms = 0
+
+         # Ensure transforms are on the correct device
+         self.spectrogram_transforms.to(device)
+
+         # Process each resolution
+         for i, transform in enumerate(self.spectrogram_transforms):
+             try:
+                 # Compute spectrograms using pre-initialized transforms
+                 try:
+                     est_spec = transform(est_source)
+                     act_spec = transform(act_source)
+                 except RuntimeError as e:
+                     print(f"Error computing spectrogram for resolution {i}: {e}")
+                     print(f"Parameters: n_fft={self.n_ffts[i]}, hop_length={self.hop_lengths[i]}, win_length={self.win_lengths[i]}")
+                     continue
+
+                 # Ensure same (B, C, F, T); crop only (F, T) if needed
+                 if est_spec.shape != act_spec.shape:
+                     min_f = min(est_spec.shape[-2], act_spec.shape[-2])
+                     min_t = min(est_spec.shape[-1], act_spec.shape[-1])
+                     est_spec = est_spec[..., :min_f, :min_t]
+                     act_spec = act_spec[..., :min_f, :min_t]
+
+                 # Compute complex spectral loss (either L1 or L1SNR based on self.pure_l1_mode)
+                 try:
+                     spec_loss = self._compute_complex_spec_l1snr_loss(est_spec, act_spec)
+                 except RuntimeError as e:
+                     print(f"Error computing complex spectral loss for resolution {i}: {e}")
+                     continue
+
+                 # Check for numerical issues
+                 if torch.isnan(spec_loss) or torch.isinf(spec_loss):
+                     self.nan_inf_counts["spec_loss"] += 1
+                     continue
+
+                 # Only compute regularization if not in pure L1 mode and regularization is enabled
+                 if not self.pure_l1_mode and self.use_regularization:
+                     try:
+                         spec_reg_loss = self._compute_spec_level_matching(est_spec, act_spec)
+
+                         # Check for numerical issues
+                         if torch.isnan(spec_reg_loss) or torch.isinf(spec_reg_loss):
+                             self.nan_inf_counts["spec_loss"] += 1
+                             spec_reg_loss = 0.0  # Use zero reg_loss if there are issues
+
+                         # Accumulate regularization loss
+                         total_spec_reg_loss += spec_reg_loss
+                     except RuntimeError as e:
+                         print(f"Error computing spectral level-matching for resolution {i}: {e}")
+
+                 # Accumulate loss
+                 total_spec_loss += spec_loss
+                 valid_transforms += 1
+
+             except RuntimeError as e:
+                 print(f"Runtime error in spectrogram transform {i}: {e}")
+                 continue
+
+         # If all transforms failed, return zero loss
+         if valid_transforms == 0:
+             print("Warning: All spectrogram transforms failed. Returning zero loss.")
+             return torch.tensor(0.0, device=device)
+
+         # Average losses across valid transforms
+         avg_spec_loss = total_spec_loss / valid_transforms
+
+         # For standard mode, apply regularization if enabled
+         if not self.pure_l1_mode and self.use_regularization:
+             avg_spec_reg_loss = total_spec_reg_loss / valid_transforms
+             # Scale spectral regularization by both (1 - l1_weight) and spec_reg_coef
+             l1snr_weight = 1.0 - self.l1_weight
+             final_loss = avg_spec_loss + l1snr_weight * (self.spec_reg_coef * avg_spec_reg_loss)
+         else:
+             final_loss = avg_spec_loss
+
+         return final_loss * self.weight
+
+
+ class MultiL1SNRDBLoss(torch.nn.Module):
+     """
+     A modular loss function that combines time-domain and spectrogram-domain L1SNR and
+     adaptive level-matching losses, as described in arXiv:2501.16171, with an optional
+     L1 loss component to balance "all-or-nothing" behavior.
+
+     This implementation uses separate specialized components:
+     - L1SNRDBLoss for time domain processing
+     - STFTL1SNRDBLoss for spectrogram domain processing
+
+     The loss combines time-domain and spectrogram-domain losses:
+         Loss = weight * [(1 - spec_weight) * time_loss + spec_weight * spec_loss]
+
+     Where time_loss and spec_loss are computed by L1SNRDBLoss and STFTL1SNRDBLoss respectively,
+     each handling its own L1SNR, regularization, and optional L1 components as described
+     in their individual docstrings.
+
+     When l1_weight=1.0, this loss efficiently switches to a pure L1 loss calculation in both
+     domains, bypassing all SNR and regularization computations for standard L1 behavior.
+     This is useful when you want to avoid the "all-or-nothing" behavior of the SNR-style loss.
+
+     The regularization components use adaptive weighting based on level differences
+     between estimated and target signals, with weighting controlled by lambda0 and delta_lambda.
+
+     Input Shape:
+         Accepts waveform tensors (time-domain audio) of any shape as long as they are batch-first
+         and time-last. Recommended shapes:
+         - [batch, time] for single-source audio
+         - [batch, num_sources, time] for multi-source audio
+         - [batch, num_sources, channels, time] for multi-channel multi-source audio
+
+     Attributes:
+         name (str): The name identifier for the loss.
+         weight (float): The overall weight multiplier for the loss.
+         spec_weight (float): The weight for spectrogram domain loss relative to time domain.
+             Default 0.5 (equal weighting). Set higher to emphasize spectral accuracy.
+         use_time_regularization (bool): Whether to use level-matching regularization in the time domain.
+         use_spec_regularization (bool): Whether to use level-matching regularization in the spectrogram domain.
+         l1_weight (float): Weight for the L1 loss component vs the L1SNR+reg components.
+             Default 0 (disabled). As this increases, the regularization term is also scaled down.
+             When set to 1.0, efficiently computes only L1 loss in both domains.
+         lambda0 (float): Minimum regularization weight for both domains.
+         delta_lambda (float): Range of extra weight for regularization in both domains.
+         time_loss_params (dict): Optional additional parameters to pass to the time domain loss.
+         spec_loss_params (dict): Optional additional parameters to pass to the spectrogram domain loss.
+     """
+ def __init__(
+     def __init__(
+         self,
+         name,
+         weight: float = 1.0,
+         spec_weight: float = 0.5,  # Balance between time and frequency domain
+         # L1 component parameters
+         l1_weight: float = 0.0,  # Weight for the L1 loss component vs (L1SNR + Regularization).
+                                  # Note: Regularization term is also scaled by (1.0 - l1_weight).
+                                  # When set to 1.0, efficiently computes only L1 loss in both domains.
+                                  # Intermediate values use auto-balanced L1/SNR mixing.
+         # Regularization on/off flags
+         use_time_regularization: bool = True,
+         use_spec_regularization: bool = False,  # likely redundant if already using in time domain
+         # Default parameters for both loss components
+         lambda0: float = 0.1,
+         delta_lambda: float = 0.9,
+         l1snr_eps: float = 1e-3,
+         dbrms_eps: float = 1e-8,
+         lmin: float = -60.0,
+         # STFT parameters
+         n_ffts: List[int] = [512, 1024, 2048],
+         hop_lengths: List[int] = [128, 256, 512],
+         win_lengths: List[int] = [512, 1024, 2048],
+         window_fn: str = 'hann',
+         min_audio_length: int = 512,
+         # Allow for separate parameter overrides (e.g. different delta_lambda for time and spec)
+         time_loss_params: dict = None,
+         spec_loss_params: dict = None,
+     ):
+         super().__init__()
+         self.name = name
+         self.weight = weight
+         self.spec_weight = spec_weight
+
+         # Validate l1_weight is in valid range
+         assert 0.0 <= l1_weight <= 1.0, "l1_weight must be between 0.0 and 1.0 inclusive"
+ self.l1_weight = l1_weight
715
+ self.use_time_regularization = use_time_regularization
716
+ self.use_spec_regularization = use_spec_regularization
717
+
718
+ # Set up default parameters
719
+ default_time_params = {
720
+ "name": f"{name}_time",
721
+ "weight": 1.0, # Will be scaled by the combined loss
722
+ "lambda0": lambda0,
723
+ "delta_lambda": delta_lambda,
724
+ "l1snr_eps": l1snr_eps,
725
+ "dbrms_eps": dbrms_eps,
726
+ "lmin": lmin,
727
+ "l1_weight": l1_weight,
728
+ "use_regularization": use_time_regularization # Apply time domain regularization flag
729
+ }
730
+
731
+ default_spec_params = {
732
+ "name": f"{name}_spec",
733
+ "weight": 1.0, # Will be scaled by the combined loss
734
+ "lambda0": lambda0,
735
+ "delta_lambda": delta_lambda,
736
+ "l1snr_eps": l1snr_eps,
737
+ "dbrms_eps": dbrms_eps,
738
+ "lmin": lmin,
739
+ "n_ffts": n_ffts,
740
+ "hop_lengths": hop_lengths,
741
+ "win_lengths": win_lengths,
742
+ "window_fn": window_fn,
743
+ "min_audio_length": min_audio_length,
744
+ "l1_weight": l1_weight,
745
+
746
+ "use_regularization": use_spec_regularization # Apply spectrogram domain regularization flag
747
+ }
748
+
+         # Override with any custom parameters
+         if time_loss_params:
+             default_time_params.update(time_loss_params)
+         if spec_loss_params:
+             default_spec_params.update(spec_loss_params)
+
+         # Create the specialized loss components
+         # Note: Component losses handle all optimizations internally based on l1_weight
+         # When l1_weight=1.0, they will efficiently bypass SNR and regularization calculations
+         self.time_loss = L1SNRDBLoss(**default_time_params)
+         self.spec_loss = STFTL1SNRDBLoss(**default_spec_params)
+
+         # For reference only, indicate if we're in pure L1 mode
+         self.pure_l1_mode = (self.l1_weight == 1.0)
+
+ def forward(self, estimates, actuals, *args, **kwargs):
+         """
+         Forward pass to compute the combined multi-domain loss.
+
+         Args:
+             estimates: Model output predictions, shape [batch, ...] (batch-first, time-last)
+             actuals: Ground truth targets, shape [batch, ...] (batch-first, time-last)
+             *args, **kwargs: Additional arguments passed to sub-losses
+
+         Returns:
+             Combined weighted loss from time and spectrogram domains
+         """
+         # Compute time domain loss
+         time_loss = self.time_loss(estimates, actuals, *args, **kwargs)
+
+         # Compute spectrogram domain loss
+         spec_loss = self.spec_loss(estimates, actuals, *args, **kwargs)
+
+         # Combine with weighting
+         combined_loss = (1 - self.spec_weight) * time_loss + self.spec_weight * spec_loss
+
+         # Apply overall weight
+ return combined_loss * self.weight
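The weighting in `forward` reduces to a convex combination of the two domain losses scaled by the overall weight. A scalar sketch of that arithmetic (illustrative only; the real method operates on tensors returned by the sub-losses):

```python
def combine_losses(time_loss, spec_loss, spec_weight=0.5, weight=1.0):
    # Loss = weight * [(1 - spec_weight) * time_loss + spec_weight * spec_loss]
    return weight * ((1.0 - spec_weight) * time_loss + spec_weight * spec_loss)
```

For example, spec_weight=0.25 keeps three quarters of the gradient signal in the time domain before the overall weight is applied.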