torch-l1-snr 0.0.4__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torch_l1_snr-0.0.4/torch_l1_snr.egg-info → torch_l1_snr-0.0.5}/PKG-INFO +1 -1
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/setup.cfg +1 -1
- torch_l1_snr-0.0.5/tests/test_losses.py +545 -0
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5/torch_l1_snr.egg-info}/PKG-INFO +1 -1
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/torch_l1snr/__init__.py +3 -1
- torch_l1_snr-0.0.4/tests/test_losses.py +0 -143
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/LICENSE +0 -0
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/README.md +0 -0
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/pyproject.toml +0 -0
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/torch_l1_snr.egg-info/SOURCES.txt +0 -0
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/torch_l1_snr.egg-info/dependency_links.txt +0 -0
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/torch_l1_snr.egg-info/requires.txt +0 -0
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/torch_l1_snr.egg-info/top_level.txt +0 -0
- {torch_l1_snr-0.0.4 → torch_l1_snr-0.0.5}/torch_l1snr/l1snr.py +0 -0
|
@@ -0,0 +1,545 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import pytest
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from torch_l1snr import (
|
|
5
|
+
dbrms,
|
|
6
|
+
L1SNRLoss,
|
|
7
|
+
L1SNRDBLoss,
|
|
8
|
+
STFTL1SNRDBLoss,
|
|
9
|
+
MultiL1SNRDBLoss,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
# --- Test Helper: Stem Wrapper ---
|
|
13
|
+
class StemWrappedLoss(torch.nn.Module):
    """Test helper matching user's pipL1SNRLoss wrapper pattern.

    Wraps an arbitrary base loss so it can either run on the full input
    tensors (``stem_dimension is None``) or on a single stem selected
    along dimension 1 before delegating to the base loss.
    """

    def __init__(self, base_loss, stem_dimension: Optional[int] = None):
        super().__init__()
        self.base_loss = base_loss  # underlying loss callable/module
        self.stem_dimension = stem_dimension  # stem index along dim 1, or None for all stems

    def forward(self, estimates, actuals, *args, **kwargs):
        if self.stem_dimension is None:
            return self.base_loss(estimates, actuals, *args, **kwargs)
        # Ellipsis indexing covers both [B, S, T] and [B, S, C, T] inputs
        # uniformly, so the previous explicit ndim branch is unnecessary.
        est_source = estimates[:, self.stem_dimension, ...]
        act_source = actuals[:, self.stem_dimension, ...]
        return self.base_loss(est_source, act_source, *args, **kwargs)
|
|
32
|
+
|
|
33
|
+
# --- Test Fixtures ---
|
|
34
|
+
@pytest.fixture
def dummy_audio():
    """Provides a batch of dummy audio signals."""
    preds = torch.randn(2, 16000)
    targets = torch.randn(2, 16000)
    # Nudge the targets away from silence so the loss denominator stays safe.
    targets[0, :100] += 0.1
    return preds, targets
|
|
42
|
+
|
|
43
|
+
@pytest.fixture
def dummy_stems():
    """Provides a batch of dummy multi-stem signals."""
    # Layout: batch, stems, channels, samples.
    preds = torch.randn(2, 4, 1, 16000)
    targets = torch.randn(2, 4, 1, 16000)
    # Keep the targets from being all zero.
    targets[:, 0, :, :100] += 0.1
    return preds, targets
|
|
50
|
+
|
|
51
|
+
@pytest.fixture
def dummy_stems_3d():
    """Multi-stem signals: [B, S, T]"""
    preds = torch.randn(2, 4, 16000)
    targets = torch.randn(2, 4, 16000)
    # Keep the targets from being all zero.
    targets[:, 0, :100] += 0.1
    return preds, targets
|
|
58
|
+
|
|
59
|
+
@pytest.fixture
def dummy_stems_4d():
    """Multi-stem signals: [B, S, C, T]"""
    preds = torch.randn(2, 4, 1, 16000)
    targets = torch.randn(2, 4, 1, 16000)
    # Keep the targets from being all zero.
    targets[:, 0, :, :100] += 0.1
    return preds, targets
|
|
66
|
+
|
|
67
|
+
# --- Test Functions ---
|
|
68
|
+
|
|
69
|
+
def test_dbrms():
    """dbrms maps a constant 0.1 signal to -20 dB and silence to the eps floor."""
    # A constant 0.1 signal has RMS 0.1, which is -20 dB.
    const_signal = torch.full((2, 1000), 0.1)
    assert torch.allclose(dbrms(const_signal), torch.tensor([-20.0, -20.0]), atol=1e-4)

    # An all-zero signal is clamped by the default eps=1e-8, giving -80 dB.
    silence = torch.zeros(2, 1000)
    assert torch.allclose(dbrms(silence), torch.tensor([-80.0, -80.0]), atol=1e-4)
|
|
77
|
+
|
|
78
|
+
def test_l1snr_loss(dummy_audio):
    """L1SNRLoss returns a finite scalar tensor."""
    preds, targets = dummy_audio
    result = L1SNRLoss(name="test")(preds, targets)

    assert isinstance(result, torch.Tensor)
    assert result.ndim == 0
    assert not torch.isnan(result)
    assert not torch.isinf(result)
|
|
87
|
+
|
|
88
|
+
def test_l1snrdb_loss_time(dummy_audio):
    """Exercise L1SNRDBLoss across its configuration space."""
    preds, targets = dummy_audio

    # Default configuration: L1SNR term plus regularization.
    reg_loss = L1SNRDBLoss(name="test", use_regularization=True, l1_weight=0.0)(preds, targets)
    assert reg_loss.ndim == 0 and not torch.isnan(reg_loss)

    # Regularization disabled.
    no_reg = L1SNRDBLoss(name="test_no_reg", use_regularization=False, l1_weight=0.0)(preds, targets)
    assert no_reg.ndim == 0 and not torch.isnan(no_reg)

    # Blended with an L1 component.
    blended = L1SNRDBLoss(name="test_l1", l1_weight=0.2)(preds, targets)
    assert blended.ndim == 0 and not torch.isnan(blended)

    # l1_weight=1.0 degenerates to plain torch.nn.L1Loss over flattened signals,
    # so the result must match a manual L1 computation exactly.
    pure = L1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)(preds, targets)
    reference = torch.nn.L1Loss()(
        preds.reshape(preds.shape[0], -1),
        targets.reshape(targets.shape[0], -1),
    )
    assert torch.allclose(pure, reference)
