torch-l1-snr 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Christopher Landschoot
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,201 @@
+ Metadata-Version: 2.4
+ Name: torch-l1-snr
+ Version: 0.0.1
+ Summary: L1-SNR loss functions for audio source separation in PyTorch
+ Home-page: https://github.com/crlandsc/torch-l1-snr
+ Author: Christopher Landschoot
+ Author-email: crlandschoot@gmail.com
+ License: MIT
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Operating System :: OS Independent
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch
+ Requires-Dist: torchaudio
+ Requires-Dist: numpy>=1.21.0
+ Dynamic: license-file
+
+ <!-- ![torch-l1-snr-logo](https://raw.githubusercontent.com/crlandsc/torch-l1-snr/main/images/logo.png) -->
+
+ # NOTE: Repo is currently a work-in-progress and not ready for installation & use.
+
+ [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1snr)](https://github.com/crlandsc/torch-l1snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1snr)](https://github.com/crlandsc/torch-l1snr/stargazers)
+
+ # torch-l1-snr
+
+ A PyTorch implementation of L1-based Signal-to-Noise Ratio (SNR) loss functions for audio source separation. This package provides implementations and novel extensions based on concepts from recent academic papers, offering flexible and robust loss functions that can be easily integrated into any PyTorch-based audio separation pipeline.
+
+ The core `L1SNRLoss` is based on the loss function described in [1], while `L1SNRDBLoss` and `STFTL1SNRDBLoss` are extensions of the adaptive level-matching regularization technique proposed in [2].
+
+ ## Features
+
+ - **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [1].
+ - **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [2], plus an optional L1 loss component.
+ - **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [2], calculated over multiple STFT resolutions.
+ - **Modular Stem-based Loss**: A wrapper that combines time- and spectrogram-domain losses and can be configured to run on specific stems.
+ - **Efficient & Robust**: Includes optimizations for pure L1 loss calculation and robust handling of `NaN`/`inf` values and short audio segments.
+
+ ## Installation
+
+ <!-- Add PyPI badges once the package is published -->
+ <!-- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
+ <!-- [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
+ <!-- [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
+
+ You can install the package directly from GitHub:
+
+ ```bash
+ pip install git+https://github.com/crlandsc/torch-l1snr.git
+ ```
+
+ Or, you can clone the repository and install it in editable mode for development:
+
+ ```bash
+ git clone https://github.com/crlandsc/torch-l1snr.git
+ cd torch-l1snr
+ pip install -e .
+ ```
+
+ ## Dependencies
+
+ - [PyTorch](https://pytorch.org/)
+ - [torchaudio](https://pytorch.org/audio/stable/index.html)
+ - [NumPy](https://numpy.org/)
+
+ ## Supported Tensor Shapes
+
+ All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)` or `(batch, channels, samples)`. For 3D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated. For the spectrogram-domain loss, a separate STFT is computed for each channel.
+
+ ## Usage
+
+ The loss functions can be imported directly from the `torch_l1snr` package.
+
+ ### Example: `L1SNRDBLoss` (Time Domain)
+
+ ```python
+ import torch
+ from torch_l1snr import L1SNRDBLoss
+
+ # Create dummy audio signals
+ estimates = torch.randn(4, 32000)  # Batch of 4, 32000 samples
+ actuals = torch.randn(4, 32000)
+
+ # Initialize the loss function
+ # l1_weight=0.1 blends L1SNR with 10% L1 loss
+ loss_fn = L1SNRDBLoss(l1_weight=0.1)
+
+ # Calculate loss
+ loss = loss_fn(estimates, actuals)
+ loss.backward()
+
+ print(f"L1SNRDBLoss: {loss.item()}")
+ ```
+
+ ### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)
+
+ ```python
+ import torch
+ from torch_l1snr import STFTL1SNRDBLoss
+
+ # Create dummy audio signals
+ estimates = torch.randn(4, 32000)
+ actuals = torch.randn(4, 32000)
+
+ # Initialize the loss function
+ # Uses multiple STFT resolutions by default
+ loss_fn = STFTL1SNRDBLoss(l1_weight=0.0)  # Pure L1SNR + Regularization
+
+ # Calculate loss
+ loss = loss_fn(estimates, actuals)
+ loss.backward()
+
+ print(f"STFTL1SNRDBLoss: {loss.item()}")
+ ```
+
+ ### Example: `MultiL1SNRDBLoss` for a Combined Time+Spectrogram Loss
+
+ This loss combines the time-domain and spectrogram-domain losses into a single, weighted objective function.
+
+ ```python
+ import torch
+ from torch_l1snr import MultiL1SNRDBLoss
+
+ # Create dummy audio signals
+ # Shape: (batch, channels, samples)
+ estimates = torch.randn(2, 2, 44100)  # Batch of 2, stereo
+ actuals = torch.randn(2, 2, 44100)
+
+ # --- Configuration ---
+ loss_fn = MultiL1SNRDBLoss(
+     weight=1.0,       # Overall weight for this loss
+     spec_weight=0.7,  # 70% spectrogram loss, 30% time-domain loss
+     l1_weight=0.1,    # Use 10% L1, 90% L1SNR+Reg
+ )
+ loss = loss_fn(estimates, actuals)
+ print(f"Multi-domain Loss: {loss.item()}")
+ ```
+
+ ## Motivation
+
+ The goal of these loss functions is to provide a perceptually informed and robust alternative to common audio losses like L1, L2 (MSE), and SI-SDR for training audio source separation models.
+
+ - **Robustness**: The L1 norm is less sensitive to large outliers than the L2 norm, making it more suitable for audio signals, which can have sharp transients.
+ - **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
+ - **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.
+
+ #### Level-Matching Regularization
+
+ A key feature of `L1SNRDBLoss` is the adaptive regularization term, as described in [2]. This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, which increases when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically avoids the model collapsing to a trivial silent solution when uncertain.
+
+ #### Multi-Resolution Spectrogram Analysis
+
+ The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts, from short transient errors to longer tonal discrepancies. This provides a more comprehensive error signal to the model during training.
+
+ #### "All-or-Nothing" Behavior and `l1_weight`
+
+ A characteristic of SNR-style losses is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimates. However, it can also lead to "confident errors," where the model completely removes a signal component it should have kept.
+
+ While the Level-Matching Regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften its "all-or-nothing" behavior and allow for more nuanced separation.
+
+ - `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
+ - `l1_weight=1.0`: Pure L1 loss.
+ - `0.0 < l1_weight < 1.0`: A weighted combination of the two.
+
+ The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
+
+ **Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), you may need to tune `l1_scale_time` and `l1_scale_spec` to ensure the gradients of the L1 and L1SNR components are balanced, which is crucial for stable training. The default values provide a reasonable starting point, but monitoring the loss components is recommended to ensure they are scaled appropriately.
+
+ ## Limitations
+
+ - The L1SNR loss is not scale-invariant. Unlike SI-SNR, it requires the model's output to be correctly scaled relative to the target.
+ - While the dB scaling and regularization are psychoacoustically motivated, the loss does not model more complex perceptual phenomena like auditory masking.
+
+ ## Contributing
+
+ Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.
+
+ ## License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ## Acknowledgments
+
+ The loss functions implemented here are based on the work of the authors of the referenced papers.
+
+ ## References
+
+ [1] K. N. Watcharasupat, C.-W. Wu, Y. Ding, I. Orife, A. J. Hipple, P. A. Williams, S. Kramer, A. Lerch, and W. Wolcott, "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation," IEEE Open Journal of Signal Processing, 2023. (arXiv:2309.02539)
+
+ [2] K. N. Watcharasupat and A. Lerch, "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries," arXiv:2501.16171.
+
+ [3] K. N. Watcharasupat and A. Lerch, "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems," Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024. (arXiv:2406.18747)
@@ -0,0 +1,173 @@
+ <!-- ![torch-l1-snr-logo](https://raw.githubusercontent.com/crlandsc/torch-l1-snr/main/images/logo.png) -->
+
+ # NOTE: Repo is currently a work-in-progress and not ready for installation & use.
+
+ [![LICENSE](https://img.shields.io/github/license/crlandsc/torch-l1snr)](https://github.com/crlandsc/torch-l1snr/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/torch-l1snr)](https://github.com/crlandsc/torch-l1snr/stargazers)
+
+ # torch-l1-snr
+
+ A PyTorch implementation of L1-based Signal-to-Noise Ratio (SNR) loss functions for audio source separation. This package provides implementations and novel extensions based on concepts from recent academic papers, offering flexible and robust loss functions that can be easily integrated into any PyTorch-based audio separation pipeline.
+
+ The core `L1SNRLoss` is based on the loss function described in [1], while `L1SNRDBLoss` and `STFTL1SNRDBLoss` are extensions of the adaptive level-matching regularization technique proposed in [2].
+
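As a rough illustration of the idea behind an L1-based SNR objective, here is a minimal sketch in plain Python. This is a hypothetical stand-in, not the package's actual implementation; the exact scaling and stabilization used by `L1SNRLoss` (per [1]) may differ.

```python
import math

def l1_snr_loss(estimate, target, eps=1e-8):
    """Sketch of an L1-based SNR loss: ratio of L1 norms, on a dB scale.

    Hypothetical illustration only; the real L1SNRLoss may use
    different scaling and numerical stabilization.
    """
    signal = sum(abs(t) for t in target)
    noise = sum(abs(t - e) for t, e in zip(target, estimate))
    # Higher SNR should mean lower loss, hence the negation
    return -20.0 * math.log10((signal + eps) / (noise + eps))
```

A perfect estimate drives the noise term toward zero, so the loss becomes strongly negative; a fully silent estimate makes the noise term equal the signal term, giving a loss near zero.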
+ ## Features
+
+ - **Time-Domain L1SNR Loss**: A basic, time-domain L1-SNR loss, based on [1].
+ - **Regularized Time-Domain L1SNRDBLoss**: An extension of the L1SNR loss with adaptive level-matching regularization from [2], plus an optional L1 loss component.
+ - **Multi-Resolution STFT L1SNRDBLoss**: A spectrogram-domain version of the loss from [2], calculated over multiple STFT resolutions.
+ - **Modular Stem-based Loss**: A wrapper that combines time- and spectrogram-domain losses and can be configured to run on specific stems.
+ - **Efficient & Robust**: Includes optimizations for pure L1 loss calculation and robust handling of `NaN`/`inf` values and short audio segments.
+
+ ## Installation
+
+ <!-- Add PyPI badges once the package is published -->
+ <!-- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
+ <!-- [![PyPI - Version](https://img.shields.io/pypi/v/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
+ <!-- [![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-l1snr)](https://pypi.org/project/torch-l1snr/) -->
+
+ You can install the package directly from GitHub:
+
+ ```bash
+ pip install git+https://github.com/crlandsc/torch-l1snr.git
+ ```
+
+ Or, you can clone the repository and install it in editable mode for development:
+
+ ```bash
+ git clone https://github.com/crlandsc/torch-l1snr.git
+ cd torch-l1snr
+ pip install -e .
+ ```
+
+ ## Dependencies
+
+ - [PyTorch](https://pytorch.org/)
+ - [torchaudio](https://pytorch.org/audio/stable/index.html)
+ - [NumPy](https://numpy.org/)
+
+ ## Supported Tensor Shapes
+
+ All loss functions in this package (`L1SNRLoss`, `L1SNRDBLoss`, `STFTL1SNRDBLoss`, and `MultiL1SNRDBLoss`) accept standard audio tensors of shape `(batch, samples)` or `(batch, channels, samples)`. For 3D tensors, the channel and sample dimensions are flattened before the time-domain losses are calculated. For the spectrogram-domain loss, a separate STFT is computed for each channel.
+
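In other words, a stereo batch is collapsed to `(batch, channels * samples)` before the time-domain loss runs. That reshape can be sketched as follows (a hypothetical helper for illustration, not part of the package API):

```python
def time_loss_shape(shape):
    # Collapse channel and sample dims for the time-domain loss,
    # as described above; 2D input passes through unchanged.
    if len(shape) == 3:
        batch, channels, samples = shape
        return (batch, channels * samples)
    if len(shape) == 2:
        return shape
    raise ValueError(f"expected 2D or 3D audio, got {len(shape)}D")
```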
+ ## Usage
+
+ The loss functions can be imported directly from the `torch_l1snr` package.
+
+ ### Example: `L1SNRDBLoss` (Time Domain)
+
+ ```python
+ import torch
+ from torch_l1snr import L1SNRDBLoss
+
+ # Create dummy audio signals
+ estimates = torch.randn(4, 32000)  # Batch of 4, 32000 samples
+ actuals = torch.randn(4, 32000)
+
+ # Initialize the loss function
+ # l1_weight=0.1 blends L1SNR with 10% L1 loss
+ loss_fn = L1SNRDBLoss(l1_weight=0.1)
+
+ # Calculate loss
+ loss = loss_fn(estimates, actuals)
+ loss.backward()
+
+ print(f"L1SNRDBLoss: {loss.item()}")
+ ```
+
+ ### Example: `STFTL1SNRDBLoss` (Spectrogram Domain)
+
+ ```python
+ import torch
+ from torch_l1snr import STFTL1SNRDBLoss
+
+ # Create dummy audio signals
+ estimates = torch.randn(4, 32000)
+ actuals = torch.randn(4, 32000)
+
+ # Initialize the loss function
+ # Uses multiple STFT resolutions by default
+ loss_fn = STFTL1SNRDBLoss(l1_weight=0.0)  # Pure L1SNR + Regularization
+
+ # Calculate loss
+ loss = loss_fn(estimates, actuals)
+ loss.backward()
+
+ print(f"STFTL1SNRDBLoss: {loss.item()}")
+ ```
+
+ ### Example: `MultiL1SNRDBLoss` for a Combined Time+Spectrogram Loss
+
+ This loss combines the time-domain and spectrogram-domain losses into a single, weighted objective function.
+
+ ```python
+ import torch
+ from torch_l1snr import MultiL1SNRDBLoss
+
+ # Create dummy audio signals
+ # Shape: (batch, channels, samples)
+ estimates = torch.randn(2, 2, 44100)  # Batch of 2, stereo
+ actuals = torch.randn(2, 2, 44100)
+
+ # --- Configuration ---
+ loss_fn = MultiL1SNRDBLoss(
+     weight=1.0,       # Overall weight for this loss
+     spec_weight=0.7,  # 70% spectrogram loss, 30% time-domain loss
+     l1_weight=0.1,    # Use 10% L1, 90% L1SNR+Reg
+ )
+ loss = loss_fn(estimates, actuals)
+ print(f"Multi-domain Loss: {loss.item()}")
+ ```
+
+ ## Motivation
+
+ The goal of these loss functions is to provide a perceptually informed and robust alternative to common audio losses like L1, L2 (MSE), and SI-SDR for training audio source separation models.
+
+ - **Robustness**: The L1 norm is less sensitive to large outliers than the L2 norm, making it more suitable for audio signals, which can have sharp transients.
+ - **Perceptual Relevance**: The loss is scaled to decibels (dB), which more closely aligns with human perception of loudness.
+ - **Adaptive Regularization**: Prevents the model from collapsing to silent outputs by penalizing mismatches in the overall loudness (dBRMS) between the estimate and the target.
+
+ #### Level-Matching Regularization
+
+ A key feature of `L1SNRDBLoss` is the adaptive regularization term, as described in [2]. This component calculates the difference in decibel-scaled root-mean-square (dBRMS) levels between the estimated and actual signals. An adaptive weight (`lambda`) is applied to this difference, which increases when the model incorrectly silences a non-silent target. This encourages the model to learn the correct output level and specifically avoids the model collapsing to a trivial silent solution when uncertain.
+
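The mechanics above can be sketched in a few lines of plain Python. This is a simplified illustration: the `dbrms` behavior below (floor of -80 dB for silence with `eps=1e-8`) mirrors what the package's test suite checks, but the adaptive up-weighting rule is an assumption, not the exact scheme from [2].

```python
import math

def dbrms(signal, eps=1e-8):
    # dB-scaled RMS level; eps floors silence at 10*log10(eps) = -80 dB
    mean_sq = sum(x * x for x in signal) / len(signal)
    return 10.0 * math.log10(mean_sq + eps)

def level_matching_penalty(estimate, target, base_lambda=1.0):
    """Hypothetical sketch of the level-matching term.

    Penalizes the dBRMS gap between estimate and target, weighting the
    penalty more heavily (an assumed doubling here) when the estimate is
    quieter than the target, i.e. the failure mode where the model
    collapses toward silence.
    """
    gap = dbrms(estimate) - dbrms(target)
    lam = base_lambda * (2.0 if gap < 0.0 else 1.0)  # assumed up-weighting
    return lam * abs(gap)
```

With the default `eps`, `dbrms` of a constant 0.1-amplitude signal is about -20 dB and of pure silence is -80 dB.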
+ #### Multi-Resolution Spectrogram Analysis
+
+ The `STFTL1SNRDBLoss` module applies the L1SNRDB loss across multiple time-frequency resolutions. By analyzing the signal with different STFT window sizes and hop lengths, the loss function can capture a wider range of artifacts, from short transient errors to longer tonal discrepancies. This provides a more comprehensive error signal to the model during training.
+
+ #### "All-or-Nothing" Behavior and `l1_weight`
+
+ A characteristic of SNR-style losses is that they encourage the model to make definitive, "all-or-nothing" separation decisions. This can be highly effective for well-defined sources, as it pushes the model to be confident in its estimates. However, it can also lead to "confident errors," where the model completely removes a signal component it should have kept.
+
+ While the Level-Matching Regularization prevents a *total collapse to silence*, it does not by itself solve this issue of overly confident, hard-boundary separation. To provide a tunable solution, this implementation introduces a novel `l1_weight` hyperparameter. This allows you to create a hybrid loss, blending the decisive L1SNR objective with a standard L1 loss to soften its "all-or-nothing" behavior and allow for more nuanced separation.
+
+ - `l1_weight=0.0` (Default): Pure L1SNR (+ regularization).
+ - `l1_weight=1.0`: Pure L1 loss.
+ - `0.0 < l1_weight < 1.0`: A weighted combination of the two.
+
+ The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
+
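The blending and the endpoint short-circuit described above amount to the following sketch. The names `l1snr_term` and `l1_term` are placeholders for the two loss computations, not actual package functions:

```python
def blended_loss(l1snr_term, l1_term, l1_weight=0.0):
    # Short-circuit the endpoints so the unused term is never computed;
    # otherwise return the convex combination of the two losses.
    if l1_weight == 0.0:
        return l1snr_term()
    if l1_weight == 1.0:
        return l1_term()
    return (1.0 - l1_weight) * l1snr_term() + l1_weight * l1_term()
```

Passing the terms as callables is what makes the optimization possible: at `l1_weight=0.0` the L1 branch is simply never invoked.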
+ **Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), you may need to tune `l1_scale_time` and `l1_scale_spec` to ensure the gradients of the L1 and L1SNR components are balanced, which is crucial for stable training. The default values provide a reasonable starting point, but monitoring the loss components is recommended to ensure they are scaled appropriately.
+
+ ## Limitations
+
+ - The L1SNR loss is not scale-invariant. Unlike SI-SNR, it requires the model's output to be correctly scaled relative to the target.
+ - While the dB scaling and regularization are psychoacoustically motivated, the loss does not model more complex perceptual phenomena like auditory masking.
+
+ ## Contributing
+
+ Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.
+
+ ## License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ## Acknowledgments
+
+ The loss functions implemented here are based on the work of the authors of the referenced papers.
+
+ ## References
+
+ [1] K. N. Watcharasupat, C.-W. Wu, Y. Ding, I. Orife, A. J. Hipple, P. A. Williams, S. Kramer, A. Lerch, and W. Wolcott, "A Generalized Bandsplit Neural Network for Cinematic Audio Source Separation," IEEE Open Journal of Signal Processing, 2023. (arXiv:2309.02539)
+
+ [2] K. N. Watcharasupat and A. Lerch, "Separate This, and All of these Things Around It: Music Source Separation via Hyperellipsoidal Queries," arXiv:2501.16171.
+
+ [3] K. N. Watcharasupat and A. Lerch, "A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems," Proceedings of the 25th International Society for Music Information Retrieval Conference, 2024. (arXiv:2406.18747)
@@ -0,0 +1,3 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,36 @@
+ [metadata]
+ name = torch-l1-snr
+ version = 0.0.1
+ author = Christopher Landschoot
+ author_email = crlandschoot@gmail.com
+ description = L1-SNR loss functions for audio source separation in PyTorch
+ long_description = file: README.md
+ long_description_content_type = text/markdown
+ url = https://github.com/crlandsc/torch-l1-snr
+ license = MIT
+ classifiers =
+     Intended Audience :: Developers
+     Intended Audience :: Science/Research
+     License :: OSI Approved :: MIT License
+     Programming Language :: Python :: 3
+     Programming Language :: Python :: 3.8
+     Programming Language :: Python :: 3.9
+     Programming Language :: Python :: 3.10
+     Programming Language :: Python :: 3.11
+     Programming Language :: Python :: 3.12
+     Operating System :: OS Independent
+     Topic :: Scientific/Engineering :: Artificial Intelligence
+     Topic :: Multimedia :: Sound/Audio :: Analysis
+
+ [options]
+ packages = find:
+ python_requires = >=3.8
+ install_requires =
+     torch
+     torchaudio
+     numpy>=1.21.0
+
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,143 @@
+ import torch
+ import pytest
+ from torch_l1snr import (
+     dbrms,
+     L1SNRLoss,
+     L1SNRDBLoss,
+     STFTL1SNRDBLoss,
+     MultiL1SNRDBLoss,
+ )
+
+ # --- Test Fixtures ---
+ @pytest.fixture
+ def dummy_audio():
+     """Provides a batch of dummy audio signals."""
+     estimates = torch.randn(2, 16000)
+     actuals = torch.randn(2, 16000)
+     # Ensure actuals are not all zero to avoid division by zero in loss
+     actuals[0, :100] += 0.1
+     return estimates, actuals
+
+ @pytest.fixture
+ def dummy_stems():
+     """Provides a batch of dummy multi-stem signals."""
+     estimates = torch.randn(2, 4, 1, 16000)  # batch, stems, channels, samples
+     actuals = torch.randn(2, 4, 1, 16000)
+     actuals[:, 0, :, :100] += 0.1  # Ensure not all zero
+     return estimates, actuals
+
+ # --- Test Functions ---
+
+ def test_dbrms():
+     signal = torch.ones(2, 1000) * 0.1
+     # RMS of 0.1 is -20 dB
+     assert torch.allclose(dbrms(signal), torch.tensor([-20.0, -20.0]), atol=1e-4)
+
+     zeros = torch.zeros(2, 1000)
+     # dbrms of zero should be -80 dB with default eps=1e-8
+     assert torch.allclose(dbrms(zeros), torch.tensor([-80.0, -80.0]), atol=1e-4)
+
+ def test_l1snr_loss(dummy_audio):
+     estimates, actuals = dummy_audio
+     loss_fn = L1SNRLoss(name="test")
+     loss = loss_fn(estimates, actuals)
+
+     assert isinstance(loss, torch.Tensor)
+     assert loss.ndim == 0
+     assert not torch.isnan(loss)
+     assert not torch.isinf(loss)
+
+ def test_l1snrdb_loss_time(dummy_audio):
+     estimates, actuals = dummy_audio
+
+     # Test with default settings (L1SNR + Regularization)
+     loss_fn = L1SNRDBLoss(name="test", use_regularization=True, l1_weight=0.0)
+     loss = loss_fn(estimates, actuals)
+     assert loss.ndim == 0 and not torch.isnan(loss)
+
+     # Test without regularization
+     loss_fn_no_reg = L1SNRDBLoss(name="test_no_reg", use_regularization=False, l1_weight=0.0)
+     loss_no_reg = loss_fn_no_reg(estimates, actuals)
+     assert loss_no_reg.ndim == 0 and not torch.isnan(loss_no_reg)
+
+     # Test with L1 loss component
+     loss_fn_l1 = L1SNRDBLoss(name="test_l1", l1_weight=0.2)
+     loss_l1 = loss_fn_l1(estimates, actuals)
+     assert loss_l1.ndim == 0 and not torch.isnan(loss_l1)
+
+     # Test pure L1 loss mode
+     loss_fn_pure_l1 = L1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
+     pure_l1_loss = loss_fn_pure_l1(estimates, actuals)
+     # Pure L1 mode uses torch.nn.L1Loss, so compare with manual L1 calculation
+     l1_loss_manual = torch.nn.L1Loss()(
+         estimates.reshape(estimates.shape[0], -1),
+         actuals.reshape(actuals.shape[0], -1),
+     )
+     assert torch.allclose(pure_l1_loss, l1_loss_manual)
+
+ def test_stft_l1snrdb_loss(dummy_audio):
+     estimates, actuals = dummy_audio
+
+     # Test with default settings
+     loss_fn = STFTL1SNRDBLoss(name="test", l1_weight=0.0)
+     loss = loss_fn(estimates, actuals)
+     assert loss.ndim == 0 and not torch.isnan(loss) and not torch.isinf(loss)
+
+     # Test pure L1 mode
+     loss_fn_pure_l1 = STFTL1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
+     l1_loss = loss_fn_pure_l1(estimates, actuals)
+     assert l1_loss.ndim == 0 and not torch.isnan(l1_loss) and not torch.isinf(l1_loss)
+
+     # Test with very short audio
+     short_estimates = estimates[:, :500]
+     short_actuals = actuals[:, :500]
+     loss_short = loss_fn(short_estimates, short_actuals)
+     # min_audio_length is 512, so this should fall back to time-domain loss
+     assert loss_short.ndim == 0 and not torch.isnan(loss_short)
+
+ def test_stem_multi_loss(dummy_stems):
+     estimates, actuals = dummy_stems
+
+     # Test with a specific stem - users now manage stems manually by slicing
+     # Extract stem 1 (second stem) manually
+     est_stem = estimates[:, 1, ...]  # Shape: [batch, channels, samples]
+     act_stem = actuals[:, 1, ...]
+     loss_fn_stem = MultiL1SNRDBLoss(
+         name="test_loss_stem",
+         spec_weight=0.5,
+         l1_weight=0.1,
+     )
+     loss = loss_fn_stem(est_stem, act_stem)
+     assert loss.ndim == 0 and not torch.isnan(loss)
+
+     # Test with all stems jointly - flatten all stems together
+     # Reshape to [batch, -1] to process all stems at once
+     est_all = estimates.reshape(estimates.shape[0], -1)
+     act_all = actuals.reshape(actuals.shape[0], -1)
+     loss_fn_all = MultiL1SNRDBLoss(
+         name="test_loss_all",
+         spec_weight=0.5,
+         l1_weight=0.1,
+     )
+     loss_all = loss_fn_all(est_all, act_all)
+     assert loss_all.ndim == 0 and not torch.isnan(loss_all)
+
+     # Test pure L1 mode on all stems
+     loss_fn_l1 = MultiL1SNRDBLoss(name="l1_only", l1_weight=1.0)
+     l1_loss = loss_fn_l1(est_all, act_all)
+
+     # Can't easily compute multi-res STFT L1 here, but can check it's not NaN
+     assert l1_loss.ndim == 0 and not torch.isnan(l1_loss)
+
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
+ def test_loss_variants(dummy_audio, l1_weight):
+     """Test L1SNRDBLoss and STFTL1SNRDBLoss with different l1_weights."""
+     estimates, actuals = dummy_audio
+
+     time_loss_fn = L1SNRDBLoss(name=f"test_time_{l1_weight}", l1_weight=l1_weight)
+     time_loss = time_loss_fn(estimates, actuals)
+     assert not torch.isnan(time_loss) and not torch.isinf(time_loss)
+
+     spec_loss_fn = STFTL1SNRDBLoss(name=f"test_spec_{l1_weight}", l1_weight=l1_weight)
+     spec_loss = spec_loss_fn(estimates, actuals)
+     assert not torch.isnan(spec_loss) and not torch.isinf(spec_loss)