torch-l1-snr 0.0.4__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-l1-snr
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: L1-SNR loss functions for audio source separation in PyTorch
5
5
  Home-page: https://github.com/crlandsc/torch-l1-snr
6
6
  Author: Christopher Landscaping
@@ -1,6 +1,6 @@
1
1
  [metadata]
2
2
  name = torch-l1-snr
3
- version = 0.0.4
3
+ version = attr: torch_l1snr.__version__
4
4
  author = Christopher Landscaping
5
5
  author_email = crlandschoot@gmail.com
6
6
  description = L1-SNR loss functions for audio source separation in PyTorch
@@ -0,0 +1,545 @@
1
+ import torch
2
+ import pytest
3
+ from typing import Optional
4
+ from torch_l1snr import (
5
+ dbrms,
6
+ L1SNRLoss,
7
+ L1SNRDBLoss,
8
+ STFTL1SNRDBLoss,
9
+ MultiL1SNRDBLoss,
10
+ )
11
+
12
+ # --- Test Helper: Stem Wrapper ---
13
+ class StemWrappedLoss(torch.nn.Module):
14
+ """Test helper matching user's pipL1SNRLoss wrapper pattern."""
15
+ def __init__(self, base_loss, stem_dimension: Optional[int] = None):
16
+ super().__init__()
17
+ self.base_loss = base_loss
18
+ self.stem_dimension = stem_dimension
19
+
20
+ def forward(self, estimates, actuals, *args, **kwargs):
21
+ if self.stem_dimension is not None:
22
+ # Handle both [B,S,T] and [B,S,C,T] shapes
23
+ if estimates.ndim == 3: # [B, S, T]
24
+ est_source = estimates[:, self.stem_dimension, :]
25
+ act_source = actuals[:, self.stem_dimension, :]
26
+ else: # [B, S, C, T]
27
+ est_source = estimates[:, self.stem_dimension, :, :]
28
+ act_source = actuals[:, self.stem_dimension, :, :]
29
+ return self.base_loss(est_source, act_source, *args, **kwargs)
30
+ else:
31
+ return self.base_loss(estimates, actuals, *args, **kwargs)
32
+
33
+ # --- Test Fixtures ---
34
+ @pytest.fixture
35
+ def dummy_audio():
36
+ """Provides a batch of dummy audio signals."""
37
+ estimates = torch.randn(2, 16000)
38
+ actuals = torch.randn(2, 16000)
39
+ # Ensure actuals are not all zero to avoid division by zero in loss
40
+ actuals[0, :100] += 0.1
41
+ return estimates, actuals
42
+
43
+ @pytest.fixture
44
+ def dummy_stems():
45
+ """Provides a batch of dummy multi-stem signals."""
46
+ estimates = torch.randn(2, 4, 1, 16000) # batch, stems, channels, samples
47
+ actuals = torch.randn(2, 4, 1, 16000)
48
+ actuals[:, 0, :, :100] += 0.1 # Ensure not all zero
49
+ return estimates, actuals
50
+
51
+ @pytest.fixture
52
+ def dummy_stems_3d():
53
+ """Multi-stem signals: [B, S, T]"""
54
+ estimates = torch.randn(2, 4, 16000)
55
+ actuals = torch.randn(2, 4, 16000)
56
+ actuals[:, 0, :100] += 0.1 # Ensure not all zero
57
+ return estimates, actuals
58
+
59
+ @pytest.fixture
60
+ def dummy_stems_4d():
61
+ """Multi-stem signals: [B, S, C, T]"""
62
+ estimates = torch.randn(2, 4, 1, 16000)
63
+ actuals = torch.randn(2, 4, 1, 16000)
64
+ actuals[:, 0, :, :100] += 0.1
65
+ return estimates, actuals
66
+
67
+ # --- Test Functions ---
68
+
69
+ def test_dbrms():
70
+ signal = torch.ones(2, 1000) * 0.1
71
+ # RMS of 0.1 is -20 dB
72
+ assert torch.allclose(dbrms(signal), torch.tensor([-20.0, -20.0]), atol=1e-4)
73
+
74
+ zeros = torch.zeros(2, 1000)
75
+ # dbrms of zero should be -80dB with default eps=1e-8
76
+ assert torch.allclose(dbrms(zeros), torch.tensor([-80.0, -80.0]), atol=1e-4)
77
+
78
+ def test_l1snr_loss(dummy_audio):
79
+ estimates, actuals = dummy_audio
80
+ loss_fn = L1SNRLoss(name="test")
81
+ loss = loss_fn(estimates, actuals)
82
+
83
+ assert isinstance(loss, torch.Tensor)
84
+ assert loss.ndim == 0
85
+ assert not torch.isnan(loss)
86
+ assert not torch.isinf(loss)
87
+
88
+ def test_l1snrdb_loss_time(dummy_audio):
89
+ estimates, actuals = dummy_audio
90
+
91
+ # Test with default settings (L1SNR + Regularization)
92
+ loss_fn = L1SNRDBLoss(name="test", use_regularization=True, l1_weight=0.0)
93
+ loss = loss_fn(estimates, actuals)
94
+ assert loss.ndim == 0 and not torch.isnan(loss)
95
+
96
+ # Test without regularization
97
+ loss_fn_no_reg = L1SNRDBLoss(name="test_no_reg", use_regularization=False, l1_weight=0.0)
98
+ loss_no_reg = loss_fn_no_reg(estimates, actuals)
99
+ assert loss_no_reg.ndim == 0 and not torch.isnan(loss_no_reg)
100
+
101
+ # Test with L1 loss component
102
+ loss_fn_l1 = L1SNRDBLoss(name="test_l1", l1_weight=0.2)
103
+ loss_l1 = loss_fn_l1(estimates, actuals)
104
+ assert loss_l1.ndim == 0 and not torch.isnan(loss_l1)
105
+
106
+ # Test pure L1 loss mode
107
+ loss_fn_pure_l1 = L1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
108
+ pure_l1_loss = loss_fn_pure_l1(estimates, actuals)
109
+ # Pure L1 mode uses torch.nn.L1Loss, so compare with manual L1 calculation
110
+ l1_loss_manual = torch.nn.L1Loss()(
111
+ estimates.reshape(estimates.shape[0], -1),
112
+ actuals.reshape(actuals.shape[0], -1)
113
+ )
114
+ assert torch.allclose(pure_l1_loss, l1_loss_manual)
115
+
116
+ def test_stft_l1snrdb_loss(dummy_audio):
117
+ estimates, actuals = dummy_audio
118
+
119
+ # Test with default settings
120
+ loss_fn = STFTL1SNRDBLoss(name="test", l1_weight=0.0)
121
+ loss = loss_fn(estimates, actuals)
122
+ assert loss.ndim == 0 and not torch.isnan(loss) and not torch.isinf(loss)
123
+
124
+ # Test pure L1 mode
125
+ loss_fn_pure_l1 = STFTL1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
126
+ l1_loss = loss_fn_pure_l1(estimates, actuals)
127
+ assert l1_loss.ndim == 0 and not torch.isnan(l1_loss) and not torch.isinf(l1_loss)
128
+
129
+ # Test with very short audio
130
+ short_estimates = estimates[:, :500]
131
+ short_actuals = actuals[:, :500]
132
+ loss_short = loss_fn(short_estimates, short_actuals)
133
+ # min_audio_length is 512, so this should fallback to time-domain loss
134
+ assert loss_short.ndim == 0 and not torch.isnan(loss_short)
135
+
136
+ def test_stem_multi_loss(dummy_stems):
137
+ estimates, actuals = dummy_stems
138
+
139
+ # Test with a specific stem - users now manage stems manually by slicing
140
+ # Extract stem 1 (second stem) manually
141
+ est_stem = estimates[:, 1, ...] # Shape: [batch, channels, samples]
142
+ act_stem = actuals[:, 1, ...]
143
+ loss_fn_stem = MultiL1SNRDBLoss(
144
+ name="test_loss_stem",
145
+ spec_weight=0.5,
146
+ l1_weight=0.1
147
+ )
148
+ loss = loss_fn_stem(est_stem, act_stem)
149
+ assert loss.ndim == 0 and not torch.isnan(loss)
150
+
151
+ # Test with all stems jointly - flatten all stems together
152
+ # Reshape to [batch, -1] to process all stems at once
153
+ est_all = estimates.reshape(estimates.shape[0], -1)
154
+ act_all = actuals.reshape(actuals.shape[0], -1)
155
+ loss_fn_all = MultiL1SNRDBLoss(
156
+ name="test_loss_all",
157
+ spec_weight=0.5,
158
+ l1_weight=0.1
159
+ )
160
+ loss_all = loss_fn_all(est_all, act_all)
161
+ assert loss_all.ndim == 0 and not torch.isnan(loss_all)
162
+
163
+ # Test pure L1 mode on all stems
164
+ loss_fn_l1 = MultiL1SNRDBLoss(name="l1_only", l1_weight=1.0)
165
+ l1_loss = loss_fn_l1(est_all, act_all)
166
+
167
+ # Can't easily compute multi-res STFT L1 here, but can check it's not nan
168
+ assert l1_loss.ndim == 0 and not torch.isnan(l1_loss)
169
+
170
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
171
+ def test_loss_variants(dummy_audio, l1_weight):
172
+ """Test L1SNRDBLoss and STFTL1SNRDBLoss with different l1_weights."""
173
+ estimates, actuals = dummy_audio
174
+
175
+ time_loss_fn = L1SNRDBLoss(name=f"test_time_{l1_weight}", l1_weight=l1_weight)
176
+ time_loss = time_loss_fn(estimates, actuals)
177
+ assert not torch.isnan(time_loss) and not torch.isinf(time_loss)
178
+
179
+ spec_loss_fn = STFTL1SNRDBLoss(name=f"test_spec_{l1_weight}", l1_weight=l1_weight)
180
+ spec_loss = spec_loss_fn(estimates, actuals)
181
+ assert not torch.isnan(spec_loss) and not torch.isinf(spec_loss)
182
+
183
+ # --- Wrapper-Paradigm Tests ---
184
+
185
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
186
+ def test_l1snr_wrapper_all_stems_3d(dummy_stems_3d, l1_weight):
187
+ """Test L1SNRLoss wrapper with stem_dimension=None on [B,S,T]."""
188
+ estimates, actuals = dummy_stems_3d
189
+ base_loss = L1SNRLoss(name="test", weight=1.0, l1_weight=l1_weight)
190
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
191
+
192
+ wrapped_result = wrapped_loss(estimates, actuals)
193
+ direct_result = base_loss(estimates, actuals)
194
+
195
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
196
+ assert wrapped_result.ndim == 0
197
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
198
+
199
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
200
+ def test_l1snr_wrapper_all_stems_4d(dummy_stems_4d, l1_weight):
201
+ """Test L1SNRLoss wrapper with stem_dimension=None on [B,S,C,T]."""
202
+ estimates, actuals = dummy_stems_4d
203
+ base_loss = L1SNRLoss(name="test", weight=1.0, l1_weight=l1_weight)
204
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
205
+
206
+ wrapped_result = wrapped_loss(estimates, actuals)
207
+ direct_result = base_loss(estimates, actuals)
208
+
209
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
210
+ assert wrapped_result.ndim == 0
211
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
212
+
213
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
214
+ @pytest.mark.parametrize("stem_idx", [0, 3])
215
+ def test_l1snr_wrapper_single_stem_3d(dummy_stems_3d, l1_weight, stem_idx):
216
+ """Test L1SNRLoss wrapper with stem_dimension=k on [B,S,T]."""
217
+ estimates, actuals = dummy_stems_3d
218
+ base_loss = L1SNRLoss(name="test", weight=1.0, l1_weight=l1_weight)
219
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=stem_idx)
220
+
221
+ wrapped_result = wrapped_loss(estimates, actuals)
222
+ est_slice = estimates[:, stem_idx, :]
223
+ act_slice = actuals[:, stem_idx, :]
224
+ direct_result = base_loss(est_slice, act_slice)
225
+
226
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
227
+ assert wrapped_result.ndim == 0
228
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
229
+
230
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
231
+ @pytest.mark.parametrize("stem_idx", [0, 3])
232
+ def test_l1snr_wrapper_single_stem_4d(dummy_stems_4d, l1_weight, stem_idx):
233
+ """Test L1SNRLoss wrapper with stem_dimension=k on [B,S,C,T]."""
234
+ estimates, actuals = dummy_stems_4d
235
+ base_loss = L1SNRLoss(name="test", weight=1.0, l1_weight=l1_weight)
236
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=stem_idx)
237
+
238
+ wrapped_result = wrapped_loss(estimates, actuals)
239
+ est_slice = estimates[:, stem_idx, :, :]
240
+ act_slice = actuals[:, stem_idx, :, :]
241
+ direct_result = base_loss(est_slice, act_slice)
242
+
243
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
244
+ assert wrapped_result.ndim == 0
245
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
246
+
247
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
248
+ @pytest.mark.parametrize("use_reg", [True, False])
249
+ def test_l1snrdb_wrapper_all_stems_3d(dummy_stems_3d, l1_weight, use_reg):
250
+ """Test L1SNRDBLoss wrapper with stem_dimension=None on [B,S,T]."""
251
+ estimates, actuals = dummy_stems_3d
252
+ base_loss = L1SNRDBLoss(name="test", weight=1.0, l1_weight=l1_weight, use_regularization=use_reg)
253
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
254
+
255
+ wrapped_result = wrapped_loss(estimates, actuals)
256
+ direct_result = base_loss(estimates, actuals)
257
+
258
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
259
+ assert wrapped_result.ndim == 0
260
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
261
+
262
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
263
+ @pytest.mark.parametrize("use_reg", [True, False])
264
+ def test_l1snrdb_wrapper_all_stems_4d(dummy_stems_4d, l1_weight, use_reg):
265
+ """Test L1SNRDBLoss wrapper with stem_dimension=None on [B,S,C,T]."""
266
+ estimates, actuals = dummy_stems_4d
267
+ base_loss = L1SNRDBLoss(name="test", weight=1.0, l1_weight=l1_weight, use_regularization=use_reg)
268
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
269
+
270
+ wrapped_result = wrapped_loss(estimates, actuals)
271
+ direct_result = base_loss(estimates, actuals)
272
+
273
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
274
+ assert wrapped_result.ndim == 0
275
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
276
+
277
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
278
+ @pytest.mark.parametrize("use_reg", [True, False])
279
+ @pytest.mark.parametrize("stem_idx", [0, 3])
280
+ def test_l1snrdb_wrapper_single_stem_3d(dummy_stems_3d, l1_weight, use_reg, stem_idx):
281
+ """Test L1SNRDBLoss wrapper with stem_dimension=k on [B,S,T]."""
282
+ estimates, actuals = dummy_stems_3d
283
+ base_loss = L1SNRDBLoss(name="test", weight=1.0, l1_weight=l1_weight, use_regularization=use_reg)
284
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=stem_idx)
285
+
286
+ wrapped_result = wrapped_loss(estimates, actuals)
287
+ est_slice = estimates[:, stem_idx, :]
288
+ act_slice = actuals[:, stem_idx, :]
289
+ direct_result = base_loss(est_slice, act_slice)
290
+
291
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
292
+ assert wrapped_result.ndim == 0
293
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
294
+
295
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
296
+ @pytest.mark.parametrize("use_reg", [True, False])
297
+ @pytest.mark.parametrize("stem_idx", [0, 3])
298
+ def test_l1snrdb_wrapper_single_stem_4d(dummy_stems_4d, l1_weight, use_reg, stem_idx):
299
+ """Test L1SNRDBLoss wrapper with stem_dimension=k on [B,S,C,T]."""
300
+ estimates, actuals = dummy_stems_4d
301
+ base_loss = L1SNRDBLoss(name="test", weight=1.0, l1_weight=l1_weight, use_regularization=use_reg)
302
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=stem_idx)
303
+
304
+ wrapped_result = wrapped_loss(estimates, actuals)
305
+ est_slice = estimates[:, stem_idx, :, :]
306
+ act_slice = actuals[:, stem_idx, :, :]
307
+ direct_result = base_loss(est_slice, act_slice)
308
+
309
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
310
+ assert wrapped_result.ndim == 0
311
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
312
+
313
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
314
+ def test_stft_wrapper_all_stems_3d(dummy_stems_3d, l1_weight):
315
+ """Test STFTL1SNRDBLoss wrapper with stem_dimension=None on [B,S,T]."""
316
+ estimates, actuals = dummy_stems_3d
317
+ base_loss = STFTL1SNRDBLoss(
318
+ name="test", weight=1.0, l1_weight=l1_weight,
319
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
320
+ )
321
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
322
+
323
+ wrapped_result = wrapped_loss(estimates, actuals)
324
+ direct_result = base_loss(estimates, actuals)
325
+
326
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
327
+ assert wrapped_result.ndim == 0
328
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
329
+
330
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
331
+ def test_stft_wrapper_all_stems_4d(dummy_stems_4d, l1_weight):
332
+ """Test STFTL1SNRDBLoss wrapper with stem_dimension=None on [B,S,C,T]."""
333
+ estimates, actuals = dummy_stems_4d
334
+ base_loss = STFTL1SNRDBLoss(
335
+ name="test", weight=1.0, l1_weight=l1_weight,
336
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
337
+ )
338
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
339
+
340
+ wrapped_result = wrapped_loss(estimates, actuals)
341
+ direct_result = base_loss(estimates, actuals)
342
+
343
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
344
+ assert wrapped_result.ndim == 0
345
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
346
+
347
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
348
+ @pytest.mark.parametrize("stem_idx", [0, 3])
349
+ def test_stft_wrapper_single_stem_3d(dummy_stems_3d, l1_weight, stem_idx):
350
+ """Test STFTL1SNRDBLoss wrapper with stem_dimension=k on [B,S,T]."""
351
+ estimates, actuals = dummy_stems_3d
352
+ base_loss = STFTL1SNRDBLoss(
353
+ name="test", weight=1.0, l1_weight=l1_weight,
354
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
355
+ )
356
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=stem_idx)
357
+
358
+ wrapped_result = wrapped_loss(estimates, actuals)
359
+ est_slice = estimates[:, stem_idx, :]
360
+ act_slice = actuals[:, stem_idx, :]
361
+ direct_result = base_loss(est_slice, act_slice)
362
+
363
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
364
+ assert wrapped_result.ndim == 0
365
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
366
+
367
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
368
+ @pytest.mark.parametrize("stem_idx", [0, 3])
369
+ def test_stft_wrapper_single_stem_4d(dummy_stems_4d, l1_weight, stem_idx):
370
+ """Test STFTL1SNRDBLoss wrapper with stem_dimension=k on [B,S,C,T]."""
371
+ estimates, actuals = dummy_stems_4d
372
+ base_loss = STFTL1SNRDBLoss(
373
+ name="test", weight=1.0, l1_weight=l1_weight,
374
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
375
+ )
376
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=stem_idx)
377
+
378
+ wrapped_result = wrapped_loss(estimates, actuals)
379
+ est_slice = estimates[:, stem_idx, :, :]
380
+ act_slice = actuals[:, stem_idx, :, :]
381
+ direct_result = base_loss(est_slice, act_slice)
382
+
383
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
384
+ assert wrapped_result.ndim == 0
385
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
386
+
387
+ def test_stft_wrapper_short_audio_3d():
388
+ """Test STFTL1SNRDBLoss wrapper fallback path with short audio [B,S,T]."""
389
+ estimates = torch.randn(2, 4, 400) # Short audio
390
+ actuals = torch.randn(2, 4, 400)
391
+ actuals[:, 0, :100] += 0.1
392
+
393
+ base_loss = STFTL1SNRDBLoss(
394
+ name="test", weight=1.0, l1_weight=0.0,
395
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=512
396
+ )
397
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
398
+
399
+ wrapped_result = wrapped_loss(estimates, actuals)
400
+ direct_result = base_loss(estimates, actuals)
401
+
402
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
403
+ assert wrapped_result.ndim == 0
404
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
405
+
406
+ def test_stft_wrapper_short_audio_4d():
407
+ """Test STFTL1SNRDBLoss wrapper fallback path with short audio [B,S,C,T]."""
408
+ estimates = torch.randn(2, 4, 1, 400) # Short audio
409
+ actuals = torch.randn(2, 4, 1, 400)
410
+ actuals[:, 0, :, :100] += 0.1
411
+
412
+ base_loss = STFTL1SNRDBLoss(
413
+ name="test", weight=1.0, l1_weight=0.0,
414
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=512
415
+ )
416
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
417
+
418
+ wrapped_result = wrapped_loss(estimates, actuals)
419
+ direct_result = base_loss(estimates, actuals)
420
+
421
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
422
+ assert wrapped_result.ndim == 0
423
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
424
+
425
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
426
+ @pytest.mark.parametrize("use_time_reg", [True, False])
427
+ def test_multi_wrapper_all_stems_3d(dummy_stems_3d, l1_weight, use_time_reg):
428
+ """Test MultiL1SNRDBLoss wrapper with stem_dimension=None on [B,S,T]."""
429
+ estimates, actuals = dummy_stems_3d
430
+ base_loss = MultiL1SNRDBLoss(
431
+ name="test", weight=1.0, l1_weight=l1_weight,
432
+ use_time_regularization=use_time_reg, use_spec_regularization=False,
433
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
434
+ )
435
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
436
+
437
+ wrapped_result = wrapped_loss(estimates, actuals)
438
+ direct_result = base_loss(estimates, actuals)
439
+
440
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
441
+ assert wrapped_result.ndim == 0
442
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
443
+
444
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
445
+ @pytest.mark.parametrize("use_time_reg", [True, False])
446
+ def test_multi_wrapper_all_stems_4d(dummy_stems_4d, l1_weight, use_time_reg):
447
+ """Test MultiL1SNRDBLoss wrapper with stem_dimension=None on [B,S,C,T]."""
448
+ estimates, actuals = dummy_stems_4d
449
+ base_loss = MultiL1SNRDBLoss(
450
+ name="test", weight=1.0, l1_weight=l1_weight,
451
+ use_time_regularization=use_time_reg, use_spec_regularization=False,
452
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
453
+ )
454
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
455
+
456
+ wrapped_result = wrapped_loss(estimates, actuals)
457
+ direct_result = base_loss(estimates, actuals)
458
+
459
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
460
+ assert wrapped_result.ndim == 0
461
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
462
+
463
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
464
+ @pytest.mark.parametrize("use_time_reg", [True, False])
465
+ @pytest.mark.parametrize("stem_idx", [0, 3])
466
+ def test_multi_wrapper_single_stem_3d(dummy_stems_3d, l1_weight, use_time_reg, stem_idx):
467
+ """Test MultiL1SNRDBLoss wrapper with stem_dimension=k on [B,S,T]."""
468
+ estimates, actuals = dummy_stems_3d
469
+ base_loss = MultiL1SNRDBLoss(
470
+ name="test", weight=1.0, l1_weight=l1_weight,
471
+ use_time_regularization=use_time_reg, use_spec_regularization=False,
472
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
473
+ )
474
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=stem_idx)
475
+
476
+ wrapped_result = wrapped_loss(estimates, actuals)
477
+ est_slice = estimates[:, stem_idx, :]
478
+ act_slice = actuals[:, stem_idx, :]
479
+ direct_result = base_loss(est_slice, act_slice)
480
+
481
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
482
+ assert wrapped_result.ndim == 0
483
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
484
+
485
+ @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
486
+ @pytest.mark.parametrize("use_time_reg", [True, False])
487
+ @pytest.mark.parametrize("stem_idx", [0, 3])
488
+ def test_multi_wrapper_single_stem_4d(dummy_stems_4d, l1_weight, use_time_reg, stem_idx):
489
+ """Test MultiL1SNRDBLoss wrapper with stem_dimension=k on [B,S,C,T]."""
490
+ estimates, actuals = dummy_stems_4d
491
+ base_loss = MultiL1SNRDBLoss(
492
+ name="test", weight=1.0, l1_weight=l1_weight,
493
+ use_time_regularization=use_time_reg, use_spec_regularization=False,
494
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=256
495
+ )
496
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=stem_idx)
497
+
498
+ wrapped_result = wrapped_loss(estimates, actuals)
499
+ est_slice = estimates[:, stem_idx, :, :]
500
+ act_slice = actuals[:, stem_idx, :, :]
501
+ direct_result = base_loss(est_slice, act_slice)
502
+
503
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
504
+ assert wrapped_result.ndim == 0
505
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
506
+
507
+ def test_multi_wrapper_short_audio_3d():
508
+ """Test MultiL1SNRDBLoss wrapper fallback path with short audio [B,S,T]."""
509
+ estimates = torch.randn(2, 4, 400) # Short audio
510
+ actuals = torch.randn(2, 4, 400)
511
+ actuals[:, 0, :100] += 0.1
512
+
513
+ base_loss = MultiL1SNRDBLoss(
514
+ name="test", weight=1.0, l1_weight=0.0,
515
+ use_time_regularization=True, use_spec_regularization=False,
516
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=512
517
+ )
518
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
519
+
520
+ wrapped_result = wrapped_loss(estimates, actuals)
521
+ direct_result = base_loss(estimates, actuals)
522
+
523
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
524
+ assert wrapped_result.ndim == 0
525
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
526
+
527
+ def test_multi_wrapper_short_audio_4d():
528
+ """Test MultiL1SNRDBLoss wrapper fallback path with short audio [B,S,C,T]."""
529
+ estimates = torch.randn(2, 4, 1, 400) # Short audio
530
+ actuals = torch.randn(2, 4, 1, 400)
531
+ actuals[:, 0, :, :100] += 0.1
532
+
533
+ base_loss = MultiL1SNRDBLoss(
534
+ name="test", weight=1.0, l1_weight=0.0,
535
+ use_time_regularization=True, use_spec_regularization=False,
536
+ n_ffts=[256], hop_lengths=[64], win_lengths=[256], min_audio_length=512
537
+ )
538
+ wrapped_loss = StemWrappedLoss(base_loss, stem_dimension=None)
539
+
540
+ wrapped_result = wrapped_loss(estimates, actuals)
541
+ direct_result = base_loss(estimates, actuals)
542
+
543
+ assert torch.allclose(wrapped_result, direct_result, atol=1e-6)
544
+ assert wrapped_result.ndim == 0
545
+ assert not torch.isnan(wrapped_result) and not torch.isinf(wrapped_result)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-l1-snr
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: L1-SNR loss functions for audio source separation in PyTorch
5
5
  Home-page: https://github.com/crlandsc/torch-l1-snr
