torch-l1-snr 0.0.5__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,265 @@
Metadata-Version: 2.4
Name: torch-l1-snr
Version: 0.1.1
Summary: L1-SNR loss functions for audio source separation in PyTorch
Home-page: https://github.com/crlandsc/torch-l1-snr
Author: Christopher Landschoot
Author-email: crlandschoot@gmail.com
License: MIT
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Operating System :: OS Independent
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: torchaudio
Requires-Dist: numpy>=1.21.0
Dynamic: license-file

![torch-l1-snr-logo](https://raw.githubusercontent.com/crlandsc/torch-l1-snr/main/images/logo.png)

[![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/stargazers) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/)

L1 Signal-to-Noise Ratio (SNR) loss functions for audio source separation in PyTorch. This package provides four loss functions that combine implementations from recent academic research with novel extensions, designed to integrate easily into any audio separation or enhancement training pipeline.

The core [`L1SNRLoss`](#example-l1snrloss-time-domain) is based on the loss function described in [[1]](https://arxiv.org/abs/2309.02539). [`L1SNRDBLoss`](#example-l1snrdbloss-time-domain-with-regularization) adds the adaptive level-matching regularization proposed in [[2]](https://arxiv.org/abs/2501.16171). [`STFTL1SNRDBLoss`](#example-stftl1snrdbloss-spectrogram-domain) provides a spectrogram-domain L1SNR-style loss (real/imag STFT components as in [[1]](https://arxiv.org/abs/2309.02539) / [[3]](https://arxiv.org/abs/2406.18747)). [`MultiL1SNRDBLoss`](#example-multil1snrdbloss-combined-time--spectrogram) combines time-domain and spectrogram-domain losses into a single loss function for convenience and flexibility. Optional novel algorithmic extensions have also been included (such as multi-resolution STFT averaging, a spectrogram-domain adaptation of the level-matching regularizer from [[2]](https://arxiv.org/abs/2501.16171), time vs. spectrogram loss balancing, and blending with standard L1 loss) with the goal of increasing flexibility for improved performance depending on the specific task.

## Quick Start

```python
import torch
from torch_l1_snr import MultiL1SNRDBLoss

# Create combined time + spectrogram domain loss function with adaptive regularization
loss_fn = MultiL1SNRDBLoss(name="multi_l1_snr_db_loss")

# Calculate loss between model output and target
# (requires_grad so backward() works on this standalone example)
estimates = torch.randn(4, 32000, requires_grad=True)  # (batch, samples)
targets = torch.randn(4, 32000)
loss = loss_fn(estimates, targets)
loss.backward()
```

## Loss Functions

- [**Time-Domain L1SNR Loss**](#example-l1snrloss-time-domain): A basic, time-domain L1-SNR loss, based on [[1]](https://arxiv.org/abs/2309.02539).
- [**Regularized Time-Domain L1SNRDBLoss**](#example-l1snrdbloss-time-domain-with-regularization): An extension of the L1SNR loss with adaptive level-matching regularization from [[2]](https://arxiv.org/abs/2501.16171), plus an optional L1 loss component.
- [**Multi-Resolution STFT L1SNRDBLoss**](#example-stftl1snrdbloss-spectrogram-domain): A spectrogram-domain L1SNR-style loss (real/imag STFT components as in [[1]](https://arxiv.org/abs/2309.02539) / [[3]](https://arxiv.org/abs/2406.18747)), computed over multiple STFT resolutions, with optional spectrogram-domain level-matching regularization inspired by its time-domain counterpart in [[2]](https://arxiv.org/abs/2501.16171).
- [**Combined Multi-Domain Loss**](#example-multil1snrdbloss-combined-time--spectrogram): `MultiL1SNRDBLoss` combines time-domain and spectrogram-domain losses into a single, weighted objective function.

## Additional Features

- **L1 Loss Blending**: The `l1_weight` parameter blends L1SNR with standard L1 loss, softening the ["all-or-nothing" behavior](#all-or-nothing-behavior-and-l1_weight) of pure SNR losses for more nuanced separation.
- **Multi-Resolution STFT Averaging**: Averages the spectrogram-domain loss over multiple STFT resolutions, a common extension in recent literature.
- **Spectrogram-Domain Level-Matching Regularization**: Options to extend the adaptive level-matching regularizer of [[2]](https://arxiv.org/abs/2501.16171) to the spectrogram domain. Experimental and disabled by default.
- **Time vs. Spectrogram Loss Balancing**: The `spec_weight` parameter tunes the relative contribution of the time-domain and spectrogram-domain losses in `MultiL1SNRDBLoss`.
- **Numerical Stability**: Robust handling of `NaN` and `inf` values during training.
- **Short Audio Fallback**: Graceful fallback to the time-domain loss when audio is too short for STFT processing.

## Installation

### Install from PyPI

```bash
pip install torch-l1-snr
```

### Install from GitHub

```bash
pip install git+https://github.com/crlandsc/torch-l1-snr.git
```

Or, you can clone the repository and install it in editable mode for development:

```bash
git clone https://github.com/crlandsc/torch-l1-snr.git
cd torch-l1-snr
pip install -e .
```

## Dependencies

- [PyTorch](https://pytorch.org/)
- [torchaudio](https://pytorch.org/audio/stable/index.html)
- [NumPy](https://numpy.org/) (>=1.21.0)

## Supported Tensor Shapes

All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)`, `(batch, channels, samples)`, or `(batch, num_sources, channels, samples)`. For the time-domain losses, any 3D/4D input is flattened across all non-batch dimensions (e.g., sources, channels, and samples) into a single vector per example before the loss is computed. For the spectrogram-domain loss, inputs are reshaped to `(batch, streams, samples)` by flattening all non-time dimensions into a "stream" dimension (e.g., `streams = channels` or `streams = num_sources * channels`), and a separate STFT is computed for each stream.
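
The reshaping described above can be sketched as follows. This is an illustrative reimplementation of the flattening logic, not the package's internal code, and the helper names are hypothetical:

```python
import torch

def flatten_for_time_loss(x: torch.Tensor) -> torch.Tensor:
    # Time-domain losses: collapse every non-batch dimension
    # (sources, channels, samples) into one vector per example.
    return x.reshape(x.shape[0], -1)

def to_streams(x: torch.Tensor) -> torch.Tensor:
    # Spectrogram-domain loss: keep the time axis last and collapse
    # the remaining non-batch dimensions into a "stream" dimension.
    if x.dim() == 2:  # (batch, samples) -> (batch, 1, samples)
        return x.unsqueeze(1)
    return x.reshape(x.shape[0], -1, x.shape[-1])

x = torch.randn(4, 3, 2, 16000)  # (batch, num_sources, channels, samples)
print(tuple(flatten_for_time_loss(x).shape))  # (4, 96000)
print(tuple(to_streams(x).shape))             # (4, 6, 16000)
```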

## Usage

The loss functions can be imported directly from the `torch_l1_snr` package.

### `L1SNRLoss` (Time Domain)

The simplest loss function: pure L1SNR without regularization.

```python
import torch
from torch_l1_snr import L1SNRLoss

# Create dummy audio signals (requires_grad so backward() works here)
estimates = torch.randn(4, 2, 44100, requires_grad=True)  # Batch of 4, stereo, 44100 samples
actuals = torch.randn(4, 2, 44100)

# Basic L1SNR loss
loss_fn = L1SNRLoss(name="l1_snr_loss")

# Calculate loss
loss = loss_fn(estimates, actuals)
loss.backward()

print(f"L1SNRLoss: {loss.item()}")
```

### `L1SNRDBLoss` (Time Domain with Regularization)

Adds adaptive level-matching regularization to prevent silence collapse.

```python
import torch
from torch_l1_snr import L1SNRDBLoss

# Create dummy audio signals (requires_grad so backward() works here)
estimates = torch.randn(4, 2, 44100, requires_grad=True)  # Batch of 4, stereo, 44100 samples
actuals = torch.randn(4, 2, 44100)

# Initialize the loss function with regularization enabled
# l1_weight=0.1 blends 90% L1SNR+Regularization with 10% L1 loss
loss_fn = L1SNRDBLoss(
    name="l1_snr_db_loss",
    use_regularization=True,  # Enable adaptive level-matching regularization
    l1_weight=0.1             # 10% L1 loss, 90% L1SNR + regularization
)

# Calculate loss
loss = loss_fn(estimates, actuals)
loss.backward()

print(f"L1SNRDBLoss: {loss.item()}")
```

### `STFTL1SNRDBLoss` (Spectrogram Domain)

Computes the L1SNR loss across multiple STFT resolutions.

```python
import torch
from torch_l1_snr import STFTL1SNRDBLoss

# Create dummy audio signals (requires_grad so backward() works here)
estimates = torch.randn(4, 2, 44100, requires_grad=True)  # Batch of 4, stereo, 44100 samples
actuals = torch.randn(4, 2, 44100)

# Initialize the loss function without regularization or traditional L1
# Uses multiple STFT resolutions by default: [512, 1024, 2048] FFT sizes
loss_fn = STFTL1SNRDBLoss(
    name="stft_l1_snr_db_loss",
    l1_weight=0.0  # Pure L1SNR (no regularization, no L1)
)

# Calculate loss
loss = loss_fn(estimates, actuals)
loss.backward()

print(f"STFTL1SNRDBLoss: {loss.item()}")
```

### `MultiL1SNRDBLoss` (Combined Time + Spectrogram)

Combines time-domain and spectrogram-domain losses into a single weighted objective.

```python
import torch
from torch_l1_snr import MultiL1SNRDBLoss

# Create dummy audio signals (requires_grad so backward() works here)
estimates = torch.randn(4, 2, 44100, requires_grad=True)  # Batch of 4, stereo, 44100 samples
actuals = torch.randn(4, 2, 44100)

# Initialize the multi-domain loss function
loss_fn = MultiL1SNRDBLoss(
    name="multi_l1_snr_db_loss",
    weight=1.0,                    # Overall weight for this loss
    spec_weight=0.6,               # 60% spectrogram loss, 40% time-domain loss
    l1_weight=0.1,                 # Use 10% L1, 90% L1SNR+Reg in both domains
    use_time_regularization=True,  # Enable regularization in time domain
    use_spec_regularization=False  # Disable regularization in spec domain
)

# Calculate loss
loss = loss_fn(estimates, actuals)
loss.backward()
print(f"Multi-domain Loss: {loss.item()}")
```

## Motivation

The goal of these loss functions is to provide a perceptually informed and robust alternative to common audio losses like L1, L2 (MSE), and SI-SDR for training audio source separation models.

- **Robustness**: The L1 norm is less sensitive to large outliers than the L2 norm, making it more suitable for audio signals, which can have sharp transients.
- **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
- **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in overall loudness (dBRMS) between the estimate and the target.

This package is motivated by, and largely follows, the objectives and regularizers described in the cited papers ([1–3]). Several novel algorithmic extensions have been included with the goal of increasing flexibility for improved performance depending on the specific task.

### Level-Matching Regularization

A key feature of `L1SNRDBLoss` is the adaptive regularization term described in [[2]](https://arxiv.org/abs/2501.16171). This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, increasing when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically prevents it from collapsing to a trivial silent solution when uncertain.
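
A minimal sketch of the idea follows. This is illustrative only: the package exports a `dbrms` utility, but the adaptive-weight schedule shown here, along with the `lam_base`, `lam_silenced`, and `silence_db` parameters, are hypothetical simplifications:

```python
import torch

def dbrms(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Decibel-scaled RMS level, one value per signal.
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1) + eps)
    return 20.0 * torch.log10(rms + eps)

def level_match_penalty(est, tgt, lam_base=0.1, lam_silenced=1.0, silence_db=-60.0):
    # Penalize the dBRMS gap between estimate and target, with a larger
    # weight when the estimate is near-silent but the target is not,
    # i.e. the failure mode the regularizer is meant to discourage.
    d_est, d_tgt = dbrms(est), dbrms(tgt)
    silenced = (d_est < silence_db) & (d_tgt >= silence_db)
    lam = torch.where(silenced, torch.as_tensor(lam_silenced), torch.as_tensor(lam_base))
    return (lam * (d_est - d_tgt).abs()).mean()
```

With this sketch, a silenced estimate of a loud target incurs a much larger penalty than a mild level mismatch, which is the behavior described above.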

### Multi-Resolution Spectrogram Analysis

The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency (spectrogram) resolutions. Although this extension is not described in the cited papers, analyzing the signal with *multiple different* STFT window sizes and hop lengths lets the loss capture a wider range of artifacts, from short transient errors to longer tonal discrepancies, providing a more comprehensive error signal during training. Using multiple resolutions for an STFT loss is common in recent source separation work, such as the [Band-Split RoPE Transformer](https://arxiv.org/abs/2309.02612).
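
The multi-resolution averaging can be sketched like this, using a plain L1 distance on real/imaginary STFT components for simplicity (the package's actual loss, window choices, and hop lengths may differ):

```python
import torch

def multi_res_stft_l1(est: torch.Tensor, tgt: torch.Tensor,
                      n_ffts=(512, 1024, 2048)) -> torch.Tensor:
    # Average an L1 distance over real/imag STFT components at several
    # resolutions; short windows catch transient errors, long windows
    # catch tonal discrepancies.
    total = est.new_zeros(())
    for n_fft in n_ffts:
        window = torch.hann_window(n_fft)
        E = torch.stft(est, n_fft, hop_length=n_fft // 4,
                       window=window, return_complex=True)
        T = torch.stft(tgt, n_fft, hop_length=n_fft // 4,
                       window=window, return_complex=True)
        total = total + (torch.view_as_real(E) - torch.view_as_real(T)).abs().mean()
    return total / len(n_ffts)
```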

### "All-or-Nothing" Behavior and `l1_weight`

A characteristic of these SNR-style losses that I observed across many training experiments is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources (e.g., drums vs. vocals), as it pushes the model to be confident in its estimates. However, it can also lead to "confident errors," where the model completely removes a signal component it should have kept. This poses a tradeoff for sources that share greater similarities (e.g., speech vs. singing vocals).

While the Level-Matching Regularization prevents a *total collapse to silence*, it does not by itself solve the issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This lets you create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften the "all-or-nothing" behavior and allow for more nuanced separation.

While this can reduce the "cleanliness" of separations and slightly harm metrics like SDR, I found that re-introducing some standard L1 loss allows slightly more "smearing" of sound between sources, which can mask large errors and be more perceptually acceptable for sources with many similarities. I have no hard numbers to report yet, only my experience, so I recommend starting with no standard L1 mixed in (`l1_weight=0.0`) and slowly increasing from there based on your needs.

- `l1_weight=0.0` (default): Pure L1SNR (+ regularization).
- `l1_weight=1.0`: Pure standard L1 loss.
- `0.0 < l1_weight < 1.0`: A weighted combination of the two.

The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
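
In rough terms, the blending and short-circuiting behave like the following sketch. This is not the package's implementation: the exact dB scaling of the SNR term may differ, the automatic gradient balancing is omitted, and both helper names are hypothetical:

```python
import torch

def l1_snr_component(est, tgt, eps=1e-8):
    # Negative L1-based SNR in dB, so that lower values mean a better fit.
    num = tgt.abs().sum(dim=-1)
    den = (est - tgt).abs().sum(dim=-1).clamp_min(eps)
    return -(20.0 * torch.log10(num / den + eps)).mean()

def blended_loss(est, tgt, l1_weight=0.1):
    # Skip whichever component has zero weight, so the pure settings
    # (l1_weight = 0.0 or 1.0) cost nothing extra to compute.
    loss = est.new_zeros(())
    if l1_weight < 1.0:
        loss = loss + (1.0 - l1_weight) * l1_snr_component(est, tgt)
    if l1_weight > 0.0:
        loss = loss + l1_weight * (est - tgt).abs().mean()
    return loss

est = torch.zeros(2, 100)
tgt = torch.ones(2, 100)
print(blended_loss(est, tgt, l1_weight=1.0).item())  # 1.0 (pure L1)
```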

**Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), the implementation automatically scales the L1 component to approximately match the gradient magnitudes of the L1SNR component. This helps maintain stable training without manual tuning.

## Limitations

- The L1SNR loss is not scale-invariant. Unlike SI-SNR, it requires the model's output to be correctly scaled relative to the target.
- While the dB scaling and regularization are psychoacoustically motivated, the loss does not model more complex perceptual phenomena such as auditory masking.

## Contributing

Contributions are welcome! Please open an issue or submit a pull request if you have bug fixes, improvements, or new features to suggest.

## License

This project is licensed under the MIT License; see the [LICENSE](LICENSE) file for details.

## Acknowledgments

The loss functions implemented here are largely based on the work of the authors of the referenced papers. Thank you for your research!

## References

[1] K. N. Watcharasupat, C.-W. Wu, Y. Ding, I. Orife, A. J. Hipple, P. A. Williams, S. Kramer, A. Lerch, and W. Wolcott, "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation," IEEE Open Journal of Signal Processing, 2023. [arXiv:2309.02539](https://arxiv.org/abs/2309.02539)

[2] K. N. Watcharasupat and A. Lerch, "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries," [arXiv:2501.16171](https://arxiv.org/abs/2501.16171)

[3] K. N. Watcharasupat and A. Lerch, "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems," Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024. [arXiv:2406.18747](https://arxiv.org/abs/2406.18747)
@@ -1,7 +1,7 @@
1
1
  [metadata]
2
2
  name = torch-l1-snr
3
- version = attr: torch_l1snr.__version__
4
- author = Christopher Landscaping
3
+ version = attr: torch_l1_snr.__version__
4
+ author = Christopher Landschoot
5
5
  author_email = crlandschoot@gmail.com
6
6
  description = L1-SNR loss functions for audio source separation in PyTorch
7
7
  long_description = file: README.md
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import pytest
3
3
  from typing import Optional
4
- from torch_l1snr import (
4
+ from torch_l1_snr import (
5
5
  dbrms,
6
6
  L1SNRLoss,
7
7
  L1SNRDBLoss,
@@ -14,4 +14,4 @@ __all__ = [
14
14
  "MultiL1SNRDBLoss",
15
15
  ]
16
16
 
17
- __version__ = "0.0.5"
17
+ __version__ = "0.1.1"
@@ -16,6 +16,8 @@
16
16
  # Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024
17
17
  # arXiv:2406.18747
18
18
 
19
+ import warnings
20
+
19
21
  import torch
20
22
  import torch.nn as nn
21
23
  from torchaudio.transforms import Spectrogram
@@ -107,20 +109,6 @@ class L1SNRLoss(torch.nn.Module):
107
109
  scale_time = c * inv_mean
108
110
  l1_term = torch.mean(l1_error) * scale_time
109
111
 
110
- if getattr(self, "balance_per_sample", False):
111
- # per-sample w-independent scaling
112
- bal = c / (l1_error.detach() + self.eps)
113
- l1_term = torch.mean(l1_error * bal)
114
-
115
- if getattr(self, "debug_balance", False):
116
- g_d1 = (1.0 - w) * c * inv_mean
117
- if getattr(self, "balance_per_sample", False):
118
- g_l1 = w * torch.mean(c / (l1_error.detach() + self.eps))
119
- else:
120
- g_l1 = w * c * inv_mean
121
- ratio = (g_l1 / (g_d1 + 1e-12)).item()
122
- setattr(self, "last_balance_ratio", ratio)
123
-
124
112
  loss = (1.0 - w) * l1snr_loss + w * l1_term
125
113
  return loss * self.weight
126
114
 
@@ -464,11 +452,6 @@ class STFTL1SNRDBLoss(torch.nn.Module):
464
452
  scale_spec = 2.0 * c * inv_mean_comp
465
453
  l1_term = 0.5 * (torch.mean(err_re) + torch.mean(err_im)) * scale_spec
466
454
 
467
- if getattr(self, "balance_per_sample", False):
468
- bal_re = c / (err_re.detach() + self.l1snr_eps)
469
- bal_im = c / (err_im.detach() + self.l1snr_eps)
470
- l1_term = 0.5 * (torch.mean(err_re * bal_re) + torch.mean(err_im * bal_im))
471
-
472
455
  loss = (1.0 - w) * d1_sum + w * l1_term
473
456
  return loss
474
457
  elif w >= 1.0:
@@ -563,8 +546,10 @@ class STFTL1SNRDBLoss(torch.nn.Module):
563
546
  est_spec = transform(est_source)
564
547
  act_spec = transform(act_source)
565
548
  except RuntimeError as e:
566
- print(f"Error computing spectrogram for resolution {i}: {e}")
567
- print(f"Parameters: n_fft={self.n_ffts[i]}, hop_length={self.hop_lengths[i]}, win_length={self.win_lengths[i]}")
549
+ warnings.warn(
550
+ f"Error computing spectrogram for resolution {i}: {e}. "
551
+ f"Parameters: n_fft={self.n_ffts[i]}, hop_length={self.hop_lengths[i]}, win_length={self.win_lengths[i]}"
552
+ )
568
553
  continue
569
554
 
570
555
  # Ensure same (B, C, F, T); crop only (F, T) if needed
@@ -578,7 +563,7 @@ class STFTL1SNRDBLoss(torch.nn.Module):
578
563
  try:
579
564
  spec_loss = self._compute_complex_spec_l1snr_loss(est_spec, act_spec)
580
565
  except RuntimeError as e:
581
- print(f"Error computing complex spectral loss for resolution {i}: {e}")
566
+ warnings.warn(f"Error computing complex spectral loss for resolution {i}: {e}")
582
567
  continue
583
568
 
584
569
  # Check for numerical issues
@@ -599,19 +584,19 @@ class STFTL1SNRDBLoss(torch.nn.Module):
599
584
  # Accumulate regularization loss
600
585
  total_spec_reg_loss += spec_reg_loss
601
586
  except RuntimeError as e:
602
- print(f"Error computing spectral level-matching for resolution {i}: {e}")
587
+ warnings.warn(f"Error computing spectral level-matching for resolution {i}: {e}")
603
588
 
604
589
  # Accumulate loss
605
590
  total_spec_loss += spec_loss
606
591
  valid_transforms += 1
607
592
 
608
593
  except RuntimeError as e:
609
- print(f"Runtime error in spectrogram transform {i}: {e}")
594
+ warnings.warn(f"Runtime error in spectrogram transform {i}: {e}")
610
595
  continue
611
596
 
612
597
  # If all transforms failed, return zero loss
613
598
  if valid_transforms == 0:
614
- print("Warning: All spectrogram transforms failed. Returning zero loss.")
599
+ warnings.warn("All spectrogram transforms failed. Returning zero loss.")
615
600
  return torch.tensor(0.0, device=device)
616
601
 
617
602
  # Average losses across valid transforms