flamo 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
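The recurring change across this release is a new `dtype` keyword, defaulting to `torch.float32` so existing call sites are unaffected, threaded through every tensor-constructing function and class. As a hedged illustration of the pattern (the import path follows the file name below; the call itself is a sketch, not taken from flamo's docs):

    import torch
    from flamo.functional import get_frequency_samples

    # Build the unit-circle samples in double precision; torch.polar on
    # float64 inputs returns complex128.
    freqs = get_frequency_samples(512, dtype=torch.float64)
    print(freqs.dtype)  # torch.complex128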
flamo/functional.py CHANGED
@@ -56,19 +56,20 @@ def skew_matrix(X):
     return A - A.transpose(-1, -2)


-def get_frequency_samples(num: int, device: str | torch.device = None):
+def get_frequency_samples(num: int, device: str | torch.device = None, dtype: torch.dtype = torch.float32):
     r"""
     Get frequency samples (in radians) sampled at linearly spaced points along the unit circle.

     **Arguments**
     - **num** (int): number of frequency samples
     - **device** (torch.device, str): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype): The dtype of constructed tensors. Default: torch.float32.

     **Returns**
     - frequency samples in radians between [0, pi]
     """
-    angle = torch.linspace(0, 1, num, device=device)
-    abs = torch.ones(num, device=device)
+    angle = torch.linspace(0, 1, num, device=device, dtype=dtype)
+    abs = torch.ones(num, device=device, dtype=dtype)
     return torch.polar(abs, angle * np.pi)


@@ -77,17 +78,18 @@ class HadamardMatrix(nn.Module):
     Generate a Hadamard matrix of size N as a nn.Module.
     """

-    def __init__(self, N, device: Optional[str] = None):
+    def __init__(self, N, device: Optional[str] = None, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.N = N
         self.device = device
+        self.dtype = dtype

     def forward(self, x):
-        U = torch.tensor([[1.0]], device=self.device)
+        U = torch.tensor([[1.0]], device=self.device, dtype=self.dtype)
         while U.shape[0] < self.N:
             U = torch.kron(
                 U, torch.tensor([[1, 1], [1, -1]], dtype=U.dtype, device=U.device)
-            ) / torch.sqrt(torch.tensor(2.0, device=U.device))
+            ) / torch.sqrt(torch.tensor(2.0, device=U.device, dtype=U.dtype))
         return U


@@ -103,6 +105,7 @@ class RotationMatrix(nn.Module):
         max_angle: float = torch.pi / 4,
         iter: Optional[int] = None,
         device: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
     ):

         super().__init__()
@@ -111,14 +114,15 @@ class RotationMatrix(nn.Module):
         self.max_angle = max_angle
         self.iter = iter
         self.device = device
+        self.dtype = dtype

     def create_submatrix(self, angles: torch.Tensor, iters: int = 1):
         """Create a submatrix for each group."""
-        X = torch.zeros(2, 2, device=self.device)
+        X = torch.zeros(2, 2, device=self.device, dtype=self.dtype)
         angles[0] = torch.clamp(angles[0], self.min_angle, self.max_angle)
         X.fill_diagonal_(torch.cos(angles[0]))
-        X[1, 0] = -torch.sin(torch.tensor(angles[0], device=self.device))
-        X[0, 1] = torch.sin(torch.tensor(angles[0], device=self.device))
+        X[1, 0] = -torch.sin(angles[0])
+        X[0, 1] = torch.sin(angles[0])

         if iters is None:
             iters = torch.log2(torch.tensor(self.N)).int().item() - 1
@@ -166,6 +170,7 @@ def signal_gallery(
     rate: float = 1.0,
     reference: torch.Tensor = None,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Generate a tensor containing a signal based on the specified signal type.
@@ -187,6 +192,7 @@ def signal_gallery(
     - **fs** (int, optional): The sampling frequency of the signals. Defaults to 48000.
     - **reference** (torch.Tensor, optional): A reference signal to use. Defaults to None.
     - **device** (torch.device, optional): The device of constructed tensors. Defaults to None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Defaults to torch.float32.

     **Returns**:
     - torch.Tensor: A tensor of shape (batch_size, n_samples, n) containing the generated signals.
@@ -207,7 +213,7 @@ def signal_gallery(
         raise ValueError(f"Signal type {signal_type} not recognized.")
     match signal_type:
         case "impulse":
-            x = torch.zeros(batch_size, n_samples, n)
+            x = torch.zeros(batch_size, n_samples, n, dtype=dtype)
             x[:, 0, :] = 1
             return x.to(device)
         case "sine":
@@ -218,7 +224,7 @@ def signal_gallery(
                     * np.pi
                     * rate
                     / fs
-                    * torch.linspace(0, n_samples / fs, n_samples)
+                    * torch.linspace(0, n_samples / fs, n_samples, dtype=dtype)
                 )
                 .unsqueeze(-1)
                 .expand(batch_size, n_samples, n)
@@ -226,44 +232,45 @@ def signal_gallery(
                 )
             else:
                 return torch.sin(
-                    torch.linspace(0, 2 * np.pi, n_samples)
+                    torch.linspace(0, 2 * np.pi, n_samples, dtype=dtype)
                     .unsqueeze(-1)
                     .expand(batch_size, n_samples, n)
                 ).to(device)
         case "sweep":
-            t = torch.linspace(0, n_samples / fs - 1 / fs, n_samples)
+            t = torch.linspace(0, n_samples / fs - 1 / fs, n_samples, dtype=dtype)
             x = torch.tensor(
                 scipy.signal.chirp(t, f0=20, f1=20000, t1=t[-1], method="linear"),
                 device=device,
+                dtype=dtype,
             ).unsqueeze(-1)
             return x.expand(batch_size, n_samples, n)
         case "wgn":
-            return torch.randn((batch_size, n_samples, n), device=device)
+            return torch.randn((batch_size, n_samples, n), device=device, dtype=dtype)
         case "exp":
             return (
-                torch.exp(-rate * torch.arange(n_samples) / fs)
+                torch.exp(-rate * torch.arange(n_samples, dtype=dtype) / fs)
                 .unsqueeze(-1)
                 .expand(batch_size, n_samples, n)
                 .to(device)
             )
         case "velvet":
-            x = torch.empty((batch_size, n_samples, n), device=device)
+            x = torch.empty((batch_size, n_samples, n), device=device, dtype=dtype)
             for i_batch in range(batch_size):
                 for i_ch in range(n):
-                    x[i_batch, :, i_ch] = gen_velvet_noise(n_samples, fs, rate, device)
+                    x[i_batch, :, i_ch] = gen_velvet_noise(n_samples, fs, rate, device, dtype)
             return x
         case "reference":
             if isinstance(reference, torch.Tensor):
                 return reference.expand(batch_size, n_samples, n).to(device)
             else:
-                return torch.tensor(reference, device=device).expand(
+                return torch.tensor(reference, device=device, dtype=dtype).expand(
                     batch_size, n_samples, n
                 )
         case "noise":
-            return torch.randn((batch_size, n_samples, n), device=device)
+            return torch.randn((batch_size, n_samples, n), device=device, dtype=dtype)


-def gen_velvet_noise(n_samples: int, fs: int, density: float, device: str | torch.device = None) -> torch.Tensor:
+def gen_velvet_noise(n_samples: int, fs: int, density: float, device: str | torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
     r"""
     Generate a velvet noise sequence.
     **Arguments**:
@@ -271,15 +278,16 @@ def gen_velvet_noise(n_samples: int, fs: int, density: float, device: str | torc
     - **fs** (int): The sampling frequency of the signal in Hz.
     - **density** (float): The density of impulses in impulses per second.
     - **device** (str | torch.device): The device of constructed tensors.
+    - **dtype** (torch.dtype): The dtype of constructed tensors.
     **Returns**:
     - torch.Tensor: A tensor of shape (n_samples,) containing the velvet noise sequence.
     """
     Td = fs / density  # average distance between impulses
     num_impulses = n_samples / Td  # expected number of impulses
     floor_impulses = math.floor(num_impulses)
-    grid = torch.arange(floor_impulses) * Td
+    grid = torch.arange(floor_impulses, dtype=dtype) * Td

-    jitter_factors = torch.rand(floor_impulses)
+    jitter_factors = torch.rand(floor_impulses, dtype=dtype)
     impulse_indices = torch.ceil(grid + jitter_factors * (Td - 1)).long()

     # first impulse is at position 0 and all indices are within bounds
@@ -290,7 +298,7 @@ def gen_velvet_noise(n_samples: int, fs: int, density: float, device: str | torc
     signs = 2 * torch.randint(0, 2, (floor_impulses,)) - 1

     # Construct sparse signal
-    sequence = torch.zeros(n_samples, device=device)
+    sequence = torch.zeros(n_samples, device=device, dtype=dtype)
     sequence[impulse_indices] = signs.float()

     return sequence
@@ -370,6 +378,7 @@ def lowpass_filter(
     gain: float = 0.0,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ) -> tuple:
     r"""
     Lowpass filter coefficients. It uses the `RBJ cookbook formulas <https://webaudio.github.io/Audio-EQ-Cookbook/Audio-EQ-Cookbook.txt>`_ to map
@@ -402,12 +411,12 @@ def lowpass_filter(
     """

     omegaC = hertz2rad(fc, fs).to(device=device)
-    two = torch.tensor(2, device=device)
+    two = torch.tensor(2, device=device, dtype=dtype)
     alpha = torch.sin(omegaC) / 2 * torch.sqrt(two)
     cosOC = torch.cos(omegaC)

-    a = torch.ones(3, *omegaC.shape, device=device)
-    b = torch.ones(3, *omegaC.shape, device=device)
+    a = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)
+    b = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)

     b[0] = (1 - cosOC) / 2
     b[1] = 1 - cosOC
@@ -424,6 +433,7 @@ def highpass_filter(
     gain: float = 0.0,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ) -> tuple:
     r"""
     Highpass filter coefficients. It uses the `RBJ cookbook formulas <https://webaudio.github.io/Audio-EQ-Cookbook/Audio-EQ-Cookbook.txt>`_ to map
@@ -455,12 +465,12 @@ def highpass_filter(
     """

     omegaC = hertz2rad(fc, fs)
-    two = torch.tensor(2, device=device)
+    two = torch.tensor(2, device=device, dtype=dtype)
     alpha = torch.sin(omegaC) / 2 * torch.sqrt(two)
     cosOC = torch.cos(omegaC)

-    a = torch.ones(3, *omegaC.shape, device=device)
-    b = torch.ones(3, *omegaC.shape, device=device)
+    a = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)
+    b = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)

     b[0] = (1 + cosOC) / 2
     b[1] = -(1 + cosOC)
@@ -478,6 +488,7 @@ def bandpass_filter(
     gain: float = 0.0,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ) -> tuple:
     r"""
     Bandpass filter coefficients. It uses the `RBJ cookbook formulas <https://webaudio.github.io/Audio-EQ-Cookbook/Audio-EQ-Cookbook.txt>`_ to map
@@ -521,15 +532,15 @@ def bandpass_filter(

     omegaC = (hertz2rad(fc1, fs) + hertz2rad(fc2, fs)) / 2
     BW = torch.log2(fc2 / fc1)
-    two = torch.tensor(2, device=device)
+    two = torch.tensor(2, device=device, dtype=dtype)
     alpha = torch.sin(omegaC) * torch.sinh(
         torch.log(two) / two * BW * (omegaC / torch.sin(omegaC))
     )

     cosOC = torch.cos(omegaC)

-    a = torch.ones(3, *omegaC.shape, device=device)
-    b = torch.ones(3, *omegaC.shape, device=device)
+    a = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)
+    b = torch.ones(3, *omegaC.shape, device=device, dtype=dtype)

     b[0] = alpha
     b[1] = 0
@@ -547,7 +558,8 @@ def shelving_filter(
     type: str = "low",
     fs: int = 48000,
     device: torch.device | str = None,
-):
+    dtype: torch.dtype = torch.float32,
+) -> tuple:
     r"""
     Shelving filter coefficients.
     Maps the cutoff frequencies and gain to the :math:`\mathbf{b}` and :math:`\mathbf{a}` biquad coefficients.
@@ -582,8 +594,8 @@ def shelving_filter(
     - **b** (torch.Tensor): The numerator coefficients of the filter transfer function.
     - **a** (torch.Tensor): The denominator coefficients of the filter transfer function.
     """
-    b = torch.ones(3, device=device)
-    a = torch.ones(3, device=device)
+    b = torch.ones(3, device=device, dtype=dtype)
+    a = torch.ones(3, device=device, dtype=dtype)

     omegaC = hertz2rad(fc, fs)
     t = torch.tan(omegaC / 2)
@@ -591,7 +603,7 @@ def shelving_filter(
     g2 = gain**0.5
     g4 = gain**0.25

-    two = torch.tensor(2, device=device)
+    two = torch.tensor(2, device=device, dtype=dtype)
     b[0] = g2 * t2 + torch.sqrt(two) * t * g4 + 1
     b[1] = 2 * g2 * t2 - 2
     b[2] = g2 * t2 - torch.sqrt(two) * t * g4 + 1
@@ -616,6 +628,7 @@ def peak_filter(
     Q: torch.Tensor,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ) -> tuple:
     r"""
     Peak filter coefficients.
@@ -644,8 +657,8 @@ def peak_filter(
     - **b** (torch.Tensor): The numerator coefficients of the filter transfer function.
     - **a** (torch.Tensor): The denominator coefficients of the filter transfer function
     """
-    b = torch.ones(3, device=device)
-    a = torch.ones(3, device=device)
+    b = torch.ones(3, device=device, dtype=dtype)
+    a = torch.ones(3, device=device, dtype=dtype)

     omegaC = hertz2rad(fc, fs)
     bandWidth = omegaC / Q
@@ -668,6 +681,7 @@ def prop_shelving_filter(
     type: str = "low",
     fs: int = 48000,
     device="cpu",
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Proportional first order Shelving filter coefficients.
@@ -700,6 +714,7 @@ def prop_shelving_filter(
     - **type** (str, optional): The type of shelving filter. Can be 'low' or 'high'. Default: 'low'.
     - **fs** (int, optional): The sampling frequency of the signal in Hz.
     - **device** (torch.device | str, optional): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Default: torch.float32.

     **Returns**:
     - **b** (torch.Tensor): The numerator coefficients of the filter transfer function.
@@ -712,7 +727,7 @@ def prop_shelving_filter(
     t = torch.tan(torch.pi * fc / fs)
     k = 10 ** (gain / 20)

-    a = torch.zeros((2, *fc.shape), device=device)
+    a = torch.zeros((2, *fc.shape), device=device, dtype=dtype)
     b = torch.zeros_like(a)

     if type == "low":
@@ -736,6 +751,7 @@ def prop_peak_filter(
     gain: torch.Tensor,
     fs: int = 48000,
     device="cpu",
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Proportional Peak (Presence) filter coefficients.
@@ -761,6 +777,7 @@ def prop_peak_filter(
     - **gain** (torch.Tensor): The gain in dB of the filter.
     - **fs** (int, optional): The sampling frequency of the signal in Hz.
     - **device** (torch.device | str, optional): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Default: torch.float32.

     **Returns**:
     - **b** (torch.Tensor): The numerator coefficients of the filter transfer function.
@@ -774,7 +791,7 @@ def prop_peak_filter(
     c = torch.cos(2 * np.pi * fc / fs)
     k = 10 ** (gain / 20)

-    a = torch.zeros((3, *fc.shape), device=device)
+    a = torch.zeros((3, *fc.shape), device=device, dtype=dtype)
     b = torch.zeros_like(a)

     b[0] = 1 + torch.sqrt(k) * t
@@ -815,6 +832,7 @@ def svf(
     filter_type: str = None,
     fs: int = 48000,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Implements a State Variable Filter (SVF) with various filter types.
@@ -827,6 +845,7 @@ def svf(
     - **filter_type** (str, optional): The type of filter to be applied. Can be one of "lowpass", "highpass", "bandpass", "lowshelf", "highshelf", "peaking", "notch", or None. Default: None.
     - **fs** (int, optional): The sampling frequency. Default: 48000.
     - **device** (torch.device, optional): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Default: torch.float32.

     **Returns**:
     Tuple[torch.Tensor, torch.Tensor]: The numerator and denominator coefficients of the filter transfer function.
@@ -897,8 +916,8 @@ def svf(
         case None:
             print("No filter type specified. Using the given mixing coefficents.")

-    b = torch.zeros((3, *f.shape), device=device)
-    a = torch.zeros((3, *f.shape), device=device)
+    b = torch.zeros((3, *f.shape), device=device, dtype=dtype)
+    a = torch.zeros((3, *f.shape), device=device, dtype=dtype)

     b[0] = (f**2) * m[..., 0] + f * m[..., 1] + m[..., 2]
     b[1] = 2 * (f**2) * m[..., 0] - 2 * m[..., 2]
@@ -917,6 +936,7 @@ def probe_sos(
     nfft: int,
     fs: int,
     device: str | torch.device = None,
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Probe the frequency / magnitude response of a cascaded SOS filter at the points
@@ -928,6 +948,7 @@ def probe_sos(
     - **nfft** (int): Length of the FFT used for frequency analysis.
     - **fs** (float): Sampling frequency in Hz.
     - **device** (torch.device, optional): The device of constructed tensors. Default: None.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Default: torch.float32.

     **Returns**:
     tuple: A tuple containing the following:
@@ -938,8 +959,8 @@ def probe_sos(
     n_freqs = sos.shape[-1]

     H = torch.zeros((nfft // 2 + 1, n_freqs), dtype=torch.cdouble, device=device)
-    W = torch.zeros((nfft // 2 + 1, n_freqs), device=device)
-    G = torch.zeros((len(control_freqs), n_freqs), device=device)
+    W = torch.zeros((nfft // 2 + 1, n_freqs), device=device, dtype=dtype)
+    G = torch.zeros((len(control_freqs), n_freqs), device=device, dtype=dtype)

     for band in range(n_freqs):
         sos[:, band] = sos[:, band] / sos[3, band]
@@ -1003,7 +1024,7 @@ def find_onset(rir: torch.Tensor):


 def WGN_reverb(
-    matrix_size: tuple = (1, 1), t60: float = 1.0, samplerate: int = 48000, device=None
+    matrix_size: tuple = (1, 1), t60: float = 1.0, samplerate: int = 48000, device=None, dtype: torch.dtype = torch.float32
 ) -> torch.Tensor:
     r"""
     Generates White-Gaussian-Noise-reverb impulse responses.
@@ -1013,6 +1034,7 @@ def WGN_reverb(
     - **t60** (float, optional): Reverberation time. Defaults to 1.0.
     - **samplerate** (int, optional): Sampling frequency. Defaults to 48000.
     - **nfft** (int, optional): Number of frequency bins. Defaults to 2**11.
+    - **dtype** (torch.dtype, optional): The dtype of constructed tensors. Defaults to torch.float32.

     **Returns**:
     torch.Tensor: Matrix of WGN-reverb impulse responses.
@@ -1020,10 +1042,10 @@ def WGN_reverb(
     # Number of samples
     n_samples = int(1.5 * t60 * samplerate)
     # White Guassian Noise
-    noise = torch.randn(n_samples, *matrix_size, device=device)
+    noise = torch.randn(n_samples, *matrix_size, device=device, dtype=dtype)
     # Decay
-    dr = t60 / torch.log(torch.tensor(1000, dtype=torch.float32, device=device))
-    decay = torch.exp(-1 / dr * torch.linspace(0, t60, n_samples))
+    dr = t60 / torch.log(torch.tensor(1000, dtype=dtype, device=device))
+    decay = torch.exp(-1 / dr * torch.linspace(0, t60, n_samples, dtype=dtype))
     decay = decay.view(-1, *(1,) * (len(matrix_size))).expand(-1, *matrix_size)
     # Decaying WGN
     IRs = torch.mul(noise, decay)
@@ -1031,11 +1053,11 @@ def WGN_reverb(
     TFs = torch.fft.rfft(input=IRs, n=n_samples, dim=0)

     # Generate bandpass filter
-    fc_left = torch.tensor([20], dtype=torch.float32, device=device)
-    fc_right = torch.tensor([20000], dtype=torch.float32, device=device)
-    g = torch.tensor([1], dtype=torch.float32, device=device)
+    fc_left = torch.tensor([20], dtype=dtype, device=device)
+    fc_right = torch.tensor([20000], dtype=dtype, device=device)
+    g = torch.tensor([1], dtype=dtype, device=device)
     b, a = bandpass_filter(
-        fc1=fc_left, fc2=fc_right, gain=g, fs=samplerate, device=device
+        fc1=fc_left, fc2=fc_right, gain=g, fs=samplerate, device=device, dtype=dtype
     )
     sos = torch.cat((b.reshape(1, 3), a.reshape(1, 3)), dim=1)
     bp_H = sosfreqz(sos=sos, nfft=n_samples).squeeze()
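Taken together, the functional.py changes let the signal generators and the RBJ biquad helpers allocate their tensors in a caller-chosen precision. A sketch of combined use; note that the leading positional parameters of signal_gallery (batch_size, n_samples, n) are inferred from the output-shape docstring in this diff, not shown directly:

    import torch
    from flamo.functional import signal_gallery, lowpass_filter

    # White Gaussian noise generated directly in float64 (assumed argument
    # order batch_size, n_samples, n, inferred from the docstring above).
    x = signal_gallery(1, 48000, 2, signal_type="wgn", fs=48000, dtype=torch.float64)

    # The coefficient buffers a and b are now created with dtype=dtype, so
    # they come back in the precision of the signal they will filter.
    # fc is passed as a tensor here, matching the *omegaC.shape usage above.
    b, a = lowpass_filter(fc=torch.tensor(1000.0), fs=48000, dtype=torch.float64)
    assert b.dtype == torch.float64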
flamo/optimize/dataset.py CHANGED
@@ -29,9 +29,10 @@ class Dataset(torch.utils.data.Dataset):
         target: torch.Tensor = torch.randn(1, 1),
         expand: int = 1,
         device: str = "cpu",
+        dtype: torch.dtype = torch.float32,
     ):
-        self.input = input.to(device)
-        self.target = target.to(device)
+        self.input = input.to(device).to(dtype)
+        self.target = target.to(device).to(dtype)
         self.expand = expand
         self.device = device
         self.input = self.input.expand(tuple([expand] + [d for d in input.shape[1:]]))
@@ -69,11 +70,12 @@ class DatasetColorless(Dataset):
     target_shape: tuple,
     expand: int = 1000,
     device: str = "cpu",
+    dtype: torch.dtype = torch.float32,
 ):
-        input = torch.zeros(input_shape, device=device)
+        input = torch.zeros(input_shape, device=device, dtype=dtype)
         input[:, 0, :] = 1
-        target = torch.ones(target_shape, device=device)
-        super().__init__(input=input, target=target, expand=expand, device=device)
+        target = torch.ones(target_shape, device=device, dtype=dtype)
+        super().__init__(input=input, target=target, expand=expand, device=device, dtype=dtype)

     def __getitem__(self, index):
         return self.input[index], self.target[index]
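The dataset changes follow the same pattern: inputs and targets are moved to the requested device and cast to the requested dtype at construction. A sketch of DatasetColorless under the new keyword (the shapes here are illustrative, not taken from flamo's docs):

    import torch
    from flamo.optimize.dataset import DatasetColorless

    # Impulse input and all-ones target, both built in float64.
    dataset = DatasetColorless(
        input_shape=(1, 2**11, 1),
        target_shape=(1, 2**11 // 2 + 1, 1),
        expand=10,
        device="cpu",
        dtype=torch.float64,
    )
    x, y = dataset[0]
    assert x.dtype == torch.float64 and y.dtype == torch.float64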
flamo/optimize/surface.py CHANGED
@@ -58,7 +58,7 @@ class LossProfile:

     """

-    def __init__(self, net: Shell, loss_config: LossConfig, device: str = "cpu"):
+    def __init__(self, net: Shell, loss_config: LossConfig, device: str = "cpu", dtype: torch.dtype = torch.float32):

         super().__init__()
         self.net = net
@@ -68,6 +68,7 @@ class LossProfile:
         self.n_runs = loss_config.n_runs
         self.output_dir = loss_config.output_dir
         self.device = device
+        self.dtype = dtype
         self.register_steps()

     def compute_loss(self, input: torch.Tensor, target: torch.Tensor):
@@ -101,9 +102,9 @@ class LossProfile:
             if type(self.param_config.lower_bound) == list:
                 # interpolate between the lower and upper bound
                 new_value = (1 - steps[i_step]) * torch.tensor(
-                    self.param_config.lower_bound, device=self.device
+                    self.param_config.lower_bound, device=self.device, dtype=self.dtype
                 ) + steps[i_step] * torch.tensor(
-                    self.param_config.upper_bound, device=self.device
+                    self.param_config.upper_bound, device=self.device, dtype=self.dtype
                 )
             else:
                 new_value = steps[i_step]
@@ -230,12 +231,14 @@ class LossProfile:
             upper_bound = param_upper_bound

         if scale == "linear":
-            steps = torch.linspace(lower_bound, upper_bound, n_steps)
+            steps = torch.linspace(lower_bound, upper_bound, n_steps, device=self.device, dtype=self.dtype)
         elif scale == "log":
             steps = torch.logspace(
-                torch.log10(torch.tensor(lower_bound, device=self.device)),
-                torch.log10(torch.tensor(upper_bound, device=self.device)),
+                torch.log10(torch.tensor(lower_bound, device=self.device, dtype=self.dtype)),
+                torch.log10(torch.tensor(upper_bound, device=self.device, dtype=self.dtype)),
                 n_steps,
+                device=self.device,
+                dtype=self.dtype,
             )
         else:
             raise ValueError("Scale must be either 'linear' or 'log'")
@@ -336,9 +339,9 @@ class LossSurface(LossProfile):
     - **steps** (dict): Dictionary of steps between the lower and upper bound of the parameters.
     """

-    def __init__(self, net: Shell, loss_config: LossConfig, device: str = "cpu"):
+    def __init__(self, net: Shell, loss_config: LossConfig, device: str = "cpu", dtype: torch.dtype = torch.float32):

-        super().__init__(net, loss_config, device)
+        super().__init__(net, loss_config, device, dtype)

         assert (
             len(loss_config.param_config) == 2
@@ -403,9 +406,9 @@ class LossSurface(LossProfile):
             if type(self.param_config[0].lower_bound) == list:
                 # interpolate between the lower and upper bound
                 new_value = (1 - steps_0[i_step_0]) * torch.tensor(
-                    self.param_config[0].lower_bound, device=self.device
+                    self.param_config[0].lower_bound, device=self.device, dtype=self.dtype
                 ) + steps_0[i_step_0] * torch.tensor(
-                    self.param_config[0].upper_bound, device=self.device
+                    self.param_config[0].upper_bound, device=self.device, dtype=self.dtype
                 )
             else:
                 new_value = steps_0[i_step_0]
@@ -419,9 +422,9 @@ class LossSurface(LossProfile):
             if type(self.param_config[1].lower_bound) == list:
                 # interpolate between the lower and upper bound
                 new_value = (1 - steps_1[i_step_1]) * torch.tensor(
-                    self.param_config[1].lower_bound, device=self.device
+                    self.param_config[1].lower_bound, device=self.device, dtype=self.dtype
                 ) + steps_1[i_step_1] * torch.tensor(
-                    self.param_config[1].upper_bound, device=self.device
+                    self.param_config[1].upper_bound, device=self.device, dtype=self.dtype
                 )
             else:
                 new_value = steps_1[i_step_1]
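The most numerically consequential change in surface.py is that the parameter sweep grids are now built on the profile's device and dtype, where they previously used torch's global defaults. A standalone sketch of the new log-scale construction (variable names mirror the diff; the enclosing class is omitted):

    import torch

    lower_bound, upper_bound, n_steps = 0.1, 10.0, 8
    device, dtype = "cpu", torch.float64

    # Log-spaced sweep built in the profile's dtype, exactly as in the
    # updated LossProfile code above; torch.logspace accepts 0-dim tensor
    # endpoints, which is the pattern the diff relies on.
    steps = torch.logspace(
        torch.log10(torch.tensor(lower_bound, device=device, dtype=dtype)),
        torch.log10(torch.tensor(upper_bound, device=device, dtype=dtype)),
        n_steps,
        device=device,
        dtype=dtype,
    )
    assert steps.dtype == torch.float64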