|
|
115
|
+
|
|
116
|
+
def test_stft_l1snrdb_loss(dummy_audio):
    """STFTL1SNRDBLoss yields finite scalars and handles short audio."""
    preds, targets = dummy_audio

    # Default configuration.
    spec_loss_fn = STFTL1SNRDBLoss(name="test", l1_weight=0.0)
    spec_loss = spec_loss_fn(preds, targets)
    assert spec_loss.ndim == 0 and not torch.isnan(spec_loss) and not torch.isinf(spec_loss)

    # Pure L1 mode.
    pure = STFTL1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)(preds, targets)
    assert pure.ndim == 0 and not torch.isnan(pure) and not torch.isinf(pure)

    # Audio shorter than min_audio_length (512 samples) should fall back to
    # the time-domain loss rather than attempting an invalid STFT.
    short_loss = spec_loss_fn(preds[:, :500], targets[:, :500])
    assert short_loss.ndim == 0 and not torch.isnan(short_loss)
|
|
135
|
+
|
|
136
|
+
def test_stem_multi_loss(dummy_stems):
    """MultiL1SNRDBLoss works on a sliced stem, on all stems, and in pure-L1 mode."""
    preds, targets = dummy_stems

    # Users now manage stems manually by slicing; pick stem 1 explicitly.
    stem_loss_fn = MultiL1SNRDBLoss(
        name="test_loss_stem",
        spec_weight=0.5,
        l1_weight=0.1
    )
    stem_loss = stem_loss_fn(preds[:, 1, ...], targets[:, 1, ...])
    assert stem_loss.ndim == 0 and not torch.isnan(stem_loss)

    # All stems jointly: flatten everything past the batch dimension so the
    # whole multi-stem tensor is processed at once.
    flat_preds = preds.reshape(preds.shape[0], -1)
    flat_targets = targets.reshape(targets.shape[0], -1)
    all_loss_fn = MultiL1SNRDBLoss(
        name="test_loss_all",
        spec_weight=0.5,
        l1_weight=0.1
    )
    all_loss = all_loss_fn(flat_preds, flat_targets)
    assert all_loss.ndim == 0 and not torch.isnan(all_loss)

    # Pure L1 mode on all stems. The exact multi-resolution STFT L1 value is
    # impractical to reproduce here, so only finiteness is checked.
    l1_only = MultiL1SNRDBLoss(name="l1_only", l1_weight=1.0)(flat_preds, flat_targets)
    assert l1_only.ndim == 0 and not torch.isnan(l1_only)
