flamo 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -33,7 +33,8 @@ class HomogeneousFDNConfig(BaseModel):
33
33
  delays_grad: bool = False
34
34
  mixing_matrix_grad: bool = True
35
35
  attenuation_grad: bool = True
36
-
36
+ is_delay_int: bool = True
37
+
37
38
  def __init__(self, **data):
38
39
  super().__init__(**data)
39
40
  if self.delays is None:
flamo/auxiliary/eq.py CHANGED
@@ -86,7 +86,8 @@ def geq(
86
86
 
87
87
  for band in range(num_bands):
88
88
  if band == 0:
89
- b = torch.tensor([db2mag(gain_db[band]), 0, 0], device=device)
89
+ b = torch.zeros(3, device=device)
90
+ b[0] = db2mag(gain_db[band])
90
91
  a = torch.tensor([1, 0, 0], device=device)
91
92
  elif band == 1:
92
93
  b, a = shelving_filter(
flamo/auxiliary/reverb.py CHANGED
@@ -129,7 +129,7 @@ class HomogeneousFDN:
129
129
  size=(self.N,),
130
130
  max_len=delay_lines.max(),
131
131
  nfft=self.config_dict.nfft,
132
- isint=True,
132
+ isint=self.config_dict.is_delay_int,
133
133
  requires_grad=self.config_dict.delays_grad,
134
134
  alias_decay_db=self.config_dict.alias_decay_db,
135
135
  device=self.config_dict.device,
flamo/functional.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import torch
2
+ import torch.nn as nn
2
3
  import numpy as np
3
4
  import scipy.signal
5
+ from typing import Optional
4
6
  from flamo.utils import RegularGridInterpolator
5
7
 
6
8
 
@@ -69,6 +71,68 @@ def get_frequency_samples(num: int, device: str | torch.device = None):
69
71
  return torch.polar(abs, angle * np.pi)
70
72
 
71
73
 
74
class HadamardMatrix(nn.Module):
    """
    Generate an orthonormal Hadamard matrix of size N as a nn.Module.

    The matrix is built with the Sylvester construction (repeated Kronecker
    products of ``[[1, 1], [1, -1]]``) and scaled by ``1 / sqrt(N)`` so that
    ``H @ H.T == I``. The module has no learnable parameters; the input of
    :meth:`forward` is ignored.

    Args:
        N (int): Matrix size. Must be a positive power of two (1, 2, 4, 8, ...).
        device (str, optional): Device of the generated matrix. Default: None.

    Raises:
        ValueError: If ``N`` is not a positive power of two.
    """

    def __init__(self, N, device: Optional[str] = None):
        super().__init__()
        size = int(N)
        # The Sylvester construction only yields power-of-two sizes; any other
        # N would silently produce a matrix larger than requested.
        if size < 1 or size & (size - 1):
            raise ValueError(
                f"Hadamard matrix size must be a positive power of two, got {N}"
            )
        self.N = N
        self.device = device

    def forward(self, x):
        """
        Return the N x N orthonormal Hadamard matrix.

        Args:
            x: Ignored; accepted so the module can be used as a ``map``.

        Returns:
            torch.Tensor: Tensor of shape (N, N) on ``self.device``.
        """
        kernel = torch.tensor([[1.0, 1.0], [1.0, -1.0]], device=self.device)
        U = torch.ones(1, 1, device=self.device)
        while U.shape[0] < self.N:
            U = torch.kron(U, kernel)
        # Normalise once: each Kronecker step scales the row norm by sqrt(2).
        return U / self.N ** 0.5
91
+
92
+
93
class RotationMatrix(nn.Module):
    """
    Generate an orthogonal rotation matrix as a nn.Module from given angle(s).

    The matrix is a Kronecker product of 2x2 rotation blocks
    ``[[cos(a), sin(a)], [-sin(a), cos(a)]]``, with every angle clamped to
    ``[min_angle, max_angle]``. Each iteration doubles the size, so the result
    is ``2 ** (iters + 1)`` square. When ``iter`` is None the number of
    iterations is derived from ``N`` (``log2(N) - 1``), which requires ``N`` to
    be a power of two.

    Args:
        N (int): Target matrix size, used to derive the number of iterations
            when ``iter`` is None.
        min_angle (float): Lower clamp bound for the angles. Default: 0.
        max_angle (float): Upper clamp bound for the angles. Default: pi/4.
        iter (int, optional): Number of Kronecker products applied after the
            first 2x2 block. Default: None (derived from ``N``).
        device (str, optional): Device of the generated matrix. Default: None.

    Note:
        ``min_angle`` is the second positional parameter. Pass ``iter`` by
        keyword, otherwise it is interpreted as ``min_angle``.
    """

    def __init__(
        self,
        N: int,
        min_angle: float = 0,
        max_angle: float = torch.pi / 4,
        iter: Optional[int] = None,
        device: Optional[str] = None,
    ):

        super().__init__()
        self.N = N
        self.min_angle = min_angle
        self.max_angle = max_angle
        self.iter = iter
        self.device = device

    def _rotation_block(self, angle) -> torch.Tensor:
        """Return the 2x2 rotation block for one (clamped) angle."""
        # as_tensor keeps the autograd graph of tensor inputs (torch.tensor would
        # copy and detach), so gradients reach the angle parameter. Unlike the
        # previous in-place assignment, the caller's input is never mutated.
        a = torch.as_tensor(angle, device=self.device)
        if not torch.is_floating_point(a):
            a = a.to(torch.get_default_dtype())
        a = torch.clamp(a, self.min_angle, self.max_angle)
        c, s = torch.cos(a), torch.sin(a)
        return torch.stack((torch.stack((c, s)), torch.stack((-s, c))))

    def _default_iters(self) -> int:
        """Number of Kronecker iterations needed to reach an N x N matrix."""
        size = int(self.N)
        if size < 2 or size & (size - 1):
            raise ValueError(
                f"N must be a power of two >= 2 to derive the number of iterations, got {self.N}"
            )
        # 2 ** (iters + 1) == size  ->  iters == log2(size) - 1
        return size.bit_length() - 2

    def create_submatrix(self, angles, iters: Optional[int] = 1):
        """
        Create the rotation matrix from a sequence of angles.

        Args:
            angles: Indexable sequence (list or 1-D tensor) of angles. With a
                single angle the same 2x2 block is repeated; otherwise
                ``angles[i]`` is used for the i-th extra block.
            iters (int, optional): Number of Kronecker iterations. If None, it
                is derived from ``self.N``. Default: 1.

        Returns:
            torch.Tensor: Square tensor of size ``2 ** (iters + 1)``.

        Raises:
            ValueError: If ``iters`` is None and ``self.N`` is not a power of
                two >= 2.
        """
        base = self._rotation_block(angles[0])
        if iters is None:
            iters = self._default_iters()
        X = base
        for i in range(iters):
            # Always extend by a 2x2 block so the size doubles per iteration
            # (kron(X, X) would square the size instead).
            # NOTE(review): the multi-angle path keeps the original indexing
            # (angles[i], so angles[0] is reused) — confirm intended mapping.
            block = self._rotation_block(angles[i]) if len(angles) > 1 else base
            X = torch.kron(X, block)
        return X

    def forward(self, theta):
        """Return the rotation matrix for angles ``theta`` using ``self.iter`` iterations."""
        return self.create_submatrix(theta, self.iter)
134
+
135
+
72
136
  def biquad2tf(b: torch.Tensor, a: torch.Tensor, nfft: int):
73
137
  r"""
74
138
  Converts a biquad filter representation to a transfer function.
@@ -133,10 +197,11 @@ def signal_gallery(
133
197
  "wgn",
134
198
  "exp",
135
199
  "reference",
200
+ "noise",
136
201
  }
137
202
 
138
203
  if signal_type not in signal_types:
139
- raise ValueError(f"Matrix type {signal_type} not recognized.")
204
+ raise ValueError(f"Signal type {signal_type} not recognized.")
140
205
  match signal_type:
141
206
  case "impulse":
142
207
  x = torch.zeros(batch_size, n_samples, n)
@@ -185,6 +250,8 @@ def signal_gallery(
185
250
  return torch.tensor(reference, device=device).expand(
186
251
  batch_size, n_samples, n
187
252
  )
253
+ case "noise":
254
+ return torch.randn((batch_size, n_samples, n), device=device)
188
255
 
189
256
 
190
257
  def hertz2rad(hertz: torch.Tensor, fs: int):
flamo/optimize/loss.py CHANGED
@@ -38,7 +38,7 @@ class sparsity_loss(nn.Module):
38
38
  A = core.feedback_loop.feedback.map(core.feedback_loop.feedback.param)
39
39
  except:
40
40
  try:
41
- A = core.feedback_loop.feedback.map(
41
+ A = core.feedback_loop.feedback.mixing_matrix.map(
42
42
  core.feedback_loop.feedback.mixing_matrix.param
43
43
  )
44
44
  except:
@@ -77,6 +77,7 @@ class mse_loss(nn.Module):
77
77
  self.nfft = nfft
78
78
  self.device = device
79
79
  self.mse_loss = nn.MSELoss()
80
+ self.name = "MSE"
80
81
 
81
82
  def forward(self, y_pred, y_true):
82
83
  """
flamo/processor/dsp.py CHANGED
@@ -8,7 +8,9 @@ from flamo.functional import (
8
8
  lowpass_filter,
9
9
  highpass_filter,
10
10
  bandpass_filter,
11
- rad2hertz )
11
+ rad2hertz,
12
+ HadamardMatrix,
13
+ RotationMatrix)
12
14
  from flamo.auxiliary.eq import (
13
15
  eq_freqs,
14
16
  geq,
@@ -535,11 +537,13 @@ class Matrix(Gain):
535
537
  nfft: int = 2**11,
536
538
  map: callable = lambda x: x,
537
539
  matrix_type: str = "random",
540
+ iter: int = 1,
538
541
  requires_grad: bool = False,
539
542
  alias_decay_db: float = 0.0,
540
543
  device: Optional[str] = None,
541
544
  ):
542
545
  self.matrix_type = matrix_type
546
+ self.iter = iter # iterations number for the rotation matrix
543
547
  super().__init__(
544
548
  size=size,
545
549
  nfft=nfft,
@@ -557,14 +561,31 @@ class Matrix(Gain):
557
561
  Warning(
558
562
  f"you asked for {self.matrix_type} matrix type, map will be overwritten"
559
563
  )
564
+ N = self.size[0]
560
565
  match self.matrix_type:
561
566
  case "random":
562
567
  self.map = lambda x: x
563
568
  case "orthogonal":
564
569
  assert (
565
- self.size[0] == self.size[1]
570
+ N == self.size[1]
566
571
  ), "Matrix must be square to be orthogonal"
567
572
  self.map = lambda x: torch.matrix_exp(skew_matrix(x))
573
+ case "hadamard":
574
+ assert (
575
+ N == self.size[1]
576
+ ), "Matrix must be square to be Hadamard"
577
+ assert (
578
+ N % 2 == 0
579
+ ), "Matrix must have even dimensions to be Hadamard"
580
+ self.map = lambda x: HadamardMatrix(self.size[0], device=self.device)(x)
581
+ case "rotation":
582
+ assert (
583
+ N == self.size[1]
584
+ ), "Matrix must be square to be a rotation matrix"
585
+ assert (
586
+ N % 2 == 0
587
+ ), "Matrix must have even dimensions to be a rotation matrix"
588
+ self.map = lambda x: RotationMatrix(self.size[0], self.iter, device=self.device)([x[0][0]])
568
589
 
569
590
  def initialize_class(self):
570
591
  r"""
@@ -2722,7 +2743,7 @@ class Delay(DSP):
2722
2743
  """
2723
2744
  m = self.get_delays()
2724
2745
  if self.isint:
2725
- self.freq_response = lambda param: (self.gamma ** m(param)) * torch.exp(
2746
+ self.freq_response = lambda param: (self.gamma ** m(param).round()) * torch.exp(
2726
2747
  -1j
2727
2748
  * torch.einsum(
2728
2749
  "fo, omn -> fmn",
@@ -2859,14 +2880,24 @@ class parallelDelay(Delay):
2859
2880
  Computes the frequency response of the delay module.
2860
2881
  """
2861
2882
  m = self.get_delays()
2862
- self.freq_response = lambda param: (self.gamma ** m(param)) * torch.exp(
2863
- -1j
2864
- * torch.einsum(
2865
- "fo, on -> fn",
2866
- self.omega,
2867
- m(param).unsqueeze(0),
2883
+ if self.isint:
2884
+ self.freq_response = lambda param: (self.gamma ** m(param).round()) * torch.exp(
2885
+ -1j
2886
+ * torch.einsum(
2887
+ "fo, on -> fn",
2888
+ self.omega,
2889
+ m(param).round().unsqueeze(0),
2890
+ )
2891
+ )
2892
+ else:
2893
+ self.freq_response = lambda param: (self.gamma ** m(param)) * torch.exp(
2894
+ -1j
2895
+ * torch.einsum(
2896
+ "fo, on -> fn",
2897
+ self.omega,
2898
+ m(param).unsqueeze(0),
2899
+ )
2868
2900
  )
2869
- )
2870
2901
 
2871
2902
  def get_io(self):
2872
2903
  r"""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flamo
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: An Open-Source Library for Frequency-Domain Differentiable Audio Processing
5
5
  Project-URL: Homepage, https://github.com/gdalsanto/flamo
6
6
  Project-URL: Issues, https://github.com/gdalsanto/flamo/issues
@@ -38,6 +38,7 @@ Requires-Dist: numpy
38
38
  Requires-Dist: pydantic
39
39
  Requires-Dist: pyfar
40
40
  Requires-Dist: pysoundfile
41
+ Requires-Dist: pyyaml
41
42
  Requires-Dist: scipy
42
43
  Requires-Dist: torch
43
44
  Requires-Dist: torchaudio
@@ -1,23 +1,23 @@
1
1
  flamo/__init__.py,sha256=ujezWOJfD7DUoj4q1meeMUnB97rOEtNR7mYw_PE9LMg,49
2
- flamo/functional.py,sha256=oFVgab3uqXw2bKwwUOzWUCXjVLmLbZjBfL1wJ7PYgGQ,33100
2
+ flamo/functional.py,sha256=9wl6fHkc8KMB5IMvbd_K7-z8Z2Miw0qOsNxWPItliPU,35138
3
3
  flamo/utils.py,sha256=ypGKSABZMphgIrjCKgCH-zgR7BaupRbyzuUhsZFqAAM,3350
4
4
  flamo/auxiliary/__init__.py,sha256=7lVNh8OxHloZ4KPmp-iTUJnUbi8XbuRzGaQ3Z-NKXio,42
5
- flamo/auxiliary/eq.py,sha256=dkULcVlQrL3LKi4ejFnWb6VSWSmEb4PYSNLrOQMvGws,6767
5
+ flamo/auxiliary/eq.py,sha256=eIWMIq0ggizXLhTdeWWbgBXWUFXCJyoEbkBH7Gzasao,6779
6
6
  flamo/auxiliary/filterbank.py,sha256=02w8dI8HoNDtKpdVhSJkIkd-h-KNXvZtivf3l4_ozzU,9866
7
7
  flamo/auxiliary/minimize.py,sha256=fMTAAAk9yD7Y4luKS4XA1-HTq44xo2opq_dRPRrhlIY,2474
8
- flamo/auxiliary/reverb.py,sha256=Rmv5oCW49MsfuJnM7ujZnJRQB6y1hQa1KAn1Hki2Bwk,31611
8
+ flamo/auxiliary/reverb.py,sha256=9iKSuyuqRiHGGvaj0eizqVpu2V7plsX13OWiB6o1whU,31636
9
9
  flamo/auxiliary/scattering.py,sha256=ITPT0TTOAROy3G0_kpykffRSqjoA9dFJ2LnaLxtUMF4,9482
10
- flamo/auxiliary/config/config.py,sha256=7WYQsk3rfzb-OOY5JyRv6GzXPv8deLL_Viv1EbAUwu4,2859
10
+ flamo/auxiliary/config/config.py,sha256=CxXj-8sLq0_m9KyLg1a6NwLoK1UvTz3i0jZOLraq14I,2893
11
11
  flamo/optimize/__init__.py,sha256=grgxLmQ7m-c9MvRdIejmEAaaajfBwgeaZAv2qjHIvPw,65
12
12
  flamo/optimize/dataset.py,sha256=2mfzsnyX_bzavXouII9ee_pd6ti4lv215ieGJHscceI,5829
13
- flamo/optimize/loss.py,sha256=d5SJLlhXOhSSMWFRQFTyPkgHCSnLdInrvPZJ0VIVNQQ,33426
13
+ flamo/optimize/loss.py,sha256=h6EeqjdX5P1SqDBKBavSxV25VBgnYK8tuX91wk6lw_g,33466
14
14
  flamo/optimize/surface.py,sha256=uvsgxLFSvJ18s8kPcb22G3W1rgycXP1nNX0q48Pda2g,26135
15
15
  flamo/optimize/trainer.py,sha256=he4nUjLC-3RTlxxBIw33r5k8mQfgAGvN1wpPBAWCjVo,12045
16
16
  flamo/optimize/utils.py,sha256=R5-KoZagRho3eykY88pC3UB2mc5SsE4Yv9X-ogskXdA,1610
17
17
  flamo/processor/__init__.py,sha256=paGdxGVZgA2VAs0tBwRd0bobzGxeyK79DS7ZGO8drkI,41
18
- flamo/processor/dsp.py,sha256=Adb6pWyLHYO1mlMKPt8QnwC2oXkhjO4VrW-34jvUpyY,118696
18
+ flamo/processor/dsp.py,sha256=w5BT3lWAJ0kOC7rjVl1dgYFcGdQ6214pa_NyAL-K9QI,119983
19
19
  flamo/processor/system.py,sha256=9XwLtaGEVs9glVOFvyiPnQpsnR_Wjrv6k1i1qCs8D1Q,42516
20
- flamo-0.1.3.dist-info/METADATA,sha256=4by0jhNq-p5kPZxs3WXFREJk25J7KuOFgNVPgsArb6w,7803
21
- flamo-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
22
- flamo-0.1.3.dist-info/licenses/LICENSE,sha256=smMocRH7xdPT5RvFNqSLtbSNzohXJM5G_rX1Qaej6vg,1120
23
- flamo-0.1.3.dist-info/RECORD,,
20
+ flamo-0.1.5.dist-info/METADATA,sha256=hTn11LzrhyQUm4JbHD_JiSmDfpOLbvgrGgzPezvegVo,7825
21
+ flamo-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
22
+ flamo-0.1.5.dist-info/licenses/LICENSE,sha256=smMocRH7xdPT5RvFNqSLtbSNzohXJM5G_rX1Qaej6vg,1120
23
+ flamo-0.1.5.dist-info/RECORD,,
File without changes