torch-l1-snr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch-l1-snr
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: L1-SNR loss functions for audio source separation in PyTorch
|
|
5
|
+
Home-page: https://github.com/crlandsc/torch-l1-snr
|
|
6
|
+
Author: Christopher Landschoot
|
|
7
|
+
Author-email: crlandschoot@gmail.com
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Operating System :: OS Independent
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
|
|
21
|
+
Requires-Python: >=3.8
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
License-File: LICENSE
|
|
24
|
+
Requires-Dist: torch
|
|
25
|
+
Requires-Dist: torchaudio
|
|
26
|
+
Requires-Dist: numpy>=1.21.0
|
|
27
|
+
Dynamic: license-file
|
|
28
|
+
|
|
29
|
+
 -->
|
|
30
|
+
|
|
31
|
+
# NOTE: Repo is currently a work-in-progress and not ready for installation & use.
|
|
32
|
+
|
|
33
|
+
[](https://github.com/crlandsc/torch-l1snr/blob/main/LICENSE) [](https://github.com/crlandsc/torch-l1snr/stargazers)
|
|
34
|
+
|
|
35
|
+
# torch-l1-snr
|
|
36
|
+
|
|
37
|
+
A PyTorch implementation of L1-based Signal-to-Noise Ratio (SNR) loss functions for audio source separation. This package provides implementations and novel extensions based on concepts from recent academic papers, offering flexible and robust loss functions that can be easily integrated into any PyTorch-based audio separation pipeline.
|
|
38
|
+
|
|
39
|
+
The core `L1SNRLoss` is based on the loss function described in [1], while `L1SNRDBLoss` and `STFTL1SNRDBLoss` are extensions of the adaptive level-matching regularization technique proposed in [2].
|
|
40
|
+
|
|
41
|
+
## Features
|
|
42
|
+
|
|
43
|
+
- **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [1].
|
|
44
|
+
- **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [2], plus an optional L1 loss component.
|
|
45
|
+
- **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [2], calculated over multiple STFT resolutions.
|
|
46
|
+
- **Modular Stem-based Loss**: A wrapper that combines time and spectrogram domain losses and can be configured to run on specific stems.
|
|
47
|
+
- **Efficient & Robust**: Includes optimizations for pure L1 loss calculation and robust handling of `NaN`/`inf` values and short audio segments.
|
|
48
|
+
|
|
49
|
+
## Installation
|
|
50
|
+
|
|
51
|
+
<!-- Add PyPI badges once the package is published -->
|
|
52
|
+
<!-- [](https://pypi.org/project/torch-l1snr/) -->
|
|
53
|
+
<!-- [](https://pypi.org/project/torch-l1snr/) -->
|
|
54
|
+
<!-- [](https://pypi.org/project/torch-l1snr/) -->
|
|
55
|
+
|
|
56
|
+
You can install the package directly from GitHub:
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
pip install git+https://github.com/crlandsc/torch-l1snr.git
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
Or, you can clone the repository and install it in editable mode for development:
|
|
63
|
+
|
|
64
|
+
```bash
|
|
65
|
+
git clone https://github.com/crlandsc/torch-l1snr.git
|
|
66
|
+
cd torch-l1snr
|
|
67
|
+
pip install -e .
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
## Dependencies
|
|
71
|
+
|
|
72
|
+
- [PyTorch](https://pytorch.org/)
|
|
73
|
+
- [torchaudio](https://pytorch.org/audio/stable/index.html)
|
|
74
|
+
|
|
75
|
+
## Supported Tensor Shapes
|
|
76
|
+
|
|
77
|
+
All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)` or `(batch, channels, samples)`. For 3D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated. For the spectrogram-domain loss, a separate STFT is computed for each channel.
|
|
78
|
+
|
|
79
|
+
## Usage
|
|
80
|
+
|
|
81
|
+
The loss functions can be imported directly from the `torch_l1snr` package.
|
|
82
|
+
|
|
83
|
+
### Example: `L1SNRDBLoss` (Time Domain)
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
import torch
|
|
87
|
+
from torch_l1snr import L1SNRDBLoss
|
|
88
|
+
|
|
89
|
+
# Create dummy audio signals
|
|
90
|
+
estimates = torch.randn(4, 32000) # Batch of 4, 32000 samples
|
|
91
|
+
actuals = torch.randn(4, 32000)
|
|
92
|
+
|
|
93
|
+
# Initialize the loss function
|
|
94
|
+
# l1_weight=0.1 blends L1SNR with 10% L1 loss
|
|
95
|
+
loss_fn = L1SNRDBLoss(l1_weight=0.1)
|
|
96
|
+
|
|
97
|
+
# Calculate loss
|
|
98
|
+
loss = loss_fn(estimates, actuals)
|
|
99
|
+
loss.backward()
|
|
100
|
+
|
|
101
|
+
print(f"L1SNRDBLoss: {loss.item()}")
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
import torch
|
|
108
|
+
from torch_l1snr import STFTL1SNRDBLoss
|
|
109
|
+
|
|
110
|
+
# Create dummy audio signals
|
|
111
|
+
estimates = torch.randn(4, 32000)
|
|
112
|
+
actuals = torch.randn(4, 32000)
|
|
113
|
+
|
|
114
|
+
# Initialize the loss function
|
|
115
|
+
# Uses multiple STFT resolutions by default
|
|
116
|
+
loss_fn = STFTL1SNRDBLoss(l1_weight=0.0) # Pure L1SNR + Regularization
|
|
117
|
+
|
|
118
|
+
# Calculate loss
|
|
119
|
+
loss = loss_fn(estimates, actuals)
|
|
120
|
+
loss.backward()
|
|
121
|
+
|
|
122
|
+
print(f"STFTL1SNRDBLoss: {loss.item()}")
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
### Example: `MultiL1SNRDBLoss` for a Combined Time+Spectrogram Loss
|
|
126
|
+
|
|
127
|
+
This loss combines the time-domain and spectrogram-domain losses into a single, weighted objective function.
|
|
128
|
+
|
|
129
|
+
```python
|
|
130
|
+
import torch
|
|
131
|
+
from torch_l1snr import MultiL1SNRDBLoss
|
|
132
|
+
|
|
133
|
+
# Create dummy audio signals
|
|
134
|
+
# Shape: (batch, channels, samples)
|
|
135
|
+
estimates = torch.randn(2, 2, 44100) # Batch of 2, stereo
|
|
136
|
+
actuals = torch.randn(2, 2, 44100)
|
|
137
|
+
|
|
138
|
+
# --- Configuration ---
|
|
139
|
+
loss_fn = MultiL1SNRDBLoss(
|
|
140
|
+
weight=1.0, # Overall weight for this loss
|
|
141
|
+
spec_weight=0.7, # 70% spectrogram loss, 30% time-domain loss
|
|
142
|
+
l1_weight=0.1, # Use 10% L1, 90% L1SNR+Reg
|
|
143
|
+
)
|
|
144
|
+
loss = loss_fn(estimates, actuals)
|
|
145
|
+
print(f"Multi-domain Loss: {loss.item()}")
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
## Motivation
|
|
149
|
+
|
|
150
|
+
The goal of these loss functions is to provide a perceptually-informed and robust alternative to common audio losses like L1, L2 (MSE), and SI-SDR for training audio source separation models.
|
|
151
|
+
|
|
152
|
+
- **Robustness**: The L1 norm is less sensitive to large outliers than the L2 norm, making it more suitable for audio signals which can have sharp transients.
|
|
153
|
+
- **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
|
|
154
|
+
- **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.
|
|
155
|
+
|
|
156
|
+
#### Level-Matching Regularization
|
|
157
|
+
|
|
158
|
+
A key feature of `L1SNRDBLoss` is the adaptive regularization term, as described in [2]. This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, which increases when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically avoids the model collapsing to a trivial silent solution when uncertain.
|
|
159
|
+
|
|
160
|
+
#### Multi-Resolution Spectrogram Analysis
|
|
161
|
+
|
|
162
|
+
The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts—from short, transient errors to longer, tonal discrepancies. This provides a more comprehensive error signal to the model during training.
|
|
163
|
+
|
|
164
|
+
#### "All-or-Nothing" Behavior and `l1_weight`
|
|
165
|
+
|
|
166
|
+
A characteristic of SNR-style losses is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimations. However, this can also lead to "confident errors," where the model completely removes a signal component it should have kept.
|
|
167
|
+
|
|
168
|
+
While the Level-Matching Regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften its "all-or-nothing"-style behavior and allow for more nuanced separation.
|
|
169
|
+
|
|
170
|
+
- `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
|
|
171
|
+
- `l1_weight=1.0`: Pure L1 loss.
|
|
172
|
+
- `0.0 < l1_weight < 1.0`: A weighted combination of the two.
|
|
173
|
+
|
|
174
|
+
The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
|
|
175
|
+
|
|
176
|
+
**Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), you may need to tune `l1_scale_time` and `l1_scale_spec`. This is to ensure the gradients of the L1 and L1SNR components are balanced, which is crucial for stable training. The default values provide a reasonable starting point, but monitoring the loss components is recommended to ensure they are scaled appropriately.
|
|
177
|
+
|
|
178
|
+
## Limitations
|
|
179
|
+
|
|
180
|
+
- The L1SNR loss is not scale-invariant. Unlike SI-SNR, it requires the model's output to be correctly scaled relative to the target.
|
|
181
|
+
- While the dB scaling and regularization are psychoacoustically motivated, the loss does not model more complex perceptual phenomena like auditory masking.
|
|
182
|
+
|
|
183
|
+
## Contributing
|
|
184
|
+
|
|
185
|
+
Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.
|
|
186
|
+
|
|
187
|
+
## License
|
|
188
|
+
|
|
189
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
190
|
+
|
|
191
|
+
## Acknowledgments
|
|
192
|
+
|
|
193
|
+
The loss functions implemented here are based on the work of the authors of the referenced papers.
|
|
194
|
+
|
|
195
|
+
## References
|
|
196
|
+
|
|
197
|
+
[1] K. N. Watcharasupat, C.-W. Wu, Y. Ding, I. Orife, A. J. Hipple, P. A. Williams, S. Kramer, A. Lerch, and W. Wolcott, "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation," IEEE Open Journal of Signal Processing, 2023. (arXiv:2309.02539)
|
|
198
|
+
|
|
199
|
+
[2] K. N. Watcharasupat and A. Lerch, "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries," arXiv:2501.16171.
|
|
200
|
+
|
|
201
|
+
[3] K. N. Watcharasupat and A. Lerch, "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems," Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024. (arXiv:2406.18747)
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
torch_l1_snr-0.0.1.dist-info/licenses/LICENSE,sha256=gBRgAD6TvJXVLS9LVkO0V1GzNHKX2poZLFGtyc6hwq0,1079
|
|
2
|
+
torch_l1snr/__init__.py,sha256=pR9jg3fjTKt_suZoVDC67tqB7EWRkbfaXaPP7pYQrlQ,220
|
|
3
|
+
torch_l1snr/l1snr.py,sha256=aqmtNfT_8A0IRI9jiVGwNse3igBvelQGKnjfe23Xh7w,35304
|
|
4
|
+
torch_l1_snr-0.0.1.dist-info/METADATA,sha256=0EuezNJH0APVDEWIsmqW_6Wc9B1J-Zk3ksg8VknU7kY,10374
|
|
5
|
+
torch_l1_snr-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
+
torch_l1_snr-0.0.1.dist-info/top_level.txt,sha256=NfaRND6pcjZ7-035d4XAg8xJuz31EEU210Y9xWeFOxc,12
|
|
7
|
+
torch_l1_snr-0.0.1.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Christopher Landschoot
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torch_l1snr
|
torch_l1snr/__init__.py
ADDED
torch_l1snr/l1snr.py
ADDED
|
@@ -0,0 +1,786 @@
|
|
|
1
|
+
# PyTorch implementation of L1SNR loss functions for audio source separation
|
|
2
|
+
# https://github.com/crlandsc/torch-l1-snr
|
|
3
|
+
# Copyright (c) 2025 crlandsc
|
|
4
|
+
# MIT License
|
|
5
|
+
#
|
|
6
|
+
# This implementation is based on and extends the loss functions described in:
|
|
7
|
+
# [1] "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries"
|
|
8
|
+
# Karn N. Watcharasupat, Alexander Lerch
|
|
9
|
+
# arXiv:2501.16171
|
|
10
|
+
# [2] "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation"
|
|
11
|
+
# Karn N. Watcharasupat, Chih-Wei Wu, Yiwei Ding, Iroro Orife, Aaron J. Hipple, Phillip A. Williams, Scott Kramer, Alexander Lerch, William Wolcott
|
|
12
|
+
# IEEE Open Journal of Signal Processing, 2023
|
|
13
|
+
# arXiv:2309.02539
|
|
14
|
+
# [3] "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems"
|
|
15
|
+
# Karn N. Watcharasupat, Alexander Lerch
|
|
16
|
+
# Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024
|
|
17
|
+
# arXiv:2406.18747
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
import torch.nn as nn
|
|
21
|
+
from torchaudio.transforms import Spectrogram
|
|
22
|
+
import math
|
|
23
|
+
from typing import Union, Dict, Tuple, Optional, List
|
|
24
|
+
|
|
25
|
+
def dbrms(x, eps=1e-8):
    """
    Return the dB-scaled RMS level of each signal in a batch.

    Args:
        x: tensor whose first dimension is the batch; all remaining
            dimensions are flattened before the RMS is taken.
        eps: small stability constant added inside both the square root
            and the log10 to avoid log(0) on silent inputs.

    Returns:
        1-D tensor of shape (batch,) holding 20 * log10(rms) values.
    """
    flat = x.reshape(x.shape[0], -1)
    mean_square = torch.mean(flat * flat, dim=-1)
    level = torch.sqrt(mean_square + eps)
    return 20.0 * torch.log10(level + eps)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class L1SNRLoss(torch.nn.Module):
    """
    L1 Signal-to-Noise Ratio (SNR) loss with an optional blended L1 component
    that softens the "all-or-nothing" behavior of SNR-style losses.

    Paper-aligned D1(y_hat; y) form:
        D1 = 10 * log10((||y_hat - y||_1 + eps) / (||y||_1 + eps))
        L1SNR_loss = mean(D1)

    When l1_weight > 0 the loss blends L1SNR with an auto-scaled L1 term:
        loss = (1 - l1_weight) * L1SNR_loss + l1_weight * L1_auto_scaled

    Input shape:
        Batch-first waveform tensors of any shape; everything after the batch
        dimension is flattened. Recommended shapes:
        - [batch, time]
        - [batch, num_sources, time]
        - [batch, num_sources, channels, time]

    Args:
        name: identifier for the loss. Defaults to "l1_snr" so the loss can
            be constructed with keyword arguments only, matching the README
            usage examples.
        weight: global multiplier applied to the final loss.
        eps: stability epsilon inside D1 (default 1e-3 per the papers).
        l1_weight: blend factor in [0, 1]; 0 = pure L1SNR, 1 = pure L1.
    """

    def __init__(
        self,
        name: str = "l1_snr",
        weight: float = 1.0,
        eps: float = 1e-3,
        l1_weight: float = 0.0,
    ):
        super().__init__()
        self.name = name
        self.weight = weight
        self.eps = eps
        self.l1_weight = l1_weight

    def forward(self, estimates, actuals, *args, **kwargs):
        """Compute the (possibly blended) L1SNR loss for a batch."""
        batch_size = estimates.shape[0]

        # Flatten everything after the batch dimension.
        est_source = estimates.reshape(batch_size, -1)
        act_source = actuals.reshape(batch_size, -1)

        # Per-sample L1 error and L1 reference level.
        l1_error = torch.mean(torch.abs(est_source - act_source), dim=-1)
        l1_true = torch.mean(torch.abs(act_source), dim=-1)

        w = float(self.l1_weight)

        # Pure-L1 shortcut: skip the D1 (log) computation entirely.
        if w >= 1.0:
            return torch.mean(l1_error) * self.weight

        # Pure-SNR shortcut: skip the L1 auto-scaling math.
        if w <= 0.0:
            d1 = 10.0 * torch.log10((l1_error + self.eps) / (l1_true + self.eps))
            return torch.mean(d1) * self.weight

        # Mixed path: compute both terms.
        d1 = 10.0 * torch.log10((l1_error + self.eps) / (l1_true + self.eps))
        l1snr_loss = torch.mean(d1)

        # c = d/dx [10*log10(x)] prefactor; used to match the typical
        # gradient magnitude of the D1 term (detached, so scaling does not
        # feed back into the gradient).
        c = 10.0 / math.log(10.0)
        inv_mean = torch.mean(1.0 / (l1_error.detach() + self.eps))
        scale_time = c * inv_mean
        l1_term = torch.mean(l1_error) * scale_time

        # Optional per-sample balancing (opt-in via attribute; off by default).
        if getattr(self, "balance_per_sample", False):
            bal = c / (l1_error.detach() + self.eps)
            l1_term = torch.mean(l1_error * bal)

        # Optional diagnostics: record the L1/D1 gradient-magnitude ratio.
        if getattr(self, "debug_balance", False):
            g_d1 = (1.0 - w) * c * inv_mean
            if getattr(self, "balance_per_sample", False):
                g_l1 = w * torch.mean(c / (l1_error.detach() + self.eps))
            else:
                g_l1 = w * c * inv_mean
            ratio = (g_l1 / (g_d1 + 1e-12)).item()
            setattr(self, "last_balance_ratio", ratio)

        loss = (1.0 - w) * l1snr_loss + w * l1_term
        return loss * self.weight
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class L1SNRDBLoss(torch.nn.Module):
    """
    Time-domain L1SNR loss with adaptive level-matching regularization, as
    described in arXiv:2501.16171, plus an optional blended L1 component to
    soften "all-or-nothing" behavior.

    Components:
        1. L1SNR loss: mean(10*log10((l1_error + eps) / (l1_true + eps))),
           delegated to ``L1SNRLoss`` (which also handles the L1 blend).
        2. Level-matching regularization: lambda * |L_pred - L_true|, where
           lambda is computed adaptively from the dBRMS levels and increases
           when the model wrongly silences a non-silent target.
        3. Optional plain L1 loss, controlled by ``l1_weight``.

    Total loss:
        l1_weight < 1.0:  l1snr_loss + (1 - l1_weight) * mean(reg_loss)
        l1_weight == 1.0: pure L1 loss (SNR and regularization are skipped)

    Input shape:
        Batch-first waveform tensors of any shape; everything after the
        batch dimension is flattened. Recommended shapes:
        - [batch, time]
        - [batch, num_sources, time]
        - [batch, num_sources, channels, time]

    Args:
        name: identifier for the loss. Defaults to "l1_snr_db" so the loss
            can be constructed with keyword arguments only, matching the
            README usage examples.
        weight: overall multiplier applied to the final loss.
        lambda0: minimum regularization weight (lambda_min).
        delta_lambda: range of extra adaptive regularization weight.
        l1snr_eps: epsilon for the L1SNR component to avoid log(0).
        dbrms_eps: epsilon for the dBRMS calculation to avoid log(0).
        lmin: minimum dBRMS considered non-silent for adaptive weighting.
        use_regularization: if False, only the L1SNR (and optional L1)
            components are used.
        l1_weight: blend factor in [0, 1]; 0 = pure L1SNR+reg, 1 = pure L1.
            The regularization term is scaled down proportionally.

    Raises:
        ValueError: if ``l1_weight`` is outside [0.0, 1.0].
    """

    def __init__(
        self,
        name: str = "l1_snr_db",
        weight: float = 1.0,
        lambda0: float = 0.1,
        delta_lambda: float = 0.9,
        l1snr_eps: float = 1e-3,
        dbrms_eps: float = 1e-8,
        lmin: float = -60.0,
        use_regularization: bool = True,
        l1_weight: float = 0.0,
    ):
        super().__init__()
        self.name = name
        self.weight = weight
        self.lambda0 = lambda0  # minimum regularization weight
        self.delta_lambda = delta_lambda  # range of extra weight
        self.l1snr_eps = l1snr_eps
        self.dbrms_eps = dbrms_eps
        self.lmin = lmin
        self.use_regularization = use_regularization

        # Raise (not assert) so the check survives `python -O`.
        if not 0.0 <= l1_weight <= 1.0:
            raise ValueError("l1_weight must be between 0.0 and 1.0 inclusive")
        self.l1_weight = l1_weight

        # Initialize component losses based on l1_weight.
        if self.l1_weight == 1.0:
            # Pure L1 mode - only the L1 loss is needed.
            self.l1snr_loss = None
            self.l1_loss = torch.nn.L1Loss()
        else:
            # Standard mode with L1SNR (L1 blend handled inside L1SNRLoss).
            self.l1snr_loss = L1SNRLoss(
                name="l1_snr",
                weight=1.0,  # the overall weight is applied at the end
                eps=l1snr_eps,
                l1_weight=l1_weight,
            )
            self.l1_loss = None

    @staticmethod
    def compute_adaptive_weight(L_pred, L_true, L_min, lambda0, delta_lambda, R):
        """
        Adaptive weight for the level-matching term, per arXiv:2501.16171.

        Args:
            L_pred: predicted dBRMS, shape (batch,)
            L_true: reference dBRMS, shape (batch,)
            L_min: minimum dBRMS considered non-silent (float)
            lambda0: minimum weight for regularization
            delta_lambda: range of extra weight for regularization
            R: |L_pred - L_true|, shape (batch,)

        Returns:
            lambda_weight tensor of shape (batch,), detached from the graph.
        """
        # eta = 1 only when the target is louder than both the estimate and
        # the silence floor, i.e. the model wrongly silenced the stem.
        max_val = torch.max(L_pred, torch.full_like(L_true, L_min))
        eta = (L_true > max_val).float()
        denom = (L_true - L_min).clamp(min=1e-6)
        clamp_arg = (R / denom).clamp(0.0, 1.0)
        lam = lambda0 + eta * delta_lambda * clamp_arg
        return lam.detach()  # stop-gradient: lambda acts as a weight only

    def forward(self, estimates, actuals, *args, **kwargs):
        """Compute the regularized (or pure-L1) loss for a batch."""
        batch_size = estimates.shape[0]

        est_source = estimates.reshape(batch_size, -1)
        act_source = actuals.reshape(batch_size, -1)

        # Pure L1 mode - efficient path that bypasses SNR and regularization.
        if self.l1_loss is not None:
            l1_loss = self.l1_loss(est_source, act_source)
            return l1_loss * self.weight

        # Standard mode: L1SNR reconstruction loss (with L1 blend if any).
        l1snr_loss = self.l1snr_loss(estimates, actuals, *args, **kwargs)

        if self.use_regularization:
            # Level-matching regularization on dBRMS levels.
            L_true = dbrms(act_source, self.dbrms_eps)  # (batch,)
            L_pred = dbrms(est_source, self.dbrms_eps)  # (batch,)
            R = torch.abs(L_pred - L_true)  # (batch,)

            lambda_weight = self.compute_adaptive_weight(
                L_pred, L_true, self.lmin, self.lambda0, self.delta_lambda, R
            )  # (batch,)

            reg_loss = lambda_weight * R

            # Scale regularization by the same factor as the L1SNR loss so
            # the blend with L1 stays consistent.
            l1snr_weight = 1.0 - self.l1_weight
            total_loss = l1snr_loss + (l1snr_weight * torch.mean(reg_loss))
        else:
            # Skip regularization calculation entirely.
            total_loss = l1snr_loss

        return total_loss * self.weight
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class STFTL1SNRDBLoss(torch.nn.Module):
|
|
272
|
+
"""
|
|
273
|
+
Implements L1SNR plus adaptive level-matching regularization in the spectrogram domain
|
|
274
|
+
as described in arXiv:2501.16171, with multi-resolution STFT and optional L1 loss component
|
|
275
|
+
to balance "all-or-nothing" behavior.
|
|
276
|
+
|
|
277
|
+
This loss applies the same principles as L1SNRDBLoss but operates in the complex
|
|
278
|
+
spectrogram domain across multiple time-frequency resolutions. For each resolution:
|
|
279
|
+
|
|
280
|
+
1. L1SNR loss: Computed on the complex STFT representation (real/imaginary parts)
|
|
281
|
+
2. Level-matching regularization: Applied to STFT magnitudes with adaptive weighting
|
|
282
|
+
3. Optional L1 loss: Direct L1 penalty on STFT differences
|
|
283
|
+
|
|
284
|
+
Multi-resolution processing helps capture both fine temporal details and frequency
|
|
285
|
+
characteristics. The loss averages results across all valid STFT resolutions.
|
|
286
|
+
|
|
287
|
+
The complete loss structure is similar to L1SNRDBLoss:
|
|
288
|
+
When l1_weight < 1.0: total_loss = l1snr_loss + (1-l1_weight) * spec_reg_coef * mean(reg_loss)
|
|
289
|
+
When l1_weight = 1.0: total_loss = l1_loss (pure L1 in spectrogram domain, bypassing all other computations)
|
|
290
|
+
|
|
291
|
+
When l1_weight=1.0, this loss efficiently switches to a pure L1 loss calculation in the
|
|
292
|
+
spectrogram domain, bypassing all SNR and regularization computations for standard L1 behavior.
|
|
293
|
+
This is useful when you want to avoid the "all-or-nothing" behavior of the SNR-style loss.
|
|
294
|
+
|
|
295
|
+
Input Shape:
|
|
296
|
+
Accepts waveform tensors (time-domain audio) of any shape as long as they are batch-first
|
|
297
|
+
and time-last. Recommended shapes:
|
|
298
|
+
- [batch, time] for single-source audio
|
|
299
|
+
- [batch, num_sources, time] for multi-source audio
|
|
300
|
+
- [batch, num_sources, channels, time] for multi-channel multi-source audio
|
|
301
|
+
|
|
302
|
+
Attributes:
|
|
303
|
+
name (str): The name identifier for the loss.
|
|
304
|
+
weight (float): The overall weight multiplier for the loss.
|
|
305
|
+
lambda0 (float): Minimum regularization weight (λ_min).
|
|
306
|
+
delta_lambda (float): Range of extra weight for regularization (Δλ).
|
|
307
|
+
l1snr_eps (float): Epsilon value for the L1SNR component to avoid log(0).
|
|
308
|
+
dbrms_eps (float): Epsilon value for dBRMS calculation to avoid log(0).
|
|
309
|
+
lmin (float): Minimum dBRMS considered non-silent for adaptive weighting.
|
|
310
|
+
n_ffts (List[int]): List of FFT sizes for multi-resolution STFT analysis.
|
|
311
|
+
hop_lengths (List[int]): List of hop lengths (STFT time steps) for each resolution.
|
|
312
|
+
win_lengths (List[int]): List of window lengths for each resolution.
|
|
313
|
+
window_fn (str): Window function for the STFT ('hann', 'hamming', etc.)
|
|
314
|
+
min_audio_length (int): Minimum audio length required for processing.
|
|
315
|
+
If audio is shorter, returns zero loss to avoid errors.
|
|
316
|
+
use_regularization (bool): Whether to use level-matching regularization.
|
|
317
|
+
If False, only the L1SNR (and optional L1) components are used.
|
|
318
|
+
l1_weight (float): Weight for the L1 loss component. Default 0 (disabled).
|
|
319
|
+
As this increases, the regularization term is also scaled down proportionally.
|
|
320
|
+
When set to 1.0, efficiently computes only L1 loss.
|
|
321
|
+
"""
|
|
322
|
+
def __init__(
    self,
    name,
    weight: float = 1.0,
    lambda0: float = 0.1,
    delta_lambda: float = 0.9,
    l1snr_eps: float = 1e-3,
    dbrms_eps: float = 1e-8,
    lmin: float = -60.0,
    n_ffts: List[int] = None,
    hop_lengths: List[int] = None,
    win_lengths: List[int] = None,
    window_fn: str = 'hann',
    min_audio_length: int = 512,
    use_regularization: bool = False,
    spec_reg_coef: float = 0.1,
    l1_weight: float = 0.0,
):
    """
    Build the multi-resolution STFT-domain L1SNR loss.

    Args:
        name: Identifier for this loss instance.
        weight: Overall multiplier applied to the final loss.
        lambda0: Minimum adaptive regularization weight (λ_min).
        delta_lambda: Extra adaptive regularization weight range (Δλ).
        l1snr_eps: Epsilon for the L1SNR ratio to avoid log(0).
        dbrms_eps: Epsilon for dBRMS computation to avoid log(0).
        lmin: Minimum dBRMS considered non-silent for adaptive weighting.
        n_ffts: FFT sizes per resolution. Defaults to [512, 1024, 2048]
            when None (None sentinel avoids a shared mutable default).
        hop_lengths: Hop lengths per resolution. Defaults to [128, 256, 512].
        win_lengths: Window lengths per resolution. Defaults to [512, 1024, 2048].
        window_fn: Name of a torch window factory ('hann' -> torch.hann_window).
        min_audio_length: Minimum number of samples required for STFT processing.
        use_regularization: Enable level-matching regularization in the TF domain.
        spec_reg_coef: Scale factor for the spectral regularization term.
        l1_weight: Weight of the plain L1 component in [0, 1]; 1.0 switches
            to pure-L1 mode.

    Raises:
        ValueError: If the STFT parameter lists differ in length, an FFT size
            is smaller than its window length, or l1_weight is out of range.
    """
    super().__init__()

    # Materialize per-call copies of the default STFT configuration
    # (mutable default arguments would be shared across instances).
    n_ffts = [512, 1024, 2048] if n_ffts is None else n_ffts
    hop_lengths = [128, 256, 512] if hop_lengths is None else hop_lengths
    win_lengths = [512, 1024, 2048] if win_lengths is None else win_lengths

    self.name = name
    self.weight = weight
    self.min_audio_length = min_audio_length

    # Validate STFT parameters with explicit raises (asserts vanish under `python -O`).
    if not (len(n_ffts) == len(hop_lengths) == len(win_lengths)):
        raise ValueError("All STFT parameter lists must have the same length")

    # Store STFT parameters for validation during the forward pass.
    self.n_ffts = n_ffts
    self.hop_lengths = hop_lengths
    self.win_lengths = win_lengths
    self.window_fn_name = window_fn

    # Each window must fit inside its FFT frame.
    for n_fft, win_length in zip(n_ffts, win_lengths):
        if n_fft < win_length:
            raise ValueError(
                f"FFT size ({n_fft}) must be greater than or equal to window length ({win_length})"
            )

    # Pre-initialize one Spectrogram transform per resolution so forward()
    # never has to rebuild them.
    self.spectrogram_transforms = nn.ModuleList()
    window_fn_callable = getattr(torch, f"{window_fn}_window")

    for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
        transform = Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode="reflect",
            center=True,
            window_fn=window_fn_callable,
            normalized=True,
            power=None,  # power=None keeps the output complex (real + imag)
        )
        self.spectrogram_transforms.append(transform)

    # Parameters for spectrogram-domain level matching.
    self.lambda0 = lambda0
    self.delta_lambda = delta_lambda
    self.lmin = lmin
    self.dbrms_eps = dbrms_eps
    self.l1snr_eps = l1snr_eps

    # L1 mixing weight must stay inside [0, 1].
    if not 0.0 <= l1_weight <= 1.0:
        raise ValueError("l1_weight must be between 0.0 and 1.0 inclusive")
    self.l1_weight = l1_weight

    # Pure-L1 shortcut: when l1_weight == 1.0 the SNR machinery is bypassed.
    self.pure_l1_mode = (self.l1_weight == 1.0)
    self.l1_loss = torch.nn.L1Loss() if self.pure_l1_mode else None

    # Regularization controls.
    self.use_regularization = use_regularization
    # Coefficient to scale spectral regularization (regularization itself is
    # disabled by default via use_regularization=False).
    self.spec_reg_coef = spec_reg_coef

    # Fallback time-domain loss, used when the audio is too short for TF processing.
    self.fallback_time_loss = L1SNRDBLoss(
        name=f"{name}_fallback_time",
        weight=1.0,
        lambda0=self.lambda0,
        delta_lambda=self.delta_lambda,
        l1snr_eps=self.l1snr_eps,
        dbrms_eps=self.dbrms_eps,
        lmin=self.lmin,
        use_regularization=False,  # the regularizer belongs to the TF domain for this class
        l1_weight=self.l1_weight,
    )

    # Simplified NaN/Inf bookkeeping used by forward().
    self.nan_inf_counts = {"inputs": 0, "spec_loss": 0}
def _compute_complex_spec_l1snr_loss(self, est_spec, act_spec):
    """
    TF-domain loss following the papers: the D1 log-ratio measure is
    evaluated on the real and imaginary parts separately and summed,
    with an optional L1 component mixed in symmetrically on Re/Im.

    Args:
        est_spec: Estimated complex spectrogram, shape (B, C, F, T).
        act_spec: Target complex spectrogram, same shape as ``est_spec``.

    Returns:
        Scalar loss tensor (mean over the batch).
    """
    # Fail loudly on shape mismatches instead of silently broadcasting.
    assert est_spec.shape == act_spec.shape, f"Spec shapes must match: {est_spec.shape} vs {act_spec.shape}"

    batch = est_spec.shape[0]
    eps = self.l1snr_eps

    # Per-sample L1 error and reference magnitude for one Re/Im component.
    def component_stats(pred, target):
        pred_flat = pred.reshape(batch, -1)
        target_flat = target.reshape(batch, -1)
        err = torch.mean(torch.abs(pred_flat - target_flat), dim=1)
        ref = torch.mean(torch.abs(target_flat), dim=1)
        return err, ref

    err_re, ref_re = component_stats(est_spec.real, act_spec.real)
    err_im, ref_im = component_stats(est_spec.imag, act_spec.imag)

    # Pure L1 mode: symmetric average of the Re/Im mean errors, nothing else.
    if self.pure_l1_mode:
        return 0.5 * (torch.mean(err_re) + torch.mean(err_im))

    # Paper-aligned D1 = 10*log10((||e||_1 + eps) / (||y||_1 + eps)), per component.
    d1_re = 10.0 * torch.log10((err_re + eps) / (ref_re + eps))
    d1_im = 10.0 * torch.log10((err_im + eps) / (ref_im + eps))
    d1_sum = torch.mean(d1_re + d1_im)

    w = float(self.l1_weight)

    if w <= 0.0:
        # Pure SNR-style (D1) loss.
        return d1_sum

    if w >= 1.0:
        # Pure L1 (defensive: normally unreachable because l1_weight == 1.0
        # already sets pure_l1_mode above).
        return 0.5 * (torch.mean(err_re) + torch.mean(err_im))

    # Mixed mode: auto-balance the L1 term against D1's gradient magnitude.
    c = 10.0 / math.log(10.0)
    inv_mean_comp = torch.mean(0.5 * (1.0 / (err_re.detach() + eps) +
                                      1.0 / (err_im.detach() + eps)))
    # w-independent scaling matching typical gradient magnitudes
    # (factor 2.0 accounts for the Re/Im symmetry).
    scale_spec = 2.0 * c * inv_mean_comp
    l1_term = 0.5 * (torch.mean(err_re) + torch.mean(err_im)) * scale_spec

    # Optional per-sample balancing (opt-in attribute; defaults to off).
    if getattr(self, "balance_per_sample", False):
        bal_re = c / (err_re.detach() + eps)
        bal_im = c / (err_im.detach() + eps)
        l1_term = 0.5 * (torch.mean(err_re * bal_re) + torch.mean(err_im * bal_im))

    return (1.0 - w) * d1_sum + w * l1_term
def _compute_spec_level_matching(self, est_spec, act_spec):
    """
    Level-matching regularization for one spectrogram resolution.

    Compares the dB-RMS level of the estimated magnitude spectrogram
    against the target's and penalizes the absolute level difference,
    scaled by the adaptive lambda weighting.
    """
    n_samples = est_spec.shape[0]

    # Crop both spectrograms to their common (freq, time) extent when the
    # transforms produced mismatched sizes.
    if est_spec.shape != act_spec.shape:
        common_f = min(est_spec.shape[-2], act_spec.shape[-2])
        common_t = min(est_spec.shape[-1], act_spec.shape[-1])
        est_spec = est_spec[..., :common_f, :common_t]
        act_spec = act_spec[..., :common_f, :common_t]

    # Level matching operates on magnitudes, flattened once per sample.
    est_flat = torch.abs(est_spec).reshape(n_samples, -1)
    act_flat = torch.abs(act_spec).reshape(n_samples, -1)

    # dB-RMS levels of target and estimate.
    level_true = dbrms(act_flat, self.dbrms_eps)
    level_pred = dbrms(est_flat, self.dbrms_eps)

    # Absolute level gap per sample.
    level_gap = torch.abs(level_pred - level_true)

    # Adaptive per-sample weighting of the level penalty.
    adaptive_w = L1SNRDBLoss.compute_adaptive_weight(
        level_pred, level_true, self.lmin, self.lambda0, self.delta_lambda, level_gap
    )

    return torch.mean(adaptive_w * level_gap)
def _validate_audio_length(self, audio_length):
    """
    Return True when the waveform is long enough for every configured STFT.

    Requires at least ``min_audio_length`` samples, and at least two STFT
    frames at every configured hop length.
    """
    if audio_length < self.min_audio_length:
        return False

    # Every resolution must yield >= 2 frames: floor(len / hop) + 1 >= 2.
    # (n_ffts is zipped only to pair each resolution with its hop length.)
    return all(
        (audio_length // hop) + 1 >= 2
        for _n_fft, hop in zip(self.n_ffts, self.hop_lengths)
    )
def forward(self, estimates, actuals, *args, **kwargs):
    """
    Compute the multi-resolution STFT-domain loss.

    Args:
        estimates: Model output waveform, batch-first / time-last.
        actuals: Ground-truth waveform, same shape layout as ``estimates``.
        *args, **kwargs: Forwarded to the time-domain fallback loss when the
            audio is too short for TF processing.

    Returns:
        Scalar loss tensor scaled by ``self.weight`` (zero if every
        resolution failed).
    """
    device = estimates.device
    batch_size = estimates.shape[0]

    # Basic NaN/Inf handling (simplified): sanitize both inputs and count
    # the occurrence rather than aborting.
    if torch.isnan(estimates).any() or torch.isinf(estimates).any() or torch.isnan(actuals).any() or torch.isinf(actuals).any():
        self.nan_inf_counts["inputs"] += 1
        estimates = torch.nan_to_num(estimates, nan=0.0, posinf=1.0, neginf=-1.0)
        actuals = torch.nan_to_num(actuals, nan=0.0, posinf=1.0, neginf=-1.0)

    # Collapse any intermediate dims (sources/channels) to a single channel
    # axis: (batch, -1, time) for the Spectrogram transforms.
    est_source = estimates.reshape(batch_size, -1, estimates.shape[-1])
    act_source = actuals.reshape(batch_size, -1, actuals.shape[-1])

    # Validate audio length against min_audio_length and all hop lengths.
    audio_length = est_source.shape[-1]
    if not self._validate_audio_length(audio_length):
        # Fallback to time-domain L1SNR-style loss instead of zero.
        return self.fallback_time_loss(estimates, actuals, *args, **kwargs) * self.weight

    # Track losses (initialize as tensors on the correct device for stability).
    total_spec_loss = torch.tensor(0.0, device=device)
    total_spec_reg_loss = torch.tensor(0.0, device=device)
    valid_transforms = 0

    # Ensure transforms (and their windows) live on the input's device.
    self.spectrogram_transforms.to(device)

    # Process each resolution; failures skip that resolution, not the loss.
    for i, transform in enumerate(self.spectrogram_transforms):
        try:
            # Compute spectrograms using the pre-initialized transforms.
            try:
                est_spec = transform(est_source)
                act_spec = transform(act_source)
            except RuntimeError as e:
                print(f"Error computing spectrogram for resolution {i}: {e}")
                print(f"Parameters: n_fft={self.n_ffts[i]}, hop_length={self.hop_lengths[i]}, win_length={self.win_lengths[i]}")
                continue

            # Ensure same (B, C, F, T); crop only the (F, T) axes if needed.
            if est_spec.shape != act_spec.shape:
                min_f = min(est_spec.shape[-2], act_spec.shape[-2])
                min_t = min(est_spec.shape[-1], act_spec.shape[-1])
                est_spec = est_spec[..., :min_f, :min_t]
                act_spec = act_spec[..., :min_f, :min_t]

            # Compute the complex spectral loss (either L1 or L1SNR depending
            # on self.pure_l1_mode).
            try:
                spec_loss = self._compute_complex_spec_l1snr_loss(est_spec, act_spec)
            except RuntimeError as e:
                print(f"Error computing complex spectral loss for resolution {i}: {e}")
                continue

            # Skip this resolution entirely on numerical issues.
            if torch.isnan(spec_loss) or torch.isinf(spec_loss):
                self.nan_inf_counts["spec_loss"] += 1
                continue

            # Only compute regularization if not in pure L1 mode and
            # regularization is enabled.
            if not self.pure_l1_mode and self.use_regularization:
                try:
                    spec_reg_loss = self._compute_spec_level_matching(est_spec, act_spec)

                    # A bad regularizer value zeroes the reg term only; the
                    # main spec_loss for this resolution is still kept.
                    if torch.isnan(spec_reg_loss) or torch.isinf(spec_reg_loss):
                        self.nan_inf_counts["spec_loss"] += 1
                        spec_reg_loss = 0.0  # Use zero reg_loss if there are issues

                    # Accumulate regularization loss.
                    total_spec_reg_loss += spec_reg_loss
                except RuntimeError as e:
                    print(f"Error computing spectral level-matching for resolution {i}: {e}")

            # Accumulate the main loss and count this resolution as valid.
            total_spec_loss += spec_loss
            valid_transforms += 1

        except RuntimeError as e:
            # Catch-all for unexpected runtime errors in this resolution.
            print(f"Runtime error in spectrogram transform {i}: {e}")
            continue

    # If all transforms failed, return zero loss rather than raising.
    if valid_transforms == 0:
        print("Warning: All spectrogram transforms failed. Returning zero loss.")
        return torch.tensor(0.0, device=device)

    # Average losses across the resolutions that succeeded.
    avg_spec_loss = total_spec_loss / valid_transforms

    # For standard mode, apply regularization if enabled.
    if not self.pure_l1_mode and self.use_regularization:
        avg_spec_reg_loss = total_spec_reg_loss / valid_transforms
        # Scale spectral regularization by both (1 - l1_weight) and spec_reg_coef,
        # so the reg term shrinks as the pure-L1 share grows.
        l1snr_weight = 1.0 - self.l1_weight
        final_loss = avg_spec_loss + l1snr_weight * (self.spec_reg_coef * avg_spec_reg_loss)
    else:
        final_loss = avg_spec_loss

    return final_loss * self.weight
class MultiL1SNRDBLoss(torch.nn.Module):
    """
    A modular loss function that combines time-domain and spectrogram-domain L1SNR and
    adaptive level-matching losses, as described in arXiv:2501.16171, with optional
    L1 loss component to balance "all-or-nothing" behavior.

    This implementation uses separate specialized components:
    - L1SNRDBLoss for time domain processing
    - STFTL1SNRDBLoss for spectrogram domain processing

    The loss combines time-domain and spectrogram-domain losses:
        Loss = weight * [(1-spec_weight) * time_loss + spec_weight * spec_loss]

    Where time_loss and spec_loss are computed by L1SNRDBLoss and STFTL1SNRDBLoss respectively,
    each handling their own L1SNR, regularization, and optional L1 components as described
    in their individual docstrings.

    When l1_weight=1.0, this loss efficiently switches to a pure L1 loss calculation in both
    domains, bypassing all SNR and regularization computations for standard L1 behavior.
    This is useful when you want to avoid the "all-or-nothing" behavior of the SNR-style loss.

    The regularization components use adaptive weighting based on level differences
    between estimated and target signals, with weighting controlled by lambda0 and delta_lambda.

    Input Shape:
        Accepts waveform tensors (time-domain audio) of any shape as long as they are batch-first
        and time-last. Recommended shapes:
        - [batch, time] for single-source audio
        - [batch, num_sources, time] for multi-source audio
        - [batch, num_sources, channels, time] for multi-channel multi-source audio

    Attributes:
        name (str): The name identifier for the loss.
        weight (float): The overall weight multiplier for the loss.
        spec_weight (float): The weight for spectrogram domain loss relative to time domain.
            Default 0.5 (equal weighting). Set higher to emphasize spectral accuracy.
        use_time_regularization (bool): Whether to use level-matching regularization in time domain.
        use_spec_regularization (bool): Whether to use level-matching regularization in spectrogram domain.
        l1_weight (float): Weight for the L1 loss component vs the L1SNR+reg components.
            Default 0 (disabled). As this increases, the regularization term is also scaled down.
            When set to 1.0, efficiently computes only L1 loss in both domains.
        lambda0 (float): Minimum regularization weight for both domains.
        delta_lambda (float): Range of extra weight for regularization in both domains.
        time_loss_params (dict): Optional additional parameters to pass to time domain loss.
        spec_loss_params (dict): Optional additional parameters to pass to spectrogram domain loss.
    """
    def __init__(
        self,
        name,
        weight: float = 1.0,
        spec_weight: float = 0.5,  # Balance between time and frequency domain
        # L1 component parameters. The regularization term is also scaled by
        # (1.0 - l1_weight); when set to 1.0, only L1 loss is computed in both
        # domains, otherwise auto-balanced mixing is used.
        l1_weight: float = 0.0,
        # Regularization on/off flags
        use_time_regularization: bool = True,
        use_spec_regularization: bool = False,  # likely redundant if already using in time domain
        # Default parameters for both loss components
        lambda0: float = 0.1,
        delta_lambda: float = 0.9,
        l1snr_eps: float = 1e-3,
        dbrms_eps: float = 1e-8,
        lmin: float = -60.0,
        # STFT parameters (None -> per-call defaults; avoids mutable default arguments)
        n_ffts: List[int] = None,
        hop_lengths: List[int] = None,
        win_lengths: List[int] = None,
        window_fn: str = 'hann',
        min_audio_length: int = 512,
        # Allow for separate parameter overrides (e.g. different delta_lambda for time and spec)
        time_loss_params: dict = None,
        spec_loss_params: dict = None,
    ):
        super().__init__()
        self.name = name
        self.weight = weight
        self.spec_weight = spec_weight

        # Materialize per-call copies of the default STFT configuration
        # (mutable default arguments would be shared across instances).
        n_ffts = [512, 1024, 2048] if n_ffts is None else n_ffts
        hop_lengths = [128, 256, 512] if hop_lengths is None else hop_lengths
        win_lengths = [512, 1024, 2048] if win_lengths is None else win_lengths

        # Validate l1_weight with an explicit raise (asserts vanish under `python -O`).
        if not 0.0 <= l1_weight <= 1.0:
            raise ValueError("l1_weight must be between 0.0 and 1.0 inclusive")
        self.l1_weight = l1_weight
        self.use_time_regularization = use_time_regularization
        self.use_spec_regularization = use_spec_regularization

        # Default parameter set for the time-domain component.
        default_time_params = {
            "name": f"{name}_time",
            "weight": 1.0,  # Will be scaled by the combined loss
            "lambda0": lambda0,
            "delta_lambda": delta_lambda,
            "l1snr_eps": l1snr_eps,
            "dbrms_eps": dbrms_eps,
            "lmin": lmin,
            "l1_weight": l1_weight,
            "use_regularization": use_time_regularization,  # Apply time domain regularization flag
        }

        # Default parameter set for the spectrogram-domain component.
        default_spec_params = {
            "name": f"{name}_spec",
            "weight": 1.0,  # Will be scaled by the combined loss
            "lambda0": lambda0,
            "delta_lambda": delta_lambda,
            "l1snr_eps": l1snr_eps,
            "dbrms_eps": dbrms_eps,
            "lmin": lmin,
            "n_ffts": n_ffts,
            "hop_lengths": hop_lengths,
            "win_lengths": win_lengths,
            "window_fn": window_fn,
            "min_audio_length": min_audio_length,
            "l1_weight": l1_weight,
            "use_regularization": use_spec_regularization,  # Apply spectrogram domain regularization flag
        }

        # Override with any custom parameters.
        if time_loss_params:
            default_time_params.update(time_loss_params)
        if spec_loss_params:
            default_spec_params.update(spec_loss_params)

        # Create the specialized loss components.
        # Note: Component losses handle all optimizations internally based on l1_weight;
        # when l1_weight=1.0 they efficiently bypass SNR and regularization calculations.
        self.time_loss = L1SNRDBLoss(**default_time_params)
        self.spec_loss = STFTL1SNRDBLoss(**default_spec_params)

        # For reference only, indicate if we're in pure L1 mode.
        self.pure_l1_mode = (self.l1_weight == 1.0)

    def forward(self, estimates, actuals, *args, **kwargs):
        """
        Forward pass to compute the combined multi-domain loss.

        Args:
            estimates: Model output predictions, shape [batch, ...] (batch-first, ..., time-last)
            actuals: Ground truth targets, shape [batch, ...] (batch-first, ..., time-last)
            *args, **kwargs: Additional arguments passed to sub-losses

        Returns:
            Combined weighted loss from time and spectrogram domains
        """
        # Each domain is handled by its specialized component.
        time_loss = self.time_loss(estimates, actuals, *args, **kwargs)
        spec_loss = self.spec_loss(estimates, actuals, *args, **kwargs)

        # Convex combination of the two domains, then the overall weight.
        combined_loss = (1 - self.spec_weight) * time_loss + self.spec_weight * spec_loss
        return combined_loss * self.weight
|