flamo 0.1.12__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flamo/auxiliary/config/config.py +3 -1
- flamo/auxiliary/eq.py +22 -18
- flamo/auxiliary/reverb.py +28 -17
- flamo/auxiliary/scattering.py +21 -21
- flamo/functional.py +74 -52
- flamo/optimize/dataset.py +7 -5
- flamo/optimize/surface.py +15 -13
- flamo/processor/dsp.py +158 -99
- flamo/processor/system.py +15 -10
- flamo/utils.py +2 -2
- {flamo-0.1.12.dist-info → flamo-0.2.0.dist-info}/METADATA +1 -1
- flamo-0.2.0.dist-info/RECORD +24 -0
- flamo-0.1.12.dist-info/RECORD +0 -24
- {flamo-0.1.12.dist-info → flamo-0.2.0.dist-info}/WHEEL +0 -0
- {flamo-0.1.12.dist-info → flamo-0.2.0.dist-info}/licenses/LICENSE +0 -0
flamo/processor/dsp.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import math
|
|
3
|
+
import abc
|
|
4
|
+
import warnings
|
|
3
5
|
from typing import Optional
|
|
4
6
|
import torch.nn as nn
|
|
5
7
|
import torch.nn.functional as F
|
|
@@ -41,11 +43,11 @@ class Transform(nn.Module):
|
|
|
41
43
|
tensor([1, 4, 9])
|
|
42
44
|
"""
|
|
43
45
|
|
|
44
|
-
def __init__(self, transform: callable = lambda x: x, device: Optional[str] = None):
|
|
45
|
-
|
|
46
|
+
def __init__(self, transform: callable = lambda x: x, device: Optional[str] = None, dtype: torch.dtype = torch.float32):
|
|
46
47
|
super().__init__()
|
|
47
48
|
self.transform = transform
|
|
48
49
|
self.device = device
|
|
50
|
+
self.dtype = dtype
|
|
49
51
|
|
|
50
52
|
def forward(self, x: torch.Tensor):
|
|
51
53
|
r"""
|
|
@@ -75,12 +77,12 @@ class FFT(Transform):
|
|
|
75
77
|
For details on the real FFT function, see `torch.fft.rfft documentation <https://pytorch.org/docs/stable/generated/torch.fft.rfft.html>`_.
|
|
76
78
|
"""
|
|
77
79
|
|
|
78
|
-
def __init__(self, nfft: int = 2**11, norm: str = "backward"):
|
|
79
|
-
|
|
80
|
+
def __init__(self, nfft: int = 2**11, norm: str = "backward", dtype: torch.dtype = torch.float32):
|
|
80
81
|
self.nfft = nfft
|
|
81
82
|
self.norm = norm
|
|
83
|
+
self.dtype = dtype
|
|
82
84
|
transform = lambda x: torch.fft.rfft(x, n=self.nfft, dim=1, norm=self.norm)
|
|
83
|
-
super().__init__(transform=transform)
|
|
85
|
+
super().__init__(transform=transform, dtype=self.dtype)
|
|
84
86
|
|
|
85
87
|
|
|
86
88
|
class iFFT(Transform):
|
|
@@ -97,12 +99,12 @@ class iFFT(Transform):
|
|
|
97
99
|
For details on the inverse real FFT function, see `torch.fft.irfft documentation <https://pytorch.org/docs/stable/generated/torch.fft.irfft.html>`_.
|
|
98
100
|
"""
|
|
99
101
|
|
|
100
|
-
def __init__(self, nfft: int = 2**11, norm: str = "backward"):
|
|
101
|
-
|
|
102
|
+
def __init__(self, nfft: int = 2**11, norm: str = "backward", dtype: torch.dtype = torch.float32):
|
|
102
103
|
self.nfft = nfft
|
|
103
104
|
self.norm = norm
|
|
105
|
+
self.dtype = dtype
|
|
104
106
|
transform = lambda x: torch.fft.irfft(x, n=self.nfft, dim=1, norm=self.norm)
|
|
105
|
-
super().__init__(transform=transform)
|
|
107
|
+
super().__init__(transform=transform, dtype=self.dtype)
|
|
106
108
|
|
|
107
109
|
|
|
108
110
|
class FFTAntiAlias(Transform):
|
|
@@ -119,7 +121,7 @@ class FFTAntiAlias(Transform):
|
|
|
119
121
|
- **norm** (str): The normalization mode for the FFT.
|
|
120
122
|
- **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. Default: 0.0.
|
|
121
123
|
- **device** (str): The device of the constructed tensors. Default: None.
|
|
122
|
-
|
|
124
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
123
125
|
|
|
124
126
|
For details on the FFT function, see `torch.fft.rfft documentation <https://pytorch.org/docs/stable/generated/torch.fft.rfft.html>`_.
|
|
125
127
|
"""
|
|
@@ -130,24 +132,23 @@ class FFTAntiAlias(Transform):
|
|
|
130
132
|
norm: str = "backward",
|
|
131
133
|
alias_decay_db: float = 0.0,
|
|
132
134
|
device: Optional[str] = None,
|
|
135
|
+
dtype: torch.dtype = torch.float32,
|
|
133
136
|
):
|
|
134
|
-
|
|
135
137
|
self.nfft = nfft
|
|
136
138
|
self.norm = norm
|
|
137
139
|
self.device = device
|
|
138
|
-
|
|
140
|
+
self.dtype = dtype
|
|
139
141
|
gamma = 10 ** (
|
|
140
|
-
-torch.abs(torch.tensor(alias_decay_db, device=self.device))
|
|
142
|
+
-torch.abs(torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype))
|
|
141
143
|
/ (self.nfft)
|
|
142
144
|
/ 20
|
|
143
145
|
)
|
|
144
146
|
self.alias_envelope = gamma ** torch.arange(
|
|
145
|
-
0, -self.nfft, -1, device=self.device
|
|
147
|
+
0, -self.nfft, -1, device=self.device, dtype=self.dtype
|
|
146
148
|
)
|
|
147
|
-
|
|
148
149
|
fft = lambda x: torch.fft.rfft(x, n=self.nfft, dim=1, norm=self.norm)
|
|
149
150
|
transform = lambda x: fft(torch.einsum("btm, t->btm", x, self.alias_envelope))
|
|
150
|
-
super().__init__(transform=transform)
|
|
151
|
+
super().__init__(transform=transform, dtype=self.dtype)
|
|
151
152
|
|
|
152
153
|
|
|
153
154
|
class iFFTAntiAlias(Transform):
|
|
@@ -163,6 +164,7 @@ class iFFTAntiAlias(Transform):
|
|
|
163
164
|
- **norm** (str): The normalization mode. Default: "backward".
|
|
164
165
|
- **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. Default: 0.0.
|
|
165
166
|
- **device** (str): The device of the constructed tensors. Default: None.
|
|
167
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
166
168
|
|
|
167
169
|
For details on the inverse FFT function, see `torch.fft.irfft documentation <https://pytorch.org/docs/stable/generated/torch.fft.irfft.html>`_.
|
|
168
170
|
"""
|
|
@@ -173,24 +175,23 @@ class iFFTAntiAlias(Transform):
|
|
|
173
175
|
norm: str = "backward",
|
|
174
176
|
alias_decay_db: float = 0.0,
|
|
175
177
|
device: Optional[str] = None,
|
|
178
|
+
dtype: torch.dtype = torch.float32,
|
|
176
179
|
):
|
|
177
|
-
|
|
178
180
|
self.nfft = nfft
|
|
179
181
|
self.norm = norm
|
|
180
182
|
self.device = device
|
|
181
|
-
|
|
183
|
+
self.dtype = dtype
|
|
182
184
|
gamma = 10 ** (
|
|
183
|
-
-torch.abs(torch.tensor(alias_decay_db, device=self.device))
|
|
185
|
+
-torch.abs(torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype))
|
|
184
186
|
/ (self.nfft)
|
|
185
187
|
/ 20
|
|
186
188
|
)
|
|
187
189
|
self.alias_envelope = gamma ** torch.arange(
|
|
188
|
-
0, -self.nfft, -1, device=self.device
|
|
190
|
+
0, -self.nfft, -1, device=self.device, dtype=self.dtype
|
|
189
191
|
)
|
|
190
|
-
|
|
191
192
|
ifft = lambda x: torch.fft.irfft(x, n=self.nfft, dim=1, norm=self.norm)
|
|
192
193
|
transform = lambda x: torch.einsum("btm, t->btm", ifft(x), self.alias_envelope)
|
|
193
|
-
super().__init__(transform=transform)
|
|
194
|
+
super().__init__(transform=transform, dtype=self.dtype)
|
|
194
195
|
|
|
195
196
|
|
|
196
197
|
# ============================= CORE ================================
|
|
@@ -216,6 +217,7 @@ class DSP(nn.Module):
|
|
|
216
217
|
- **requires_grad** (bool, optional): Whether the parameters require gradients. Default: False.
|
|
217
218
|
- **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
218
219
|
- **device** (str): The device of the constructed tensor, if any. Default: None.
|
|
220
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
219
221
|
|
|
220
222
|
**Attributes**:
|
|
221
223
|
- **param** (nn.Parameter): The parameters of the DSP module.
|
|
@@ -234,6 +236,7 @@ class DSP(nn.Module):
|
|
|
234
236
|
requires_grad: bool = False,
|
|
235
237
|
alias_decay_db: float = 0.0,
|
|
236
238
|
device: Optional[str] = None,
|
|
239
|
+
dtype: torch.dtype = torch.float32,
|
|
237
240
|
):
|
|
238
241
|
|
|
239
242
|
super().__init__()
|
|
@@ -244,25 +247,31 @@ class DSP(nn.Module):
|
|
|
244
247
|
self.new_value = 0 # flag indicating if new values have been assigned
|
|
245
248
|
self.requires_grad = requires_grad
|
|
246
249
|
self.device = device
|
|
250
|
+
self.dtype = dtype
|
|
247
251
|
self.param = nn.Parameter(
|
|
248
|
-
torch.empty(self.size, device=self.device), requires_grad=self.requires_grad
|
|
252
|
+
torch.empty(self.size, device=self.device, dtype=self.dtype), requires_grad=self.requires_grad
|
|
249
253
|
)
|
|
250
254
|
self.fft = lambda x: torch.fft.rfft(x, n=self.nfft, dim=0)
|
|
251
255
|
self.ifft = lambda x: torch.fft.irfft(x, n=self.nfft, dim=0)
|
|
252
256
|
# initialize time anti-aliasing envelope function
|
|
253
|
-
self.alias_decay_db = torch.tensor(alias_decay_db, device=self.device)
|
|
257
|
+
self.alias_decay_db = torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype)
|
|
254
258
|
self.init_param()
|
|
255
259
|
self.get_gamma()
|
|
256
260
|
|
|
257
|
-
|
|
261
|
+
@abc.abstractmethod
|
|
262
|
+
def forward(self, x, **kwArguments):
|
|
258
263
|
r"""
|
|
259
264
|
Forward method.
|
|
260
265
|
|
|
261
266
|
Input is returned. Forward method is to be implemented by the child class.
|
|
262
267
|
|
|
263
268
|
"""
|
|
264
|
-
|
|
269
|
+
warnings.warn(
|
|
270
|
+
"Forward method not implemented. Input is returned.",
|
|
271
|
+
UserWarning
|
|
272
|
+
)
|
|
265
273
|
return x
|
|
274
|
+
|
|
266
275
|
|
|
267
276
|
def init_param(self):
|
|
268
277
|
r"""
|
|
@@ -334,6 +343,7 @@ class Gain(DSP):
|
|
|
334
343
|
- **requires_grad** (bool): Whether the parameters requires gradients. Default: False.
|
|
335
344
|
- **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
336
345
|
- **device** (str): The device of the constructed tensors. Default: None.
|
|
346
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
337
347
|
|
|
338
348
|
**Attributes**:
|
|
339
349
|
- **param** (nn.Parameter): The parameters of the Gain module.
|
|
@@ -354,6 +364,7 @@ class Gain(DSP):
|
|
|
354
364
|
requires_grad: bool = False,
|
|
355
365
|
alias_decay_db: float = 0.0,
|
|
356
366
|
device: Optional[str] = None,
|
|
367
|
+
dtype: torch.dtype = torch.float32,
|
|
357
368
|
):
|
|
358
369
|
super().__init__(
|
|
359
370
|
size=size,
|
|
@@ -362,6 +373,7 @@ class Gain(DSP):
|
|
|
362
373
|
requires_grad=requires_grad,
|
|
363
374
|
alias_decay_db=alias_decay_db,
|
|
364
375
|
device=device,
|
|
376
|
+
dtype=dtype,
|
|
365
377
|
)
|
|
366
378
|
self.initialize_class()
|
|
367
379
|
|
|
@@ -461,6 +473,7 @@ class parallelGain(Gain):
|
|
|
461
473
|
requires_grad: bool = False,
|
|
462
474
|
alias_decay_db: float = 0.0,
|
|
463
475
|
device: Optional[str] = None,
|
|
476
|
+
dtype: torch.dtype = torch.float32,
|
|
464
477
|
):
|
|
465
478
|
super().__init__(
|
|
466
479
|
size=size,
|
|
@@ -469,6 +482,7 @@ class parallelGain(Gain):
|
|
|
469
482
|
requires_grad=requires_grad,
|
|
470
483
|
alias_decay_db=alias_decay_db,
|
|
471
484
|
device=device,
|
|
485
|
+
dtype=dtype,
|
|
472
486
|
)
|
|
473
487
|
|
|
474
488
|
def check_param_shape(self):
|
|
@@ -520,6 +534,7 @@ class Matrix(Gain):
|
|
|
520
534
|
- **requires_grad** (bool, optional): Whether the matrix requires gradient computation. Default: False.
|
|
521
535
|
- **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
522
536
|
- **device** (str, optional): The device of the constructed tensors. Default: None.
|
|
537
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
523
538
|
|
|
524
539
|
**Attributes**:
|
|
525
540
|
- **param** (nn.Parameter): The parameters of the Matrix module.
|
|
@@ -543,6 +558,7 @@ class Matrix(Gain):
|
|
|
543
558
|
requires_grad: bool = False,
|
|
544
559
|
alias_decay_db: float = 0.0,
|
|
545
560
|
device: Optional[str] = None,
|
|
561
|
+
dtype: torch.dtype = torch.float32,
|
|
546
562
|
):
|
|
547
563
|
self.matrix_type = matrix_type
|
|
548
564
|
self.iter = iter # iterations number for the rotation matrix
|
|
@@ -553,6 +569,7 @@ class Matrix(Gain):
|
|
|
553
569
|
requires_grad=requires_grad,
|
|
554
570
|
alias_decay_db=alias_decay_db,
|
|
555
571
|
device=device,
|
|
572
|
+
dtype=dtype,
|
|
556
573
|
)
|
|
557
574
|
|
|
558
575
|
def matrix_gallery(self):
|
|
@@ -579,7 +596,7 @@ class Matrix(Gain):
|
|
|
579
596
|
assert (
|
|
580
597
|
N % 2 == 0
|
|
581
598
|
), "Matrix must have even dimensions to be Hadamard"
|
|
582
|
-
self.map = lambda x: HadamardMatrix(self.size[0], device=self.device)(x)
|
|
599
|
+
self.map = lambda x: HadamardMatrix(self.size[0], device=self.device, dtype=self.dtype)(x)
|
|
583
600
|
case "rotation":
|
|
584
601
|
assert (
|
|
585
602
|
N == self.size[1]
|
|
@@ -587,7 +604,7 @@ class Matrix(Gain):
|
|
|
587
604
|
assert (
|
|
588
605
|
N % 2 == 0
|
|
589
606
|
), "Matrix must have even dimensions to be a rotation matrix"
|
|
590
|
-
self.map = lambda x: RotationMatrix(self.size[0], self.iter, device=self.device)([x[0][0]])
|
|
607
|
+
self.map = lambda x: RotationMatrix(self.size[0], self.iter, device=self.device, dtype=self.dtype)([x[0][0]])
|
|
591
608
|
|
|
592
609
|
def initialize_class(self):
|
|
593
610
|
r"""
|
|
@@ -619,6 +636,7 @@ class HouseholderMatrix(Gain):
|
|
|
619
636
|
- **requires_grad** (bool, optional): If True, gradients will be computed for the parameters. Defaults to False.
|
|
620
637
|
- **alias_decay_db** (float, optional): Alias decay in decibels. Defaults to 0.0.
|
|
621
638
|
- **device** (optional): Device on which to perform computations. Defaults to None.
|
|
639
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
622
640
|
|
|
623
641
|
**Attributes**:
|
|
624
642
|
- **param** (nn.Parameter): The parameters `u`` used to construct the Householder matrix.
|
|
@@ -640,6 +658,7 @@ class HouseholderMatrix(Gain):
|
|
|
640
658
|
requires_grad: bool = False,
|
|
641
659
|
alias_decay_db: float = 0.0,
|
|
642
660
|
device: Optional[str] = None,
|
|
661
|
+
dtype: torch.dtype = torch.float32,
|
|
643
662
|
):
|
|
644
663
|
assert size[0] == size[1], "Matrix must be square"
|
|
645
664
|
size = (size[0], 1)
|
|
@@ -651,6 +670,7 @@ class HouseholderMatrix(Gain):
|
|
|
651
670
|
requires_grad=requires_grad,
|
|
652
671
|
alias_decay_db=alias_decay_db,
|
|
653
672
|
device=device,
|
|
673
|
+
dtype=dtype,
|
|
654
674
|
)
|
|
655
675
|
|
|
656
676
|
def forward(self, x, ext_param=None):
|
|
@@ -735,6 +755,7 @@ class Filter(DSP):
|
|
|
735
755
|
- **requires_grad** (bool): Whether the filter parameters require gradients. Default: False.
|
|
736
756
|
- **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
737
757
|
- **device** (str): The device of the constructed tensors. Default: None.
|
|
758
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
738
759
|
|
|
739
760
|
**Attributes**:
|
|
740
761
|
- **param** (nn.Parameter): The parameters of the Filter module.
|
|
@@ -758,6 +779,7 @@ class Filter(DSP):
|
|
|
758
779
|
requires_grad: bool = False,
|
|
759
780
|
alias_decay_db: float = 0.0,
|
|
760
781
|
device: Optional[str] = None,
|
|
782
|
+
dtype: torch.dtype = torch.float32,
|
|
761
783
|
):
|
|
762
784
|
super().__init__(
|
|
763
785
|
size=size,
|
|
@@ -766,6 +788,7 @@ class Filter(DSP):
|
|
|
766
788
|
requires_grad=requires_grad,
|
|
767
789
|
alias_decay_db=alias_decay_db,
|
|
768
790
|
device=device,
|
|
791
|
+
dtype=dtype,
|
|
769
792
|
)
|
|
770
793
|
self.initialize_class()
|
|
771
794
|
|
|
@@ -886,6 +909,7 @@ class parallelFilter(Filter):
|
|
|
886
909
|
requires_grad: bool = False,
|
|
887
910
|
alias_decay_db: float = 0.0,
|
|
888
911
|
device: Optional[str] = None,
|
|
912
|
+
dtype: torch.dtype = torch.float32,
|
|
889
913
|
):
|
|
890
914
|
super().__init__(
|
|
891
915
|
size=size,
|
|
@@ -894,6 +918,7 @@ class parallelFilter(Filter):
|
|
|
894
918
|
requires_grad=requires_grad,
|
|
895
919
|
alias_decay_db=alias_decay_db,
|
|
896
920
|
device=device,
|
|
921
|
+
dtype=dtype,
|
|
897
922
|
)
|
|
898
923
|
|
|
899
924
|
def check_param_shape(self):
|
|
@@ -973,7 +998,8 @@ class ScatteringMatrix(Filter):
|
|
|
973
998
|
- **requires_grad** (bool): Whether the filter parameters require gradients. Default: False.
|
|
974
999
|
- **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
975
1000
|
- **device** (str): The device of the constructed tensors. Default: None.
|
|
976
|
-
|
|
1001
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
1002
|
+
|
|
977
1003
|
**Attributes**:
|
|
978
1004
|
- **param** (nn.Parameter): The parameters of the Filter module.
|
|
979
1005
|
- **map** (function): Mapping function to ensure orthogonality of :math:`\mathbf{U}_k`.
|
|
@@ -1002,6 +1028,7 @@ class ScatteringMatrix(Filter):
|
|
|
1002
1028
|
requires_grad: bool = False,
|
|
1003
1029
|
alias_decay_db: float = 0.0,
|
|
1004
1030
|
device: Optional[str] = None,
|
|
1031
|
+
dtype: torch.dtype = torch.float32,
|
|
1005
1032
|
):
|
|
1006
1033
|
self.sparsity = sparsity
|
|
1007
1034
|
self.gain_per_sample = gain_per_sample
|
|
@@ -1017,6 +1044,7 @@ class ScatteringMatrix(Filter):
|
|
|
1017
1044
|
requires_grad=requires_grad,
|
|
1018
1045
|
alias_decay_db=alias_decay_db,
|
|
1019
1046
|
device=device,
|
|
1047
|
+
dtype=dtype,
|
|
1020
1048
|
)
|
|
1021
1049
|
|
|
1022
1050
|
def get_freq_convolve(self):
|
|
@@ -1071,6 +1099,7 @@ class ScatteringMatrix(Filter):
|
|
|
1071
1099
|
m_L=self.m_L,
|
|
1072
1100
|
m_R=self.m_R,
|
|
1073
1101
|
device=self.device,
|
|
1102
|
+
dtype=self.dtype,
|
|
1074
1103
|
)
|
|
1075
1104
|
self.check_param_shape()
|
|
1076
1105
|
self.get_io()
|
|
@@ -1110,6 +1139,7 @@ class VelvetNoiseMatrix(Filter):
|
|
|
1110
1139
|
- **requires_grad** (bool): Whether the filter parameters require gradients. Default: False.
|
|
1111
1140
|
- **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
1112
1141
|
- **device** (str): The device of the constructed tensors. Default: None.
|
|
1142
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
1113
1143
|
|
|
1114
1144
|
**Attributes**:
|
|
1115
1145
|
- **param** (nn.Parameter): The parameters of the Filter module.
|
|
@@ -1139,6 +1169,7 @@ class VelvetNoiseMatrix(Filter):
|
|
|
1139
1169
|
m_R: torch.tensor = None,
|
|
1140
1170
|
alias_decay_db: float = 0.0,
|
|
1141
1171
|
device: Optional[str] = None,
|
|
1172
|
+
dtype: torch.dtype = torch.float32,
|
|
1142
1173
|
):
|
|
1143
1174
|
self.sparsity = 1/density
|
|
1144
1175
|
self.gain_per_sample = gain_per_sample
|
|
@@ -1155,8 +1186,9 @@ class VelvetNoiseMatrix(Filter):
|
|
|
1155
1186
|
requires_grad=False,
|
|
1156
1187
|
alias_decay_db=alias_decay_db,
|
|
1157
1188
|
device=device,
|
|
1189
|
+
dtype=dtype,
|
|
1158
1190
|
)
|
|
1159
|
-
self.assign_value(torch.tensor(hadamard_matrix(self.size[-1]), device=self.device).unsqueeze(0).repeat(self.size[0], 1, 1))
|
|
1191
|
+
self.assign_value(torch.tensor(hadamard_matrix(self.size[-1]), device=self.device, dtype=self.dtype).unsqueeze(0).repeat(self.size[0], 1, 1))
|
|
1160
1192
|
|
|
1161
1193
|
def get_freq_convolve(self):
|
|
1162
1194
|
r"""
|
|
@@ -1210,6 +1242,7 @@ class VelvetNoiseMatrix(Filter):
|
|
|
1210
1242
|
m_L=self.m_L,
|
|
1211
1243
|
m_R=self.m_R,
|
|
1212
1244
|
device=self.device,
|
|
1245
|
+
dtype=self.dtype,
|
|
1213
1246
|
)
|
|
1214
1247
|
self.check_param_shape()
|
|
1215
1248
|
self.get_io()
|
|
@@ -1253,6 +1286,7 @@ class Biquad(Filter):
|
|
|
1253
1286
|
- **requires_grad** (bool, optional): Whether the filter parameters require gradient computation. Default: True.
|
|
1254
1287
|
- **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
1255
1288
|
- **device** (str, optional): The device of the constructed tensors. Default: None.
|
|
1289
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
1256
1290
|
|
|
1257
1291
|
**Attributes**:
|
|
1258
1292
|
- **param** (nn.Parameter): The parameters of the Filter module.
|
|
@@ -1279,17 +1313,19 @@ class Biquad(Filter):
|
|
|
1279
1313
|
requires_grad: bool = False,
|
|
1280
1314
|
alias_decay_db: float = 0.0,
|
|
1281
1315
|
device: Optional[str] = None,
|
|
1316
|
+
dtype: torch.dtype = torch.float32,
|
|
1282
1317
|
):
|
|
1283
1318
|
assert filter_type in ["lowpass", "highpass", "bandpass"], "Invalid filter type"
|
|
1284
1319
|
self.n_sections = n_sections
|
|
1285
1320
|
self.filter_type = filter_type
|
|
1286
1321
|
self.fs = fs
|
|
1287
1322
|
self.device = device
|
|
1323
|
+
self.dtype = dtype
|
|
1288
1324
|
self.get_map()
|
|
1289
1325
|
gamma = 10 ** (
|
|
1290
|
-
-torch.abs(torch.tensor(alias_decay_db, device=self.device)) / (nfft) / 20
|
|
1326
|
+
-torch.abs(torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype)) / (nfft) / 20
|
|
1291
1327
|
)
|
|
1292
|
-
self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=self.device)
|
|
1328
|
+
self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=self.device, dtype=self.dtype)
|
|
1293
1329
|
super().__init__(
|
|
1294
1330
|
size=(n_sections, *self.get_size(), *size),
|
|
1295
1331
|
nfft=nfft,
|
|
@@ -1297,6 +1333,7 @@ class Biquad(Filter):
|
|
|
1297
1333
|
requires_grad=requires_grad,
|
|
1298
1334
|
alias_decay_db=alias_decay_db,
|
|
1299
1335
|
device=device,
|
|
1336
|
+
dtype=dtype,
|
|
1300
1337
|
)
|
|
1301
1338
|
|
|
1302
1339
|
def get_size(self):
|
|
@@ -1361,6 +1398,7 @@ class Biquad(Filter):
|
|
|
1361
1398
|
gain=param[:, 1, :, :],
|
|
1362
1399
|
fs=self.fs,
|
|
1363
1400
|
device=self.device,
|
|
1401
|
+
dtype=self.dtype,
|
|
1364
1402
|
)
|
|
1365
1403
|
case "highpass":
|
|
1366
1404
|
b, a = highpass_filter(
|
|
@@ -1368,6 +1406,7 @@ class Biquad(Filter):
|
|
|
1368
1406
|
gain=param[:, 1, :, :],
|
|
1369
1407
|
fs=self.fs,
|
|
1370
1408
|
device=self.device,
|
|
1409
|
+
dtype=self.dtype,
|
|
1371
1410
|
)
|
|
1372
1411
|
case "bandpass":
|
|
1373
1412
|
b, a = bandpass_filter(
|
|
@@ -1376,15 +1415,15 @@ class Biquad(Filter):
|
|
|
1376
1415
|
gain=param[:, 2, :, :],
|
|
1377
1416
|
fs=self.fs,
|
|
1378
1417
|
device=self.device,
|
|
1418
|
+
dtype=self.dtype,
|
|
1379
1419
|
)
|
|
1380
|
-
b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy
|
|
1381
|
-
a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy
|
|
1420
|
+
b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, b)
|
|
1421
|
+
a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, a)
|
|
1382
1422
|
B = torch.fft.rfft(b_aa, self.nfft, dim=0)
|
|
1383
1423
|
A = torch.fft.rfft(a_aa, self.nfft, dim=0)
|
|
1384
1424
|
H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
|
|
1385
1425
|
H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
1386
|
-
|
|
1387
|
-
return H.to(H_type), B, A
|
|
1426
|
+
return H, B, A
|
|
1388
1427
|
|
|
1389
1428
|
def get_map(self):
|
|
1390
1429
|
r"""
|
|
@@ -1398,10 +1437,10 @@ class Biquad(Filter):
|
|
|
1398
1437
|
(x[:, 0, :, :], 20 * torch.log10(torch.abs(x[:, 1, :, :]))),
|
|
1399
1438
|
dim=1,
|
|
1400
1439
|
),
|
|
1401
|
-
min=torch.tensor([0, -60], device=self.device)
|
|
1440
|
+
min=torch.tensor([0, -60], device=self.device, dtype=self.dtype)
|
|
1402
1441
|
.view(-1, 1, 1)
|
|
1403
1442
|
.expand_as(x),
|
|
1404
|
-
max=torch.tensor([1, 60], device=self.device)
|
|
1443
|
+
max=torch.tensor([1, 60], device=self.device, dtype=self.dtype)
|
|
1405
1444
|
.view(-1, 1, 1)
|
|
1406
1445
|
.expand_as(x),
|
|
1407
1446
|
)
|
|
@@ -1415,10 +1454,10 @@ class Biquad(Filter):
|
|
|
1415
1454
|
),
|
|
1416
1455
|
dim=1,
|
|
1417
1456
|
),
|
|
1418
|
-
min=torch.tensor([0 + torch.finfo(
|
|
1457
|
+
min=torch.tensor([0 + torch.finfo(self.dtype).eps, 0 + torch.finfo(self.dtype).eps, -60], device=self.device, dtype=self.dtype)
|
|
1419
1458
|
.view(-1, 1, 1)
|
|
1420
1459
|
.expand_as(x),
|
|
1421
|
-
max=torch.tensor([1 - torch.finfo(
|
|
1460
|
+
max=torch.tensor([1 - torch.finfo(self.dtype).eps, 1 - torch.finfo(self.dtype).eps, 60], device=self.device, dtype=self.dtype)
|
|
1422
1461
|
.view(-1, 1, 1)
|
|
1423
1462
|
.expand_as(x),
|
|
1424
1463
|
)
|
|
@@ -1491,6 +1530,7 @@ class parallelBiquad(Biquad):
|
|
|
1491
1530
|
requires_grad: bool = False,
|
|
1492
1531
|
alias_decay_db: float = 0.0,
|
|
1493
1532
|
device: Optional[str] = None,
|
|
1533
|
+
dtype: torch.dtype = torch.float32,
|
|
1494
1534
|
):
|
|
1495
1535
|
super().__init__(
|
|
1496
1536
|
size=size,
|
|
@@ -1501,6 +1541,7 @@ class parallelBiquad(Biquad):
|
|
|
1501
1541
|
requires_grad=requires_grad,
|
|
1502
1542
|
alias_decay_db=alias_decay_db,
|
|
1503
1543
|
device=device,
|
|
1544
|
+
dtype=dtype,
|
|
1504
1545
|
)
|
|
1505
1546
|
|
|
1506
1547
|
def check_param_shape(self):
|
|
@@ -1519,10 +1560,10 @@ class parallelBiquad(Biquad):
|
|
|
1519
1560
|
torch.stack(
|
|
1520
1561
|
(x[:, 0, :], 20 * torch.log10(torch.abs(x[:, -1, :]))), dim=1
|
|
1521
1562
|
),
|
|
1522
|
-
min=torch.tensor([0, -60], device=self.device)
|
|
1563
|
+
min=torch.tensor([0, -60], device=self.device, dtype=self.dtype)
|
|
1523
1564
|
.view(-1, 1)
|
|
1524
1565
|
.expand_as(x),
|
|
1525
|
-
max=torch.tensor([1, 60], device=self.device)
|
|
1566
|
+
max=torch.tensor([1, 60], device=self.device, dtype=self.dtype)
|
|
1526
1567
|
.view(-1, 1)
|
|
1527
1568
|
.expand_as(x),
|
|
1528
1569
|
)
|
|
@@ -1536,10 +1577,10 @@ class parallelBiquad(Biquad):
|
|
|
1536
1577
|
),
|
|
1537
1578
|
dim=1,
|
|
1538
1579
|
),
|
|
1539
|
-
min=torch.tensor([0 + torch.finfo(
|
|
1580
|
+
min=torch.tensor([0 + torch.finfo(self.dtype).eps, 0 + torch.finfo(self.dtype).eps, -60], device=self.device, dtype=self.dtype)
|
|
1540
1581
|
.view(-1, 1)
|
|
1541
1582
|
.expand_as(x),
|
|
1542
|
-
max=torch.tensor([1 - torch.finfo(
|
|
1583
|
+
max=torch.tensor([1 - torch.finfo(self.dtype).eps, 1 - torch.finfo(self.dtype).eps, 60], device=self.device, dtype=self.dtype)
|
|
1543
1584
|
.view(-1, 1)
|
|
1544
1585
|
.expand_as(x),
|
|
1545
1586
|
)
|
|
@@ -1582,6 +1623,7 @@ class parallelBiquad(Biquad):
|
|
|
1582
1623
|
gain=param[:, 1, :],
|
|
1583
1624
|
fs=self.fs,
|
|
1584
1625
|
device=self.device,
|
|
1626
|
+
dtype=self.dtype,
|
|
1585
1627
|
)
|
|
1586
1628
|
case "highpass":
|
|
1587
1629
|
b, a = highpass_filter(
|
|
@@ -1589,6 +1631,7 @@ class parallelBiquad(Biquad):
|
|
|
1589
1631
|
gain=param[:, 1, :],
|
|
1590
1632
|
fs=self.fs,
|
|
1591
1633
|
device=self.device,
|
|
1634
|
+
dtype=self.dtype,
|
|
1592
1635
|
)
|
|
1593
1636
|
case "bandpass":
|
|
1594
1637
|
b, a = bandpass_filter(
|
|
@@ -1597,15 +1640,15 @@ class parallelBiquad(Biquad):
|
|
|
1597
1640
|
gain=param[:, 2, :],
|
|
1598
1641
|
fs=self.fs,
|
|
1599
1642
|
device=self.device,
|
|
1643
|
+
dtype=self.dtype,
|
|
1600
1644
|
)
|
|
1601
|
-
b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy
|
|
1602
|
-
a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy
|
|
1645
|
+
b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, b)
|
|
1646
|
+
a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, a)
|
|
1603
1647
|
B = torch.fft.rfft(b_aa, self.nfft, dim=0)
|
|
1604
1648
|
A = torch.fft.rfft(a_aa, self.nfft, dim=0)
|
|
1605
1649
|
H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
|
|
1606
1650
|
H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
1607
|
-
|
|
1608
|
-
return H.to(H_type), B, A
|
|
1651
|
+
return H, B, A
|
|
1609
1652
|
|
|
1610
1653
|
def get_freq_convolve(self):
|
|
1611
1654
|
self.freq_convolve = lambda x, param: torch.einsum(
|
|
@@ -1673,7 +1716,8 @@ class SVF(Filter):
|
|
|
1673
1716
|
- **requires_grad** (bool, optional): Whether the filter parameters require gradients. Default: False.
|
|
1674
1717
|
- **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
1675
1718
|
- **device** (str, optional): The device of the constructed tensors. Default: None.
|
|
1676
|
-
|
|
1719
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
1720
|
+
|
|
1677
1721
|
**Attributes**:
|
|
1678
1722
|
- **alias_envelope_dcy** (torch.Tensor): The anti time-aliasing decaying envelope.
|
|
1679
1723
|
- **fft** (function): The FFT function. Calls the torch.fft.rfft function.
|
|
@@ -1700,6 +1744,7 @@ class SVF(Filter):
|
|
|
1700
1744
|
requires_grad: bool = False,
|
|
1701
1745
|
alias_decay_db: float = 0.0,
|
|
1702
1746
|
device: Optional[str] = None,
|
|
1747
|
+
dtype: torch.dtype = torch.float32,
|
|
1703
1748
|
):
|
|
1704
1749
|
self.fs = fs
|
|
1705
1750
|
self.n_sections = n_sections
|
|
@@ -1715,9 +1760,9 @@ class SVF(Filter):
|
|
|
1715
1760
|
], "Invalid filter type"
|
|
1716
1761
|
self.filter_type = filter_type
|
|
1717
1762
|
gamma = 10 ** (
|
|
1718
|
-
-torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20
|
|
1763
|
+
-torch.abs(torch.tensor(alias_decay_db, device=device, dtype=dtype)) / (nfft) / 20
|
|
1719
1764
|
)
|
|
1720
|
-
self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device)
|
|
1765
|
+
self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device, dtype=dtype)
|
|
1721
1766
|
super().__init__(
|
|
1722
1767
|
size=(5, self.n_sections, *size),
|
|
1723
1768
|
nfft=nfft,
|
|
@@ -1725,6 +1770,7 @@ class SVF(Filter):
|
|
|
1725
1770
|
requires_grad=requires_grad,
|
|
1726
1771
|
alias_decay_db=alias_decay_db,
|
|
1727
1772
|
device=device,
|
|
1773
|
+
dtype=dtype,
|
|
1728
1774
|
)
|
|
1729
1775
|
|
|
1730
1776
|
def check_param_shape(self):
|
|
@@ -1767,14 +1813,13 @@ class SVF(Filter):
|
|
|
1767
1813
|
a[2] = (f**2) - 2 * R * f + 1
|
|
1768
1814
|
|
|
1769
1815
|
# apply anti-aliasing
|
|
1770
|
-
b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy
|
|
1771
|
-
a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy
|
|
1816
|
+
b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, b)
|
|
1817
|
+
a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, a)
|
|
1772
1818
|
B = torch.fft.rfft(b_aa, self.nfft, dim=0)
|
|
1773
1819
|
A = torch.fft.rfft(a_aa, self.nfft, dim=0)
|
|
1774
1820
|
H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
|
|
1775
1821
|
H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
1776
|
-
|
|
1777
|
-
return H.to(H_type), B, A
|
|
1822
|
+
return H, B, A
|
|
1778
1823
|
|
|
1779
1824
|
def param2freq(self, param):
|
|
1780
1825
|
r"""
|
|
@@ -1801,8 +1846,8 @@ class SVF(Filter):
|
|
|
1801
1846
|
|
|
1802
1847
|
"""
|
|
1803
1848
|
return torch.div(
|
|
1804
|
-
torch.log(torch.ones(1, device=self.device) + torch.exp(param)),
|
|
1805
|
-
torch.log(torch.tensor(2, device=self.device)),
|
|
1849
|
+
torch.log(torch.ones(1, device=self.device, dtype=self.dtype) + torch.exp(param)),
|
|
1850
|
+
torch.log(torch.tensor(2, device=self.device, dtype=self.dtype)),
|
|
1806
1851
|
)
|
|
1807
1852
|
|
|
1808
1853
|
def param2mix(self, param, R=None):
|
|
@@ -1887,8 +1932,8 @@ class SVF(Filter):
|
|
|
1887
1932
|
)
|
|
1888
1933
|
case None:
|
|
1889
1934
|
# general SVF filter
|
|
1890
|
-
bias = torch.ones((param.shape), device=self.device)
|
|
1891
|
-
bias[1] = 2 * torch.ones((param.shape[1:]), device=self.device)
|
|
1935
|
+
bias = torch.ones((param.shape), device=self.device, dtype=self.dtype)
|
|
1936
|
+
bias[1] = 2 * torch.ones((param.shape[1:]), device=self.device, dtype=self.dtype)
|
|
1892
1937
|
return param + bias
|
|
1893
1938
|
|
|
1894
1939
|
def map_param2svf(self, param):
|
|
@@ -1899,7 +1944,7 @@ class SVF(Filter):
|
|
|
1899
1944
|
r = self.param2R(param[1])
|
|
1900
1945
|
if self.filter_type == "lowshelf" or self.filter_type == "highshelf":
|
|
1901
1946
|
# R = r + torch.sqrt(torch.tensor(2))
|
|
1902
|
-
R = torch.tensor(1, device=self.device)
|
|
1947
|
+
R = torch.tensor(1, device=self.device, dtype=self.dtype)
|
|
1903
1948
|
if self.filter_type == "peaking":
|
|
1904
1949
|
R = 1 / r # temporary fix for peaking filter
|
|
1905
1950
|
m = self.param2mix(param[2:], r)
|
|
@@ -1945,6 +1990,7 @@ class parallelSVF(SVF):
|
|
|
1945
1990
|
requires_grad: bool = False,
|
|
1946
1991
|
alias_decay_db: float = 0.0,
|
|
1947
1992
|
device: Optional[str] = None,
|
|
1993
|
+
dtype: torch.dtype = torch.float32,
|
|
1948
1994
|
):
|
|
1949
1995
|
super().__init__(
|
|
1950
1996
|
size=size,
|
|
@@ -1955,6 +2001,7 @@ class parallelSVF(SVF):
|
|
|
1955
2001
|
requires_grad=requires_grad,
|
|
1956
2002
|
alias_decay_db=alias_decay_db,
|
|
1957
2003
|
device=device,
|
|
2004
|
+
dtype=dtype,
|
|
1958
2005
|
)
|
|
1959
2006
|
|
|
1960
2007
|
def check_param_shape(self):
|
|
@@ -1985,14 +2032,13 @@ class parallelSVF(SVF):
|
|
|
1985
2032
|
a[2] = (f**2) - 2 * R * f + 1
|
|
1986
2033
|
|
|
1987
2034
|
# apply anti-aliasing
|
|
1988
|
-
b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy
|
|
1989
|
-
a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy
|
|
2035
|
+
b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, b)
|
|
2036
|
+
a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, a)
|
|
1990
2037
|
B = torch.fft.rfft(b_aa, self.nfft, dim=0)
|
|
1991
2038
|
A = torch.fft.rfft(a_aa, self.nfft, dim=0)
|
|
1992
2039
|
H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
|
|
1993
2040
|
H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
1994
|
-
|
|
1995
|
-
return H.to(H_type), B, A
|
|
2041
|
+
return H, B, A
|
|
1996
2042
|
|
|
1997
2043
|
def get_freq_convolve(self):
|
|
1998
2044
|
self.freq_convolve = lambda x, param: torch.einsum(
|
|
@@ -2038,6 +2084,7 @@ class GEQ(Filter):
|
|
|
2038
2084
|
- **requires_grad** (bool, optional): Whether the filter parameters require gradients. Default: False.
|
|
2039
2085
|
- **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
|
|
2040
2086
|
- **device** (str, optional): The device of the constructed tensors. Default: None.
|
|
2087
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
2041
2088
|
|
|
2042
2089
|
**Attributes**:
|
|
2043
2090
|
- **alias_envelope_dcy** (torch.Tensor): The anti time-aliasing decaying envelope.
|
|
@@ -2072,6 +2119,7 @@ class GEQ(Filter):
|
|
|
2072
2119
|
requires_grad: bool = False,
|
|
2073
2120
|
alias_decay_db: float = 0.0,
|
|
2074
2121
|
device: Optional[str] = None,
|
|
2122
|
+
dtype: torch.dtype = torch.float32,
|
|
2075
2123
|
):
|
|
2076
2124
|
self.octave_interval = octave_interval
|
|
2077
2125
|
self.fs = fs
|
|
@@ -2080,9 +2128,9 @@ class GEQ(Filter):
|
|
|
2080
2128
|
)
|
|
2081
2129
|
self.n_gains = len(self.center_freq) + 3
|
|
2082
2130
|
gamma = 10 ** (
|
|
2083
|
-
-torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20
|
|
2131
|
+
-torch.abs(torch.tensor(alias_decay_db, device=device, dtype=dtype)) / (nfft) / 20
|
|
2084
2132
|
)
|
|
2085
|
-
self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device)
|
|
2133
|
+
self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device, dtype=dtype)
|
|
2086
2134
|
super().__init__(
|
|
2087
2135
|
size=(self.n_gains, *size),
|
|
2088
2136
|
nfft=nfft,
|
|
@@ -2090,6 +2138,7 @@ class GEQ(Filter):
|
|
|
2090
2138
|
requires_grad=requires_grad,
|
|
2091
2139
|
alias_decay_db=alias_decay_db,
|
|
2092
2140
|
device=device,
|
|
2141
|
+
dtype=dtype,
|
|
2093
2142
|
)
|
|
2094
2143
|
|
|
2095
2144
|
def init_param(self):
|
|
@@ -2124,14 +2173,13 @@ class GEQ(Filter):
|
|
|
2124
2173
|
device=self.device,
|
|
2125
2174
|
)
|
|
2126
2175
|
|
|
2127
|
-
b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy
|
|
2128
|
-
a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy
|
|
2176
|
+
b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, a)
|
|
2177
|
+
a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, b)
|
|
2129
2178
|
B = torch.fft.rfft(b_aa, self.nfft, dim=0)
|
|
2130
2179
|
A = torch.fft.rfft(a_aa, self.nfft, dim=0)
|
|
2131
2180
|
H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
|
|
2132
2181
|
H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
2133
|
-
|
|
2134
|
-
return H.to(H_type), B, A
|
|
2182
|
+
return H, B, A
|
|
2135
2183
|
|
|
2136
2184
|
def initialize_class(self):
|
|
2137
2185
|
self.check_param_shape()
|
|
@@ -2178,6 +2226,7 @@ class parallelGEQ(GEQ):
|
|
|
2178
2226
|
requires_grad: bool = False,
|
|
2179
2227
|
alias_decay_db: float = 0.0,
|
|
2180
2228
|
device: Optional[str] = None,
|
|
2229
|
+
dtype: torch.dtype = torch.float32,
|
|
2181
2230
|
):
|
|
2182
2231
|
super().__init__(
|
|
2183
2232
|
size=size,
|
|
@@ -2188,6 +2237,7 @@ class parallelGEQ(GEQ):
|
|
|
2188
2237
|
requires_grad=requires_grad,
|
|
2189
2238
|
alias_decay_db=alias_decay_db,
|
|
2190
2239
|
device=device,
|
|
2240
|
+
dtype=dtype,
|
|
2191
2241
|
)
|
|
2192
2242
|
|
|
2193
2243
|
def check_param_shape(self):
|
|
@@ -2210,14 +2260,13 @@ class parallelGEQ(GEQ):
|
|
|
2210
2260
|
device=self.device,
|
|
2211
2261
|
)
|
|
2212
2262
|
|
|
2213
|
-
b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy
|
|
2214
|
-
a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy
|
|
2263
|
+
b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, a)
|
|
2264
|
+
a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, b)
|
|
2215
2265
|
B = torch.fft.rfft(b_aa, self.nfft, dim=0)
|
|
2216
2266
|
A = torch.fft.rfft(a_aa, self.nfft, dim=0)
|
|
2217
2267
|
H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
|
|
2218
2268
|
H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
2219
|
-
|
|
2220
|
-
return H.to(H_type), B, A
|
|
2269
|
+
return H, B, A
|
|
2221
2270
|
|
|
2222
2271
|
def get_freq_convolve(self):
|
|
2223
2272
|
self.freq_convolve = lambda x, param: torch.einsum(
|
|
@@ -2248,6 +2297,7 @@ class PEQ(Filter):
|
|
|
2248
2297
|
requires_grad: bool = False,
|
|
2249
2298
|
alias_decay_db: float = 0.0,
|
|
2250
2299
|
device: Optional[str] = None,
|
|
2300
|
+
dtype: torch.dtype = torch.float32,
|
|
2251
2301
|
):
|
|
2252
2302
|
self.n_bands = n_bands
|
|
2253
2303
|
self.design = design
|
|
@@ -2255,11 +2305,11 @@ class PEQ(Filter):
|
|
|
2255
2305
|
self.f_min = f_min
|
|
2256
2306
|
self.f_max = f_max
|
|
2257
2307
|
gamma = 10 ** (
|
|
2258
|
-
-torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20
|
|
2308
|
+
-torch.abs(torch.tensor(alias_decay_db, device=device, dtype=dtype)) / (nfft) / 20
|
|
2259
2309
|
)
|
|
2260
|
-
k = torch.arange(1, self.n_bands + 1, dtype=
|
|
2310
|
+
k = torch.arange(1, self.n_bands + 1, dtype=dtype)
|
|
2261
2311
|
self.center_freq_bias = f_min * (f_max / f_min) ** ((k - 1) / (self.n_bands - 1))
|
|
2262
|
-
self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device)
|
|
2312
|
+
self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device, dtype=dtype)
|
|
2263
2313
|
super().__init__(
|
|
2264
2314
|
size=(self.n_bands, 3, *size),
|
|
2265
2315
|
nfft=nfft,
|
|
@@ -2267,6 +2317,7 @@ class PEQ(Filter):
|
|
|
2267
2317
|
requires_grad=requires_grad,
|
|
2268
2318
|
alias_decay_db=alias_decay_db,
|
|
2269
2319
|
device=device,
|
|
2320
|
+
dtype=dtype,
|
|
2270
2321
|
)
|
|
2271
2322
|
|
|
2272
2323
|
def init_param(self):
|
|
@@ -2316,14 +2367,13 @@ class PEQ(Filter):
|
|
|
2316
2367
|
G=G[1:-1],
|
|
2317
2368
|
type='peaking',
|
|
2318
2369
|
)
|
|
2319
|
-
b_aa = torch.einsum("p, opmn -> opmn", self.alias_envelope_dcy
|
|
2320
|
-
a_aa = torch.einsum("p, opmn -> opmn", self.alias_envelope_dcy
|
|
2370
|
+
b_aa = torch.einsum("p, opmn -> opmn", self.alias_envelope_dcy, b)
|
|
2371
|
+
a_aa = torch.einsum("p, opmn -> opmn", self.alias_envelope_dcy, a)
|
|
2321
2372
|
B = torch.fft.rfft(b_aa, self.nfft, dim=1)
|
|
2322
2373
|
A = torch.fft.rfft(a_aa, self.nfft, dim=1)
|
|
2323
2374
|
H_temp = torch.prod(B, dim=0) / (torch.prod(A, dim=0))
|
|
2324
2375
|
H = torch.where(torch.abs(torch.prod(A, dim=0)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
2325
|
-
|
|
2326
|
-
return H.to(H_type), B, A
|
|
2376
|
+
return H, B, A
|
|
2327
2377
|
|
|
2328
2378
|
def compute_biquad_coeff(self, f, R, G, type='peaking'):
|
|
2329
2379
|
# f : freq, R : resonance, G : gain in dB
|
|
@@ -2432,6 +2482,7 @@ class parallelPEQ(PEQ):
|
|
|
2432
2482
|
requires_grad: bool = False,
|
|
2433
2483
|
alias_decay_db: float = 0.0,
|
|
2434
2484
|
device: Optional[str] = None,
|
|
2485
|
+
dtype: torch.dtype = torch.float32,
|
|
2435
2486
|
):
|
|
2436
2487
|
super().__init__(
|
|
2437
2488
|
size=size,
|
|
@@ -2445,6 +2496,7 @@ class parallelPEQ(PEQ):
|
|
|
2445
2496
|
requires_grad=requires_grad,
|
|
2446
2497
|
alias_decay_db=alias_decay_db,
|
|
2447
2498
|
device=device,
|
|
2499
|
+
dtype=dtype,
|
|
2448
2500
|
)
|
|
2449
2501
|
|
|
2450
2502
|
def init_param(self):
|
|
@@ -2487,14 +2539,13 @@ class parallelPEQ(PEQ):
|
|
|
2487
2539
|
G=G[1:-1],
|
|
2488
2540
|
type='peaking'
|
|
2489
2541
|
)
|
|
2490
|
-
b_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy
|
|
2491
|
-
a_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy
|
|
2542
|
+
b_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy, b)
|
|
2543
|
+
a_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy, a)
|
|
2492
2544
|
B = torch.fft.rfft(b_aa, self.nfft, dim=1)
|
|
2493
2545
|
A = torch.fft.rfft(a_aa, self.nfft, dim=1)
|
|
2494
2546
|
H_temp = torch.prod(B, dim=0) / (torch.prod(A, dim=0))
|
|
2495
2547
|
H = torch.where(torch.abs(torch.prod(A, dim=0)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
2496
|
-
|
|
2497
|
-
return H.to(H_type), B, A
|
|
2548
|
+
return H, B, A
|
|
2498
2549
|
|
|
2499
2550
|
|
|
2500
2551
|
def map_eq(self, param):
|
|
@@ -2566,6 +2617,7 @@ class AccurateGEQ(Filter):
|
|
|
2566
2617
|
- start_freq (float, optional): The starting frequency for the filter bands. Default: 31.25.
|
|
2567
2618
|
- end_freq (float, optional): The ending frequency for the filter bands. Default: 16000.0.
|
|
2568
2619
|
- device (str, optional): The device of the constructed tensors. Default: None.
|
|
2620
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
2569
2621
|
|
|
2570
2622
|
**Attributes**:
|
|
2571
2623
|
- fs (int): The sampling frequency.
|
|
@@ -2598,22 +2650,24 @@ class AccurateGEQ(Filter):
|
|
|
2598
2650
|
alias_decay_db: float = 0.0,
|
|
2599
2651
|
start_freq: float = 31.25,
|
|
2600
2652
|
end_freq: float = 16000.0,
|
|
2601
|
-
device=None
|
|
2653
|
+
device=None,
|
|
2654
|
+
dtype: torch.dtype = torch.float32,
|
|
2602
2655
|
):
|
|
2603
2656
|
self.octave_interval = octave_interval
|
|
2604
2657
|
self.fs = fs
|
|
2605
2658
|
self.center_freq, self.shelving_crossover = eq_freqs(
|
|
2606
2659
|
interval=self.octave_interval, start_freq=start_freq, end_freq=end_freq)
|
|
2607
2660
|
self.n_gains = len(self.center_freq) + 2
|
|
2608
|
-
gamma = 10 ** (-torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20)
|
|
2609
|
-
self.alias_envelope_dcy = (gamma ** torch.arange(0, 3, 1, device=device))
|
|
2661
|
+
gamma = 10 ** (-torch.abs(torch.tensor(alias_decay_db, device=device, dtype=dtype)) / (nfft) / 20)
|
|
2662
|
+
self.alias_envelope_dcy = (gamma ** torch.arange(0, 3, 1, device=device, dtype=dtype))
|
|
2610
2663
|
super().__init__(
|
|
2611
2664
|
size=(self.n_gains, *size),
|
|
2612
2665
|
nfft=nfft,
|
|
2613
2666
|
map=map,
|
|
2614
2667
|
requires_grad=False,
|
|
2615
2668
|
alias_decay_db=alias_decay_db,
|
|
2616
|
-
device=device
|
|
2669
|
+
device=device,
|
|
2670
|
+
dtype=dtype,
|
|
2617
2671
|
)
|
|
2618
2672
|
|
|
2619
2673
|
def init_param(self):
|
|
@@ -2646,14 +2700,13 @@ class AccurateGEQ(Filter):
|
|
|
2646
2700
|
device=self.device
|
|
2647
2701
|
)
|
|
2648
2702
|
|
|
2649
|
-
b_aa = torch.einsum('p, pomn -> pomn', self.alias_envelope_dcy
|
|
2650
|
-
a_aa = torch.einsum('p, pomn -> pomn', self.alias_envelope_dcy
|
|
2703
|
+
b_aa = torch.einsum('p, pomn -> pomn', self.alias_envelope_dcy, a)
|
|
2704
|
+
a_aa = torch.einsum('p, pomn -> pomn', self.alias_envelope_dcy, b)
|
|
2651
2705
|
B = torch.fft.rfft(b_aa, self.nfft, dim=0)
|
|
2652
2706
|
A = torch.fft.rfft(a_aa, self.nfft, dim=0)
|
|
2653
2707
|
H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
|
|
2654
2708
|
H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
2655
|
-
|
|
2656
|
-
return H.to(H_type), B, A
|
|
2709
|
+
return H, B, A
|
|
2657
2710
|
|
|
2658
2711
|
def initialize_class(self):
|
|
2659
2712
|
self.check_param_shape()
|
|
@@ -2695,7 +2748,8 @@ class parallelAccurateGEQ(AccurateGEQ):
|
|
|
2695
2748
|
alias_decay_db: float = 0.0,
|
|
2696
2749
|
start_freq: float = 31.25,
|
|
2697
2750
|
end_freq: float = 16000.0,
|
|
2698
|
-
device=None
|
|
2751
|
+
device=None,
|
|
2752
|
+
dtype: torch.dtype = torch.float32,
|
|
2699
2753
|
):
|
|
2700
2754
|
super().__init__(
|
|
2701
2755
|
size=size,
|
|
@@ -2706,7 +2760,8 @@ class parallelAccurateGEQ(AccurateGEQ):
|
|
|
2706
2760
|
alias_decay_db=alias_decay_db,
|
|
2707
2761
|
start_freq=start_freq,
|
|
2708
2762
|
end_freq=end_freq,
|
|
2709
|
-
device=device
|
|
2763
|
+
device=device,
|
|
2764
|
+
dtype=dtype,
|
|
2710
2765
|
)
|
|
2711
2766
|
|
|
2712
2767
|
def check_param_shape(self):
|
|
@@ -2729,14 +2784,13 @@ class parallelAccurateGEQ(AccurateGEQ):
|
|
|
2729
2784
|
device=self.device
|
|
2730
2785
|
)
|
|
2731
2786
|
|
|
2732
|
-
b_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy
|
|
2733
|
-
a_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy
|
|
2787
|
+
b_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy, a)
|
|
2788
|
+
a_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy, b)
|
|
2734
2789
|
B = torch.fft.rfft(b_aa, self.nfft, dim=0)
|
|
2735
2790
|
A = torch.fft.rfft(a_aa, self.nfft, dim=0)
|
|
2736
2791
|
H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
|
|
2737
2792
|
H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
|
|
2738
|
-
|
|
2739
|
-
return H.to(H_type), B, A
|
|
2793
|
+
return H, B, A
|
|
2740
2794
|
|
|
2741
2795
|
def get_freq_convolve(self):
|
|
2742
2796
|
self.freq_convolve = lambda x, param: torch.einsum(
|
|
@@ -2788,6 +2842,7 @@ class Delay(DSP):
|
|
|
2788
2842
|
- **requires_grad** (bool, optional): Flag indicating whether the module parameters require gradients. Default: False.
|
|
2789
2843
|
- **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Defaults to 0.
|
|
2790
2844
|
- **device** (str, optional): The device of the constructed tensors. Default: None.
|
|
2845
|
+
- **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
|
|
2791
2846
|
|
|
2792
2847
|
**Attributes**:
|
|
2793
2848
|
- **alias_envelope_dcy** (torch.Tensor): The anti time-aliasing decaying envelope.
|
|
@@ -2815,6 +2870,7 @@ class Delay(DSP):
|
|
|
2815
2870
|
requires_grad: bool = False,
|
|
2816
2871
|
alias_decay_db: float = 0.0,
|
|
2817
2872
|
device: Optional[str] = None,
|
|
2873
|
+
dtype: torch.dtype = torch.float32,
|
|
2818
2874
|
):
|
|
2819
2875
|
self.fs = fs
|
|
2820
2876
|
self.max_len = max_len
|
|
@@ -2826,6 +2882,7 @@ class Delay(DSP):
|
|
|
2826
2882
|
requires_grad=requires_grad,
|
|
2827
2883
|
alias_decay_db=alias_decay_db,
|
|
2828
2884
|
device=device,
|
|
2885
|
+
dtype=dtype,
|
|
2829
2886
|
)
|
|
2830
2887
|
self.initialize_class()
|
|
2831
2888
|
|
|
@@ -2949,7 +3006,7 @@ class Delay(DSP):
|
|
|
2949
3006
|
self.omega = (
|
|
2950
3007
|
2
|
|
2951
3008
|
* torch.pi
|
|
2952
|
-
* torch.arange(0, self.nfft // 2 + 1, device=self.device)
|
|
3009
|
+
* torch.arange(0, self.nfft // 2 + 1, device=self.device, dtype=self.dtype)
|
|
2953
3010
|
/ self.nfft
|
|
2954
3011
|
).unsqueeze(1)
|
|
2955
3012
|
self.get_freq_response()
|
|
@@ -2989,6 +3046,7 @@ class parallelDelay(Delay):
|
|
|
2989
3046
|
requires_grad: bool = False,
|
|
2990
3047
|
alias_decay_db: float = 0.0,
|
|
2991
3048
|
device: Optional[str] = None,
|
|
3049
|
+
dtype: torch.dtype = torch.float32,
|
|
2992
3050
|
):
|
|
2993
3051
|
super().__init__(
|
|
2994
3052
|
size=size,
|
|
@@ -3000,6 +3058,7 @@ class parallelDelay(Delay):
|
|
|
3000
3058
|
requires_grad=requires_grad,
|
|
3001
3059
|
alias_decay_db=alias_decay_db,
|
|
3002
3060
|
device=device,
|
|
3061
|
+
dtype=dtype,
|
|
3003
3062
|
)
|
|
3004
3063
|
|
|
3005
3064
|
def check_param_shape(self):
|