torch-l1-snr 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_l1_snr/__init__.py +1 -1
- torch_l1_snr/l1snr.py +9 -7
- {torch_l1_snr-0.1.1.dist-info → torch_l1_snr-0.1.2.dist-info}/METADATA +2 -2
- torch_l1_snr-0.1.2.dist-info/RECORD +7 -0
- torch_l1_snr-0.1.1.dist-info/RECORD +0 -7
- {torch_l1_snr-0.1.1.dist-info → torch_l1_snr-0.1.2.dist-info}/WHEEL +0 -0
- {torch_l1_snr-0.1.1.dist-info → torch_l1_snr-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {torch_l1_snr-0.1.1.dist-info → torch_l1_snr-0.1.2.dist-info}/top_level.txt +0 -0
torch_l1_snr/__init__.py
CHANGED
torch_l1_snr/l1snr.py
CHANGED
|
@@ -104,9 +104,10 @@ class L1SNRLoss(torch.nn.Module):
|
|
|
104
104
|
l1snr_loss = torch.mean(d1)
|
|
105
105
|
|
|
106
106
|
c = 10.0 / math.log(10.0)
|
|
107
|
-
|
|
108
|
-
#
|
|
109
|
-
|
|
107
|
+
# Scale by reference signal magnitude (not error) to preserve gradient distinction
|
|
108
|
+
# L1SNR has inverse-error gradients; L1 should have uniform gradients
|
|
109
|
+
inv_ref_mean = torch.mean(1.0 / (l1_true.detach() + self.eps))
|
|
110
|
+
scale_time = c * inv_ref_mean
|
|
110
111
|
l1_term = torch.mean(l1_error) * scale_time
|
|
111
112
|
|
|
112
113
|
loss = (1.0 - w) * l1snr_loss + w * l1_term
|
|
@@ -446,10 +447,11 @@ class STFTL1SNRDBLoss(torch.nn.Module):
|
|
|
446
447
|
w = float(self.l1_weight)
|
|
447
448
|
if 0.0 < w < 1.0:
|
|
448
449
|
c = 10.0 / math.log(10.0)
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
450
|
+
# Scale by reference signal magnitude (not error) to preserve gradient distinction
|
|
451
|
+
# L1SNR has inverse-error gradients; L1 should have uniform gradients
|
|
452
|
+
inv_ref_mean_comp = torch.mean(0.5 * (1.0 / (ref_re.detach() + self.l1snr_eps) +
|
|
453
|
+
1.0 / (ref_im.detach() + self.l1snr_eps)))
|
|
454
|
+
scale_spec = 2.0 * c * inv_ref_mean_comp
|
|
453
455
|
l1_term = 0.5 * (torch.mean(err_re) + torch.mean(err_im)) * scale_spec
|
|
454
456
|
|
|
455
457
|
loss = (1.0 - w) * d1_sum + w * l1_term
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-l1-snr
|
|
3
|
-
Version: 0.1.1
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: L1-SNR loss functions for audio source separation in PyTorch
|
|
5
5
|
Home-page: https://github.com/crlandsc/torch-l1-snr
|
|
6
6
|
Author: Christopher Landschoot
|
|
@@ -237,7 +237,7 @@ While this can potentially reduce the "cleanliness" of separations and slightly
|
|
|
237
237
|
|
|
238
238
|
The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
|
|
239
239
|
|
|
240
|
-
**Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), the implementation automatically scales the L1 component to approximately match
|
|
240
|
+
**Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), the implementation automatically scales the L1 component to approximately match gradient magnitudes while preserving distinct gradient behaviors. This helps maintain stable training without manual tuning.
|
|
241
241
|
|
|
242
242
|
## Limitations
|
|
243
243
|
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
torch_l1_snr/__init__.py,sha256=eJvtoXxJqrrayRFH1xsx9oYMC-aobnh5VK1SaxjsytU,244
|
|
2
|
+
torch_l1_snr/l1snr.py,sha256=aA3BlPZRNsXGAHWnk3PFY8lDiH4HcWeFuhmM4gAKYiQ,34711
|
|
3
|
+
torch_l1_snr-0.1.2.dist-info/licenses/LICENSE,sha256=JdS2Pv6DDs3jvXHACGdcHYdiFMe9EO1XGeHkEHLTr8Y,1079
|
|
4
|
+
torch_l1_snr-0.1.2.dist-info/METADATA,sha256=hBBGDWtSD31aCu31K6bMZuTRBCqCx5UYg6WPXrxsBn4,15722
|
|
5
|
+
torch_l1_snr-0.1.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
6
|
+
torch_l1_snr-0.1.2.dist-info/top_level.txt,sha256=VUo0QlGvu7tOF8BKWWDoIiLlhcAcetYwR6c8Ldhhpco,13
|
|
7
|
+
torch_l1_snr-0.1.2.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
torch_l1_snr/__init__.py,sha256=L3Cpdpnhz80gpfcTf6aNM1ROPdOIbdNoN8vO9LxcZEQ,244
|
|
2
|
-
torch_l1_snr/l1snr.py,sha256=F1NF3VGodaLWFtHs9xco9MbxfEJ01ip_JSHFS2GgBkU,34520
|
|
3
|
-
torch_l1_snr-0.1.1.dist-info/licenses/LICENSE,sha256=JdS2Pv6DDs3jvXHACGdcHYdiFMe9EO1XGeHkEHLTr8Y,1079
|
|
4
|
-
torch_l1_snr-0.1.1.dist-info/METADATA,sha256=8zC2S_NgV8B4Wg59QJwoupNfF063TxD6nFEzXJdeZIw,15704
|
|
5
|
-
torch_l1_snr-0.1.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
6
|
-
torch_l1_snr-0.1.1.dist-info/top_level.txt,sha256=VUo0QlGvu7tOF8BKWWDoIiLlhcAcetYwR6c8Ldhhpco,13
|
|
7
|
-
torch_l1_snr-0.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|