torch-l1-snr 0.1.0__tar.gz → 0.1.1__tar.gz

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-l1-snr
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: L1-SNR loss functions for audio source separation in PyTorch
5
5
  Home-page: https://github.com/crlandsc/torch-l1-snr
6
6
  Author: Christopher Landschoot
@@ -28,11 +28,12 @@ Dynamic: license-file
28
28
 
29
29
  ![torch-l1-snr-logo](https://raw.githubusercontent.com/crlandsc/torch-l1-snr/main/images/logo.png)
30
30
 
31
- [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/stargazers)
31
+ [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/stargazers) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/)
32
32
 
33
- L1 Signal-to-Noise Ratio (SNR) loss functions for audio source separation in PyTorch. This package provides four loss functions that combine implementations from recent academic research with novel extensions, designed to integrate easily into any audio separation training pipeline.
34
33
 
35
- The core `L1SNRLoss` is based on the loss function described in [[1]](https://arxiv.org/abs/2309.02539), while `L1SNRDBLoss` and `STFTL1SNRDBLoss` are extensions of the adaptive level-matching regularization technique proposed in [[2]](https://arxiv.org/abs/2501.16171). `MultiL1SNRDBLoss` combines both time-domain and spectrogram-domain losses into a single loss function for convenience and flexibility.
34
+ L1 Signal-to-Noise Ratio (SNR) loss functions for audio source separation in PyTorch. This package provides four loss functions that combine implementations from recent academic research with novel extensions, designed to integrate easily into any audio separation or enhancement training pipeline.
35
+
36
+ The core [`L1SNRLoss`](#example-l1snrloss-time-domain) is based on the loss function described in [[1]](https://arxiv.org/abs/2309.02539). [`L1SNRDBLoss`](#example-l1snrdbloss-time-domain-with-regularization) adds adaptive level-matching regularization proposed in [[2]](https://arxiv.org/abs/2501.16171). [`STFTL1SNRDBLoss`](#example-stftl1snrdbloss-spectrogram-domain) provides a spectrogram-domain L1SNR-style loss (real/imag STFT components as in [[1]](https://arxiv.org/abs/2309.02539) / [[3]](https://arxiv.org/abs/2406.18747)). [`MultiL1SNRDBLoss`](#example-multil1snrdbloss-combined-time--spectrogram) combines time-domain and spectrogram-domain losses into a single loss function for convenience and flexibility. Optional novel algorithmic extensions have also been included (such as multi-resolution STFT averaging, spectrogram-domain adaptation of the level-matching regularizer from [[2]](https://arxiv.org/abs/2501.16171), time vs. spectrogram loss balancing, and blending of standard L1 loss) with the goal of increasing flexibility for improved performance depending on the specific task.
36
37
 
37
38
  ## Quick Start
38
39
 
@@ -50,20 +51,24 @@ loss = loss_fn(estimates, targets)
50
51
  loss.backward()
51
52
  ```
52
53
 
53
- ## Features
54
+ ## Loss Functions
55
+
56
+ - [**Time-Domain L1SNR Loss**](#example-l1snrloss-time-domain): A basic, time-domain L1-SNR loss, based on [[1]](https://arxiv.org/abs/2309.02539).
57
+ - [**Regularized Time-Domain L1SNRDBLoss**](#example-l1snrdbloss-time-domain-with-regularization): An extension of the L1SNR loss with adaptive level-matching regularization from [[2]](https://arxiv.org/abs/2501.16171), plus an optional L1 loss component.
58
+ - [**Multi-Resolution STFT L1SNRDBLoss**](#example-stftl1snrdbloss-spectrogram-domain): A spectrogram-domain L1SNR-style loss (real/imag STFT components as in [[1]](https://arxiv.org/abs/2309.02539) / [[3]](https://arxiv.org/abs/2406.18747)), computed over multiple STFT resolutions, with optional spectrogram-domain level-matching regularization inspired by its time-domain counterpart in [[2]](https://arxiv.org/abs/2501.16171).
59
+ - [**Combined Multi-Domain Loss**](#example-multil1snrdbloss-combined-time--spectrogram): `MultiL1SNRDBLoss` combines time-domain and spectrogram-domain losses into a single, weighted objective function.
60
+
61
+ ## Additional Features
54
62
 
55
- - **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [[1]](https://arxiv.org/abs/2309.02539).
56
- - **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [[2]](https://arxiv.org/abs/2501.16171), plus an optional L1 loss component.
57
- - **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [[2]](https://arxiv.org/abs/2501.16171), calculated over multiple STFT resolutions.
58
- - **Combined Multi-Domain Loss**: `MultiL1SNRDBLoss` combines time-domain and spectrogram-domain losses into a single, weighted objective function.
59
63
  - **L1 Loss Blending**: The `l1_weight` parameter allows mixing between L1SNR and standard L1 loss, softening the ["all-or-nothing" behavior](#all-or-nothing-behavior-and-l1_weight) of pure SNR losses for more nuanced separation.
64
+ - **Multi-Resolution STFT Averaging** - Extending an STFT-based loss to multiple resolutions is common in recent literature.
65
+ - **Spectrogram-Domain Adaptation of Level-Matching Regularizer [[2]](https://arxiv.org/abs/2501.16171)** - Options to extend adaptive level-matching regularization to spectrogram-domain. Experimental and not used by default.
66
+ - **Time vs. Spectrogram Loss Balancing.** - Allows fine-tuning the relative contribution of time-domain and spectrogram-domain losses in `MultiL1SNRDBLoss` via the `spec_weight` parameter.
60
67
  - **Numerical Stability**: Robust handling of `NaN` and `inf` values during training.
61
68
  - **Short Audio Fallback**: Graceful fallback to time-domain loss when audio is too short for STFT processing.
62
69
 
63
70
  ## Installation
64
71
 
65
- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/)
66
-
67
72
  ### Install from PyPI
68
73
 
69
74
  ```bash
@@ -92,13 +97,13 @@ pip install -e .
92
97
 
93
98
  ## Supported Tensor Shapes
94
99
 
95
- All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)`, `(batch, channels, samples)`, or `(batch, num_sources, channels, samples)`. For 3D & 4D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated. For the spectrogram-domain loss, a separate STFT is computed for each channel.
100
+ All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)`, `(batch, channels, samples)`, or `(batch, num_sources, channels, samples)`. For the time-domain losses, any 3D/4D input is flattened across all non-batch dimensions (e.g., sources, channels, and samples) into a single vector per example before the loss is computed. For the spectrogram-domain loss, inputs are reshaped to `(batch, streams, samples)` by flattening all non-time dimensions into a “stream” dimension (e.g., `streams = channels` or `streams = num_sources * channels`), and a separate STFT is computed for each stream.
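The reshaping described above can be pictured with a short sketch. The helper names `flatten_for_time_domain` and `reshape_to_streams` are hypothetical and are not part of the package; they only mirror the two layouts the paragraph describes.

```python
import torch


def flatten_for_time_domain(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten every non-batch dimension into one vector per example."""
    return x.reshape(x.shape[0], -1)


def reshape_to_streams(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: fold all non-time dimensions (except batch) into a 'streams' axis."""
    batch, samples = x.shape[0], x.shape[-1]
    return x.reshape(batch, -1, samples)


# (batch, num_sources, channels, samples)
x4 = torch.randn(2, 4, 2, 16000)

time_view = flatten_for_time_domain(x4)  # shape (2, 4 * 2 * 16000)
stream_view = reshape_to_streams(x4)     # shape (2, 4 * 2, 16000): one STFT per stream
```

A 2D input of shape `(batch, samples)` passes through `reshape_to_streams` as a single stream, `(batch, 1, samples)`.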
96
101
 
97
102
  ## Usage
98
103
 
99
104
  The loss functions can be imported directly from the `torch_l1_snr` package.
100
105
 
101
- ### Example: `L1SNRLoss` (Time Domain)
106
+ ### `L1SNRLoss` (Time Domain)
102
107
 
103
108
  The simplest loss function - pure L1SNR without regularization.
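For orientation, the L1-SNR objective from [1] is, to the best of my understanding, the dB-scaled ratio of the L1 error to the L1 norm of the target. The sketch below illustrates that formula only and is not the package's implementation; the function name `l1_snr_sketch`, the value of `eps`, and the batch-mean reduction are assumptions.

```python
import torch


def l1_snr_sketch(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """Illustrative L1-SNR: 10 * log10((|est - tgt|_1 + eps) / (|tgt|_1 + eps)), averaged over the batch."""
    est = estimate.reshape(estimate.shape[0], -1)  # one vector per example
    tgt = target.reshape(target.shape[0], -1)
    err_l1 = (est - tgt).abs().sum(dim=-1)
    ref_l1 = tgt.abs().sum(dim=-1)
    return (10.0 * torch.log10((err_l1 + eps) / (ref_l1 + eps))).mean()


torch.manual_seed(0)
target = torch.randn(4, 2, 16000)
silent = torch.zeros_like(target)
close = target + 0.01 * torch.randn_like(target)

silent_loss = l1_snr_sketch(silent, target)  # error equals the target, so the ratio is 1 (0 dB)
close_loss = l1_snr_sketch(close, target)    # small error, so the loss is well below 0 dB
```

Lower is better under this definition. A fully silent estimate sits at 0 dB, which gives some intuition for why the level-matching regularizer exists.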
104
109
 
@@ -120,7 +125,7 @@ loss.backward()
120
125
  print(f"L1SNRLoss: {loss.item()}")
121
126
  ```
122
127
 
123
- ### Example: `L1SNRDBLoss` (Time Domain with Regularization)
128
+ ### `L1SNRDBLoss` (Time Domain with Regularization)
124
129
 
125
130
  Adds adaptive level-matching regularization to prevent silence collapse.
126
131
 
@@ -147,7 +152,7 @@ loss.backward()
147
152
  print(f"L1SNRDBLoss: {loss.item()}")
148
153
  ```
149
154
 
150
- ### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)
155
+ ### `STFTL1SNRDBLoss` (Spectrogram Domain)
151
156
 
152
157
  Computes L1SNR loss across multiple STFT resolutions.
153
158
 
@@ -173,7 +178,7 @@ loss.backward()
173
178
  print(f"STFTL1SNRDBLoss: {loss.item()}")
174
179
  ```
175
180
 
176
- ### Example: `MultiL1SNRDBLoss` (Combined Time + Spectrogram)
181
+ ### `MultiL1SNRDBLoss` (Combined Time + Spectrogram)
177
182
 
178
183
  Combines time-domain and spectrogram-domain losses into a single weighted objective.
179
184
 
@@ -208,24 +213,26 @@ The goal of these loss functions is to provide a perceptually-informed and robus
208
213
  - **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
209
214
  - **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.
210
215
 
216
+ This package is motivated by, and largely follows, the objectives and regularizers described in the cited papers ([1–3]). Several novel algorithmic extensions have been included with the goal of increasing flexibility for improved performance depending on the specific task.
217
+
211
218
  ### Level-Matching Regularization
212
219
 
213
220
  A key feature of `L1SNRDBLoss` is the adaptive regularization term, as described in [[2]](https://arxiv.org/abs/2501.16171). This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, which increases when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically avoids the model collapsing to a trivial silent solution when uncertain.
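A rough sketch of the idea follows. The dBRMS computation is standard, but the adaptive weight shown here is a hypothetical stand-in: the exact weighting rule from [2] and from the package is not reproduced, and `level_matching_sketch`, `base_weight`, `max_extra`, and `db_range` are illustrative names only.

```python
import torch


def db_rms(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Decibel-scaled RMS level per example."""
    flat = x.reshape(x.shape[0], -1)
    rms = flat.pow(2).mean(dim=-1).sqrt()
    return 20.0 * torch.log10(rms + eps)


def level_matching_sketch(
    estimate: torch.Tensor,
    target: torch.Tensor,
    base_weight: float = 1.0,  # assumption: weight applied when levels roughly match
    max_extra: float = 1.0,    # assumption: extra weight when the estimate is far too quiet
    db_range: float = 60.0,    # assumption: undershoot (in dB) at which the weight saturates
) -> torch.Tensor:
    est_db, tgt_db = db_rms(estimate), db_rms(target)
    level_gap = (est_db - tgt_db).abs()
    # Hypothetical adaptive weight: grows as the estimate falls below the target level.
    undershoot = (tgt_db - est_db).clamp(min=0.0, max=db_range) / db_range
    weight = base_weight + max_extra * undershoot
    # The weight is detached so it scales the penalty without receiving gradients itself.
    return (weight.detach() * level_gap).mean()
```

With this sketch, an all-zero estimate against a non-silent target receives both a large level gap and the maximum weight, while an estimate at the correct level receives no penalty.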
214
221
 
215
222
  ### Multi-Resolution Spectrogram Analysis
216
223
 
217
- The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts - from short, transient errors to longer, tonal discrepancies. This provides a more comprehensive error signal to the model during training. Using multiple resolutions for an STFT loss is common among many recent source separation works.
224
+ The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency (spectrogram) resolutions. While not mentioned in the cited papers, by analyzing the signal with *multiple different* STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts - from short, transient errors to longer, tonal discrepancies. This provides a more comprehensive error signal to the model during training. Using multiple resolutions for an STFT loss is common among many recent source separation works, such as the [Band-Split RoPE Transformer](https://arxiv.org/abs/2309.02612).
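The multi-resolution idea can be sketched as follows. This is a minimal illustration, not the package's code: the `resolutions` values, the equal averaging of the real and imaginary terms, and the names `stft_l1_snr_sketch` and `multi_res_stft_sketch` are all assumptions.

```python
import torch


def _l1_ratio_db(est: torch.Tensor, tgt: torch.Tensor, eps: float) -> torch.Tensor:
    """dB-scaled ratio of L1 error to L1 target norm, one value per batch item."""
    err = (est - tgt).abs().flatten(1).sum(dim=-1)
    ref = tgt.abs().flatten(1).sum(dim=-1)
    return 10.0 * torch.log10((err + eps) / (ref + eps))


def stft_l1_snr_sketch(estimate, target, n_fft, hop, eps=1e-3):
    """L1-SNR-style loss on the real and imaginary STFT parts at one resolution.

    Expects inputs of shape (batch, samples).
    """
    window = torch.hann_window(n_fft, device=estimate.device, dtype=estimate.dtype)
    est = torch.stft(estimate, n_fft, hop_length=hop, window=window, return_complex=True)
    tgt = torch.stft(target, n_fft, hop_length=hop, window=window, return_complex=True)
    real_term = _l1_ratio_db(est.real, tgt.real, eps)
    imag_term = _l1_ratio_db(est.imag, tgt.imag, eps)
    return (0.5 * (real_term + imag_term)).mean()


def multi_res_stft_sketch(estimate, target, resolutions=((512, 128), (1024, 256), (2048, 512))):
    """Average the single-resolution loss over several (n_fft, hop) pairs."""
    losses = [stft_l1_snr_sketch(estimate, target, n_fft, hop) for n_fft, hop in resolutions]
    return torch.stack(losses).mean()
```

Short windows resolve transients more finely in time, while long windows resolve tonal content more finely in frequency, so averaging over both exposes errors that a single resolution could miss.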
218
225
 
219
226
  ### "All-or-Nothing" Behavior and `l1_weight`
220
227
 
221
- A characteristic of SNR-style losses (that I experienced in many training experiments) is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimations. However, this can also lead to "confident errors," where the model completely removes a signal component it should have kept.
228
+ A characteristic of these SNR-style losses that I experienced in many training experiments is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources (e.g. drums vs vocals), as it pushes the model to be confident in its estimations. However, this can also lead to "confident errors," where the model completely removes a signal component it should have kept. This poses a tradeoff for sources that may share greater similarities (e.g. speech vs singing vocals).
222
229
 
223
230
  While the Level-Matching Regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften its "all-or-nothing"-style behavior and allow for more nuanced separation.
224
231
 
225
- While this can potentially reduce metrics like SDR, I found that re-introducing some standard L1 loss allows for slightly more "smearing" of sound between sources to mask large errors and be more perceptually acceptable. I have no hard numbers on this, just my experience, so I recommend starting with no standard L1 mixed in (`l1_weight=0.0`), and then slowly increasing from there based on your needs.
232
+ While this can potentially reduce the "cleanliness" of separations and slightly harm metrics like SDR, I found that re-introducing some standard L1 loss allows for slightly more "smearing" of sound between sources to mask large errors and be more perceptually acceptable for sources with many similarities. I have no hard numbers to report on this yet, just my experience. So I recommend starting with no standard L1 mixed in (`l1_weight=0.0`), and then slowly increasing from there based on your needs.
226
233
 
227
234
  - `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
228
- - `l1_weight=1.0`: Pure L1 loss.
235
+ - `l1_weight=1.0`: Pure standard L1 loss.
229
236
  - `0.0 < l1_weight < 1.0`: A weighted combination of the two.
230
237
 
231
238
  The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
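One way to picture the blending and the short-circuit at the endpoints is the sketch below. It assumes a plain linear interpolation between the two terms, which may differ from how the package scales them. `blend_losses_sketch` and the two fake loss callables are hypothetical; in practice the callables would return tensors rather than floats.

```python
from typing import Callable, List


def blend_losses_sketch(
    l1snr_term: Callable[[], float],
    l1_term: Callable[[], float],
    l1_weight: float = 0.0,
) -> float:
    """Blend two lazily evaluated loss terms; skip whichever one has zero weight."""
    if not 0.0 <= l1_weight <= 1.0:
        raise ValueError("l1_weight must be in [0, 1]")
    if l1_weight == 0.0:
        return l1snr_term()  # pure L1SNR: the L1 term is never evaluated
    if l1_weight == 1.0:
        return l1_term()     # pure L1: the L1SNR term is never evaluated
    return (1.0 - l1_weight) * l1snr_term() + l1_weight * l1_term()


calls: List[str] = []


def fake_l1snr() -> float:
    calls.append("l1snr")
    return -20.0


def fake_l1() -> float:
    calls.append("l1")
    return 0.5


pure_snr = blend_losses_sketch(fake_l1snr, fake_l1, l1_weight=0.0)  # evaluates only fake_l1snr
hybrid = blend_losses_sketch(fake_l1snr, fake_l1, l1_weight=0.25)   # evaluates both terms
```

Passing the terms as callables is just one way to make the skipped computation explicit; an `if` around each term inside a module's `forward` would achieve the same effect.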
@@ -239,7 +246,7 @@ The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`
239
246
 
240
247
  ## Contributing
241
248
 
242
- Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.
249
+ Contributions are welcome! Please open an issue or submit a pull request if you have any bug fixes, improvements, or new features to suggest.
243
250
 
244
251
  ## License
245
252
 
@@ -1,10 +1,11 @@
1
1
  ![torch-l1-snr-logo](https://raw.githubusercontent.com/crlandsc/torch-l1-snr/main/images/logo.png)
2
2
 
3
- [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/stargazers)
3
+ [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/stargazers) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/)
4
4
 
5
- L1 Signal-to-Noise Ratio (SNR) loss functions for audio source separation in PyTorch. This package provides four loss functions that combine implementations from recent academic research with novel extensions, designed to integrate easily into any audio separation training pipeline.
6
5
 
7
- The core `L1SNRLoss` is based on the loss function described in [[1]](https://arxiv.org/abs/2309.02539), while `L1SNRDBLoss` and `STFTL1SNRDBLoss` are extensions of the adaptive level-matching regularization technique proposed in [[2]](https://arxiv.org/abs/2501.16171). `MultiL1SNRDBLoss` combines both time-domain and spectrogram-domain losses into a single loss function for convenience and flexibility.
6
+ L1 Signal-to-Noise Ratio (SNR) loss functions for audio source separation in PyTorch. This package provides four loss functions that combine implementations from recent academic research with novel extensions, designed to integrate easily into any audio separation or enhancement training pipeline.
7
+
8
+ The core [`L1SNRLoss`](#example-l1snrloss-time-domain) is based on the loss function described in [[1]](https://arxiv.org/abs/2309.02539). [`L1SNRDBLoss`](#example-l1snrdbloss-time-domain-with-regularization) adds adaptive level-matching regularization proposed in [[2]](https://arxiv.org/abs/2501.16171). [`STFTL1SNRDBLoss`](#example-stftl1snrdbloss-spectrogram-domain) provides a spectrogram-domain L1SNR-style loss (real/imag STFT components as in [[1]](https://arxiv.org/abs/2309.02539) / [[3]](https://arxiv.org/abs/2406.18747)). [`MultiL1SNRDBLoss`](#example-multil1snrdbloss-combined-time--spectrogram) combines time-domain and spectrogram-domain losses into a single loss function for convenience and flexibility. Optional novel algorithmic extensions have also been included (such as multi-resolution STFT averaging, spectrogram-domain adaptation of the level-matching regularizer from [[2]](https://arxiv.org/abs/2501.16171), time vs. spectrogram loss balancing, and blending of standard L1 loss) with the goal of increasing flexibility for improved performance depending on the specific task.
8
9
 
9
10
  ## Quick Start
10
11
 
@@ -22,20 +23,24 @@ loss = loss_fn(estimates, targets)
22
23
  loss.backward()
23
24
  ```
24
25
 
25
- ## Features
26
+ ## Loss Functions
27
+
28
+ - [**Time-Domain L1SNR Loss**](#example-l1snrloss-time-domain): A basic, time-domain L1-SNR loss, based on [[1]](https://arxiv.org/abs/2309.02539).
29
+ - [**Regularized Time-Domain L1SNRDBLoss**](#example-l1snrdbloss-time-domain-with-regularization): An extension of the L1SNR loss with adaptive level-matching regularization from [[2]](https://arxiv.org/abs/2501.16171), plus an optional L1 loss component.
30
+ - [**Multi-Resolution STFT L1SNRDBLoss**](#example-stftl1snrdbloss-spectrogram-domain): A spectrogram-domain L1SNR-style loss (real/imag STFT components as in [[1]](https://arxiv.org/abs/2309.02539) / [[3]](https://arxiv.org/abs/2406.18747)), computed over multiple STFT resolutions, with optional spectrogram-domain level-matching regularization inspired by its time-domain counterpart in [[2]](https://arxiv.org/abs/2501.16171).
31
+ - [**Combined Multi-Domain Loss**](#example-multil1snrdbloss-combined-time--spectrogram): `MultiL1SNRDBLoss` combines time-domain and spectrogram-domain losses into a single, weighted objective function.
32
+
33
+ ## Additional Features
26
34
 
27
- - **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [[1]](https://arxiv.org/abs/2309.02539).
28
- - **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [[2]](https://arxiv.org/abs/2501.16171), plus an optional L1 loss component.
29
- - **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [[2]](https://arxiv.org/abs/2501.16171), calculated over multiple STFT resolutions.
30
- - **Combined Multi-Domain Loss**: `MultiL1SNRDBLoss` combines time-domain and spectrogram-domain losses into a single, weighted objective function.
31
35
  - **L1 Loss Blending**: The `l1_weight` parameter allows mixing between L1SNR and standard L1 loss, softening the ["all-or-nothing" behavior](#all-or-nothing-behavior-and-l1_weight) of pure SNR losses for more nuanced separation.
36
+ - **Multi-Resolution STFT Averaging** - Extending an STFT-based loss to multiple resolutions is common in recent literature.
37
+ - **Spectrogram-Domain Adaptation of Level-Matching Regularizer [[2]](https://arxiv.org/abs/2501.16171)** - Options to extend adaptive level-matching regularization to spectrogram-domain. Experimental and not used by default.
38
+ - **Time vs. Spectrogram Loss Balancing.** - Allows fine-tuning the relative contribution of time-domain and spectrogram-domain losses in `MultiL1SNRDBLoss` via the `spec_weight` parameter.
32
39
  - **Numerical Stability**: Robust handling of `NaN` and `inf` values during training.
33
40
  - **Short Audio Fallback**: Graceful fallback to time-domain loss when audio is too short for STFT processing.
34
41
 
35
42
  ## Installation
36
43
 
37
- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/)
38
-
39
44
  ### Install from PyPI
40
45
 
41
46
  ```bash
@@ -64,13 +69,13 @@ pip install -e .
64
69
 
65
70
  ## Supported Tensor Shapes
66
71
 
67
- All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)`, `(batch, channels, samples)`, or `(batch, num_sources, channels, samples)`. For 3D & 4D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated. For the spectrogram-domain loss, a separate STFT is computed for each channel.
72
+ All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)`, `(batch, channels, samples)`, or `(batch, num_sources, channels, samples)`. For the time-domain losses, any 3D/4D input is flattened across all non-batch dimensions (e.g., sources, channels, and samples) into a single vector per example before the loss is computed. For the spectrogram-domain loss, inputs are reshaped to `(batch, streams, samples)` by flattening all non-time dimensions into a “stream” dimension (e.g., `streams = channels` or `streams = num_sources * channels`), and a separate STFT is computed for each stream.
68
73
 
69
74
  ## Usage
70
75
 
71
76
  The loss functions can be imported directly from the `torch_l1_snr` package.
72
77
 
73
- ### Example: `L1SNRLoss` (Time Domain)
78
+ ### `L1SNRLoss` (Time Domain)
74
79
 
75
80
  The simplest loss function - pure L1SNR without regularization.
76
81
 
@@ -92,7 +97,7 @@ loss.backward()
92
97
  print(f"L1SNRLoss: {loss.item()}")
93
98
  ```
94
99
 
95
- ### Example: `L1SNRDBLoss` (Time Domain with Regularization)
100
+ ### `L1SNRDBLoss` (Time Domain with Regularization)
96
101
 
97
102
  Adds adaptive level-matching regularization to prevent silence collapse.
98
103
 
@@ -119,7 +124,7 @@ loss.backward()
119
124
  print(f"L1SNRDBLoss: {loss.item()}")
120
125
  ```
121
126
 
122
- ### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)
127
+ ### `STFTL1SNRDBLoss` (Spectrogram Domain)
123
128
 
124
129
  Computes L1SNR loss across multiple STFT resolutions.
125
130
 
@@ -145,7 +150,7 @@ loss.backward()
145
150
  print(f"STFTL1SNRDBLoss: {loss.item()}")
146
151
  ```
147
152
 
148
- ### Example: `MultiL1SNRDBLoss` (Combined Time + Spectrogram)
153
+ ### `MultiL1SNRDBLoss` (Combined Time + Spectrogram)
149
154
 
150
155
  Combines time-domain and spectrogram-domain losses into a single weighted objective.
151
156
 
@@ -180,24 +185,26 @@ The goal of these loss functions is to provide a perceptually-informed and robus
180
185
  - **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
181
186
  - **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.
182
187
 
188
+ This package is motivated by, and largely follows, the objectives and regularizers described in the cited papers ([1–3]). Several novel algorithmic extensions have been included with the goal of increasing flexibility for improved performance depending on the specific task.
189
+
183
190
  ### Level-Matching Regularization
184
191
 
185
192
  A key feature of `L1SNRDBLoss` is the adaptive regularization term, as described in [[2]](https://arxiv.org/abs/2501.16171). This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, which increases when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically avoids the model collapsing to a trivial silent solution when uncertain.
186
193
 
187
194
  ### Multi-Resolution Spectrogram Analysis
188
195
 
189
- The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts - from short, transient errors to longer, tonal discrepancies. This provides a more comprehensive error signal to the model during training. Using multiple resolutions for an STFT loss is common among many recent source separation works.
196
+ The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency (spectrogram) resolutions. While not mentioned in the cited papers, by analyzing the signal with *multiple different* STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts - from short, transient errors to longer, tonal discrepancies. This provides a more comprehensive error signal to the model during training. Using multiple resolutions for an STFT loss is common among many recent source separation works, such as the [Band-Split RoPE Transformer](https://arxiv.org/abs/2309.02612).
190
197
 
191
198
  ### "All-or-Nothing" Behavior and `l1_weight`
192
199
 
193
- A characteristic of SNR-style losses (that I experienced in many training experiments) is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimations. However, this can also lead to "confident errors," where the model completely removes a signal component it should have kept.
200
+ A characteristic of these SNR-style losses that I experienced in many training experiments is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources (e.g. drums vs vocals), as it pushes the model to be confident in its estimations. However, this can also lead to "confident errors," where the model completely removes a signal component it should have kept. This poses a tradeoff for sources that may share greater similarities (e.g. speech vs singing vocals).
194
201
 
195
202
  While the Level-Matching Regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften its "all-or-nothing"-style behavior and allow for more nuanced separation.
196
203
 
197
- While this can potentially reduce metrics like SDR, I found that re-introducing some standard L1 loss allows for slightly more "smearing" of sound between sources to mask large errors and be more perceptually acceptable. I have no hard numbers on this, just my experience, so I recommend starting with no standard L1 mixed in (`l1_weight=0.0`), and then slowly increasing from there based on your needs.
204
+ While this can potentially reduce the "cleanliness" of separations and slightly harm metrics like SDR, I found that re-introducing some standard L1 loss allows for slightly more "smearing" of sound between sources to mask large errors and be more perceptually acceptable for sources with many similarities. I have no hard numbers to report on this yet, just my experience. So I recommend starting with no standard L1 mixed in (`l1_weight=0.0`), and then slowly increasing from there based on your needs.
198
205
 
199
206
  - `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
200
- - `l1_weight=1.0`: Pure L1 loss.
207
+ - `l1_weight=1.0`: Pure standard L1 loss.
201
208
  - `0.0 < l1_weight < 1.0`: A weighted combination of the two.
202
209
 
203
210
  The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
@@ -211,7 +218,7 @@ The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`
211
218
 
212
219
  ## Contributing
213
220
 
214
- Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.
221
+ Contributions are welcome! Please open an issue or submit a pull request if you have any bug fixes, improvements, or new features to suggest.
215
222
 
216
223
  ## License
217
224
 
@@ -14,4 +14,4 @@ __all__ = [
14
14
  "MultiL1SNRDBLoss",
15
15
  ]
16
16
 
17
- __version__ = "0.1.0"
17
+ __version__ = "0.1.1"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-l1-snr
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: L1-SNR loss functions for audio source separation in PyTorch
5
5
  Home-page: https://github.com/crlandsc/torch-l1-snr
6
6
  Author: Christopher Landschoot
@@ -28,11 +28,12 @@ Dynamic: license-file
28
28
 
29
29
  ![torch-l1-snr-logo](https://raw.githubusercontent.com/crlandsc/torch-l1-snr/main/images/logo.png)
30
30
 
31
- [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/stargazers)
31
+ [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1-snr)](https://github.com/crlandsc/torch-l1-snr/stargazers) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/)
32
32
 
33
- L1 Signal-to-Noise Ratio (SNR) loss functions for audio source separation in PyTorch. This package provides four loss functions that combine implementations from recent academic research with novel extensions, designed to integrate easily into any audio separation training pipeline.
34
33
 
35
- The core `L1SNRLoss` is based on the loss function described in [[1]](https://arxiv.org/abs/2309.02539), while `L1SNRDBLoss` and `STFTL1SNRDBLoss` are extensions of the adaptive level-matching regularization technique proposed in [[2]](https://arxiv.org/abs/2501.16171). `MultiL1SNRDBLoss` combines both time-domain and spectrogram-domain losses into a single loss function for convenience and flexibility.
34
+ L1 Signal-to-Noise Ratio (SNR) loss functions for audio source separation in PyTorch. This package provides four loss functions that combine implementations from recent academic research with novel extensions, designed to integrate easily into any audio separation or enhancement training pipeline.
35
+
36
+ The core [`L1SNRLoss`](#example-l1snrloss-time-domain) is based on the loss function described in [[1]](https://arxiv.org/abs/2309.02539). [`L1SNRDBLoss`](#example-l1snrdbloss-time-domain-with-regularization) adds adaptive level-matching regularization proposed in [[2]](https://arxiv.org/abs/2501.16171). [`STFTL1SNRDBLoss`](#example-stftl1snrdbloss-spectrogram-domain) provides a spectrogram-domain L1SNR-style loss (real/imag STFT components as in [[1]](https://arxiv.org/abs/2309.02539) / [[3]](https://arxiv.org/abs/2406.18747)). [`MultiL1SNRDBLoss`](#example-multil1snrdbloss-combined-time--spectrogram) combines time-domain and spectrogram-domain losses into a single loss function for convenience and flexibility. Optional novel algorithmic extensions have also been included (such as multi-resolution STFT averaging, spectrogram-domain adaptation of the level-matching regularizer from [[2]](https://arxiv.org/abs/2501.16171), time vs. spectrogram loss balancing, and blending of standard L1 loss) with the goal of increasing flexibility for improved performance depending on the specific task.
36
37
 
37
38
  ## Quick Start
38
39
 
@@ -50,20 +51,24 @@ loss = loss_fn(estimates, targets)
50
51
  loss.backward()
51
52
  ```
52
53
 
53
- ## Features
54
+ ## Loss Functions
55
+
56
+ - [**Time-Domain L1SNR Loss**](#example-l1snrloss-time-domain): A basic, time-domain L1-SNR loss, based on [[1]](https://arxiv.org/abs/2309.02539).
57
+ - [**Regularized Time-Domain L1SNRDBLoss**](#example-l1snrdbloss-time-domain-with-regularization): An extension of the L1SNR loss with adaptive level-matching regularization from [[2]](https://arxiv.org/abs/2501.16171), plus an optional L1 loss component.
58
+ - [**Multi-Resolution STFT L1SNRDBLoss**](#example-stftl1snrdbloss-spectrogram-domain): A spectrogram-domain L1SNR-style loss (real/imag STFT components as in [[1]](https://arxiv.org/abs/2309.02539) / [[3]](https://arxiv.org/abs/2406.18747)), computed over multiple STFT resolutions, with optional spectrogram-domain level-matching regularization inspired by its time-domain counterpart in [[2]](https://arxiv.org/abs/2501.16171).
59
+ - [**Combined Multi-Domain Loss**](#example-multil1snrdbloss-combined-time--spectrogram): `MultiL1SNRDBLoss` combines time-domain and spectrogram-domain losses into a single, weighted objective function.
60
+
61
+ ## Additional Features
54
62
 
55
- - **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [[1]](https://arxiv.org/abs/2309.02539).
56
- - **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [[2]](https://arxiv.org/abs/2501.16171), plus an optional L1 loss component.
57
- - **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [[2]](https://arxiv.org/abs/2501.16171), calculated over multiple STFT resolutions.
58
- - **Combined Multi-Domain Loss**: `MultiL1SNRDBLoss` combines time-domain and spectrogram-domain losses into a single, weighted objective function.
59
63
  - **L1 Loss Blending**: The `l1_weight` parameter allows mixing between L1SNR and standard L1 loss, softening the ["all-or-nothing" behavior](#all-or-nothing-behavior-and-l1_weight) of pure SNR losses for more nuanced separation.
64
+ - **Multi-Resolution STFT Averaging** - Extending an STFT-based loss to multiple resolutions is common in recent literature.
65
+ - **Spectrogram-Domain Adaptation of Level-Matching Regularizer [[2]](https://arxiv.org/abs/2501.16171)** - Options to extend adaptive level-matching regularization to spectrogram-domain. Experimental and not used by default.
66
+ - **Time vs. Spectrogram Loss Balancing** - Allows fine-tuning the relative contribution of time-domain and spectrogram-domain losses in `MultiL1SNRDBLoss` via the `spec_weight` parameter.
60
67
  - **Numerical Stability**: Robust handling of `NaN` and `inf` values during training.
61
68
  - **Short Audio Fallback**: Graceful fallback to time-domain loss when audio is too short for STFT processing.
62
69
 
63
70
  ## Installation
64
71
 
65
- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/) [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1-snr)](https://pypi.org/project/torch-l1-snr/)
66
-
67
72
  ### Install from PyPI
68
73
 
69
74
  ```bash
@@ -92,13 +97,13 @@ pip install -e .
92
97
 
93
98
  ## Supported Tensor Shapes
94
99
 
95
- All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)`, `(batch, channels, samples)`, or `(batch, num_sources, channels, samples)`. For 3D & 4D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated. For the spectrogram-domain loss, a separate STFT is computed for each channel.
100
+ All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)`, `(batch, channels, samples)`, or `(batch, num_sources, channels, samples)`. For the time-domain losses, any 3D/4D input is flattened across all non-batch dimensions (e.g., sources, channels, and samples) into a single vector per example before the loss is computed. For the spectrogram-domain loss, inputs are reshaped to `(batch, streams, samples)` by flattening all non-time dimensions into a “stream” dimension (e.g., `streams = channels` or `streams = num_sources * channels`), and a separate STFT is computed for each stream.
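
As a rough illustration of the reshaping described above, the sketch below computes only the resulting shapes in plain Python, without calling the package. The helper names `time_domain_shape` and `stft_stream_shape` are hypothetical and not part of `torch_l1_snr`. Treating a 2D `(batch, samples)` input as a single stream is an assumption.

```python
from math import prod


def time_domain_shape(shape):
    """Shape after flattening every non-batch dimension into one vector."""
    batch, *rest = shape
    return (batch, prod(rest))


def stft_stream_shape(shape):
    """Shape after folding all dimensions between batch and time into streams.

    Assumption: a 2D (batch, samples) input is treated as a single stream.
    """
    batch, *middle, samples = shape
    return (batch, prod(middle), samples)


# The three input layouts listed in the README text.
for shape in [(8, 44100), (8, 2, 44100), (8, 4, 2, 44100)]:
    print(shape, "->", time_domain_shape(shape), stft_stream_shape(shape))
```

Under these assumptions, a `(batch, num_sources, channels, samples)` input yields one long vector per example for the time-domain losses, and `num_sources * channels` separate streams for the STFT loss.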
96
101
 
97
102
  ## Usage
98
103
 
99
104
  The loss functions can be imported directly from the `torch_l1_snr` package.
100
105
 
101
- ### Example: `L1SNRLoss` (Time Domain)
106
+ ### `L1SNRLoss` (Time Domain)
102
107
 
103
108
  The simplest loss function - pure L1SNR without regularization.
104
109
 
@@ -120,7 +125,7 @@ loss.backward()
120
125
  print(f"L1SNRLoss: {loss.item()}")
121
126
  ```
122
127
 
123
- ### Example: `L1SNRDBLoss` (Time Domain with Regularization)
128
+ ### `L1SNRDBLoss` (Time Domain with Regularization)
124
129
 
125
130
  Adds adaptive level-matching regularization to prevent silence collapse.
126
131
 
@@ -147,7 +152,7 @@ loss.backward()
147
152
  print(f"L1SNRDBLoss: {loss.item()}")
148
153
  ```
149
154
 
150
- ### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)
155
+ ### `STFTL1SNRDBLoss` (Spectrogram Domain)
151
156
 
152
157
  Computes L1SNR loss across multiple STFT resolutions.
153
158
 
@@ -173,7 +178,7 @@ loss.backward()
173
178
  print(f"STFTL1SNRDBLoss: {loss.item()}")
174
179
  ```
175
180
 
176
- ### Example: `MultiL1SNRDBLoss` (Combined Time + Spectrogram)
181
+ ### `MultiL1SNRDBLoss` (Combined Time + Spectrogram)
177
182
 
178
183
  Combines time-domain and spectrogram-domain losses into a single weighted objective.
179
184
 
@@ -208,24 +213,26 @@ The goal of these loss functions is to provide a perceptually-informed and robus
208
213
  - **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
209
214
  - **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.
210
215
 
216
+ This package is motivated by, and largely follows, the objectives and regularizers described in the cited papers ([1–3]). Several novel algorithmic extensions have been included with the goal of increasing flexibility for improved performance depending on the specific task.
217
+
211
218
  ### Level-Matching Regularization
212
219
 
213
220
  A key feature of `L1SNRDBLoss` is the adaptive regularization term, as described in [[2]](https://arxiv.org/abs/2501.16171). This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, which increases when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically avoids the model collapsing to a trivial silent solution when uncertain.
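
To make the dBRMS comparison concrete, here is a minimal sketch in plain Python of the level difference the paragraph above refers to. The function names and the `eps` value are assumptions for illustration. The adaptive weight (`lambda`) from [2] is deliberately not reproduced here, and the package's own implementation may differ.

```python
import math


def db_rms(signal, eps=1e-8):
    """RMS level of a sequence of samples in decibels.

    20 * log10(rms) is written as 10 * log10(mean_square); eps (an assumed
    value) keeps the logarithm finite for an all-zero signal.
    """
    mean_square = sum(x * x for x in signal) / len(signal)
    return 10.0 * math.log10(mean_square + eps)


def level_mismatch_db(estimate, target, eps=1e-8):
    """Absolute difference in dBRMS level between estimate and target."""
    return abs(db_rms(estimate, eps) - db_rms(target, eps))


target = [0.5, -0.5, 0.5, -0.5]
half_level = [0.25, -0.25, 0.25, -0.25]  # same waveform at half the amplitude
silent = [0.0, 0.0, 0.0, 0.0]            # a collapsed, all-zero estimate

print(level_mismatch_db(half_level, target))  # mismatch in dB for a quieter estimate
print(level_mismatch_db(silent, target))      # much larger mismatch for silence
```

A silent estimate against a non-silent target produces a far larger level mismatch than a merely quieter estimate, which is the case the regularizer is meant to penalize.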
214
221
 
215
222
  ### Multi-Resolution Spectrogram Analysis
216
223
 
217
- The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts - from short, transient errors to longer, tonal discrepancies. This provides a more comprehensive error signal to the model during training. Using multiple resolutions for an STFT loss is common among many recent source separation works.
224
+ The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency (spectrogram) resolutions. While not mentioned in the cited papers, by analyzing the signal with *multiple different* STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts - from short, transient errors to longer, tonal discrepancies. This provides a more comprehensive error signal to the model during training. Using multiple resolutions for an STFT loss is common among many recent source separation works, such as the [Band-Split RoPE Transformer](https://arxiv.org/abs/2309.02612).
218
225
 
219
226
  ### "All-or-Nothing" Behavior and `l1_weight`
220
227
 
221
- A characteristic of SNR-style losses (that I experienced in many training experiments) is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimations. However, this can also lead to "confident errors," where the model completely removes a signal component it should have kept.
228
+ A characteristic of these SNR-style losses that I experienced in many training experiments is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources (e.g. drums vs vocals), as it pushes the model to be confident in its estimations. However, this can also lead to "confident errors," where the model completely removes a signal component it should have kept. This poses a tradeoff for sources that may share greater similarities (e.g. speech vs singing vocals).
222
229
 
223
230
  While the Level-Matching Regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften its "all-or-nothing"-style behavior and allow for more nuanced separation.
224
231
 
225
- While this can potentially reduce metrics like SDR, I found that re-introducing some standard L1 loss allows for slightly more "smearing" of sound between sources to mask large errors and be more perceptually acceptable. I have no hard numbers on this, just my experience, so I recommend starting with no standard L1 mixed in (`l1_weight=0.0`), and then slowly increasing from there based on your needs.
232
+ While this can potentially reduce the "cleanliness" of separations and slightly harm metrics like SDR, I found that re-introducing some standard L1 loss allows for slightly more "smearing" of sound between sources to mask large errors and be more perceptually acceptable for sources with many similarities. I have no hard numbers to report on this yet, just my experience. So I recommend starting with no standard L1 mixed in (`l1_weight=0.0`), and then slowly increasing from there based on your needs.
226
233
 
227
234
  - `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
228
- - `l1_weight=1.0`: Pure L1 loss.
235
+ - `l1_weight=1.0`: Pure standard L1 loss.
229
236
  - `0.0 < l1_weight < 1.0`: A weighted combination of the two.
230
237
 
231
238
  The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
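
The blending and the skip-unused-component behavior described above could be sketched as follows. This is an assumption-laden illustration, not the package's code: it assumes a linear interpolation `(1 - l1_weight) * l1snr + l1_weight * l1`, and the function name `blend_losses` and its callable arguments are hypothetical.

```python
def blend_losses(l1snr_fn, l1_fn, l1_weight):
    """Blend two loss terms, evaluating only the ones that are needed.

    l1snr_fn and l1_fn are zero-argument callables so that an unused
    component is never computed. The linear interpolation is an assumption.
    """
    if not 0.0 <= l1_weight <= 1.0:
        raise ValueError("l1_weight must be in [0.0, 1.0]")
    if l1_weight == 0.0:
        return l1snr_fn()  # pure L1SNR: the L1 term is skipped
    if l1_weight == 1.0:
        return l1_fn()     # pure L1: the L1SNR term is skipped
    return (1.0 - l1_weight) * l1snr_fn() + l1_weight * l1_fn()


# Placeholder loss values standing in for real computations.
print(blend_losses(lambda: -20.0, lambda: 0.5, 0.25))
```

Passing the two terms as callables is one simple way to avoid computing a component whose weight is zero; with tensors, the same branching would apply before any forward computation of the unused term.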
@@ -239,7 +246,7 @@ The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`
239
246
 
240
247
  ## Contributing
241
248
 
242
- Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.
249
+ Contributions are welcome! Please open an issue or submit a pull request if you have any bug fixes, improvements, or new features to suggest.
243
250
 
244
251
  ## License
245
252
 
File without changes
File without changes