flamo 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flamo/optimize/loss.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
2
  import torch.nn as nn
3
3
  import numpy as np
4
4
  from flamo.optimize.utils import generate_partitions
5
+ from flamo.processor.dsp import HouseholderMatrix
5
6
  from nnAudio import features
6
7
  import pyfar as pf
7
8
  import torch.nn.functional as F
@@ -34,23 +35,30 @@ class sparsity_loss(nn.Module):
34
35
 
35
36
  def forward(self, y_pred: torch.Tensor, y_target: torch.Tensor, model: nn.Module):
36
37
  core = model.get_core()
38
+ # Try to get the mixing matrix from different possible locations
39
+ mixing_matrix = None
37
40
  try:
38
- A = core.feedback_loop.feedback.map(core.feedback_loop.feedback.param)
41
+ mixing_matrix = core.feedback_loop.feedback
42
+ A = mixing_matrix.map(mixing_matrix.param)
39
43
  except:
40
44
  try:
41
- A = core.feedback_loop.feedback.mixing_matrix.map(
42
- core.feedback_loop.feedback.mixing_matrix.param
43
- )
45
+ mixing_matrix = core.feedback_loop.feedback.mixing_matrix
46
+ A = mixing_matrix.map(mixing_matrix.param)
44
47
  except:
45
- A = core.branchA.feedback_loop.feedback.mixing_matrix.map(
46
- core.branchA.feedback_loop.feedback.mixing_matrix.param
47
- )
48
+ mixing_matrix = core.branchA.feedback_loop.feedback.mixing_matrix
49
+ A = mixing_matrix.map(mixing_matrix.param)
50
+
51
+ if isinstance(mixing_matrix, HouseholderMatrix):
52
+ u = A
53
+ A = torch.eye(u.shape[0], device=u.device, dtype=u.dtype) - 2 * u @ u.T
54
+
48
55
  N = A.shape[-1]
49
56
  if len(A.shape) == 3:
50
57
  return torch.mean(
51
58
  (torch.sum(torch.abs(A), dim=(-2, -1)) - N * np.sqrt(N))
52
59
  / (N * (1 - np.sqrt(N)))
53
60
  )
61
+
54
62
  # A = torch.matrix_exp(skew_matrix(A))
55
63
  return -(torch.sum(torch.abs(A)) - N * np.sqrt(N)) / (N * (np.sqrt(N) - 1))
56
64
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flamo
3
- Version: 0.2.10
3
+ Version: 0.2.11
4
4
  Summary: An Open-Source Library for Frequency-Domain Differentiable Audio Processing
5
5
  Project-URL: Homepage, https://github.com/gdalsanto/flamo
6
6
  Project-URL: Issues, https://github.com/gdalsanto/flamo/issues
@@ -11,14 +11,14 @@ flamo/auxiliary/velvet.py,sha256=B4pYEnhaQPkh02pxqiGdAhLRX2g-eWtHezphi0_h4Qs,420
11
11
  flamo/auxiliary/config/config.py,sha256=sZ3XvqwV6KiIc2n8HRtg7YJE3zhc7Vqblbqs-Z0bsKg,2978
12
12
  flamo/optimize/__init__.py,sha256=grgxLmQ7m-c9MvRdIejmEAaaajfBwgeaZAv2qjHIvPw,65
13
13
  flamo/optimize/dataset.py,sha256=WPvWDhT-U-gFkPaP1UzvFfB2bxlxdDDQ64zQ2-OcbYY,6789
14
- flamo/optimize/loss.py,sha256=h6EeqjdX5P1SqDBKBavSxV25VBgnYK8tuX91wk6lw_g,33466
14
+ flamo/optimize/loss.py,sha256=sc_E5Dp1QQtQWvTXB6890jUiVJintV6rnkOmjPgr0Ow,33781
15
15
  flamo/optimize/surface.py,sha256=sWy1ImwxUh_QLoY6S68LXBa82_HdWJGplFg2ObtpNGc,26655
16
16
  flamo/optimize/trainer.py,sha256=LITPVS87mI6bnq4J6GIXqGb4wW7TKWVXeCu4UQ-csxM,12155
17
17
  flamo/optimize/utils.py,sha256=R5-KoZagRho3eykY88pC3UB2mc5SsE4Yv9X-ogskXdA,1610
18
18
  flamo/processor/__init__.py,sha256=paGdxGVZgA2VAs0tBwRd0bobzGxeyK79DS7ZGO8drkI,41
19
19
  flamo/processor/dsp.py,sha256=N9IZNO7taV9zwx67g1rVoQEedF4EVHXqGp23Z0BqohA,147358
20
20
  flamo/processor/system.py,sha256=Hct-o6IgF5NQ2xYbX-1j3st94hMoM8dOgAzle2gjDqU,43145
21
- flamo-0.2.10.dist-info/METADATA,sha256=E3OE6I1j_xJH8IM-BGZZhkzPfu5beXI2jpYQDz98Cko,7831
22
- flamo-0.2.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
23
- flamo-0.2.10.dist-info/licenses/LICENSE,sha256=smMocRH7xdPT5RvFNqSLtbSNzohXJM5G_rX1Qaej6vg,1120
24
- flamo-0.2.10.dist-info/RECORD,,
21
+ flamo-0.2.11.dist-info/METADATA,sha256=NSC11LRxikbEYJiODC7B7rEHjea8GyR2lonUFWxJ0lk,7831
22
+ flamo-0.2.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
23
+ flamo-0.2.11.dist-info/licenses/LICENSE,sha256=smMocRH7xdPT5RvFNqSLtbSNzohXJM5G_rX1Qaej6vg,1120
24
+ flamo-0.2.11.dist-info/RECORD,,
File without changes