flamo 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flamo/optimize/loss.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
|
2
2
|
import torch.nn as nn
|
|
3
3
|
import numpy as np
|
|
4
4
|
from flamo.optimize.utils import generate_partitions
|
|
5
|
+
from flamo.processor.dsp import HouseholderMatrix
|
|
5
6
|
from nnAudio import features
|
|
6
7
|
import pyfar as pf
|
|
7
8
|
import torch.nn.functional as F
|
|
@@ -34,23 +35,30 @@ class sparsity_loss(nn.Module):
|
|
|
34
35
|
|
|
35
36
|
def forward(self, y_pred: torch.Tensor, y_target: torch.Tensor, model: nn.Module):
|
|
36
37
|
core = model.get_core()
|
|
38
|
+
# Try to get the mixing matrix from different possible locations
|
|
39
|
+
mixing_matrix = None
|
|
37
40
|
try:
|
|
38
|
-
|
|
41
|
+
mixing_matrix = core.feedback_loop.feedback
|
|
42
|
+
A = mixing_matrix.map(mixing_matrix.param)
|
|
39
43
|
except:
|
|
40
44
|
try:
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
)
|
|
45
|
+
mixing_matrix = core.feedback_loop.feedback.mixing_matrix
|
|
46
|
+
A = mixing_matrix.map(mixing_matrix.param)
|
|
44
47
|
except:
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
+
mixing_matrix = core.branchA.feedback_loop.feedback.mixing_matrix
|
|
49
|
+
A = mixing_matrix.map(mixing_matrix.param)
|
|
50
|
+
|
|
51
|
+
if isinstance(mixing_matrix, HouseholderMatrix):
|
|
52
|
+
u = A
|
|
53
|
+
A = torch.eye(u.shape[0], device=u.device, dtype=u.dtype) - 2 * u @ u.T
|
|
54
|
+
|
|
48
55
|
N = A.shape[-1]
|
|
49
56
|
if len(A.shape) == 3:
|
|
50
57
|
return torch.mean(
|
|
51
58
|
(torch.sum(torch.abs(A), dim=(-2, -1)) - N * np.sqrt(N))
|
|
52
59
|
/ (N * (1 - np.sqrt(N)))
|
|
53
60
|
)
|
|
61
|
+
|
|
54
62
|
# A = torch.matrix_exp(skew_matrix(A))
|
|
55
63
|
return -(torch.sum(torch.abs(A)) - N * np.sqrt(N)) / (N * (np.sqrt(N) - 1))
|
|
56
64
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: flamo
|
|
3
|
-
Version: 0.2.10
|
|
3
|
+
Version: 0.2.11
|
|
4
4
|
Summary: An Open-Source Library for Frequency-Domain Differentiable Audio Processing
|
|
5
5
|
Project-URL: Homepage, https://github.com/gdalsanto/flamo
|
|
6
6
|
Project-URL: Issues, https://github.com/gdalsanto/flamo/issues
|
|
@@ -11,14 +11,14 @@ flamo/auxiliary/velvet.py,sha256=B4pYEnhaQPkh02pxqiGdAhLRX2g-eWtHezphi0_h4Qs,420
|
|
|
11
11
|
flamo/auxiliary/config/config.py,sha256=sZ3XvqwV6KiIc2n8HRtg7YJE3zhc7Vqblbqs-Z0bsKg,2978
|
|
12
12
|
flamo/optimize/__init__.py,sha256=grgxLmQ7m-c9MvRdIejmEAaaajfBwgeaZAv2qjHIvPw,65
|
|
13
13
|
flamo/optimize/dataset.py,sha256=WPvWDhT-U-gFkPaP1UzvFfB2bxlxdDDQ64zQ2-OcbYY,6789
|
|
14
|
-
flamo/optimize/loss.py,sha256=
|
|
14
|
+
flamo/optimize/loss.py,sha256=sc_E5Dp1QQtQWvTXB6890jUiVJintV6rnkOmjPgr0Ow,33781
|
|
15
15
|
flamo/optimize/surface.py,sha256=sWy1ImwxUh_QLoY6S68LXBa82_HdWJGplFg2ObtpNGc,26655
|
|
16
16
|
flamo/optimize/trainer.py,sha256=LITPVS87mI6bnq4J6GIXqGb4wW7TKWVXeCu4UQ-csxM,12155
|
|
17
17
|
flamo/optimize/utils.py,sha256=R5-KoZagRho3eykY88pC3UB2mc5SsE4Yv9X-ogskXdA,1610
|
|
18
18
|
flamo/processor/__init__.py,sha256=paGdxGVZgA2VAs0tBwRd0bobzGxeyK79DS7ZGO8drkI,41
|
|
19
19
|
flamo/processor/dsp.py,sha256=N9IZNO7taV9zwx67g1rVoQEedF4EVHXqGp23Z0BqohA,147358
|
|
20
20
|
flamo/processor/system.py,sha256=Hct-o6IgF5NQ2xYbX-1j3st94hMoM8dOgAzle2gjDqU,43145
|
|
21
|
-
flamo-0.2.
|
|
22
|
-
flamo-0.2.
|
|
23
|
-
flamo-0.2.
|
|
24
|
-
flamo-0.2.
|
|
21
|
+
flamo-0.2.11.dist-info/METADATA,sha256=NSC11LRxikbEYJiODC7B7rEHjea8GyR2lonUFWxJ0lk,7831
|
|
22
|
+
flamo-0.2.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
23
|
+
flamo-0.2.11.dist-info/licenses/LICENSE,sha256=smMocRH7xdPT5RvFNqSLtbSNzohXJM5G_rX1Qaej6vg,1120
|
|
24
|
+
flamo-0.2.11.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|