flamo 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flamo/auxiliary/config/config.py +2 -1
- flamo/auxiliary/eq.py +2 -1
- flamo/auxiliary/reverb.py +1 -1
- flamo/functional.py +68 -1
- flamo/optimize/loss.py +2 -1
- flamo/processor/dsp.py +41 -10
- {flamo-0.1.3.dist-info → flamo-0.1.5.dist-info}/METADATA +2 -1
- {flamo-0.1.3.dist-info → flamo-0.1.5.dist-info}/RECORD +10 -10
- {flamo-0.1.3.dist-info → flamo-0.1.5.dist-info}/WHEEL +0 -0
- {flamo-0.1.3.dist-info → flamo-0.1.5.dist-info}/licenses/LICENSE +0 -0
flamo/auxiliary/config/config.py
CHANGED
flamo/auxiliary/eq.py
CHANGED
|
@@ -86,7 +86,8 @@ def geq(
|
|
|
86
86
|
|
|
87
87
|
for band in range(num_bands):
|
|
88
88
|
if band == 0:
|
|
89
|
-
b = torch.
|
|
89
|
+
b = torch.zeros(3, device=device)
|
|
90
|
+
b[0] = db2mag(gain_db[band])
|
|
90
91
|
a = torch.tensor([1, 0, 0], device=device)
|
|
91
92
|
elif band == 1:
|
|
92
93
|
b, a = shelving_filter(
|
flamo/auxiliary/reverb.py
CHANGED
|
@@ -129,7 +129,7 @@ class HomogeneousFDN:
|
|
|
129
129
|
size=(self.N,),
|
|
130
130
|
max_len=delay_lines.max(),
|
|
131
131
|
nfft=self.config_dict.nfft,
|
|
132
|
-
isint=
|
|
132
|
+
isint=self.config_dict.is_delay_int,
|
|
133
133
|
requires_grad=self.config_dict.delays_grad,
|
|
134
134
|
alias_decay_db=self.config_dict.alias_decay_db,
|
|
135
135
|
device=self.config_dict.device,
|
flamo/functional.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
import torch.nn as nn
|
|
2
3
|
import numpy as np
|
|
3
4
|
import scipy.signal
|
|
5
|
+
from typing import Optional
|
|
4
6
|
from flamo.utils import RegularGridInterpolator
|
|
5
7
|
|
|
6
8
|
|
|
@@ -69,6 +71,68 @@ def get_frequency_samples(num: int, device: str | torch.device = None):
|
|
|
69
71
|
return torch.polar(abs, angle * np.pi)
|
|
70
72
|
|
|
71
73
|
|
|
74
|
+
class HadamardMatrix(nn.Module):
|
|
75
|
+
"""
|
|
76
|
+
Generate a Hadamard matrix of size N as a nn.Module.
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
def __init__(self, N, device: Optional[str] = None):
|
|
80
|
+
super().__init__()
|
|
81
|
+
self.N = N
|
|
82
|
+
self.device = device
|
|
83
|
+
|
|
84
|
+
def forward(self, x):
|
|
85
|
+
U = torch.tensor([[1.0]], device=self.device)
|
|
86
|
+
while U.shape[0] < self.N:
|
|
87
|
+
U = torch.kron(
|
|
88
|
+
U, torch.tensor([[1, 1], [1, -1]], dtype=U.dtype, device=U.device)
|
|
89
|
+
) / torch.sqrt(torch.tensor(2.0, device=U.device))
|
|
90
|
+
return U
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class RotationMatrix(nn.Module):
|
|
94
|
+
"""
|
|
95
|
+
Generate a rotation matrix of size N as a nn.Module from a given angle.
|
|
96
|
+
"""
|
|
97
|
+
|
|
98
|
+
def __init__(
|
|
99
|
+
self,
|
|
100
|
+
N: int,
|
|
101
|
+
min_angle: float = 0,
|
|
102
|
+
max_angle: float = torch.pi / 4,
|
|
103
|
+
iter: Optional[int] = None,
|
|
104
|
+
device: Optional[str] = None,
|
|
105
|
+
):
|
|
106
|
+
|
|
107
|
+
super().__init__()
|
|
108
|
+
self.N = N
|
|
109
|
+
self.min_angle = min_angle
|
|
110
|
+
self.max_angle = max_angle
|
|
111
|
+
self.iter = iter
|
|
112
|
+
self.device = device
|
|
113
|
+
|
|
114
|
+
def create_submatrix(self, angles: torch.Tensor, iters: int = 1):
|
|
115
|
+
"""Create a submatrix for each group."""
|
|
116
|
+
X = torch.zeros(2, 2, device=self.device)
|
|
117
|
+
angles[0] = torch.clamp(angles[0], self.min_angle, self.max_angle)
|
|
118
|
+
X.fill_diagonal_(torch.cos(angles[0]))
|
|
119
|
+
X[1, 0] = -torch.sin(torch.tensor(angles[0], device=self.device))
|
|
120
|
+
X[0, 1] = torch.sin(torch.tensor(angles[0], device=self.device))
|
|
121
|
+
|
|
122
|
+
if iters is None:
|
|
123
|
+
iters = torch.log2(torch.tensor(self.N)).int().item() - 1
|
|
124
|
+
for i in range(iters):
|
|
125
|
+
if len(angles) > 1:
|
|
126
|
+
X = torch.kron(X, self.create_submatrix([angles[i]]))
|
|
127
|
+
else:
|
|
128
|
+
X = torch.kron(X, X)
|
|
129
|
+
return X
|
|
130
|
+
|
|
131
|
+
def forward(self, theta):
|
|
132
|
+
|
|
133
|
+
return self.create_submatrix(theta, self.iter)
|
|
134
|
+
|
|
135
|
+
|
|
72
136
|
def biquad2tf(b: torch.Tensor, a: torch.Tensor, nfft: int):
|
|
73
137
|
r"""
|
|
74
138
|
Converts a biquad filter representation to a transfer function.
|
|
@@ -133,10 +197,11 @@ def signal_gallery(
|
|
|
133
197
|
"wgn",
|
|
134
198
|
"exp",
|
|
135
199
|
"reference",
|
|
200
|
+
"noise",
|
|
136
201
|
}
|
|
137
202
|
|
|
138
203
|
if signal_type not in signal_types:
|
|
139
|
-
raise ValueError(f"
|
|
204
|
+
raise ValueError(f"Signal type {signal_type} not recognized.")
|
|
140
205
|
match signal_type:
|
|
141
206
|
case "impulse":
|
|
142
207
|
x = torch.zeros(batch_size, n_samples, n)
|
|
@@ -185,6 +250,8 @@ def signal_gallery(
|
|
|
185
250
|
return torch.tensor(reference, device=device).expand(
|
|
186
251
|
batch_size, n_samples, n
|
|
187
252
|
)
|
|
253
|
+
case "noise":
|
|
254
|
+
return torch.randn((batch_size, n_samples, n), device=device)
|
|
188
255
|
|
|
189
256
|
|
|
190
257
|
def hertz2rad(hertz: torch.Tensor, fs: int):
|
flamo/optimize/loss.py
CHANGED
|
@@ -38,7 +38,7 @@ class sparsity_loss(nn.Module):
|
|
|
38
38
|
A = core.feedback_loop.feedback.map(core.feedback_loop.feedback.param)
|
|
39
39
|
except:
|
|
40
40
|
try:
|
|
41
|
-
A = core.feedback_loop.feedback.map(
|
|
41
|
+
A = core.feedback_loop.feedback.mixing_matrix.map(
|
|
42
42
|
core.feedback_loop.feedback.mixing_matrix.param
|
|
43
43
|
)
|
|
44
44
|
except:
|
|
@@ -77,6 +77,7 @@ class mse_loss(nn.Module):
|
|
|
77
77
|
self.nfft = nfft
|
|
78
78
|
self.device = device
|
|
79
79
|
self.mse_loss = nn.MSELoss()
|
|
80
|
+
self.name = "MSE"
|
|
80
81
|
|
|
81
82
|
def forward(self, y_pred, y_true):
|
|
82
83
|
"""
|
flamo/processor/dsp.py
CHANGED
|
@@ -8,7 +8,9 @@ from flamo.functional import (
|
|
|
8
8
|
lowpass_filter,
|
|
9
9
|
highpass_filter,
|
|
10
10
|
bandpass_filter,
|
|
11
|
-
rad2hertz
|
|
11
|
+
rad2hertz,
|
|
12
|
+
HadamardMatrix,
|
|
13
|
+
RotationMatrix)
|
|
12
14
|
from flamo.auxiliary.eq import (
|
|
13
15
|
eq_freqs,
|
|
14
16
|
geq,
|
|
@@ -535,11 +537,13 @@ class Matrix(Gain):
|
|
|
535
537
|
nfft: int = 2**11,
|
|
536
538
|
map: callable = lambda x: x,
|
|
537
539
|
matrix_type: str = "random",
|
|
540
|
+
iter: int = 1,
|
|
538
541
|
requires_grad: bool = False,
|
|
539
542
|
alias_decay_db: float = 0.0,
|
|
540
543
|
device: Optional[str] = None,
|
|
541
544
|
):
|
|
542
545
|
self.matrix_type = matrix_type
|
|
546
|
+
self.iter = iter # iterations number for the rotation matrix
|
|
543
547
|
super().__init__(
|
|
544
548
|
size=size,
|
|
545
549
|
nfft=nfft,
|
|
@@ -557,14 +561,31 @@ class Matrix(Gain):
|
|
|
557
561
|
Warning(
|
|
558
562
|
f"you asked for {self.matrix_type} matrix type, map will be overwritten"
|
|
559
563
|
)
|
|
564
|
+
N = self.size[0]
|
|
560
565
|
match self.matrix_type:
|
|
561
566
|
case "random":
|
|
562
567
|
self.map = lambda x: x
|
|
563
568
|
case "orthogonal":
|
|
564
569
|
assert (
|
|
565
|
-
|
|
570
|
+
N == self.size[1]
|
|
566
571
|
), "Matrix must be square to be orthogonal"
|
|
567
572
|
self.map = lambda x: torch.matrix_exp(skew_matrix(x))
|
|
573
|
+
case "hadamard":
|
|
574
|
+
assert (
|
|
575
|
+
N == self.size[1]
|
|
576
|
+
), "Matrix must be square to be Hadamard"
|
|
577
|
+
assert (
|
|
578
|
+
N % 2 == 0
|
|
579
|
+
), "Matrix must have even dimensions to be Hadamard"
|
|
580
|
+
self.map = lambda x: HadamardMatrix(self.size[0], device=self.device)(x)
|
|
581
|
+
case "rotation":
|
|
582
|
+
assert (
|
|
583
|
+
N == self.size[1]
|
|
584
|
+
), "Matrix must be square to be a rotation matrix"
|
|
585
|
+
assert (
|
|
586
|
+
N % 2 == 0
|
|
587
|
+
), "Matrix must have even dimensions to be a rotation matrix"
|
|
588
|
+
self.map = lambda x: RotationMatrix(self.size[0], self.iter, device=self.device)([x[0][0]])
|
|
568
589
|
|
|
569
590
|
def initialize_class(self):
|
|
570
591
|
r"""
|
|
@@ -2722,7 +2743,7 @@ class Delay(DSP):
|
|
|
2722
2743
|
"""
|
|
2723
2744
|
m = self.get_delays()
|
|
2724
2745
|
if self.isint:
|
|
2725
|
-
self.freq_response = lambda param: (self.gamma ** m(param)) * torch.exp(
|
|
2746
|
+
self.freq_response = lambda param: (self.gamma ** m(param).round()) * torch.exp(
|
|
2726
2747
|
-1j
|
|
2727
2748
|
* torch.einsum(
|
|
2728
2749
|
"fo, omn -> fmn",
|
|
@@ -2859,14 +2880,24 @@ class parallelDelay(Delay):
|
|
|
2859
2880
|
Computes the frequency response of the delay module.
|
|
2860
2881
|
"""
|
|
2861
2882
|
m = self.get_delays()
|
|
2862
|
-
self.
|
|
2863
|
-
|
|
2864
|
-
|
|
2865
|
-
|
|
2866
|
-
|
|
2867
|
-
|
|
2883
|
+
if self.isint:
|
|
2884
|
+
self.freq_response = lambda param: (self.gamma ** m(param).round()) * torch.exp(
|
|
2885
|
+
-1j
|
|
2886
|
+
* torch.einsum(
|
|
2887
|
+
"fo, on -> fn",
|
|
2888
|
+
self.omega,
|
|
2889
|
+
m(param).round().unsqueeze(0),
|
|
2890
|
+
)
|
|
2891
|
+
)
|
|
2892
|
+
else:
|
|
2893
|
+
self.freq_response = lambda param: (self.gamma ** m(param)) * torch.exp(
|
|
2894
|
+
-1j
|
|
2895
|
+
* torch.einsum(
|
|
2896
|
+
"fo, on -> fn",
|
|
2897
|
+
self.omega,
|
|
2898
|
+
m(param).unsqueeze(0),
|
|
2899
|
+
)
|
|
2868
2900
|
)
|
|
2869
|
-
)
|
|
2870
2901
|
|
|
2871
2902
|
def get_io(self):
|
|
2872
2903
|
r"""
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: flamo
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.5
|
|
4
4
|
Summary: An Open-Source Library for Frequency-Domain Differentiable Audio Processing
|
|
5
5
|
Project-URL: Homepage, https://github.com/gdalsanto/flamo
|
|
6
6
|
Project-URL: Issues, https://github.com/gdalsanto/flamo/issues
|
|
@@ -38,6 +38,7 @@ Requires-Dist: numpy
|
|
|
38
38
|
Requires-Dist: pydantic
|
|
39
39
|
Requires-Dist: pyfar
|
|
40
40
|
Requires-Dist: pysoundfile
|
|
41
|
+
Requires-Dist: pyyaml
|
|
41
42
|
Requires-Dist: scipy
|
|
42
43
|
Requires-Dist: torch
|
|
43
44
|
Requires-Dist: torchaudio
|
|
@@ -1,23 +1,23 @@
|
|
|
1
1
|
flamo/__init__.py,sha256=ujezWOJfD7DUoj4q1meeMUnB97rOEtNR7mYw_PE9LMg,49
|
|
2
|
-
flamo/functional.py,sha256=
|
|
2
|
+
flamo/functional.py,sha256=9wl6fHkc8KMB5IMvbd_K7-z8Z2Miw0qOsNxWPItliPU,35138
|
|
3
3
|
flamo/utils.py,sha256=ypGKSABZMphgIrjCKgCH-zgR7BaupRbyzuUhsZFqAAM,3350
|
|
4
4
|
flamo/auxiliary/__init__.py,sha256=7lVNh8OxHloZ4KPmp-iTUJnUbi8XbuRzGaQ3Z-NKXio,42
|
|
5
|
-
flamo/auxiliary/eq.py,sha256=
|
|
5
|
+
flamo/auxiliary/eq.py,sha256=eIWMIq0ggizXLhTdeWWbgBXWUFXCJyoEbkBH7Gzasao,6779
|
|
6
6
|
flamo/auxiliary/filterbank.py,sha256=02w8dI8HoNDtKpdVhSJkIkd-h-KNXvZtivf3l4_ozzU,9866
|
|
7
7
|
flamo/auxiliary/minimize.py,sha256=fMTAAAk9yD7Y4luKS4XA1-HTq44xo2opq_dRPRrhlIY,2474
|
|
8
|
-
flamo/auxiliary/reverb.py,sha256=
|
|
8
|
+
flamo/auxiliary/reverb.py,sha256=9iKSuyuqRiHGGvaj0eizqVpu2V7plsX13OWiB6o1whU,31636
|
|
9
9
|
flamo/auxiliary/scattering.py,sha256=ITPT0TTOAROy3G0_kpykffRSqjoA9dFJ2LnaLxtUMF4,9482
|
|
10
|
-
flamo/auxiliary/config/config.py,sha256=
|
|
10
|
+
flamo/auxiliary/config/config.py,sha256=CxXj-8sLq0_m9KyLg1a6NwLoK1UvTz3i0jZOLraq14I,2893
|
|
11
11
|
flamo/optimize/__init__.py,sha256=grgxLmQ7m-c9MvRdIejmEAaaajfBwgeaZAv2qjHIvPw,65
|
|
12
12
|
flamo/optimize/dataset.py,sha256=2mfzsnyX_bzavXouII9ee_pd6ti4lv215ieGJHscceI,5829
|
|
13
|
-
flamo/optimize/loss.py,sha256=
|
|
13
|
+
flamo/optimize/loss.py,sha256=h6EeqjdX5P1SqDBKBavSxV25VBgnYK8tuX91wk6lw_g,33466
|
|
14
14
|
flamo/optimize/surface.py,sha256=uvsgxLFSvJ18s8kPcb22G3W1rgycXP1nNX0q48Pda2g,26135
|
|
15
15
|
flamo/optimize/trainer.py,sha256=he4nUjLC-3RTlxxBIw33r5k8mQfgAGvN1wpPBAWCjVo,12045
|
|
16
16
|
flamo/optimize/utils.py,sha256=R5-KoZagRho3eykY88pC3UB2mc5SsE4Yv9X-ogskXdA,1610
|
|
17
17
|
flamo/processor/__init__.py,sha256=paGdxGVZgA2VAs0tBwRd0bobzGxeyK79DS7ZGO8drkI,41
|
|
18
|
-
flamo/processor/dsp.py,sha256=
|
|
18
|
+
flamo/processor/dsp.py,sha256=w5BT3lWAJ0kOC7rjVl1dgYFcGdQ6214pa_NyAL-K9QI,119983
|
|
19
19
|
flamo/processor/system.py,sha256=9XwLtaGEVs9glVOFvyiPnQpsnR_Wjrv6k1i1qCs8D1Q,42516
|
|
20
|
-
flamo-0.1.
|
|
21
|
-
flamo-0.1.
|
|
22
|
-
flamo-0.1.
|
|
23
|
-
flamo-0.1.
|
|
20
|
+
flamo-0.1.5.dist-info/METADATA,sha256=hTn11LzrhyQUm4JbHD_JiSmDfpOLbvgrGgzPezvegVo,7825
|
|
21
|
+
flamo-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
22
|
+
flamo-0.1.5.dist-info/licenses/LICENSE,sha256=smMocRH7xdPT5RvFNqSLtbSNzohXJM5G_rX1Qaej6vg,1120
|
|
23
|
+
flamo-0.1.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|