6
6
  Author: Christopher Landscaping
@@ -12,4 +12,6 @@ __all__ = [
12
12
  "L1SNRDBLoss",
13
13
  "STFTL1SNRDBLoss",
14
14
  "MultiL1SNRDBLoss",
15
- ]
15
+ ]
16
+
17
+ __version__ = "0.0.5"
@@ -1,143 +0,0 @@
1
- import torch
2
- import pytest
3
- from torch_l1snr import (
4
- dbrms,
5
- L1SNRLoss,
6
- L1SNRDBLoss,
7
- STFTL1SNRDBLoss,
8
- MultiL1SNRDBLoss,
9
- )
10
-
11
- # --- Test Fixtures ---
12
- @pytest.fixture
13
- def dummy_audio():
14
- """Provides a batch of dummy audio signals."""
15
- estimates = torch.randn(2, 16000)
16
- actuals = torch.randn(2, 16000)
17
- # Ensure actuals are not all zero to avoid division by zero in loss
18
- actuals[0, :100] += 0.1
19
- return estimates, actuals
20
-
21
- @pytest.fixture
22
- def dummy_stems():
23
- """Provides a batch of dummy multi-stem signals."""
24
- estimates = torch.randn(2, 4, 1, 16000) # batch, stems, channels, samples
25
- actuals = torch.randn(2, 4, 1, 16000)
26
- actuals[:, 0, :, :100] += 0.1 # Ensure not all zero
27
- return estimates, actuals
28
-
29
- # --- Test Functions ---
30
-
31
- def test_dbrms():
32
- signal = torch.ones(2, 1000) * 0.1
33
- # RMS of 0.1 is -20 dB
34
- assert torch.allclose(dbrms(signal), torch.tensor([-20.0, -20.0]), atol=1e-4)
35
-
36
- zeros = torch.zeros(2, 1000)
37
- # dbrms of zero should be -80dB with default eps=1e-8
38
- assert torch.allclose(dbrms(zeros), torch.tensor([-80.0, -80.0]), atol=1e-4)
39
-
40
- def test_l1snr_loss(dummy_audio):
41
- estimates, actuals = dummy_audio
42
- loss_fn = L1SNRLoss(name="test")
43
- loss = loss_fn(estimates, actuals)
44
-
45
- assert isinstance(loss, torch.Tensor)
46
- assert loss.ndim == 0
47
- assert not torch.isnan(loss)
48
- assert not torch.isinf(loss)
49
-
50
- def test_l1snrdb_loss_time(dummy_audio):
51
- estimates, actuals = dummy_audio
52
-
53
- # Test with default settings (L1SNR + Regularization)
54
- loss_fn = L1SNRDBLoss(name="test", use_regularization=True, l1_weight=0.0)
55
- loss = loss_fn(estimates, actuals)
56
- assert loss.ndim == 0 and not torch.isnan(loss)
57
-
58
- # Test without regularization
59
- loss_fn_no_reg = L1SNRDBLoss(name="test_no_reg", use_regularization=False, l1_weight=0.0)
60
- loss_no_reg = loss_fn_no_reg(estimates, actuals)
61
- assert loss_no_reg.ndim == 0 and not torch.isnan(loss_no_reg)
62
-
63
- # Test with L1 loss component
64
- loss_fn_l1 = L1SNRDBLoss(name="test_l1", l1_weight=0.2)
65
- loss_l1 = loss_fn_l1(estimates, actuals)
66
- assert loss_l1.ndim == 0 and not torch.isnan(loss_l1)
67
-
68
- # Test pure L1 loss mode
69
- loss_fn_pure_l1 = L1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
70
- pure_l1_loss = loss_fn_pure_l1(estimates, actuals)
71
- # Pure L1 mode uses torch.nn.L1Loss, so compare with manual L1 calculation
72
- l1_loss_manual = torch.nn.L1Loss()(
73
- estimates.reshape(estimates.shape[0], -1),
74
- actuals.reshape(actuals.shape[0], -1)
75
- )
76
- assert torch.allclose(pure_l1_loss, l1_loss_manual)
77
-
78
- def test_stft_l1snrdb_loss(dummy_audio):
79
- estimates, actuals = dummy_audio
80
-
81
- # Test with default settings
82
- loss_fn = STFTL1SNRDBLoss(name="test", l1_weight=0.0)
83
- loss = loss_fn(estimates, actuals)
84
- assert loss.ndim == 0 and not torch.isnan(loss) and not torch.isinf(loss)
85
-
86
- # Test pure L1 mode
87
- loss_fn_pure_l1 = STFTL1SNRDBLoss(name="test_pure_l1", l1_weight=1.0)
88
- l1_loss = loss_fn_pure_l1(estimates, actuals)
89
- assert l1_loss.ndim == 0 and not torch.isnan(l1_loss) and not torch.isinf(l1_loss)
90
-
91
- # Test with very short audio
92
- short_estimates = estimates[:, :500]
93
- short_actuals = actuals[:, :500]
94
- loss_short = loss_fn(short_estimates, short_actuals)
95
- # min_audio_length is 512, so this should fallback to time-domain loss
96
- assert loss_short.ndim == 0 and not torch.isnan(loss_short)
97
-
98
- def test_stem_multi_loss(dummy_stems):
99
- estimates, actuals = dummy_stems
100
-
101
- # Test with a specific stem - users now manage stems manually by slicing
102
- # Extract stem 1 (second stem) manually
103
- est_stem = estimates[:, 1, ...] # Shape: [batch, channels, samples]
104
- act_stem = actuals[:, 1, ...]
105
- loss_fn_stem = MultiL1SNRDBLoss(
106
- name="test_loss_stem",
107
- spec_weight=0.5,
108
- l1_weight=0.1
109
- )
110
- loss = loss_fn_stem(est_stem, act_stem)
111
- assert loss.ndim == 0 and not torch.isnan(loss)
112
-
113
- # Test with all stems jointly - flatten all stems together
114
- # Reshape to [batch, -1] to process all stems at once
115
- est_all = estimates.reshape(estimates.shape[0], -1)
116
- act_all = actuals.reshape(actuals.shape[0], -1)
117
- loss_fn_all = MultiL1SNRDBLoss(
118
- name="test_loss_all",
119
- spec_weight=0.5,
120
- l1_weight=0.1
121
- )
122
- loss_all = loss_fn_all(est_all, act_all)
123
- assert loss_all.ndim == 0 and not torch.isnan(loss_all)
124
-
125
- # Test pure L1 mode on all stems
126
- loss_fn_l1 = MultiL1SNRDBLoss(name="l1_only", l1_weight=1.0)
127
- l1_loss = loss_fn_l1(est_all, act_all)
128
-
129
- # Can't easily compute multi-res STFT L1 here, but can check it's not nan
130
- assert l1_loss.ndim == 0 and not torch.isnan(l1_loss)
131
-
132
- @pytest.mark.parametrize("l1_weight", [0.0, 0.5, 1.0])
133
- def test_loss_variants(dummy_audio, l1_weight):
134
- """Test L1SNRDBLoss and STFTL1SNRDBLoss with different l1_weights."""
135
- estimates, actuals = dummy_audio
136
-
137
- time_loss_fn = L1SNRDBLoss(name=f"test_time_{l1_weight}", l1_weight=l1_weight)
138
- time_loss = time_loss_fn(estimates, actuals)
139
- assert not torch.isnan(time_loss) and not torch.isinf(time_loss)
140
-
141
- spec_loss_fn = STFTL1SNRDBLoss(name=f"test_spec_{l1_weight}", l1_weight=l1_weight)
142
- spec_loss = spec_loss_fn(estimates, actuals)
143
- assert not torch.isnan(spec_loss) and not torch.isinf(spec_loss)
File without changes
File without changes