torch-l1-snr 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torch_l1_snr/__init__.py CHANGED
@@ -14,4 +14,4 @@ __all__ = [
14
14
  "MultiL1SNRDBLoss",
15
15
  ]
16
16
 
17
- __version__ = "0.1.1"
17
+ __version__ = "0.1.2"
torch_l1_snr/l1snr.py CHANGED
@@ -104,9 +104,10 @@ class L1SNRLoss(torch.nn.Module):
104
104
  l1snr_loss = torch.mean(d1)
105
105
 
106
106
  c = 10.0 / math.log(10.0)
107
- inv_mean = torch.mean(1.0 / (l1_error.detach() + self.eps))
108
- # w-independent scaling to match typical gradient magnitudes
109
- scale_time = c * inv_mean
107
+ # Scale by reference signal magnitude (not error) to preserve gradient distinction
108
+ # L1SNR has inverse-error gradients; L1 should have uniform gradients
109
+ inv_ref_mean = torch.mean(1.0 / (l1_true.detach() + self.eps))
110
+ scale_time = c * inv_ref_mean
110
111
  l1_term = torch.mean(l1_error) * scale_time
111
112
 
112
113
  loss = (1.0 - w) * l1snr_loss + w * l1_term
@@ -446,10 +447,11 @@ class STFTL1SNRDBLoss(torch.nn.Module):
446
447
  w = float(self.l1_weight)
447
448
  if 0.0 < w < 1.0:
448
449
  c = 10.0 / math.log(10.0)
449
- inv_mean_comp = torch.mean(0.5 * (1.0 / (err_re.detach() + self.l1snr_eps) +
450
- 1.0 / (err_im.detach() + self.l1snr_eps)))
451
- # w-independent scaling to match typical gradient magnitudes (factor 2.0 for Re/Im symmetry)
452
- scale_spec = 2.0 * c * inv_mean_comp
450
+ # Scale by reference signal magnitude (not error) to preserve gradient distinction
451
+ # L1SNR has inverse-error gradients; L1 should have uniform gradients
452
+ inv_ref_mean_comp = torch.mean(0.5 * (1.0 / (ref_re.detach() + self.l1snr_eps) +
453
+ 1.0 / (ref_im.detach() + self.l1snr_eps)))
454
+ scale_spec = 2.0 * c * inv_ref_mean_comp
453
455
  l1_term = 0.5 * (torch.mean(err_re) + torch.mean(err_im)) * scale_spec
454
456
 
455
457
  loss = (1.0 - w) * d1_sum + w * l1_term
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-l1-snr
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: L1-SNR loss functions for audio source separation in PyTorch
5
5
  Home-page: https://github.com/crlandsc/torch-l1-snr
6
6
  Author: Christopher Landschoot
@@ -237,7 +237,7 @@ While this can potentially reduce the "cleanliness" of separations and slightly
237
237
 
238
238
  The implementation is optimized for efficiency: if `l1_weight` is `0.0` or `1.0`, the unused loss component is not computed, saving computational resources.
239
239
 
240
- **Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), the implementation automatically scales the L1 component to approximately match the gradient magnitudes of the L1SNR component. This helps maintain stable training without manual tuning.
240
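+ **Note on Gradient Balancing:** When blending losses (`0.0 < l1_weight < 1.0`), the implementation automatically rescales the L1 component so that its gradient magnitudes are roughly comparable to those of the L1SNR component. The scale factor is derived from the (detached) magnitude of the reference signal rather than from the error. This preserves the two components' distinct gradient behaviors: L1SNR gradients scale inversely with the error, while L1 gradients remain uniform. This helps maintain stable training without manual tuning.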
241
241
 
242
242
  ## Limitations
243
243
 
@@ -0,0 +1,7 @@
1
+ torch_l1_snr/__init__.py,sha256=eJvtoXxJqrrayRFH1xsx9oYMC-aobnh5VK1SaxjsytU,244
2
+ torch_l1_snr/l1snr.py,sha256=aA3BlPZRNsXGAHWnk3PFY8lDiH4HcWeFuhmM4gAKYiQ,34711
3
+ torch_l1_snr-0.1.2.dist-info/licenses/LICENSE,sha256=JdS2Pv6DDs3jvXHACGdcHYdiFMe9EO1XGeHkEHLTr8Y,1079
4
+ torch_l1_snr-0.1.2.dist-info/METADATA,sha256=hBBGDWtSD31aCu31K6bMZuTRBCqCx5UYg6WPXrxsBn4,15722
5
+ torch_l1_snr-0.1.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
6
+ torch_l1_snr-0.1.2.dist-info/top_level.txt,sha256=VUo0QlGvu7tOF8BKWWDoIiLlhcAcetYwR6c8Ldhhpco,13
7
+ torch_l1_snr-0.1.2.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- torch_l1_snr/__init__.py,sha256=L3Cpdpnhz80gpfcTf6aNM1ROPdOIbdNoN8vO9LxcZEQ,244
2
- torch_l1_snr/l1snr.py,sha256=F1NF3VGodaLWFtHs9xco9MbxfEJ01ip_JSHFS2GgBkU,34520
3
- torch_l1_snr-0.1.1.dist-info/licenses/LICENSE,sha256=JdS2Pv6DDs3jvXHACGdcHYdiFMe9EO1XGeHkEHLTr8Y,1079
4
- torch_l1_snr-0.1.1.dist-info/METADATA,sha256=8zC2S_NgV8B4Wg59QJwoupNfF063TxD6nFEzXJdeZIw,15704
5
- torch_l1_snr-0.1.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
6
- torch_l1_snr-0.1.1.dist-info/top_level.txt,sha256=VUo0QlGvu7tOF8BKWWDoIiLlhcAcetYwR6c8Ldhhpco,13
7
- torch_l1_snr-0.1.1.dist-info/RECORD,,