flamo 0.1.12__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flamo/auxiliary/config/config.py +3 -1
- flamo/auxiliary/eq.py +22 -18
- flamo/auxiliary/reverb.py +28 -17
- flamo/auxiliary/scattering.py +21 -21
- flamo/functional.py +74 -52
- flamo/optimize/dataset.py +7 -5
- flamo/optimize/surface.py +15 -13
- flamo/processor/dsp.py +158 -99
- flamo/processor/system.py +15 -10
- flamo/utils.py +2 -2
- {flamo-0.1.12.dist-info → flamo-0.2.0.dist-info}/METADATA +1 -1
- flamo-0.2.0.dist-info/RECORD +24 -0
- flamo-0.1.12.dist-info/RECORD +0 -24
- {flamo-0.1.12.dist-info → flamo-0.2.0.dist-info}/WHEEL +0 -0
- {flamo-0.1.12.dist-info → flamo-0.2.0.dist-info}/licenses/LICENSE +0 -0
flamo/functional.py
CHANGED
@@ -56,19 +56,20 @@ def skew_matrix(X):
     return A - A.transpose(-1, -2)


-def get_frequency_samples(num: int, device: str | torch.device = None):
+def get_frequency_samples(num: int, device: str | torch.device = None, dtype: torch.dtype = torch.float32):
     r"""
     Get frequency samples (in radians) sampled at linearly spaced points along the unit circle.

     **Arguments**
     - **num** (int): number of frequency samples
     - **device** (torch.device, str): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype): The dtype of constructed tensors. Default: torch.float32.

     **Returns**
     - frequency samples in radians between [0, pi]
     """
-    angle = torch.linspace(0, 1, num, device=device)
-    abs = torch.ones(num, device=device)
+    angle = torch.linspace(0, 1, num, device=device, dtype=dtype)
+    abs = torch.ones(num, device=device, dtype=dtype)
     return torch.polar(abs, angle * np.pi)


@@ -77,17 +78,18 @@ class HadamardMatrix(nn.Module):
     Generate a Hadamard matrix of size N as a nn.Module.
     """

-    def __init__(self, N, device: Optional[str] = None):
+    def __init__(self, N, device: Optional[str] = None, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.N = N
         self.device = device
+        self.dtype = dtype

     def forward(self, x):
-        U = torch.tensor([[1.0]], device=self.device)
+        U = torch.tensor([[1.0]], device=self.device, dtype=self.dtype)
         while U.shape[0] < self.N:
             U = torch.kron(
                 U, torch.tensor([[1, 1], [1, -1]], dtype=U.dtype, device=U.device)
-            ) / torch.sqrt(torch.tensor(2.0, device=U.device))
+            ) / torch.sqrt(torch.tensor(2.0, device=U.device, dtype=U.dtype))
         return U


@@ -103,6 +105,7 @@ class RotationMatrix(nn.Module):
         max_angle: float = torch.pi / 4,
         iter: Optional[int] = None,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):

         super().__init__()
@@ -111,14 +114,15 @@ class RotationMatrix(nn.Module):
         self.max_angle = max_angle
         self.iter = iter
         self.device = device
+        self.dtype = dtype

     def create_submatrix(self, angles: torch.Tensor, iters: int = 1):
         """Create a submatrix for each group."""
-        X = torch.zeros(2, 2, device=self.device)
+        X = torch.zeros(2, 2, device=self.device, dtype=self.dtype)
         angles[0] = torch.clamp(angles[0], self.min_angle, self.max_angle)
         X.fill_diagonal_(torch.cos(angles[0]))
-        X[1, 0] = -torch.sin(
-        X[0, 1] = torch.sin(
+        X[1, 0] = -torch.sin(angles[0])
+        X[0, 1] = torch.sin(angles[0])

         if iters is None:
             iters = torch.log2(torch.tensor(self.N)).int().item() - 1
@@ -166,6 +170,7 @@ def signal_gallery(
     rate: float = 1.0,
     reference: torch.Tensor = None,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Generate a tensor containing a signal based on the specified signal type.
@@ -187,6 +192,7 @@ def signal_gallery(
     - **fs** (int, optional): The sampling frequency of the signals. Defaults to 48000.
     - **reference** (torch.Tensor, optional): A reference signal to use. Defaults to None.
     - **device** (torch.device, optional): The device of constructed tensors. Defaults to None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Defaults to torch.float32.

     **Returns**:
     - torch.Tensor: A tensor of shape (batch_size, n_samples, n) containing the generated signals.
@@ -207,7 +213,7 @@ def signal_gallery(
         raise ValueError(f"Signal type {signal_type} not recognized.")
     match signal_type:
         case "impulse":
-            x = torch.zeros(batch_size, n_samples, n)
+            x = torch.zeros(batch_size, n_samples, n, dtype=dtype)
             x[:, 0, :] = 1
             return x.to(device)
         case "sine":
@@ -218,7 +224,7 @@ def signal_gallery(
                         * np.pi
                         * rate
                         / fs
-                        * torch.linspace(0, n_samples / fs, n_samples)
+                        * torch.linspace(0, n_samples / fs, n_samples, dtype=dtype)
                     )
                     .unsqueeze(-1)
                     .expand(batch_size, n_samples, n)
@@ -226,44 +232,45 @@ def signal_gallery(
                 )
             else:
                 return torch.sin(
-                    torch.linspace(0, 2 * np.pi, n_samples)
+                    torch.linspace(0, 2 * np.pi, n_samples, dtype=dtype)
                     .unsqueeze(-1)
                     .expand(batch_size, n_samples, n)
                 ).to(device)
         case "sweep":
-            t = torch.linspace(0, n_samples / fs - 1 / fs, n_samples)
+            t = torch.linspace(0, n_samples / fs - 1 / fs, n_samples, dtype=dtype)
             x = torch.tensor(
                 scipy.signal.chirp(t, f0=20, f1=20000, t1=t[-1], method="linear"),
                 device=device,
+                dtype=dtype,
             ).unsqueeze(-1)
             return x.expand(batch_size, n_samples, n)
         case "wgn":
-            return torch.randn((batch_size, n_samples, n), device=device)
+            return torch.randn((batch_size, n_samples, n), device=device, dtype=dtype)
         case "exp":
             return (
-                torch.exp(-rate * torch.arange(n_samples) / fs)
+                torch.exp(-rate * torch.arange(n_samples, dtype=dtype) / fs)
                 .unsqueeze(-1)
                 .expand(batch_size, n_samples, n)
                 .to(device)
             )
         case "velvet":
-            x = torch.empty((batch_size, n_samples, n), device=device)
+            x = torch.empty((batch_size, n_samples, n), device=device, dtype=dtype)
             for i_batch in range(batch_size):
                 for i_ch in range(n):
-                    x[i_batch, :, i_ch] = gen_velvet_noise(n_samples, fs, rate, device)
+                    x[i_batch, :, i_ch] = gen_velvet_noise(n_samples, fs, rate, device, dtype)
             return x
         case "reference":
             if isinstance(reference, torch.Tensor):
                 return reference.expand(batch_size, n_samples, n).to(device)
             else:
-                return torch.tensor(reference, device=device).expand(
+                return torch.tensor(reference, device=device, dtype=dtype).expand(
                     batch_size, n_samples, n
                 )
         case "noise":
-            return torch.randn((batch_size, n_samples, n), device=device)
+            return torch.randn((batch_size, n_samples, n), device=device, dtype=dtype)


-def gen_velvet_noise(n_samples: int, fs: int, density: float, device: str | torch.device = None) -> torch.Tensor:
+def gen_velvet_noise(n_samples: int, fs: int, density: float, device: str | torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
     r"""
     Generate a velvet noise sequence.
     **Arguments**:
@@ -271,15 +278,16 @@ def gen_velvet_noise(n_samples: int, fs: int, density: float, device: str | torc
     - **fs** (int): The sampling frequency of the signal in Hz.
     - **density** (float): The density of impulses in impulses per second.
     - **device** (str | torch.device): The device of constructed tensors.
+    - **dtype** (torch.dtype): The dtype of constructed tensors.
     **Returns**:
     - torch.Tensor: A tensor of shape (n_samples,) containing the velvet noise sequence.
     """
     Td = fs / density  # average distance between impulses
     num_impulses = n_samples / Td  # expected number of impulses
     floor_impulses = math.floor(num_impulses)
-    grid = torch.arange(floor_impulses) * Td
+    grid = torch.arange(floor_impulses, dtype=dtype) * Td

-    jitter_factors = torch.rand(floor_impulses)
+    jitter_factors = torch.rand(floor_impulses, dtype=dtype)
     impulse_indices = torch.ceil(grid + jitter_factors * (Td - 1)).long()

     # first impulse is at position 0 and all indices are within bounds
@@ -290,7 +298,7 @@ def gen_velvet_noise(n_samples: int, fs: int, density: float, device: str | torc
     signs = 2 * torch.randint(0, 2, (floor_impulses,)) - 1

     # Construct sparse signal
-    sequence = torch.zeros(n_samples, device=device)
+    sequence = torch.zeros(n_samples, device=device, dtype=dtype)
     sequence[impulse_indices] = signs.float()

     return sequence
@@ -370,6 +378,7 @@ def lowpass_filter(
     gain: float = 0.0,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ) -> tuple:
     r"""
     Lowpass filter coefficients. It uses the `RBJ cookbook formulas <https://webaudio.github.io/Audio-EQ-Cookbook/Audio-EQ-Cookbook.txt>`_ to map
@@ -402,12 +411,12 @@ def lowpass_filter(
     """

     omegaC = hertz2rad(fc, fs).to(device=device)
-    two = torch.tensor(2, device=device)
+    two = torch.tensor(2, device=device, dtype=dtype)
     alpha = torch.sin(omegaC) / 2 * torch.sqrt(two)
     cosOC = torch.cos(omegaC)

-    a = torch.ones(3, *omegaC.shape, device=device)
-    b = torch.ones(3, *omegaC.shape, device=device)
+    a = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)
+    b = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)

     b[0] = (1 - cosOC) / 2
     b[1] = 1 - cosOC
@@ -424,6 +433,7 @@ def highpass_filter(
     gain: float = 0.0,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ) -> tuple:
     r"""
     Highpass filter coefficients. It uses the `RBJ cookbook formulas <https://webaudio.github.io/Audio-EQ-Cookbook/Audio-EQ-Cookbook.txt>`_ to map
@@ -455,12 +465,12 @@ def highpass_filter(
     """

     omegaC = hertz2rad(fc, fs)
-    two = torch.tensor(2, device=device)
+    two = torch.tensor(2, device=device, dtype=dtype)
     alpha = torch.sin(omegaC) / 2 * torch.sqrt(two)
     cosOC = torch.cos(omegaC)

-    a = torch.ones(3, *omegaC.shape, device=device)
-    b = torch.ones(3, *omegaC.shape, device=device)
+    a = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)
+    b = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)

     b[0] = (1 + cosOC) / 2
     b[1] = -(1 + cosOC)
@@ -478,6 +488,7 @@ def bandpass_filter(
     gain: float = 0.0,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ) -> tuple:
     r"""
     Bandpass filter coefficients. It uses the `RBJ cookbook formulas <https://webaudio.github.io/Audio-EQ-Cookbook/Audio-EQ-Cookbook.txt>`_ to map
@@ -521,15 +532,15 @@ def bandpass_filter(

     omegaC = (hertz2rad(fc1, fs) + hertz2rad(fc2, fs)) / 2
     BW = torch.log2(fc2 / fc1)
-    two = torch.tensor(2, device=device)
+    two = torch.tensor(2, device=device, dtype=dtype)
     alpha = torch.sin(omegaC) * torch.sinh(
         torch.log(two) / two * BW * (omegaC / torch.sin(omegaC))
     )

     cosOC = torch.cos(omegaC)

-    a = torch.ones(3, *omegaC.shape, device=device)
-    b = torch.ones(3, *omegaC.shape, device=device)
+    a = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)
+    b = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)

     b[0] = alpha
     b[1] = 0
@@ -547,7 +558,8 @@ def shelving_filter(
     type: str = "low",
     fs: int = 48000,
     device: torch.device | str = None,
-
+    dtype: torch.dtype = torch.float32,
+) -> tuple:
     r"""
     Shelving filter coefficients.
     Maps the cutoff frequencies and gain to the :math:`\mathbf{b}` and :math:`\mathbf{a}` biquad coefficients.
@@ -582,8 +594,8 @@ def shelving_filter(
     - **b** (torch.Tensor): The numerator coefficients of the filter transfer function.
     - **a** (torch.Tensor): The denominator coefficients of the filter transfer function.
     """
-    b = torch.ones(3, device=device)
-    a = torch.ones(3, device=device)
+    b = torch.ones(3, device=device, dtype=dtype)
+    a = torch.ones(3, device=device, dtype=dtype)

     omegaC = hertz2rad(fc, fs)
     t = torch.tan(omegaC / 2)
@@ -591,7 +603,7 @@ def shelving_filter(
     g2 = gain**0.5
     g4 = gain**0.25

-    two = torch.tensor(2, device=device)
+    two = torch.tensor(2, device=device, dtype=dtype)
     b[0] = g2 * t2 + torch.sqrt(two) * t * g4 + 1
     b[1] = 2 * g2 * t2 - 2
     b[2] = g2 * t2 - torch.sqrt(two) * t * g4 + 1
@@ -616,6 +628,7 @@ def peak_filter(
     Q: torch.Tensor,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ) -> tuple:
     r"""
     Peak filter coefficients.
@@ -644,8 +657,8 @@ def peak_filter(
     - **b** (torch.Tensor): The numerator coefficients of the filter transfer function.
     - **a** (torch.Tensor): The denominator coefficients of the filter transfer function
     """
-    b = torch.ones(3, device=device)
-    a = torch.ones(3, device=device)
+    b = torch.ones(3, device=device, dtype=dtype)
+    a = torch.ones(3, device=device, dtype=dtype)

     omegaC = hertz2rad(fc, fs)
     bandWidth = omegaC / Q
@@ -668,6 +681,7 @@ def prop_shelving_filter(
     type: str = "low",
     fs: int = 48000,
     device="cpu",
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Proportional first order Shelving filter coefficients.
@@ -700,6 +714,7 @@ def prop_shelving_filter(
     - **type** (str, optional): The type of shelving filter. Can be 'low' or 'high'. Default: 'low'.
     - **fs** (int, optional): The sampling frequency of the signal in Hz.
     - **device** (torch.device | str, optional): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Default: torch.float32.

     **Returns**:
     - **b** (torch.Tensor): The numerator coefficients of the filter transfer function.
@@ -712,7 +727,7 @@ def prop_shelving_filter(
     t = torch.tan(torch.pi * fc / fs)
     k = 10 ** (gain / 20)

-    a = torch.zeros((2, *fc.shape), device=device)
+    a = torch.zeros((2, *fc.shape), device=device, dtype=dtype)
     b = torch.zeros_like(a)

     if type == "low":
@@ -736,6 +751,7 @@ def prop_peak_filter(
     gain: torch.Tensor,
     fs: int = 48000,
     device="cpu",
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Proportional Peak (Presence) filter coefficients.
@@ -761,6 +777,7 @@ def prop_peak_filter(
     - **gain** (torch.Tensor): The gain in dB of the filter.
     - **fs** (int, optional): The sampling frequency of the signal in Hz.
     - **device** (torch.device | str, optional): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Default: torch.float32.

     **Returns**:
     - **b** (torch.Tensor): The numerator coefficients of the filter transfer function.
@@ -774,7 +791,7 @@ def prop_peak_filter(
     c = torch.cos(2 * np.pi * fc / fs)
     k = 10 ** (gain / 20)

-    a = torch.zeros((3, *fc.shape), device=device)
+    a = torch.zeros((3, *fc.shape), device=device, dtype=dtype)
     b = torch.zeros_like(a)

     b[0] = 1 + torch.sqrt(k) * t
@@ -815,6 +832,7 @@ def svf(
     filter_type: str = None,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Implements a State Variable Filter (SVF) with various filter types.
@@ -827,6 +845,7 @@ def svf(
     - **filter_type** (str, optional): The type of filter to be applied. Can be one of "lowpass", "highpass", "bandpass", "lowshelf", "highshelf", "peaking", "notch", or None. Default: None.
     - **fs** (int, optional): The sampling frequency. Default: 48000.
     - **device** (torch.device, optional): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Default: torch.float32.

     **Returns**:
     Tuple[torch.Tensor, torch.Tensor]: The numerator and denominator coefficients of the filter transfer function.
@@ -897,8 +916,8 @@ def svf(
         case None:
             print("No filter type specified. Using the given mixing coefficents.")

-    b = torch.zeros((3, *f.shape), device=device)
-    a = torch.zeros((3, *f.shape), device=device)
+    b = torch.zeros((3, *f.shape), device=device, dtype=dtype)
+    a = torch.zeros((3, *f.shape), device=device, dtype=dtype)

     b[0] = (f**2) * m[..., 0] + f * m[..., 1] + m[..., 2]
     b[1] = 2 * (f**2) * m[..., 0] - 2 * m[..., 2]
@@ -917,6 +936,7 @@ def probe_sos(
     nfft: int,
     fs: int,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Probe the frequency / magnitude response of a cascaded SOS filter at the points
@@ -928,6 +948,7 @@ def probe_sos(
     - **nfft** (int): Length of the FFT used for frequency analysis.
     - **fs** (float): Sampling frequency in Hz.
     - **device** (torch.device, optional): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Default: torch.float32.

     **Returns**:
     tuple: A tuple containing the following:
@@ -938,8 +959,8 @@ def probe_sos(
     n_freqs = sos.shape[-1]

     H = torch.zeros((nfft // 2 + 1, n_freqs), dtype=torch.cdouble, device=device)
-    W = torch.zeros((nfft // 2 + 1, n_freqs), device=device)
-    G = torch.zeros((len(control_freqs), n_freqs), device=device)
+    W = torch.zeros((nfft // 2 + 1, n_freqs), device=device, dtype=dtype)
+    G = torch.zeros((len(control_freqs), n_freqs), device=device, dtype=dtype)

     for band in range(n_freqs):
         sos[:, band] = sos[:, band] / sos[3, band]
@@ -1003,7 +1024,7 @@ def find_onset(rir: torch.Tensor):


 def WGN_reverb(
-    matrix_size: tuple = (1, 1), t60: float = 1.0, samplerate: int = 48000, device=None
+    matrix_size: tuple = (1, 1), t60: float = 1.0, samplerate: int = 48000, device=None, dtype: torch.dtype = torch.float32
 ) -> torch.Tensor:
     r"""
     Generates White-Gaussian-Noise-reverb impulse responses.
@@ -1013,6 +1034,7 @@ def WGN_reverb(
     - **t60** (float, optional): Reverberation time. Defaults to 1.0.
     - **samplerate** (int, optional): Sampling frequency. Defaults to 48000.
     - **nfft** (int, optional): Number of frequency bins. Defaults to 2**11.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Defaults to torch.float32.

     **Returns**:
     torch.Tensor: Matrix of WGN-reverb impulse responses.
@@ -1020,10 +1042,10 @@ def WGN_reverb(
     # Number of samples
     n_samples = int(1.5 * t60 * samplerate)
     # White Guassian Noise
-    noise = torch.randn(n_samples, *matrix_size, device=device)
+    noise = torch.randn(n_samples, *matrix_size, device=device, dtype=dtype)
     # Decay
-    dr = t60 / torch.log(torch.tensor(1000, dtype=
-    decay = torch.exp(-1 / dr * torch.linspace(0, t60, n_samples))
+    dr = t60 / torch.log(torch.tensor(1000, dtype=dtype, device=device))
+    decay = torch.exp(-1 / dr * torch.linspace(0, t60, n_samples, dtype=dtype))
     decay = decay.view(-1, *(1,) * (len(matrix_size))).expand(-1, *matrix_size)
     # Decaying WGN
     IRs = torch.mul(noise, decay)
@@ -1031,11 +1053,11 @@ def WGN_reverb(
     TFs = torch.fft.rfft(input=IRs, n=n_samples, dim=0)

     # Generate bandpass filter
-    fc_left = torch.tensor([20], dtype=
-    fc_right = torch.tensor([20000], dtype=
-    g = torch.tensor([1], dtype=
+    fc_left = torch.tensor([20], dtype=dtype, device=device)
+    fc_right = torch.tensor([20000], dtype=dtype, device=device)
+    g = torch.tensor([1], dtype=dtype, device=device)
     b, a = bandpass_filter(
-        fc1=fc_left, fc2=fc_right, gain=g, fs=samplerate, device=device
+        fc1=fc_left, fc2=fc_right, gain=g, fs=samplerate, device=device, dtype=dtype
     )
     sos = torch.cat((b.reshape(1, 3), a.reshape(1, 3)), dim=1)
     bp_H = sosfreqz(sos=sos, nfft=n_samples).squeeze()
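Taken together, the functional.py changes thread a new `dtype` argument through every tensor constructor, so signals and filter coefficients can be built directly in a chosen precision instead of defaulting to float32. A minimal usage sketch, assuming only the 0.2.0 signatures shown in the diff above (the keyword names for signal_gallery are taken from its docstring and body; the sizes are illustrative):

    import torch
    from flamo.functional import get_frequency_samples, signal_gallery

    # Unit-circle frequency samples built directly in double precision.
    freqs = get_frequency_samples(1024, device="cpu", dtype=torch.float64)

    # An impulse of shape (batch_size, n_samples, n), also created in float64
    # rather than created in float32 and cast afterwards.
    x = signal_gallery(
        batch_size=1,
        n_samples=48000,
        n=2,
        signal_type="impulse",
        fs=48000,
        device="cpu",
        dtype=torch.float64,
    )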
flamo/optimize/dataset.py
CHANGED
@@ -29,9 +29,10 @@ class Dataset(torch.utils.data.Dataset):
         target: torch.Tensor = torch.randn(1, 1),
         expand: int = 1,
         device: str = "cpu",
+        dtype: torch.dtype = torch.float32,
     ):
-        self.input = input.to(device)
-        self.target = target.to(device)
+        self.input = input.to(device).to(dtype)
+        self.target = target.to(device).to(dtype)
         self.expand = expand
         self.device = device
         self.input = self.input.expand(tuple([expand] + [d for d in input.shape[1:]]))
@@ -69,11 +70,12 @@ class DatasetColorless(Dataset):
         target_shape: tuple,
         expand: int = 1000,
         device: str = "cpu",
+        dtype: torch.dtype = torch.float32,
     ):
-        input = torch.zeros(input_shape, device=device)
+        input = torch.zeros(input_shape, device=device, dtype=dtype)
         input[:, 0, :] = 1
-        target = torch.ones(target_shape, device=device)
-        super().__init__(input=input, target=target, expand=expand, device=device)
+        target = torch.ones(target_shape, device=device, dtype=dtype)
+        super().__init__(input=input, target=target, expand=expand, device=device, dtype=dtype)

     def __getitem__(self, index):
         return self.input[index], self.target[index]
flamo/optimize/surface.py
CHANGED
@@ -58,7 +58,7 @@ class LossProfile:

     """

-    def __init__(self, net: Shell, loss_config: LossConfig, device: str = "cpu"):
+    def __init__(self, net: Shell, loss_config: LossConfig, device: str = "cpu", dtype: torch.dtype = torch.float32):

         super().__init__()
         self.net = net
@@ -68,6 +68,7 @@ class LossProfile:
         self.n_runs = loss_config.n_runs
         self.output_dir = loss_config.output_dir
         self.device = device
+        self.dtype = dtype
         self.register_steps()

     def compute_loss(self, input: torch.Tensor, target: torch.Tensor):
@@ -101,9 +102,9 @@ class LossProfile:
             if type(self.param_config.lower_bound) == list:
                 # interpolate between the lower and upper bound
                 new_value = (1 - steps[i_step]) * torch.tensor(
-                    self.param_config.lower_bound, device=self.device
+                    self.param_config.lower_bound, device=self.device, dtype=self.dtype
                 ) + steps[i_step] * torch.tensor(
-                    self.param_config.upper_bound, device=self.device
+                    self.param_config.upper_bound, device=self.device, dtype=self.dtype
                 )
             else:
                 new_value = steps[i_step]
@@ -230,12 +231,14 @@ class LossProfile:
             upper_bound = param_upper_bound

         if scale == "linear":
-            steps = torch.linspace(lower_bound, upper_bound, n_steps)
+            steps = torch.linspace(lower_bound, upper_bound, n_steps, device=self.device, dtype=self.dtype)
         elif scale == "log":
             steps = torch.logspace(
-                torch.log10(torch.tensor(lower_bound, device=self.device)),
-                torch.log10(torch.tensor(upper_bound, device=self.device)),
+                torch.log10(torch.tensor(lower_bound, device=self.device, dtype=self.dtype)),
+                torch.log10(torch.tensor(upper_bound, device=self.device, dtype=self.dtype)),
                 n_steps,
+                device=self.device,
+                dtype=self.dtype,
             )
         else:
             raise ValueError("Scale must be either 'linear' or 'log'")
@@ -336,9 +339,9 @@ class LossSurface(LossProfile):
     - **steps** (dict): Dictionary of steps between the lower and upper bound of the parameters.
     """

-    def __init__(self, net: Shell, loss_config: LossConfig, device: str = "cpu"):
+    def __init__(self, net: Shell, loss_config: LossConfig, device: str = "cpu", dtype: torch.dtype = torch.float32):

-        super().__init__(net, loss_config, device)
+        super().__init__(net, loss_config, device, dtype)

         assert (
             len(loss_config.param_config) == 2
@@ -403,9 +406,9 @@ class LossSurface(LossProfile):
             if type(self.param_config[0].lower_bound) == list:
                 # interpolate between the lower and upper bound
                 new_value = (1 - steps_0[i_step_0]) * torch.tensor(
-                    self.param_config[0].lower_bound, device=self.device
+                    self.param_config[0].lower_bound, device=self.device, dtype=self.dtype
                 ) + steps_0[i_step_0] * torch.tensor(
-                    self.param_config[0].upper_bound, device=self.device
+                    self.param_config[0].upper_bound, device=self.device, dtype=self.dtype
                 )
             else:
                 new_value = steps_0[i_step_0]
@@ -419,9 +422,9 @@ class LossSurface(LossProfile):
             if type(self.param_config[1].lower_bound) == list:
                 # interpolate between the lower and upper bound
                 new_value = (1 - steps_1[i_step_1]) * torch.tensor(
-                    self.param_config[1].lower_bound, device=self.device
+                    self.param_config[1].lower_bound, device=self.device, dtype=self.dtype
                 ) + steps_1[i_step_1] * torch.tensor(
-                    self.param_config[1].upper_bound, device=self.device
+                    self.param_config[1].upper_bound, device=self.device, dtype=self.dtype
                 )
             else:
                 new_value = steps_1[i_step_1]
@@ -621,7 +624,6 @@ class LossSurface(LossProfile):
             plt.savefig(
                 f"{self.output_dir}/{title[i_plot]}_{self.param_config[0].key}_{self.param_config[1].key}.png"
             )
-            plt.clf()
         return fig, ax

     def compute_accuracy(self, loss):
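LossProfile and LossSurface receive the same plumbing: sweep steps, interpolated bounds, and log-spaced grids are now all created with self.dtype. A sketch of the new constructors, where `shell` and `config` are placeholders for a prebuilt Shell and LossConfig that this diff does not define:

    import torch
    from flamo.optimize.surface import LossProfile, LossSurface

    # `shell` and `config` are assumed to exist; see flamo.processor.system.Shell
    # and the LossConfig used throughout flamo/optimize/surface.py.
    profile = LossProfile(net=shell, loss_config=config, device="cpu", dtype=torch.float64)
    surface = LossSurface(net=shell, loss_config=config, device="cpu", dtype=torch.float64)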