|
|
169
|
+
|
|
170
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
def test_loss_variants(dummy_audio, l1_weight):
    """Test L1SNRDBLoss and STFTL1SNRDBLoss with different l1_weights."""
    preds, targets = dummy_audio

    time_value = L1SNRDBLoss(name=f"test_time_{l1_weight}", l1_weight=l1_weight)(preds, targets)
    assert not torch.isnan(time_value) and not torch.isinf(time_value)

    spec_value = STFTL1SNRDBLoss(name=f"test_spec_{l1_weight}", l1_weight=l1_weight)(preds, targets)
    assert not torch.isnan(spec_value) and not torch.isinf(spec_value)
|
|
182
|
+
|
|
183
|
+
# --- Wrapper-Paradigm Tests ---
|
|
184
|
+
|
|
185
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
def test_l1snr_wrapper_all_stems_3d(dummy_stems_3d, l1_weight):
    """Test L1SNRLoss wrapper with stem_dimension=None on [B,S,T]."""
    preds, targets = dummy_stems_3d
    inner = L1SNRLoss(name="test", weight=1.0, l1_weight=l1_weight)
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    # The pass-through wrapper must be indistinguishable from the bare loss.
    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
198
|
+
|
|
199
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
def test_l1snr_wrapper_all_stems_4d(dummy_stems_4d, l1_weight):
    """Test L1SNRLoss wrapper with stem_dimension=None on [B,S,C,T]."""
    preds, targets = dummy_stems_4d
    inner = L1SNRLoss(name="test", weight=1.0, l1_weight=l1_weight)
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    # The pass-through wrapper must be indistinguishable from the bare loss.
    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
