flamo 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flamo/processor/dsp.py CHANGED
@@ -1,5 +1,7 @@
 import torch
 import math
+import abc
+import warnings
 from typing import Optional
 import torch.nn as nn
 import torch.nn.functional as F
@@ -41,11 +43,11 @@ class Transform(nn.Module):
         tensor([1, 4, 9])
     """
 
-    def __init__(self, transform: callable = lambda x: x, device: Optional[str] = None):
-
+    def __init__(self, transform: callable = lambda x: x, device: Optional[str] = None, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.transform = transform
         self.device = device
+        self.dtype = dtype
 
     def forward(self, x: torch.Tensor):
         r"""
@@ -75,12 +77,12 @@ class FFT(Transform):
     For details on the real FFT function, see `torch.fft.rfft documentation <https://pytorch.org/docs/stable/generated/torch.fft.rfft.html>`_.
     """
 
-    def __init__(self, nfft: int = 2**11, norm: str = "backward"):
-
+    def __init__(self, nfft: int = 2**11, norm: str = "backward", dtype: torch.dtype = torch.float32):
         self.nfft = nfft
         self.norm = norm
+        self.dtype = dtype
         transform = lambda x: torch.fft.rfft(x, n=self.nfft, dim=1, norm=self.norm)
-        super().__init__(transform=transform)
+        super().__init__(transform=transform, dtype=self.dtype)
 
 
 class iFFT(Transform):
@@ -97,12 +99,12 @@ class iFFT(Transform):
     For details on the inverse real FFT function, see `torch.fft.irfft documentation <https://pytorch.org/docs/stable/generated/torch.fft.irfft.html>`_.
     """
 
-    def __init__(self, nfft: int = 2**11, norm: str = "backward"):
-
+    def __init__(self, nfft: int = 2**11, norm: str = "backward", dtype: torch.dtype = torch.float32):
         self.nfft = nfft
         self.norm = norm
+        self.dtype = dtype
         transform = lambda x: torch.fft.irfft(x, n=self.nfft, dim=1, norm=self.norm)
-        super().__init__(transform=transform)
+        super().__init__(transform=transform, dtype=self.dtype)
 
 
 class FFTAntiAlias(Transform):
@@ -119,7 +121,7 @@ class FFTAntiAlias(Transform):
         - **norm** (str): The normalization mode for the FFT.
         - **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. Default: 0.0.
         - **device** (str): The device of the constructed tensors. Default: None.
-
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     For details on the FFT function, see `torch.fft.rfft documentation <https://pytorch.org/docs/stable/generated/torch.fft.rfft.html>`_.
     """
@@ -130,24 +132,23 @@ class FFTAntiAlias(Transform):
         norm: str = "backward",
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
-
         self.nfft = nfft
         self.norm = norm
         self.device = device
-
+        self.dtype = dtype
         gamma = 10 ** (
-            -torch.abs(torch.tensor(alias_decay_db, device=self.device))
+            -torch.abs(torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype))
             / (self.nfft)
             / 20
         )
         self.alias_envelope = gamma ** torch.arange(
-            0, -self.nfft, -1, device=self.device
+            0, -self.nfft, -1, device=self.device, dtype=self.dtype
        )
-
         fft = lambda x: torch.fft.rfft(x, n=self.nfft, dim=1, norm=self.norm)
         transform = lambda x: fft(torch.einsum("btm, t->btm", x, self.alias_envelope))
-        super().__init__(transform=transform)
+        super().__init__(transform=transform, dtype=self.dtype)
 
 
 class iFFTAntiAlias(Transform):
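Both anti-alias transforms derive their time-domain envelope from `alias_decay_db`. A minimal standalone sketch of that computation, mirroring the diffed expressions (the helper name and example values here are ours, not part of flamo), shows what the new `dtype` argument changes: the envelope tensor is now created directly in the requested precision rather than always in float32.

```python
import torch

def alias_envelope(nfft: int, alias_decay_db: float, dtype: torch.dtype) -> torch.Tensor:
    # gamma = 10 ** (-|alias_decay_db| / nfft / 20), as in FFTAntiAlias.__init__
    gamma = 10 ** (-torch.abs(torch.tensor(alias_decay_db, dtype=dtype)) / nfft / 20)
    # exponents 0, -1, ..., -(nfft - 1), matching the diffed torch.arange(0, -nfft, -1)
    return gamma ** torch.arange(0, -nfft, -1, dtype=dtype)

print(alias_envelope(8, 60.0, torch.float32).dtype)  # torch.float32
print(alias_envelope(8, 60.0, torch.float64).dtype)  # torch.float64
```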
@@ -163,6 +164,7 @@ class iFFTAntiAlias(Transform):
         - **norm** (str): The normalization mode. Default: "backward".
         - **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. Default: 0.0.
         - **device** (str): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     For details on the inverse FFT function, see `torch.fft.irfft documentation <https://pytorch.org/docs/stable/generated/torch.fft.irfft.html>`_.
     """
@@ -173,24 +175,23 @@ class iFFTAntiAlias(Transform):
         norm: str = "backward",
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
-
         self.nfft = nfft
         self.norm = norm
         self.device = device
-
+        self.dtype = dtype
         gamma = 10 ** (
-            -torch.abs(torch.tensor(alias_decay_db, device=self.device))
+            -torch.abs(torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype))
             / (self.nfft)
             / 20
         )
         self.alias_envelope = gamma ** torch.arange(
-            0, -self.nfft, -1, device=self.device
+            0, -self.nfft, -1, device=self.device, dtype=self.dtype
        )
-
         ifft = lambda x: torch.fft.irfft(x, n=self.nfft, dim=1, norm=self.norm)
         transform = lambda x: torch.einsum("btm, t->btm", ifft(x), self.alias_envelope)
-        super().__init__(transform=transform)
+        super().__init__(transform=transform, dtype=self.dtype)
 
 
 # ============================= CORE ================================
@@ -216,6 +217,7 @@ class DSP(nn.Module):
         - **requires_grad** (bool, optional): Whether the parameters require gradients. Default: False.
         - **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str): The device of the constructed tensor, if any. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **param** (nn.Parameter): The parameters of the DSP module.
@@ -234,6 +236,7 @@ class DSP(nn.Module):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
 
         super().__init__()
@@ -244,25 +247,31 @@ class DSP(nn.Module):
         self.new_value = 0 # flag indicating if new values have been assigned
         self.requires_grad = requires_grad
         self.device = device
+        self.dtype = dtype
         self.param = nn.Parameter(
-            torch.empty(self.size, device=self.device), requires_grad=self.requires_grad
+            torch.empty(self.size, device=self.device, dtype=self.dtype), requires_grad=self.requires_grad
         )
         self.fft = lambda x: torch.fft.rfft(x, n=self.nfft, dim=0)
         self.ifft = lambda x: torch.fft.irfft(x, n=self.nfft, dim=0)
         # initialize time anti-aliasing envelope function
-        self.alias_decay_db = torch.tensor(alias_decay_db, device=self.device)
+        self.alias_decay_db = torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype)
         self.init_param()
         self.get_gamma()
 
-    def forward(self, x, **kwArguments):
+    @abc.abstractmethod
+    def forward(self, x, **kwArguments):
         r"""
         Forward method.
 
         Input is returned. Forward method is to be implemented by the child class.
 
         """
-        Warning("Forward method not implemented. Input is returned")
+        warnings.warn(
+            "Forward method not implemented. Input is returned.",
+            UserWarning
+        )
         return x
+
 
     def init_param(self):
         r"""
@@ -334,6 +343,7 @@ class Gain(DSP):
         - **requires_grad** (bool): Whether the parameters requires gradients. Default: False.
         - **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **param** (nn.Parameter): The parameters of the Gain module.
@@ -354,6 +364,7 @@ class Gain(DSP):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -362,6 +373,7 @@ class Gain(DSP):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
         self.initialize_class()
 
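With the change above, every `Gain`-derived module allocates its parameter tensor in the requested precision (`torch.empty(self.size, device=..., dtype=...)` in `DSP.__init__`). A hedged usage sketch; the `size` and `nfft` keyword names are taken from the `super().__init__(size=size, nfft=nfft, ...)` call sites visible in this diff, and their defaults are assumed unchanged from 0.1.13:

```python
import torch
from flamo.processor.dsp import Gain  # the file shown in this diff

g = Gain(size=(2, 2), nfft=2**11, dtype=torch.float64)
print(g.param.dtype)           # torch.float64, allocated with the new dtype
print(g.alias_decay_db.dtype)  # torch.float64, also built with dtype in DSP.__init__
```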
@@ -461,6 +473,7 @@ class parallelGain(Gain):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -469,6 +482,7 @@ class parallelGain(Gain):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def check_param_shape(self):
@@ -520,6 +534,7 @@ class Matrix(Gain):
         - **requires_grad** (bool, optional): Whether the matrix requires gradient computation. Default: False.
         - **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str, optional): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **param** (nn.Parameter): The parameters of the Matrix module.
@@ -543,6 +558,7 @@ class Matrix(Gain):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         self.matrix_type = matrix_type
         self.iter = iter # iterations number for the rotation matrix
@@ -553,6 +569,7 @@ class Matrix(Gain):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def matrix_gallery(self):
@@ -579,7 +596,7 @@ class Matrix(Gain):
                 assert (
                     N % 2 == 0
                 ), "Matrix must have even dimensions to be Hadamard"
-                self.map = lambda x: HadamardMatrix(self.size[0], device=self.device)(x)
+                self.map = lambda x: HadamardMatrix(self.size[0], device=self.device, dtype=self.dtype)(x)
             case "rotation":
                 assert (
                     N == self.size[1]
@@ -587,7 +604,7 @@ class Matrix(Gain):
                 assert (
                     N % 2 == 0
                 ), "Matrix must have even dimensions to be a rotation matrix"
-                self.map = lambda x: RotationMatrix(self.size[0], self.iter, device=self.device)([x[0][0]])
+                self.map = lambda x: RotationMatrix(self.size[0], self.iter, device=self.device, dtype=self.dtype)([x[0][0]])
 
     def initialize_class(self):
         r"""
@@ -619,6 +636,7 @@ class HouseholderMatrix(Gain):
         - **requires_grad** (bool, optional): If True, gradients will be computed for the parameters. Defaults to False.
         - **alias_decay_db** (float, optional): Alias decay in decibels. Defaults to 0.0.
         - **device** (optional): Device on which to perform computations. Defaults to None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **param** (nn.Parameter): The parameters `u`` used to construct the Householder matrix.
@@ -640,6 +658,7 @@ class HouseholderMatrix(Gain):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         assert size[0] == size[1], "Matrix must be square"
         size = (size[0], 1)
@@ -651,6 +670,7 @@ class HouseholderMatrix(Gain):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def forward(self, x, ext_param=None):
@@ -735,6 +755,7 @@ class Filter(DSP):
         - **requires_grad** (bool): Whether the filter parameters require gradients. Default: False.
         - **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **param** (nn.Parameter): The parameters of the Filter module.
@@ -758,6 +779,7 @@ class Filter(DSP):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -766,6 +788,7 @@ class Filter(DSP):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
         self.initialize_class()
 
@@ -886,6 +909,7 @@ class parallelFilter(Filter):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -894,6 +918,7 @@ class parallelFilter(Filter):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def check_param_shape(self):
@@ -973,7 +998,8 @@ class ScatteringMatrix(Filter):
         - **requires_grad** (bool): Whether the filter parameters require gradients. Default: False.
         - **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str): The device of the constructed tensors. Default: None.
-
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
+
     **Attributes**:
         - **param** (nn.Parameter): The parameters of the Filter module.
         - **map** (function): Mapping function to ensure orthogonality of :math:`\mathbf{U}_k`.
@@ -1002,6 +1028,7 @@ class ScatteringMatrix(Filter):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         self.sparsity = sparsity
         self.gain_per_sample = gain_per_sample
@@ -1017,6 +1044,7 @@ class ScatteringMatrix(Filter):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def get_freq_convolve(self):
@@ -1071,6 +1099,7 @@ class ScatteringMatrix(Filter):
             m_L=self.m_L,
             m_R=self.m_R,
             device=self.device,
+            dtype=self.dtype,
         )
         self.check_param_shape()
         self.get_io()
@@ -1110,6 +1139,7 @@ class VelvetNoiseMatrix(Filter):
         - **requires_grad** (bool): Whether the filter parameters require gradients. Default: False.
         - **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **param** (nn.Parameter): The parameters of the Filter module.
@@ -1139,6 +1169,7 @@ class VelvetNoiseMatrix(Filter):
         m_R: torch.tensor = None,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         self.sparsity = 1/density
         self.gain_per_sample = gain_per_sample
@@ -1155,8 +1186,9 @@ class VelvetNoiseMatrix(Filter):
             requires_grad=False,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
-        self.assign_value(torch.tensor(hadamard_matrix(self.size[-1]), device=self.device).unsqueeze(0).repeat(self.size[0], 1, 1))
+        self.assign_value(torch.tensor(hadamard_matrix(self.size[-1]), device=self.device, dtype=self.dtype).unsqueeze(0).repeat(self.size[0], 1, 1))
 
     def get_freq_convolve(self):
         r"""
@@ -1210,6 +1242,7 @@ class VelvetNoiseMatrix(Filter):
             m_L=self.m_L,
             m_R=self.m_R,
             device=self.device,
+            dtype=self.dtype,
         )
         self.check_param_shape()
         self.get_io()
@@ -1253,6 +1286,7 @@ class Biquad(Filter):
         - **requires_grad** (bool, optional): Whether the filter parameters require gradient computation. Default: True.
         - **alias_decay_db** (float): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str, optional): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **param** (nn.Parameter): The parameters of the Filter module.
@@ -1279,17 +1313,19 @@ class Biquad(Filter):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         assert filter_type in ["lowpass", "highpass", "bandpass"], "Invalid filter type"
         self.n_sections = n_sections
         self.filter_type = filter_type
         self.fs = fs
         self.device = device
+        self.dtype = dtype
         self.get_map()
         gamma = 10 ** (
-            -torch.abs(torch.tensor(alias_decay_db, device=self.device)) / (nfft) / 20
+            -torch.abs(torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype)) / (nfft) / 20
         )
-        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=self.device)
+        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=self.device, dtype=self.dtype)
         super().__init__(
             size=(n_sections, *self.get_size(), *size),
             nfft=nfft,
@@ -1297,6 +1333,7 @@ class Biquad(Filter):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def get_size(self):
@@ -1361,6 +1398,7 @@ class Biquad(Filter):
                     gain=param[:, 1, :, :],
                     fs=self.fs,
                     device=self.device,
+                    dtype=self.dtype,
                 )
             case "highpass":
                 b, a = highpass_filter(
@@ -1368,6 +1406,7 @@ class Biquad(Filter):
                     gain=param[:, 1, :, :],
                     fs=self.fs,
                     device=self.device,
+                    dtype=self.dtype,
                 )
             case "bandpass":
                 b, a = bandpass_filter(
@@ -1376,15 +1415,15 @@ class Biquad(Filter):
                     gain=param[:, 2, :, :],
                     fs=self.fs,
                     device=self.device,
+                    dtype=self.dtype,
                 )
-        b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
-        a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
+        b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, b)
+        a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, a)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def get_map(self):
         r"""
@@ -1398,10 +1437,10 @@ class Biquad(Filter):
                 (x[:, 0, :, :], 20 * torch.log10(torch.abs(x[:, 1, :, :]))),
                 dim=1,
             ),
-            min=torch.tensor([0, -60], device=self.device)
+            min=torch.tensor([0, -60], device=self.device, dtype=self.dtype)
             .view(-1, 1, 1)
             .expand_as(x),
-            max=torch.tensor([1, 60], device=self.device)
+            max=torch.tensor([1, 60], device=self.device, dtype=self.dtype)
             .view(-1, 1, 1)
             .expand_as(x),
         )
@@ -1415,10 +1454,10 @@ class Biquad(Filter):
                 ),
                 dim=1,
             ),
-            min=torch.tensor([0 + torch.finfo(torch.float).eps, 0 + torch.finfo(torch.float).eps, -60], device=self.device)
+            min=torch.tensor([0 + torch.finfo(self.dtype).eps, 0 + torch.finfo(self.dtype).eps, -60], device=self.device, dtype=self.dtype)
             .view(-1, 1, 1)
             .expand_as(x),
-            max=torch.tensor([1 - torch.finfo(torch.float).eps, 1 - torch.finfo(torch.float).eps, 60], device=self.device)
+            max=torch.tensor([1 - torch.finfo(self.dtype).eps, 1 - torch.finfo(self.dtype).eps, 60], device=self.device, dtype=self.dtype)
             .view(-1, 1, 1)
             .expand_as(x),
         )
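The clamp bounds above switch from the hard-coded `torch.finfo(torch.float).eps` to `torch.finfo(self.dtype).eps`, so the open-interval margins track the working precision; the two epsilons differ by nine orders of magnitude:

```python
import torch

# Machine epsilon per floating-point dtype. 0.1.13 always used the float32
# value (about 1.19e-07), even when the parameters were float64.
for dt in (torch.float32, torch.float64):
    print(dt, torch.finfo(dt).eps)
# torch.float32 1.1920928955078125e-07
# torch.float64 2.220446049250313e-16
```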
@@ -1491,6 +1530,7 @@ class parallelBiquad(Biquad):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -1501,6 +1541,7 @@ class parallelBiquad(Biquad):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def check_param_shape(self):
@@ -1519,10 +1560,10 @@ class parallelBiquad(Biquad):
             torch.stack(
                 (x[:, 0, :], 20 * torch.log10(torch.abs(x[:, -1, :]))), dim=1
             ),
-            min=torch.tensor([0, -60], device=self.device)
+            min=torch.tensor([0, -60], device=self.device, dtype=self.dtype)
             .view(-1, 1)
             .expand_as(x),
-            max=torch.tensor([1, 60], device=self.device)
+            max=torch.tensor([1, 60], device=self.device, dtype=self.dtype)
             .view(-1, 1)
             .expand_as(x),
         )
@@ -1536,10 +1577,10 @@ class parallelBiquad(Biquad):
                 ),
                 dim=1,
             ),
-            min=torch.tensor([0 + torch.finfo(torch.float).eps, 0 + torch.finfo(torch.float).eps, -60], device=self.device)
+            min=torch.tensor([0 + torch.finfo(self.dtype).eps, 0 + torch.finfo(self.dtype).eps, -60], device=self.device, dtype=self.dtype)
             .view(-1, 1)
             .expand_as(x),
-            max=torch.tensor([1 - torch.finfo(torch.float).eps, 1 - torch.finfo(torch.float).eps, 60], device=self.device)
+            max=torch.tensor([1 - torch.finfo(self.dtype).eps, 1 - torch.finfo(self.dtype).eps, 60], device=self.device, dtype=self.dtype)
             .view(-1, 1)
             .expand_as(x),
         )
@@ -1582,6 +1623,7 @@ class parallelBiquad(Biquad):
                     gain=param[:, 1, :],
                     fs=self.fs,
                     device=self.device,
+                    dtype=self.dtype,
                 )
             case "highpass":
                 b, a = highpass_filter(
@@ -1589,6 +1631,7 @@ class parallelBiquad(Biquad):
                     gain=param[:, 1, :],
                     fs=self.fs,
                     device=self.device,
+                    dtype=self.dtype,
                 )
             case "bandpass":
                 b, a = bandpass_filter(
@@ -1597,15 +1640,15 @@ class parallelBiquad(Biquad):
                     gain=param[:, 2, :],
                     fs=self.fs,
                     device=self.device,
+                    dtype=self.dtype,
                 )
-        b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
-        a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
+        b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, b)
+        a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, a)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def get_freq_convolve(self):
         self.freq_convolve = lambda x, param: torch.einsum(
@@ -1673,7 +1716,8 @@ class SVF(Filter):
         - **requires_grad** (bool, optional): Whether the filter parameters require gradients. Default: False.
         - **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str, optional): The device of the constructed tensors. Default: None.
-
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
+
     **Attributes**:
         - **alias_envelope_dcy** (torch.Tensor): The anti time-aliasing decaying envelope.
         - **fft** (function): The FFT function. Calls the torch.fft.rfft function.
@@ -1700,6 +1744,7 @@ class SVF(Filter):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         self.fs = fs
         self.n_sections = n_sections
@@ -1715,9 +1760,9 @@ class SVF(Filter):
         ], "Invalid filter type"
         self.filter_type = filter_type
         gamma = 10 ** (
-            -torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20
+            -torch.abs(torch.tensor(alias_decay_db, device=device, dtype=dtype)) / (nfft) / 20
         )
-        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device)
+        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device, dtype=dtype)
         super().__init__(
             size=(5, self.n_sections, *size),
             nfft=nfft,
@@ -1725,6 +1770,7 @@ class SVF(Filter):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def check_param_shape(self):
@@ -1767,14 +1813,13 @@ class SVF(Filter):
         a[2] = (f**2) - 2 * R * f + 1
 
         # apply anti-aliasing
-        b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
-        a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
+        b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, b)
+        a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, a)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if f.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def param2freq(self, param):
         r"""
@@ -1801,8 +1846,8 @@ class SVF(Filter):
 
         """
         return torch.div(
-            torch.log(torch.ones(1, device=self.device) + torch.exp(param)),
-            torch.log(torch.tensor(2, device=self.device)),
+            torch.log(torch.ones(1, device=self.device, dtype=self.dtype) + torch.exp(param)),
+            torch.log(torch.tensor(2, device=self.device, dtype=self.dtype)),
         )
 
     def param2mix(self, param, R=None):
@@ -1887,8 +1932,8 @@ class SVF(Filter):
                 )
             case None:
                 # general SVF filter
-                bias = torch.ones((param.shape), device=self.device)
-                bias[1] = 2 * torch.ones((param.shape[1:]), device=self.device)
+                bias = torch.ones((param.shape), device=self.device, dtype=self.dtype)
+                bias[1] = 2 * torch.ones((param.shape[1:]), device=self.device, dtype=self.dtype)
         return param + bias
 
     def map_param2svf(self, param):
@@ -1899,7 +1944,7 @@ class SVF(Filter):
         r = self.param2R(param[1])
         if self.filter_type == "lowshelf" or self.filter_type == "highshelf":
             # R = r + torch.sqrt(torch.tensor(2))
-            R = torch.tensor(1, device=self.device)
+            R = torch.tensor(1, device=self.device, dtype=self.dtype)
         if self.filter_type == "peaking":
             R = 1 / r # temporary fix for peaking filter
         m = self.param2mix(param[2:], r)
@@ -1945,6 +1990,7 @@ class parallelSVF(SVF):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -1955,6 +2001,7 @@ class parallelSVF(SVF):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def check_param_shape(self):
@@ -1985,14 +2032,13 @@ class parallelSVF(SVF):
         a[2] = (f**2) - 2 * R * f + 1
 
         # apply anti-aliasing
-        b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
-        a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
+        b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, b)
+        a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, a)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if f.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def get_freq_convolve(self):
         self.freq_convolve = lambda x, param: torch.einsum(
@@ -2038,6 +2084,7 @@ class GEQ(Filter):
         - **requires_grad** (bool, optional): Whether the filter parameters require gradients. Default: False.
         - **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Default: 0.
         - **device** (str, optional): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **alias_envelope_dcy** (torch.Tensor): The anti time-aliasing decaying envelope.
@@ -2072,6 +2119,7 @@ class GEQ(Filter):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         self.octave_interval = octave_interval
         self.fs = fs
@@ -2080,9 +2128,9 @@ class GEQ(Filter):
         )
         self.n_gains = len(self.center_freq) + 3
         gamma = 10 ** (
-            -torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20
+            -torch.abs(torch.tensor(alias_decay_db, device=device, dtype=dtype)) / (nfft) / 20
         )
-        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device)
+        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device, dtype=dtype)
         super().__init__(
             size=(self.n_gains, *size),
             nfft=nfft,
@@ -2090,6 +2138,7 @@ class GEQ(Filter):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def init_param(self):
@@ -2124,14 +2173,13 @@ class GEQ(Filter):
             device=self.device,
         )
 
-        b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
-        a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
+        b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, a)
+        a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, b)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def initialize_class(self):
         self.check_param_shape()
@@ -2178,6 +2226,7 @@ class parallelGEQ(GEQ):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -2188,6 +2237,7 @@ class parallelGEQ(GEQ):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def check_param_shape(self):
@@ -2210,14 +2260,13 @@ class parallelGEQ(GEQ):
             device=self.device,
         )
 
-        b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
-        a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
+        b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, a)
+        a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, b)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def get_freq_convolve(self):
         self.freq_convolve = lambda x, param: torch.einsum(
@@ -2248,6 +2297,7 @@ class PEQ(Filter):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         self.n_bands = n_bands
         self.design = design
@@ -2255,11 +2305,11 @@ class PEQ(Filter):
         self.f_min = f_min
         self.f_max = f_max
         gamma = 10 ** (
-            -torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20
+            -torch.abs(torch.tensor(alias_decay_db, device=device, dtype=dtype)) / (nfft) / 20
         )
-        k = torch.arange(1, self.n_bands + 1, dtype=torch.float32)
+        k = torch.arange(1, self.n_bands + 1, dtype=dtype)
         self.center_freq_bias = f_min * (f_max / f_min) ** ((k - 1) / (self.n_bands - 1))
-        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device)
+        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device, dtype=dtype)
         super().__init__(
             size=(self.n_bands, 3, *size),
             nfft=nfft,
@@ -2267,6 +2317,7 @@ class PEQ(Filter):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def init_param(self):
@@ -2316,14 +2367,13 @@ class PEQ(Filter):
             G=G[1:-1],
             type='peaking',
         )
-        b_aa = torch.einsum("p, opmn -> opmn", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
-        a_aa = torch.einsum("p, opmn -> opmn", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
+        b_aa = torch.einsum("p, opmn -> opmn", self.alias_envelope_dcy, b)
+        a_aa = torch.einsum("p, opmn -> opmn", self.alias_envelope_dcy, a)
         B = torch.fft.rfft(b_aa, self.nfft, dim=1)
         A = torch.fft.rfft(a_aa, self.nfft, dim=1)
         H_temp = torch.prod(B, dim=0) / (torch.prod(A, dim=0))
         H = torch.where(torch.abs(torch.prod(A, dim=0)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def compute_biquad_coeff(self, f, R, G, type='peaking'):
         # f : freq, R : resonance, G : gain in dB
@@ -2432,6 +2482,7 @@ class parallelPEQ(PEQ):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -2445,6 +2496,7 @@ class parallelPEQ(PEQ):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def init_param(self):
@@ -2487,14 +2539,13 @@ class parallelPEQ(PEQ):
             G=G[1:-1],
             type='peaking'
         )
-        b_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
-        a_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
+        b_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy, b)
+        a_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy, a)
         B = torch.fft.rfft(b_aa, self.nfft, dim=1)
         A = torch.fft.rfft(a_aa, self.nfft, dim=1)
         H_temp = torch.prod(B, dim=0) / (torch.prod(A, dim=0))
         H = torch.where(torch.abs(torch.prod(A, dim=0)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
 
     def map_eq(self, param):
@@ -2566,6 +2617,7 @@ class AccurateGEQ(Filter):
         - start_freq (float, optional): The starting frequency for the filter bands. Default: 31.25.
         - end_freq (float, optional): The ending frequency for the filter bands. Default: 16000.0.
         - device (str, optional): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - fs (int): The sampling frequency.
@@ -2598,22 +2650,24 @@ class AccurateGEQ(Filter):
         alias_decay_db: float = 0.0,
         start_freq: float = 31.25,
         end_freq: float = 16000.0,
-        device=None
+        device=None,
+        dtype: torch.dtype = torch.float32,
     ):
         self.octave_interval = octave_interval
         self.fs = fs
         self.center_freq, self.shelving_crossover = eq_freqs(
             interval=self.octave_interval, start_freq=start_freq, end_freq=end_freq)
         self.n_gains = len(self.center_freq) + 2
-        gamma = 10 ** (-torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20)
-        self.alias_envelope_dcy = (gamma ** torch.arange(0, 3, 1, device=device))
+        gamma = 10 ** (-torch.abs(torch.tensor(alias_decay_db, device=device, dtype=dtype)) / (nfft) / 20)
+        self.alias_envelope_dcy = (gamma ** torch.arange(0, 3, 1, device=device, dtype=dtype))
         super().__init__(
             size=(self.n_gains, *size),
             nfft=nfft,
             map=map,
             requires_grad=False,
             alias_decay_db=alias_decay_db,
-            device=device
+            device=device,
+            dtype=dtype,
         )
 
     def init_param(self):
@@ -2646,14 +2700,13 @@ class AccurateGEQ(Filter):
             device=self.device
         )
 
-        b_aa = torch.einsum('p, pomn -> pomn', self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
-        a_aa = torch.einsum('p, pomn -> pomn', self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
+        b_aa = torch.einsum('p, pomn -> pomn', self.alias_envelope_dcy, a)
+        a_aa = torch.einsum('p, pomn -> pomn', self.alias_envelope_dcy, b)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def initialize_class(self):
         self.check_param_shape()
@@ -2695,7 +2748,8 @@ class parallelAccurateGEQ(AccurateGEQ):
         alias_decay_db: float = 0.0,
         start_freq: float = 31.25,
         end_freq: float = 16000.0,
-        device=None
+        device=None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -2706,7 +2760,8 @@ class parallelAccurateGEQ(AccurateGEQ):
             alias_decay_db=alias_decay_db,
             start_freq=start_freq,
             end_freq=end_freq,
-            device=device
+            device=device,
+            dtype=dtype,
         )
 
     def check_param_shape(self):
@@ -2729,14 +2784,13 @@ class parallelAccurateGEQ(AccurateGEQ):
             device=self.device
        )
 
-        b_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
-        a_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
+        b_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy, a)
+        a_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy, b)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-        H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
-        return H.to(H_type), B, A
+        return H, B, A
 
     def get_freq_convolve(self):
         self.freq_convolve = lambda x, param: torch.einsum(
@@ -2788,6 +2842,7 @@ class Delay(DSP):
         - **requires_grad** (bool, optional): Flag indicating whether the module parameters require gradients. Default: False.
         - **alias_decay_db** (float, optional): The decaying factor in dB for the time anti-aliasing envelope. The decay refers to the attenuation after nfft samples. Defaults to 0.
        - **device** (str, optional): The device of the constructed tensors. Default: None.
+        - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
 
     **Attributes**:
         - **alias_envelope_dcy** (torch.Tensor): The anti time-aliasing decaying envelope.
@@ -2815,6 +2870,7 @@ class Delay(DSP):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         self.fs = fs
         self.max_len = max_len
@@ -2826,6 +2882,7 @@ class Delay(DSP):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
         self.initialize_class()
 
@@ -2949,7 +3006,7 @@ class Delay(DSP):
         self.omega = (
             2
             * torch.pi
-            * torch.arange(0, self.nfft // 2 + 1, device=self.device)
+            * torch.arange(0, self.nfft // 2 + 1, device=self.device, dtype=self.dtype)
             / self.nfft
         ).unsqueeze(1)
         self.get_freq_response()
@@ -2989,6 +3046,7 @@ class parallelDelay(Delay):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__(
             size=size,
@@ -3000,6 +3058,7 @@ class parallelDelay(Delay):
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype,
         )
 
     def check_param_shape(self):