torch-l1-snr 0.0.1 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_l1_snr-0.0.1/LICENSE +21 -0
- torch_l1_snr-0.0.1/PKG-INFO +201 -0
- torch_l1_snr-0.0.1/README.md +173 -0
- torch_l1_snr-0.0.1/pyproject.toml +3 -0
- torch_l1_snr-0.0.1/setup.cfg +36 -0
- torch_l1_snr-0.0.1/tests/test_losses.py +143 -0
- torch_l1_snr-0.0.1/torch_l1_snr.egg-info/PKG-INFO +201 -0
- torch_l1_snr-0.0.1/torch_l1_snr.egg-info/SOURCES.txt +12 -0
- torch_l1_snr-0.0.1/torch_l1_snr.egg-info/dependency_links.txt +1 -0
- torch_l1_snr-0.0.1/torch_l1_snr.egg-info/requires.txt +3 -0
- torch_l1_snr-0.0.1/torch_l1_snr.egg-info/top_level.txt +1 -0
- torch_l1_snr-0.0.1/torch_l1snr/__init__.py +15 -0
- torch_l1_snr-0.0.1/torch_l1snr/l1snr.py +786 -0
torch_l1_snr-0.0.1/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Christopher Landschoot

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
torch_l1_snr-0.0.1/PKG-INFO
@@ -0,0 +1,201 @@
Metadata-Version: 2.4
Name: torch-l1-snr
Version: 0.0.1
Summary: L1-SNR loss functions for audio source separation in PyTorch
Home-page: https://github.com/crlandsc/torch-l1-snr
Author: Christopher Landschoot
Author-email: crlandschoot@gmail.com
License: MIT
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Operating System :: OS Independent
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: torchaudio
Requires-Dist: numpy>=1.21.0
Dynamic: license-file

# NOTE: Repo is currently a work-in-progress and not ready for installation & use.

[](https://github.com/crlandsc/torch-l1snr/blob/main/LICENSE) [](https://github.com/crlandsc/torch-l1snr/stargazers)

# torch-l1-snr

A PyTorch implementation of L1-based Signal-to-Noise Ratio (SNR) loss functions for audio source separation. This package provides implementations and novel extensions based on concepts from recent academic papers, offering flexible and robust loss functions that can be easily integrated into any PyTorch-based audio separation pipeline.

The core `L1SNRLoss` is based on the loss function described in [1], while `L1SNRDBLoss` and `STFTL1SNRDBLoss` are extensions of the adaptive level-matching regularization technique proposed in [2].

## Features

- **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [1].
- **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [2], plus an optional L1 loss component.
- **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [2], calculated over multiple STFT resolutions.
- **Modular Stem-based Loss**: A wrapper that combines time- and spectrogram-domain losses and can be configured to run on specific stems.
- **Efficient & Robust**: Includes optimizations for pure L1 loss calculation and robust handling of `NaN`/`inf` values and short audio segments.

## Installation

<!-- Add PyPI badges once the package is published -->
<!-- [](https://pypi.org/project/torch-l1snr/) -->
<!-- [](https://pypi.org/project/torch-l1snr/) -->
<!-- [](https://pypi.org/project/torch-l1snr/) -->

You can install the package directly from GitHub:

```bash
pip install git+https://github.com/crlandsc/torch-l1snr.git
```

Or, you can clone the repository and install it in editable mode for development:

```bash
git clone https://github.com/crlandsc/torch-l1snr.git
cd torch-l1snr
pip install -e .
```

## Dependencies

- [PyTorch](https://pytorch.org/)
- [torchaudio](https://pytorch.org/audio/stable/index.html)

## Supported Tensor Shapes

All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)` or `(batch, channels, samples)`. For 3D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated. For the spectrogram-domain loss, a separate STFT is computed for each channel.

## Usage

The loss functions can be imported directly from the `torch_l1snr` package.

### Example: `L1SNRDBLoss` (Time Domain)

```python
import torch
from torch_l1snr import L1SNRDBLoss

# Create dummy audio signals
estimates = torch.randn(4, 32000, requires_grad=True)  # Batch of 4, 32000 samples
actuals = torch.randn(4, 32000)

# Initialize the loss function
# l1_weight=0.1 blends L1SNR with 10% L1 loss
loss_fn = L1SNRDBLoss(l1_weight=0.1)

# Calculate loss
loss = loss_fn(estimates, actuals)
loss.backward()

print(f"L1SNRDBLoss: {loss.item()}")
```

### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)

```python
import torch
from torch_l1snr import STFTL1SNRDBLoss

# Create dummy audio signals
estimates = torch.randn(4, 32000, requires_grad=True)
actuals = torch.randn(4, 32000)

# Initialize the loss function
# Uses multiple STFT resolutions by default
loss_fn = STFTL1SNRDBLoss(l1_weight=0.0)  # Pure L1SNR + regularization

# Calculate loss
loss = loss_fn(estimates, actuals)
loss.backward()

print(f"STFTL1SNRDBLoss: {loss.item()}")
```

### Example: `MultiL1SNRDBLoss` for a Combined Time+Spectrogram Loss

This loss combines the time-domain and spectrogram-domain losses into a single, weighted objective function.

```python
import torch
from torch_l1snr import MultiL1SNRDBLoss

# Create dummy audio signals
# Shape: (batch, channels, samples)
estimates = torch.randn(2, 2, 44100)  # Batch of 2, stereo
actuals = torch.randn(2, 2, 44100)

# --- Configuration ---
loss_fn = MultiL1SNRDBLoss(
    weight=1.0,       # Overall weight for this loss
    spec_weight=0.7,  # 70% spectrogram loss, 30% time-domain loss
    l1_weight=0.1,    # Use 10% L1, 90% L1SNR+Reg
)
loss = loss_fn(estimates, actuals)
print(f"Multi-domain Loss: {loss.item()}")
```

## Motivation

The goal of these loss functions is to provide a perceptually informed and robust alternative to common audio losses like L1, L2 (MSE), and SI-SDR for training audio source separation models.

- **Robustness**: The L1 norm is less sensitive to large outliers than the L2 norm, making it more suitable for audio signals, which can have sharp transients.
- **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
- **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.

#### Level-Matching Regularization

A key feature of `L1SNRDBLoss` is the adaptive regularization term described in [2]. This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, and it increases when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically prevents it from collapsing to a trivial silent solution when uncertain.

#### Multi-Resolution Spectrogram Analysis

The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts, from short transient errors to longer tonal discrepancies. This provides a more comprehensive error signal to the model during training.

#### "All-or-Nothing" Behavior and `l1_weight`

A characteristic of SNR-style losses is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimations. However, it can also lead to "confident errors," where the model completely removes a signal component it should have kept.

While the level-matching regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften the "all-or-nothing" behavior and allow for more nuanced separation.

- `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
- `l1_weight=1.0`: Pure L1 loss.
- `0.0 < l1_weight < 1.0`: A weighted combination of the two.

The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.

**Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), you may need to tune `l1_scale_time` and `l1_scale_spec` so that the gradients of the L1 and L1SNR components are balanced, which is crucial for stable training. The default values provide a reasonable starting point, but monitoring the loss components is recommended to ensure they are scaled appropriately.

## Limitations

- The L1SNR loss is not scale-invariant. Unlike SI-SNR, it requires the model's output to be correctly scaled relative to the target.
- While the dB scaling and regularization are psychoacoustically motivated, the loss does not model more complex perceptual phenomena like auditory masking.

## Contributing

Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Acknowledgments

The loss functions implemented here are based on the work of the authors of the referenced papers.

## References

[1] K. N. Watcharasupat, C.-W. Wu, Y. Ding, I. Orife, A. J. Hipple, P. A. Williams, S. Kramer, A. Lerch, and W. Wolcott, "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation," IEEE Open Journal of Signal Processing, 2023. (arXiv:2309.02539)

[2] K. N. Watcharasupat and A. Lerch, "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries," arXiv:2501.16171, 2025.

[3] K. N. Watcharasupat and A. Lerch, "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems," Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024. (arXiv:2406.18747)
torch_l1_snr-0.0.1/README.md
@@ -0,0 +1,173 @@
# NOTE: Repo is currently a work-in-progress and not ready for installation & use.

[](https://github.com/crlandsc/torch-l1snr/blob/main/LICENSE) [](https://github.com/crlandsc/torch-l1snr/stargazers)

# torch-l1-snr

A PyTorch implementation of L1-based Signal-to-Noise Ratio (SNR) loss functions for audio source separation. This package provides implementations and novel extensions based on concepts from recent academic papers, offering flexible and robust loss functions that can be easily integrated into any PyTorch-based audio separation pipeline.

The core `L1SNRLoss` is based on the loss function described in [1], while `L1SNRDBLoss` and `STFTL1SNRDBLoss` are extensions of the adaptive level-matching regularization technique proposed in [2].
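
For intuition, a decibel-scaled L1-SNR can be read as the ratio between the target's L1 energy and the L1 reconstruction error, expressed in dB and negated for use as a loss. The sketch below is illustrative only; the function name `l1_snr_db` and the exact `eps` handling are assumptions, not the formulation used by [1] or by this package.

```python
import torch

def l1_snr_db(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Illustrative decibel-scaled L1-SNR (higher is better); negate it to use as a loss."""
    signal = torch.sum(torch.abs(target), dim=-1)            # L1 "energy" of the target
    error = torch.sum(torch.abs(target - estimate), dim=-1)  # L1 error of the estimate
    return 10.0 * torch.log10((signal + eps) / (error + eps))

# loss = -l1_snr_db(estimates, actuals).mean()
```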

## Features

- **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [1].
- **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [2], plus an optional L1 loss component.
- **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [2], calculated over multiple STFT resolutions.
- **Modular Stem-based Loss**: A wrapper that combines time- and spectrogram-domain losses and can be configured to run on specific stems.
- **Efficient & Robust**: Includes optimizations for pure L1 loss calculation and robust handling of `NaN`/`inf` values and short audio segments.

## Installation

<!-- Add PyPI badges once the package is published -->
<!-- [](https://pypi.org/project/torch-l1snr/) -->
<!-- [](https://pypi.org/project/torch-l1snr/) -->
<!-- [](https://pypi.org/project/torch-l1snr/) -->

You can install the package directly from GitHub:

```bash
pip install git+https://github.com/crlandsc/torch-l1snr.git
```

Or, you can clone the repository and install it in editable mode for development:

```bash
git clone https://github.com/crlandsc/torch-l1snr.git
cd torch-l1snr
pip install -e .
```

## Dependencies

- [PyTorch](https://pytorch.org/)
- [torchaudio](https://pytorch.org/audio/stable/index.html)

## Supported Tensor Shapes

All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)` or `(batch, channels, samples)`. For 3D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated. For the spectrogram-domain loss, a separate STFT is computed for each channel.
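
As a minimal sketch of that shape handling (illustrative only; the variable names and exact reshaping below are assumptions, not the package's internals):

```python
import torch

x = torch.randn(4, 2, 32000)        # (batch, channels, samples)

# Time-domain losses: channel and sample dimensions are flattened per item.
x_time = x.reshape(x.shape[0], -1)  # (batch, channels * samples)

# Spectrogram-domain loss: an STFT is computed separately for each channel.
x_flat = x.reshape(-1, x.shape[-1])  # (batch * channels, samples)
spec = torch.stft(x_flat, n_fft=1024, hop_length=256,
                  window=torch.hann_window(1024), return_complex=True)
```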

## Usage

The loss functions can be imported directly from the `torch_l1snr` package.

### Example: `L1SNRDBLoss` (Time Domain)

```python
import torch
from torch_l1snr import L1SNRDBLoss

# Create dummy audio signals
estimates = torch.randn(4, 32000, requires_grad=True)  # Batch of 4, 32000 samples
actuals = torch.randn(4, 32000)

# Initialize the loss function
# l1_weight=0.1 blends L1SNR with 10% L1 loss
loss_fn = L1SNRDBLoss(l1_weight=0.1)

# Calculate loss
loss = loss_fn(estimates, actuals)
loss.backward()

print(f"L1SNRDBLoss: {loss.item()}")
```

### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)

```python
import torch
from torch_l1snr import STFTL1SNRDBLoss

# Create dummy audio signals
estimates = torch.randn(4, 32000, requires_grad=True)
actuals = torch.randn(4, 32000)

# Initialize the loss function
# Uses multiple STFT resolutions by default
loss_fn = STFTL1SNRDBLoss(l1_weight=0.0)  # Pure L1SNR + regularization

# Calculate loss
loss = loss_fn(estimates, actuals)
loss.backward()

print(f"STFTL1SNRDBLoss: {loss.item()}")
```

### Example: `MultiL1SNRDBLoss` for a Combined Time+Spectrogram Loss

This loss combines the time-domain and spectrogram-domain losses into a single, weighted objective function.

```python
import torch
from torch_l1snr import MultiL1SNRDBLoss

# Create dummy audio signals
# Shape: (batch, channels, samples)
estimates = torch.randn(2, 2, 44100)  # Batch of 2, stereo
actuals = torch.randn(2, 2, 44100)

# --- Configuration ---
loss_fn = MultiL1SNRDBLoss(
    weight=1.0,       # Overall weight for this loss
    spec_weight=0.7,  # 70% spectrogram loss, 30% time-domain loss
    l1_weight=0.1,    # Use 10% L1, 90% L1SNR+Reg
)
loss = loss_fn(estimates, actuals)
print(f"Multi-domain Loss: {loss.item()}")
```

## Motivation

The goal of these loss functions is to provide a perceptually informed and robust alternative to common audio losses like L1, L2 (MSE), and SI-SDR for training audio source separation models.

- **Robustness**: The L1 norm is less sensitive to large outliers than the L2 norm, making it more suitable for audio signals, which can have sharp transients.
- **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
- **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.

#### Level-Matching Regularization

A key feature of `L1SNRDBLoss` is the adaptive regularization term described in [2]. This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, and it increases when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically prevents it from collapsing to a trivial silent solution when uncertain.
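
A rough sketch of the idea follows. The `dbrms` form matches the values checked in `tests/test_losses.py`; the `level_matching_reg` helper and its adaptive-weight schedule are assumptions for illustration and may differ from [2] and from this package's implementation.

```python
import torch

def dbrms(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Decibel-scaled RMS level per item (-20 dB for a constant 0.1 signal, -80 dB for silence).
    return 10.0 * torch.log10(torch.mean(x ** 2, dim=-1) + eps)

def level_matching_reg(estimate: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Penalize the gap between estimated and target loudness.
    level_diff = torch.abs(dbrms(estimate) - dbrms(target))
    # Illustrative adaptive weight: grows when the estimate is much quieter
    # than a non-silent target, discouraging a collapse to silence.
    lam = 1.0 + (dbrms(target) - dbrms(estimate)).clamp(min=0.0) / 20.0
    return (lam * level_diff).mean()
```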

#### Multi-Resolution Spectrogram Analysis

The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts, from short transient errors to longer tonal discrepancies. This provides a more comprehensive error signal to the model during training.
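
A sketch of the multi-resolution analysis is below; the `(n_fft, hop)` pairs are placeholders, since the actual default resolutions of `STFTL1SNRDBLoss` are not documented here.

```python
import torch

def multi_resolution_mags(x, resolutions=((512, 128), (1024, 256), (2048, 512))):
    """Magnitude spectrograms of x at several illustrative STFT resolutions."""
    specs = []
    for n_fft, hop in resolutions:
        window = torch.hann_window(n_fft, device=x.device)
        spec = torch.stft(x, n_fft=n_fft, hop_length=hop,
                          window=window, return_complex=True)
        specs.append(spec.abs())
    return specs

# A multi-resolution loss would apply the L1SNRDB objective at each resolution
# and average (or sum) the per-resolution terms.
```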

#### "All-or-Nothing" Behavior and `l1_weight`

A characteristic of SNR-style losses is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimations. However, it can also lead to "confident errors," where the model completely removes a signal component it should have kept.

While the level-matching regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften the "all-or-nothing" behavior and allow for more nuanced separation.

- `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
- `l1_weight=1.0`: Pure L1 loss.
- `0.0 < l1_weight < 1.0`: A weighted combination of the two.

The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
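
The blend and the short-circuit can be pictured as follows; `l1snr_term` and `l1_term` stand in for the package's internal loss components and are passed in as placeholders here.

```python
import torch.nn.functional as F

def blended_loss(estimate, target, l1snr_term, l1_term=F.l1_loss, l1_weight=0.0):
    """Illustrative blend; skips whichever component carries zero weight."""
    if l1_weight == 0.0:   # pure L1SNR (+ regularization); L1 branch never computed
        return l1snr_term(estimate, target)
    if l1_weight == 1.0:   # pure L1; L1SNR branch never computed
        return l1_term(estimate, target)
    return (1.0 - l1_weight) * l1snr_term(estimate, target) + l1_weight * l1_term(estimate, target)
```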

**Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), you may need to tune `l1_scale_time` and `l1_scale_spec` so that the gradients of the L1 and L1SNR components are balanced, which is crucial for stable training. The default values provide a reasonable starting point, but monitoring the loss components is recommended to ensure they are scaled appropriately.

## Limitations

- The L1SNR loss is not scale-invariant. Unlike SI-SNR, it requires the model's output to be correctly scaled relative to the target.
- While the dB scaling and regularization are psychoacoustically motivated, the loss does not model more complex perceptual phenomena like auditory masking.

## Contributing

Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Acknowledgments

The loss functions implemented here are based on the work of the authors of the referenced papers.

## References

[1] K. N. Watcharasupat, C.-W. Wu, Y. Ding, I. Orife, A. J. Hipple, P. A. Williams, S. Kramer, A. Lerch, and W. Wolcott, "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation," IEEE Open Journal of Signal Processing, 2023. (arXiv:2309.02539)

[2] K. N. Watcharasupat and A. Lerch, "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries," arXiv:2501.16171, 2025.

[3] K. N. Watcharasupat and A. Lerch, "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems," Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024. (arXiv:2406.18747)
torch_l1_snr-0.0.1/setup.cfg
@@ -0,0 +1,36 @@
[metadata]
name = torch-l1-snr
version = 0.0.1
author = Christopher Landschoot
author_email = crlandschoot@gmail.com
description = L1-SNR loss functions for audio source separation in PyTorch
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/crlandsc/torch-l1-snr
license = MIT
classifiers =
    Intended Audience :: Developers
    Intended Audience :: Science/Research
    License :: OSI Approved :: MIT License
    Programming Language :: Python :: 3
    Programming Language :: Python :: 3.8
    Programming Language :: Python :: 3.9
    Programming Language :: Python :: 3.10
    Programming Language :: Python :: 3.11
    Programming Language :: Python :: 3.12
    Operating System :: OS Independent
    Topic :: Scientific/Engineering :: Artificial Intelligence
    Topic :: Multimedia :: Sound/Audio :: Analysis

[options]
packages = find:
python_requires = >=3.8
install_requires =
    torch
    torchaudio
    numpy>=1.21.0

[egg_info]
tag_build =
tag_date = 0
torch_l1_snr-0.0.1/tests/test_losses.py
@@ -0,0 +1,143 @@
import torch
import pytest
from torch_l1snr import (
    dbrms,
    L1SNRLoss,
    L1SNRDBLoss,
    STFTL1SNRDBLoss,
    MultiL1SNRDBLoss,
)

# --- Test Fixtures ---
@pytest.fixture
def dummy_audio():
    """Provides a batch of dummy audio signals."""
    estimates = torch.randn(2, 16000)
    actuals = torch.randn(2, 16000)
    # Ensure actuals are not all zero to avoid division by zero in loss
    actuals[0, :100] += 0.1
    return estimates, actuals

@pytest.fixture
def dummy_stems():
    """Provides a batch of dummy multi-stem signals."""
    estimates = torch.randn(2, 4, 1, 16000)  # batch, stems, channels, samples
    actuals = torch.randn(2, 4, 1, 16000)
    actuals[:, 0, :, :100] += 0.1  # Ensure not all zero
    return estimates, actuals

# --- Test Functions ---

def test_dbrms():
    signal = torch.ones(2, 1000) * 0.1
    # RMS of 0.1 is -20 dB
    assert torch.allclose(dbrms(signal), torch.tensor([-20.0, -20.0]), atol=1e-4)

    zeros = torch.zeros(2, 1000)
    # dbrms of zero should be -80 dB with default eps=1e-8
    assert torch.allclose(dbrms(zeros), torch.tensor([-80.0, -80.0]), atol=1e-4)

def test_l1snr_loss(dummy_audio):
    estimates, actuals = dummy_audio
    loss_fn = L1SNRLoss(name="test")
    loss = loss_fn(estimates, actuals)

    assert isinstance(loss, torch.Tensor)
    assert loss.ndim == 0
    assert not torch.isnan(loss)
    assert not torch.isinf(loss)

def test_l1snrdb_loss_time(dummy_audio):
    estimates, actuals = dummy_audio

    # Test with default settings (L1SNR + Regularization)
    loss_fn = L1SNRDBLoss(name="test", use_regularization=True, l1_weight=0.0)
    loss = loss_fn(estimates, actuals)
    assert loss.ndim == 0 and not torch.isnan(loss)

    # Test without regularization
    loss_fn_no_reg = L1SNRDBLoss(name="test_no_reg", use_regularization=False, l1_weight=0.0)
    loss_no_reg = loss_fn_no_reg(estimates, actuals)
    assert loss_no_reg.ndim == 0 and not torch.isnan(loss_no_reg)

    # Test with L1 loss component
    loss_fn_l1 = L1SNRDBLoss(name="test_l1", l1_weight=0.2)
    loss_l1 = loss_fn_l1(estimates, actuals)
    assert loss_l1.ndim == 0 and not torch.isnan(loss_l1)

    # Test pure L1 loss mode
    loss_fn_pure_l1 = L1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
    pure_l1_loss = loss_fn_pure_l1(estimates, actuals)
    # Pure L1 mode uses torch.nn.L1Loss, so compare with manual L1 calculation
    l1_loss_manual = torch.nn.L1Loss()(
        estimates.reshape(estimates.shape[0], -1),
        actuals.reshape(actuals.shape[0], -1)
    )
    assert torch.allclose(pure_l1_loss, l1_loss_manual)

def test_stft_l1snrdb_loss(dummy_audio):
    estimates, actuals = dummy_audio

    # Test with default settings
    loss_fn = STFTL1SNRDBLoss(name="test", l1_weight=0.0)
    loss = loss_fn(estimates, actuals)
    assert loss.ndim == 0 and not torch.isnan(loss) and not torch.isinf(loss)

    # Test pure L1 mode
    loss_fn_pure_l1 = STFTL1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
    l1_loss = loss_fn_pure_l1(estimates, actuals)
    assert l1_loss.ndim == 0 and not torch.isnan(l1_loss) and not torch.isinf(l1_loss)

    # Test with very short audio
    short_estimates = estimates[:, :500]
    short_actuals = actuals[:, :500]
    loss_short = loss_fn(short_estimates, short_actuals)
    # min_audio_length is 512, so this should fall back to the time-domain loss
    assert loss_short.ndim == 0 and not torch.isnan(loss_short)

def test_stem_multi_loss(dummy_stems):
    estimates, actuals = dummy_stems

    # Test with a specific stem - users now manage stems manually by slicing
    # Extract stem 1 (second stem) manually
    est_stem = estimates[:, 1, ...]  # Shape: [batch, channels, samples]
    act_stem = actuals[:, 1, ...]
    loss_fn_stem = MultiL1SNRDBLoss(
        name="test_loss_stem",
        spec_weight=0.5,
        l1_weight=0.1
    )
    loss = loss_fn_stem(est_stem, act_stem)
    assert loss.ndim == 0 and not torch.isnan(loss)

    # Test with all stems jointly - flatten all stems together
    # Reshape to [batch, -1] to process all stems at once
    est_all = estimates.reshape(estimates.shape[0], -1)
    act_all = actuals.reshape(actuals.shape[0], -1)
    loss_fn_all = MultiL1SNRDBLoss(
        name="test_loss_all",
        spec_weight=0.5,
        l1_weight=0.1
    )
    loss_all = loss_fn_all(est_all, act_all)
    assert loss_all.ndim == 0 and not torch.isnan(loss_all)

    # Test pure L1 mode on all stems
    loss_fn_l1 = MultiL1SNRDBLoss(name="l1_only", l1_weight=1.0)
    l1_loss = loss_fn_l1(est_all, act_all)

    # Can't easily compute multi-res STFT L1 here, but can check it's not NaN
    assert l1_loss.ndim == 0 and not torch.isnan(l1_loss)

@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
def test_loss_variants(dummy_audio, l1_weight):
    """Test L1SNRDBLoss and STFTL1SNRDBLoss with different l1_weights."""
    estimates, actuals = dummy_audio

    time_loss_fn = L1SNRDBLoss(name=f"test_time_{l1_weight}", l1_weight=l1_weight)
    time_loss = time_loss_fn(estimates, actuals)
    assert not torch.isnan(time_loss) and not torch.isinf(time_loss)

    spec_loss_fn = STFTL1SNRDBLoss(name=f"test_spec_{l1_weight}", l1_weight=l1_weight)
    spec_loss = spec_loss_fn(estimates, actuals)
    assert not torch.isnan(spec_loss) and not torch.isinf(spec_loss)