212
|
+
|
|
213
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("stem_idx", [0, 3])
def test_l1snr_wrapper_single_stem_3d(dummy_stems_3d, l1_weight, stem_idx):
    """Test L1SNRLoss wrapper with stem_dimension=k on [B,S,T]."""
    preds, targets = dummy_stems_3d
    inner = L1SNRLoss(name="test", weight=1.0, l1_weight=l1_weight)
    wrapper = StemWrappedLoss(inner, stem_dimension=stem_idx)

    via_wrapper = wrapper(preds, targets)
    # Manually slicing the same stem must give the identical loss value.
    via_slice = inner(preds[:, stem_idx, :], targets[:, stem_idx, :])

    assert torch.allclose(via_wrapper, via_slice, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
229
|
+
|
|
230
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("stem_idx", [0, 3])
def test_l1snr_wrapper_single_stem_4d(dummy_stems_4d, l1_weight, stem_idx):
    """Test L1SNRLoss wrapper with stem_dimension=k on [B,S,C,T]."""
    preds, targets = dummy_stems_4d
    inner = L1SNRLoss(name="test", weight=1.0, l1_weight=l1_weight)
    wrapper = StemWrappedLoss(inner, stem_dimension=stem_idx)

    via_wrapper = wrapper(preds, targets)
    # Manually slicing the same stem must give the identical loss value.
    via_slice = inner(preds[:, stem_idx, :, :], targets[:, stem_idx, :, :])

    assert torch.allclose(via_wrapper, via_slice, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
246
|
+
|
|
247
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("use_reg", [True, False])
def test_l1snrdb_wrapper_all_stems_3d(dummy_stems_3d, l1_weight, use_reg):
    """Test L1SNRDBLoss wrapper with stem_dimension=None on [B,S,T]."""
    preds, targets = dummy_stems_3d
    inner = L1SNRDBLoss(name="test", weight=1.0, l1_weight=l1_weight, use_regularization=use_reg)
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    # The pass-through wrapper must be indistinguishable from the bare loss.
    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
261
|
+
|
|
262
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("use_reg", [True, False])
def test_l1snrdb_wrapper_all_stems_4d(dummy_stems_4d, l1_weight, use_reg):
    """Test L1SNRDBLoss wrapper with stem_dimension=None on [B,S,C,T]."""
    preds, targets = dummy_stems_4d
    inner = L1SNRDBLoss(name="test", weight=1.0, l1_weight=l1_weight, use_regularization=use_reg)
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    # The pass-through wrapper must be indistinguishable from the bare loss.
    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
276
|
+
|
|
277
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("use_reg", [True, False])
@pytest.mark.parametrize("stem_idx", [0, 3])
def test_l1snrdb_wrapper_single_stem_3d(dummy_stems_3d, l1_weight, use_reg, stem_idx):
    """Test L1SNRDBLoss wrapper with stem_dimension=k on [B,S,T]."""
    preds, targets = dummy_stems_3d
    inner = L1SNRDBLoss(name="test", weight=1.0, l1_weight=l1_weight, use_regularization=use_reg)
    wrapper = StemWrappedLoss(inner, stem_dimension=stem_idx)

    via_wrapper = wrapper(preds, targets)
    # Manually slicing the same stem must give the identical loss value.
    via_slice = inner(preds[:, stem_idx, :], targets[:, stem_idx, :])

    assert torch.allclose(via_wrapper, via_slice, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
294
|
+
|
|
295
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("use_reg", [True, False])
@pytest.mark.parametrize("stem_idx", [0, 3])
def test_l1snrdb_wrapper_single_stem_4d(dummy_stems_4d, l1_weight, use_reg, stem_idx):
    """Test L1SNRDBLoss wrapper with stem_dimension=k on [B,S,C,T]."""
    preds, targets = dummy_stems_4d
    inner = L1SNRDBLoss(name="test", weight=1.0, l1_weight=l1_weight, use_regularization=use_reg)
    wrapper = StemWrappedLoss(inner, stem_dimension=stem_idx)

    via_wrapper = wrapper(preds, targets)
    # Manually slicing the same stem must give the identical loss value.
    via_slice = inner(preds[:, stem_idx, :, :], targets[:, stem_idx, :, :])

    assert torch.allclose(via_wrapper, via_slice, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
312
|
+
|
|
313
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
def test_stft_wrapper_all_stems_3d(dummy_stems_3d, l1_weight):
    """Test STFTL1SNRDBLoss wrapper with stem_dimension=None on [B,S,T]."""
    preds, targets = dummy_stems_3d
    inner = STFTL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=l1_weight,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    # The pass-through wrapper must be indistinguishable from the bare loss.
    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
329
|
+
|
|
330
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
def test_stft_wrapper_all_stems_4d(dummy_stems_4d, l1_weight):
    """Test STFTL1SNRDBLoss wrapper with stem_dimension=None on [B,S,C,T]."""
    preds, targets = dummy_stems_4d
    inner = STFTL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=l1_weight,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    # The pass-through wrapper must be indistinguishable from the bare loss.
    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
346
|
+
|
|
347
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("stem_idx", [0, 3])
def test_stft_wrapper_single_stem_3d(dummy_stems_3d, l1_weight, stem_idx):
    """Test STFTL1SNRDBLoss wrapper with stem_dimension=k on [B,S,T]."""
    preds, targets = dummy_stems_3d
    inner = STFTL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=l1_weight,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=stem_idx)

    via_wrapper = wrapper(preds, targets)
    # Manually slicing the same stem must give the identical loss value.
    via_slice = inner(preds[:, stem_idx, :], targets[:, stem_idx, :])

    assert torch.allclose(via_wrapper, via_slice, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
366
|
+
|
|
367
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("stem_idx", [0, 3])
def test_stft_wrapper_single_stem_4d(dummy_stems_4d, l1_weight, stem_idx):
    """Test STFTL1SNRDBLoss wrapper with stem_dimension=k on [B,S,C,T]."""
    preds, targets = dummy_stems_4d
    inner = STFTL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=l1_weight,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=stem_idx)

    via_wrapper = wrapper(preds, targets)
    # Manually slicing the same stem must give the identical loss value.
    via_slice = inner(preds[:, stem_idx, :, :], targets[:, stem_idx, :, :])

    assert torch.allclose(via_wrapper, via_slice, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
386
|
+
|
|
387
|
+
def test_stft_wrapper_short_audio_3d():
    """Test STFTL1SNRDBLoss wrapper fallback path with short audio [B,S,T]."""
    # 400 samples < min_audio_length=512, forcing the time-domain fallback.
    preds = torch.randn(2, 4, 400)
    targets = torch.randn(2, 4, 400)
    targets[:, 0, :100] += 0.1

    inner = STFTL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=0.0,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=512
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
405
|
+
|
|
406
|
+
def test_stft_wrapper_short_audio_4d():
    """Test STFTL1SNRDBLoss wrapper fallback path with short audio [B,S,C,T]."""
    # 400 samples < min_audio_length=512, forcing the time-domain fallback.
    preds = torch.randn(2, 4, 1, 400)
    targets = torch.randn(2, 4, 1, 400)
    targets[:, 0, :, :100] += 0.1

    inner = STFTL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=0.0,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=512
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
424
|
+
|
|
425
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("use_time_reg", [True, False])
def test_multi_wrapper_all_stems_3d(dummy_stems_3d, l1_weight, use_time_reg):
    """Test MultiL1SNRDBLoss wrapper with stem_dimension=None on [B,S,T]."""
    preds, targets = dummy_stems_3d
    inner = MultiL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=l1_weight,
        use_time_regularization=use_time_reg, use_spec_regularization=False,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    # The pass-through wrapper must be indistinguishable from the bare loss.
    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
443
|
+
|
|
444
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("use_time_reg", [True, False])
def test_multi_wrapper_all_stems_4d(dummy_stems_4d, l1_weight, use_time_reg):
    """Test MultiL1SNRDBLoss wrapper with stem_dimension=None on [B,S,C,T]."""
    preds, targets = dummy_stems_4d
    inner = MultiL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=l1_weight,
        use_time_regularization=use_time_reg, use_spec_regularization=False,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    # The pass-through wrapper must be indistinguishable from the bare loss.
    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
462
|
+
|
|
463
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("use_time_reg", [True, False])
@pytest.mark.parametrize("stem_idx", [0, 3])
def test_multi_wrapper_single_stem_3d(dummy_stems_3d, l1_weight, use_time_reg, stem_idx):
    """Test MultiL1SNRDBLoss wrapper with stem_dimension=k on [B,S,T]."""
    preds, targets = dummy_stems_3d
    inner = MultiL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=l1_weight,
        use_time_regularization=use_time_reg, use_spec_regularization=False,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=stem_idx)

    via_wrapper = wrapper(preds, targets)
    # Manually slicing the same stem must give the identical loss value.
    via_slice = inner(preds[:, stem_idx, :], targets[:, stem_idx, :])

    assert torch.allclose(via_wrapper, via_slice, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
484
|
+
|
|
485
|
+
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("use_time_reg", [True, False])
@pytest.mark.parametrize("stem_idx", [0, 3])
def test_multi_wrapper_single_stem_4d(dummy_stems_4d, l1_weight, use_time_reg, stem_idx):
    """Test MultiL1SNRDBLoss wrapper with stem_dimension=k on [B,S,C,T]."""
    preds, targets = dummy_stems_4d
    inner = MultiL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=l1_weight,
        use_time_regularization=use_time_reg, use_spec_regularization=False,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=stem_idx)

    via_wrapper = wrapper(preds, targets)
    # Manually slicing the same stem must give the identical loss value.
    via_slice = inner(preds[:, stem_idx, :, :], targets[:, stem_idx, :, :])

    assert torch.allclose(via_wrapper, via_slice, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
506
|
+
|
|
507
|
+
def test_multi_wrapper_short_audio_3d():
    """Test MultiL1SNRDBLoss wrapper fallback path with short audio [B,S,T]."""
    # 400 samples < min_audio_length=512, forcing the time-domain fallback.
    preds = torch.randn(2, 4, 400)
    targets = torch.randn(2, 4, 400)
    targets[:, 0, :100] += 0.1

    inner = MultiL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=0.0,
        use_time_regularization=True, use_spec_regularization=False,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=512
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
526
|
+
|
|
527
|
+
def test_multi_wrapper_short_audio_4d():
    """Test MultiL1SNRDBLoss wrapper fallback path with short audio [B,S,C,T]."""
    # 400 samples < min_audio_length=512, forcing the time-domain fallback.
    preds = torch.randn(2, 4, 1, 400)
    targets = torch.randn(2, 4, 1, 400)
    targets[:, 0, :, :100] += 0.1

    inner = MultiL1SNRDBLoss(
        name="test", weight=1.0, l1_weight=0.0,
        use_time_regularization=True, use_spec_regularization=False,
        n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=512
    )
    wrapper = StemWrappedLoss(inner, stem_dimension=None)

    via_wrapper = wrapper(preds, targets)
    via_base = inner(preds, targets)

    assert torch.allclose(via_wrapper, via_base, atol=1e-6)
    assert via_wrapper.ndim == 0
    assert not torch.isnan(via_wrapper) and not torch.isinf(via_wrapper)
|
|
@@ -1,143 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import pytest
|
|
3
|
-
from torch_l1snr import (
|
|
4
|
-
dbrms,
|
|
5
|
-
L1SNRLoss,
|
|
6
|
-
L1SNRDBLoss,
|
|
7
|
-
STFTL1SNRDBLoss,
|
|
8
|
-
MultiL1SNRDBLoss,
|
|
9
|
-
)
|
|
10
|
-
|
|
11
|
-
# --- Test Fixtures ---
|
|
12
|
-
@pytest.fixture
|
|
13
|
-
def dummy_audio():
|
|
14
|
-
"""Provides a batch of dummy audio signals."""
|
|
15
|
-
estimates = torch.randn(2, 16000)
|
|
16
|
-
actuals = torch.randn(2, 16000)
|
|
17
|
-
# Ensure actuals are not all zero to avoid division by zero in loss
|
|
18
|
-
actuals[0, :100] += 0.1
|
|
19
|
-
return estimates, actuals
|
|
20
|
-
|
|
21
|
-
@pytest.fixture
|
|
22
|
-
def dummy_stems():
|
|
23
|
-
"""Provides a batch of dummy multi-stem signals."""
|
|
24
|
-
estimates = torch.randn(2, 4, 1, 16000) # batch, stems, channels, samples
|
|
25
|
-
actuals = torch.randn(2, 4, 1, 16000)
|
|
26
|
-
actuals[:, 0, :, :100] += 0.1 # Ensure not all zero
|
|
27
|
-
return estimates, actuals
|
|
28
|
-
|
|
29
|
-
# --- Test Functions ---
|
|
30
|
-
|
|
31
|
-
def test_dbrms():
|
|
32
|
-
signal = torch.ones(2, 1000) * 0.1
|
|
33
|
-
# RMS of 0.1 is -20 dB
|
|
34
|
-
assert torch.allclose(dbrms(signal), torch.tensor([-20.0, -20.0]), atol=1e-4)
|
|
35
|
-
|
|
36
|
-
zeros = torch.zeros(2, 1000)
|
|
37
|
-
# dbrms of zero should be -80dB with default eps=1e-8
|
|
38
|
-
assert torch.allclose(dbrms(zeros), torch.tensor([-80.0, -80.0]), atol=1e-4)
|
|
39
|
-
|
|
40
|
-
def test_l1snr_loss(dummy_audio):
|
|
41
|
-
estimates, actuals = dummy_audio
|
|
42
|
-
loss_fn = L1SNRLoss(name="test")
|
|
43
|
-
loss = loss_fn(estimates, actuals)
|
|
44
|
-
|
|
45
|
-
assert isinstance(loss, torch.Tensor)
|
|
46
|
-
assert loss.ndim == 0
|
|
47
|
-
assert not torch.isnan(loss)
|
|
48
|
-
assert not torch.isinf(loss)
|
|
49
|
-
|
|
50
|
-
def test_l1snrdb_loss_time(dummy_audio):
|
|
51
|
-
estimates, actuals = dummy_audio
|
|
52
|
-
|
|
53
|
-
# Test with default settings (L1SNR + Regularization)
|
|
54
|
-
loss_fn = L1SNRDBLoss(name="test", use_regularization=True, l1_weight=0.0)
|
|
55
|
-
loss = loss_fn(estimates, actuals)
|
|
56
|
-
assert loss.ndim == 0 and not torch.isnan(loss)
|
|
57
|
-
|
|
58
|
-
# Test without regularization
|
|
59
|
-
loss_fn_no_reg = L1SNRDBLoss(name="test_no_reg", use_regularization=False, l1_weight=0.0)
|
|
60
|
-
loss_no_reg = loss_fn_no_reg(estimates, actuals)
|
|
61
|
-
assert loss_no_reg.ndim == 0 and not torch.isnan(loss_no_reg)
|
|
62
|
-
|
|
63
|
-
# Test with L1 loss component
|
|
64
|
-
loss_fn_l1 = L1SNRDBLoss(name="test_l1", l1_weight=0.2)
|
|
65
|
-
loss_l1 = loss_fn_l1(estimates, actuals)
|
|
66
|
-
assert loss_l1.ndim == 0 and not torch.isnan(loss_l1)
|
|
67
|
-
|
|
68
|
-
# Test pure L1 loss mode
|
|
69
|
-
loss_fn_pure_l1 = L1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
|
|
70
|
-
pure_l1_loss = loss_fn_pure_l1(estimates, actuals)
|
|
71
|
-
# Pure L1 mode uses torch.nn.L1Loss, so compare with manual L1 calculation
|
|
72
|
-
l1_loss_manual = torch.nn.L1Loss()(
|
|
73
|
-
estimates.reshape(estimates.shape[0], -1),
|
|
74
|
-
actuals.reshape(actuals.shape[0], -1)
|
|
75
|
-
)
|
|
76
|
-
assert torch.allclose(pure_l1_loss, l1_loss_manual)
|
|
77
|
-
|
|
78
|
-
def test_stft_l1snrdb_loss(dummy_audio):
|
|
79
|
-
estimates, actuals = dummy_audio
|
|
80
|
-
|
|
81
|
-
# Test with default settings
|
|
82
|
-
loss_fn = STFTL1SNRDBLoss(name="test", l1_weight=0.0)
|
|
83
|
-
loss = loss_fn(estimates, actuals)
|
|
84
|
-
assert loss.ndim == 0 and not torch.isnan(loss) and not torch.isinf(loss)
|
|
85
|
-
|
|
86
|
-
# Test pure L1 mode
|
|
87
|
-
loss_fn_pure_l1 = STFTL1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
|
|
88
|
-
l1_loss = loss_fn_pure_l1(estimates, actuals)
|
|
89
|
-
assert l1_loss.ndim == 0 and not torch.isnan(l1_loss) and not torch.isinf(l1_loss)
|
|
90
|
-
|
|
91
|
-
# Test with very short audio
|
|
92
|
-
short_estimates = estimates[:, :500]
|
|
93
|
-
short_actuals = actuals[:, :500]
|
|
94
|
-
loss_short = loss_fn(short_estimates, short_actuals)
|
|
95
|
-
# min_audio_length is 512, so this should fallback to time-domain loss
|
|
96
|
-
assert loss_short.ndim == 0 and not torch.isnan(loss_short)
|
|
97
|
-
|
|
98
|
-
def test_stem_multi_loss(dummy_stems):
|
|
99
|
-
estimates, actuals = dummy_stems
|
|
100
|
-
|
|
101
|
-
# Test with a specific stem - users now manage stems manually by slicing
|
|
102
|
-
# Extract stem 1 (second stem) manually
|
|
103
|
-
est_stem = estimates[:, 1, ...] # Shape: [batch, channels, samples]
|
|
104
|
-
act_stem = actuals[:, 1, ...]
|
|
105
|
-
loss_fn_stem = MultiL1SNRDBLoss(
|
|
106
|
-
name="test_loss_stem",
|
|
107
|
-
spec_weight=0.5,
|
|
108
|
-
l1_weight=0.1
|
|
109
|
-
)
|
|
110
|
-
loss = loss_fn_stem(est_stem, act_stem)
|
|
111
|
-
assert loss.ndim == 0 and not torch.isnan(loss)
|
|
112
|
-
|
|
113
|
-
# Test with all stems jointly - flatten all stems together
|
|
114
|
-
# Reshape to [batch, -1] to process all stems at once
|
|
115
|
-
est_all = estimates.reshape(estimates.shape[0], -1)
|
|
116
|
-
act_all = actuals.reshape(actuals.shape[0], -1)
|
|
117
|
-
loss_fn_all = MultiL1SNRDBLoss(
|
|
118
|
-
name="test_loss_all",
|
|
119
|
-
spec_weight=0.5,
|
|
120
|
-
l1_weight=0.1
|
|
121
|
-
)
|
|
122
|
-
loss_all = loss_fn_all(est_all, act_all)
|
|
123
|
-
assert loss_all.ndim == 0 and not torch.isnan(loss_all)
|
|
124
|
-
|
|
125
|
-
# Test pure L1 mode on all stems
|
|
126
|
-
loss_fn_l1 = MultiL1SNRDBLoss(name="l1_only", l1_weight=1.0)
|
|
127
|
-
l1_loss = loss_fn_l1(est_all, act_all)
|
|
128
|
-
|
|
129
|
-
# Can't easily compute multi-res STFT L1 here, but can check it's not nan
|
|
130
|
-
assert l1_loss.ndim == 0 and not torch.isnan(l1_loss)
|
|
131
|
-
|
|
132
|
-
@pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
|
|
133
|
-
def test_loss_variants(dummy_audio, l1_weight):
|
|
134
|
-
"""Test L1SNRDBLoss and STFTL1SNRDBLoss with different l1_weights."""
|
|
135
|
-
estimates, actuals = dummy_audio
|
|
136
|
-
|
|
137
|
-
time_loss_fn = L1SNRDBLoss(name=f"test_time_{l1_weight}", l1_weight=l1_weight)
|
|
138
|
-
time_loss = time_loss_fn(estimates, actuals)
|
|
139
|
-
assert not torch.isnan(time_loss) and not torch.isinf(time_loss)
|
|
140
|
-
|
|
141
|
-
spec_loss_fn = STFTL1SNRDBLoss(name=f"test_spec_{l1_weight}", l1_weight=l1_weight)
|
|
142
|
-
spec_loss = spec_loss_fn(estimates, actuals)
|
|
143
|
-
assert not torch.isnan(spec_loss) and not torch.isinf(spec_loss)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|