flamo 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -23,6 +23,8 @@ class HomogeneousFDNConfig(BaseModel):
23
23
  nfft: int = 96000
24
24
  # device to run the model
25
25
  device: str = 'cpu'
26
+ # data type
27
+ dtype: torch.dtype = torch.float32
26
28
  # delays in samples
27
29
  delays: Optional[List[int]] = None
28
30
  # delay lengths range in ms
@@ -76,4 +78,4 @@ class HomogeneousFDNConfig(BaseModel):
76
78
  ), "CUDA is not available for training"
77
79
 
78
80
  # forbid extra fields - adding this to help prevent errors in config file creation
79
- model_config = ConfigDict(extra="forbid")
81
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
flamo/auxiliary/eq.py CHANGED
@@ -5,7 +5,7 @@ from flamo.functional import db2mag, shelving_filter, peak_filter, probe_sos
5
5
  from flamo.auxiliary.minimize import minimize_LBFGS
6
6
 
7
7
 
8
- def eq_freqs(interval: int = 1, start_freq: float = 31.25, end_freq: float = 16000.0):
8
+ def eq_freqs(interval: int = 1, start_freq: float = 31.25, end_freq: float = 16000.0, device: str = "cpu", dtype: torch.dtype = torch.float32):
9
9
  r"""
10
10
  Calculate the center frequencies and shelving crossover frequencies for an equalizer.
11
11
 
@@ -19,13 +19,13 @@ def eq_freqs(interval: int = 1, start_freq: float = 31.25, end_freq: float = 160
19
19
 
20
20
  """
21
21
  center_freq = torch.tensor(
22
- octave_bands(interval=interval, start_freq=start_freq, end_freq=end_freq)
22
+ octave_bands(interval=interval, start_freq=start_freq, end_freq=end_freq), device=device, dtype=dtype
23
23
  )
24
24
  shelving_crossover = torch.tensor(
25
25
  [
26
26
  center_freq[0] / np.power(2, 1 / interval / 2),
27
27
  center_freq[-1] * np.power(2, 1 / interval / 2),
28
- ]
28
+ ], device=device, dtype=dtype
29
29
  )
30
30
 
31
31
  return center_freq, shelving_crossover
@@ -61,6 +61,7 @@ def geq(
61
61
  gain_db: torch.Tensor,
62
62
  fs: int = 48000,
63
63
  device: str = "cpu",
64
+ dtype: torch.dtype = torch.float32,
64
65
  ):
65
66
  r"""
66
67
  Computes the second-order sections coefficients of a graphic equalizer.
@@ -73,6 +74,7 @@ def geq(
73
74
  - **gain_db** (torch.Tensor): Tensor containing the gain values in decibels for each frequency band.
74
75
  - **fs** (int, optional): Sampling frequency. Default: 48000 Hz.
75
76
  - **device** (str, optional): Device to use for constructing tensors. Default: cpu.
77
+ - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
76
78
 
77
79
  **Returns**:
78
80
  - tuple: A tuple containing the numerator and denominator coefficients of the GEQ filter.
@@ -82,25 +84,25 @@ def geq(
82
84
  assert (
83
85
  len(gain_db) == num_bands
84
86
  ), "The number of gains must be equal to the number of frequencies."
85
- sos = torch.zeros((6, num_bands), device=device)
87
+ sos = torch.zeros((6, num_bands), device=device, dtype=dtype)
86
88
 
87
89
  for band in range(num_bands):
88
90
  if band == 0:
89
- b = torch.zeros(3, device=device)
91
+ b = torch.zeros(3, device=device, dtype=dtype)
90
92
  b[0] = db2mag(gain_db[band])
91
- a = torch.tensor([1, 0, 0], device=device)
93
+ a = torch.tensor([1, 0, 0], device=device, dtype=dtype)
92
94
  elif band == 1:
93
95
  b, a = shelving_filter(
94
- shelving_freq[0], db2mag(gain_db[band]), "low", fs=fs, device=device
96
+ shelving_freq[0], db2mag(gain_db[band]), "low", fs=fs, device=device, dtype=dtype
95
97
  )
96
98
  elif band == num_bands - 1:
97
99
  b, a = shelving_filter(
98
- shelving_freq[1], db2mag(gain_db[band]), "high", fs=fs, device=device
100
+ shelving_freq[1], db2mag(gain_db[band]), "high", fs=fs, device=device, dtype=dtype
99
101
  )
100
102
  else:
101
103
  Q = torch.sqrt(R) / (R - 1)
102
104
  b, a = peak_filter(
103
- center_freq[band - 2], db2mag(gain_db[band]), Q, fs=fs, device=device
105
+ center_freq[band - 2], db2mag(gain_db[band]), Q, fs=fs, device=device, dtype=dtype
104
106
  )
105
107
 
106
108
  sos_band = torch.hstack((b, a))
@@ -115,6 +117,7 @@ def accurate_geq(
115
117
  shelving_crossover: torch.Tensor,
116
118
  fs=48000,
117
119
  device: str = "cpu",
120
+ dtype: torch.dtype = torch.float32,
118
121
  ):
119
122
  r"""
120
123
  Design a Graphic Equalizer (GEQ) filter.
@@ -125,6 +128,7 @@ def accurate_geq(
125
128
  - shelving_crossover (torch.Tensor): Crossover frequencies for shelving filters.
126
129
  - fs (int, optional): Sampling frequency. Default: 48000 Hz.
127
130
  - device (str, optional): Device to use for constructing tensors. Default: 'cpu'.
131
+ - dtype (torch.dtype, optional): Data type for tensors. Default: torch.float32.
128
132
 
129
133
  **Returns**:
130
134
  - tuple: A tuple containing the numerator and denominator coefficients of the GEQ filter.
@@ -141,38 +145,38 @@ def accurate_geq(
141
145
 
142
146
  nfft = 2**16
143
147
  num_freq = len(center_freq) + len(shelving_crossover)
144
- R = torch.tensor(2.7)
148
+ R = torch.tensor(2.7, dtype=dtype)
145
149
  # Control frequencies are spaced logarithmically
146
150
  num_control = 100
147
151
  control_freq = torch.round(
148
- torch.logspace(np.log10(1), np.log10(fs / 2.1), num_control + 1)
152
+ torch.logspace(np.log10(1), np.log10(fs / 2.1), num_control + 1, dtype=dtype)
149
153
  )
150
154
  # interpolate the target gain values at control frequencies
151
- target_freq = torch.cat((torch.tensor([1]), center_freq, torch.tensor([fs / 2.1])))
155
+ target_freq = torch.cat((torch.tensor([1], dtype=dtype), center_freq, torch.tensor([fs / 2.1], dtype=dtype)))
152
156
  # targetInterp = torch.tensor(np.interp(control_freq, target_freq, target_gain.squeeze()))
153
157
  interp = RegularGridInterpolator([target_freq], target_gain)
154
158
  targetInterp = interp([control_freq])
155
159
 
156
160
  # Design prototype of the biquad sections
157
161
  prototype_gain = 10 # dB
158
- prototype_gain_array = torch.full((num_freq + 1, 1), prototype_gain)
162
+ prototype_gain_array = torch.full((num_freq + 1, 1), prototype_gain, dtype=dtype)
159
163
  prototype_b, prototype_a = geq(
160
- center_freq, shelving_crossover, R, prototype_gain_array, fs
164
+ center_freq, shelving_crossover, R, prototype_gain_array, fs, dtype=dtype
161
165
  )
162
166
  prototype_sos = torch.vstack((prototype_b, prototype_a))
163
- G, _, _ = probe_sos(prototype_sos, control_freq, nfft, fs)
167
+ G, _, _ = probe_sos(prototype_sos, control_freq, nfft, fs, dtype=dtype)
164
168
  G = G / prototype_gain # dB vs control frequencies
165
169
 
166
170
  # Define the optimization bounds
167
171
  upperBound = torch.tensor(
168
- [torch.inf] + [2 * prototype_gain] * num_freq, device=device
172
+ [torch.inf] + [2 * prototype_gain] * num_freq, device=device, dtype=dtype
169
173
  )
170
- lowerBound = torch.tensor([-val for val in upperBound], device=device)
174
+ lowerBound = torch.tensor([-val for val in upperBound], device=device, dtype=dtype)
171
175
 
172
176
  # Optimization
173
177
  opt_gains = minimize_LBFGS(G, targetInterp, lowerBound, upperBound, num_freq)
174
178
 
175
179
  # Generate the SOS coefficients
176
- b, a = geq(center_freq, shelving_crossover, R, opt_gains, fs, device=device)
180
+ b, a = geq(center_freq, shelving_crossover, R, opt_gains, fs, device=device, dtype=dtype)
177
181
 
178
182
  return b, a
flamo/auxiliary/reverb.py CHANGED
@@ -104,11 +104,12 @@ class HomogeneousFDN:
104
104
  def set_model(self, input_layer=None, output_layer=None):
105
105
  # set the input and output layers of the FDN model
106
106
  if input_layer is None:
107
- input_layer = dsp.FFT(self.config_dict.nfft)
107
+ input_layer = dsp.FFT(self.config_dict.nfft, dtype=self.config_dict.dtype)
108
108
  if output_layer is None:
109
109
  output_layer = dsp.iFFTAntiAlias(
110
110
  nfft=self.config_dict.nfft,
111
111
  alias_decay_db=self.config_dict.alias_decay_db,
112
+ dtype=self.config_dict.dtype,
112
113
  )
113
114
 
114
115
  self.model = self.get_shell(input_layer, output_layer)
@@ -125,6 +126,7 @@ class HomogeneousFDN:
125
126
  requires_grad=self.config_dict.input_gain_grad,
126
127
  alias_decay_db=self.config_dict.alias_decay_db,
127
128
  device=self.config_dict.device,
129
+ dtype=self.config_dict.dtype,
128
130
  )
129
131
  output_gain = dsp.Gain(
130
132
  size=(1, self.N),
@@ -132,6 +134,7 @@ class HomogeneousFDN:
132
134
  requires_grad=self.config_dict.output_gain_grad,
133
135
  alias_decay_db=self.config_dict.alias_decay_db,
134
136
  device=self.config_dict.device,
137
+ dtype=self.config_dict.dtype,
135
138
  )
136
139
 
137
140
  # RECURSION
@@ -144,6 +147,7 @@ class HomogeneousFDN:
144
147
  requires_grad=self.config_dict.delays_grad,
145
148
  alias_decay_db=self.config_dict.alias_decay_db,
146
149
  device=self.config_dict.device,
150
+ dtype=self.config_dict.dtype,
147
151
  )
148
152
  # assign the required delay line lengths
149
153
  delays.assign_value(delays.sample2s(delay_lines))
@@ -156,6 +160,7 @@ class HomogeneousFDN:
156
160
  requires_grad=self.config_dict.mixing_matrix_grad,
157
161
  alias_decay_db=self.config_dict.alias_decay_db,
158
162
  device=self.config_dict.device,
163
+ dtype=self.config_dict.dtype,
159
164
  )
160
165
 
161
166
  # homogeneous attenuation
@@ -165,6 +170,7 @@ class HomogeneousFDN:
165
170
  requires_grad=self.config_dict.attenuation_grad,
166
171
  alias_decay_db=self.config_dict.alias_decay_db,
167
172
  device=self.config_dict.device,
173
+ dtype=self.config_dict.dtype,
168
174
  )
169
175
  attenuation.map = map_gamma(delay_lines)
170
176
  attenuation.assign_value(
@@ -328,7 +334,8 @@ class parallelFDNAccurateGEQ(dsp.parallelAccurateGEQ):
328
334
  alias_decay_db: float = 0.0,
329
335
  start_freq: float = 31.25,
330
336
  end_freq: float = 16000.0,
331
- device=None
337
+ device=None,
338
+ dtype=torch.float32
332
339
  ):
333
340
  assert (delays is not None), "Delays must be provided"
334
341
  self.delays = delays
@@ -342,7 +349,8 @@ class parallelFDNAccurateGEQ(dsp.parallelAccurateGEQ):
342
349
  alias_decay_db=alias_decay_db,
343
350
  start_freq=start_freq,
344
351
  end_freq=end_freq,
345
- device=device
352
+ device=device,
353
+ dtype=dtype
346
354
  )
347
355
 
348
356
 
@@ -394,7 +402,8 @@ class parallelGFDNAccurateGEQ(parallelFDNAccurateGEQ):
394
402
  alias_decay_db: float = 0.0,
395
403
  start_freq: float = 31.25,
396
404
  end_freq: float = 16000.0,
397
- device=None
405
+ device=None,
406
+ dtype=torch.float32
398
407
  ):
399
408
  assert (delays is not None), "Delays must be provided"
400
409
  self.delays = delays
@@ -482,6 +491,7 @@ class parallelFDNGEQ(dsp.parallelGEQ):
482
491
  requires_grad: bool = False,
483
492
  alias_decay_db: float = 0.0,
484
493
  device: Optional[str] = None,
494
+ dtype=torch.float32
485
495
  ):
486
496
  assert (delays is not None), "Delays must be provided"
487
497
  self.delays = delays
@@ -497,7 +507,8 @@ class parallelFDNGEQ(dsp.parallelGEQ):
497
507
  map=map,
498
508
  requires_grad=requires_grad,
499
509
  alias_decay_db=alias_decay_db,
500
- device=device
510
+ device=device,
511
+ dtype=dtype
501
512
  )
502
513
 
503
514
  def get_poly_coeff(self, param):
@@ -516,14 +527,13 @@ class parallelFDNGEQ(dsp.parallelGEQ):
516
527
  fs=self.fs,
517
528
  device=self.device
518
529
  )
519
- b_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
520
- a_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
530
+ b_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy, b)
531
+ a_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy, a)
521
532
  B = torch.fft.rfft(b_aa, self.nfft, dim=0)
522
533
  A = torch.fft.rfft(a_aa, self.nfft, dim=0)
523
534
  H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
524
535
  H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
525
- H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
526
- return H.to(H_type), B, A
536
+ return H, B, A
527
537
 
528
538
  def check_param_shape(self):
529
539
  assert (
@@ -560,6 +570,7 @@ class parallelFDNPEQ(Filter):
560
570
  requires_grad: bool = False,
561
571
  alias_decay_db: float = 0.0,
562
572
  device: Optional[str] = None,
573
+ dtype=torch.float32
563
574
  ):
564
575
  self.delays = delays
565
576
  self.is_twostage = is_twostage
@@ -576,12 +587,13 @@ class parallelFDNPEQ(Filter):
576
587
  self.center_freq_bias = f_min * (f_max / f_min) ** ((k - 1) / (self.n_bands - 1))
577
588
  self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device)
578
589
  super().__init__(
579
- size=(self.n_bands+1 if self.is_twostage else self.n_band, 3, 1 if self.is_proportional else len(delays)),
590
+ size=(self.n_bands+1 if self.is_twostage else self.n_bands, 3, 1 if self.is_proportional else len(delays)),
580
591
  nfft=nfft,
581
592
  map=map,
582
593
  requires_grad=requires_grad,
583
594
  alias_decay_db=alias_decay_db,
584
595
  device=device,
596
+ dtype=dtype
585
597
  )
586
598
 
587
599
  def get_poly_coeff(self, param):
@@ -644,16 +656,13 @@ class parallelFDNPEQ(Filter):
644
656
  type='highshelf'
645
657
  )
646
658
 
647
- b_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy.to(torch.double), b.to(torch.double))
648
- a_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy.to(torch.double), a.to(torch.double))
659
+ b_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy, b)
660
+ a_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy, a)
649
661
  B = torch.fft.rfft(b_aa, self.nfft, dim=1)
650
662
  A = torch.fft.rfft(a_aa, self.nfft, dim=1)
651
663
  H_temp = (torch.prod(B, dim=0) / (torch.prod(A, dim=0)))
652
- # H_temp = (torch.prod(B, dim=0) / (torch.prod(A, dim=0)))
653
-
654
664
  H = torch.where(torch.abs(torch.prod(A, dim=0)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
655
- H_type = torch.complex128 if param.dtype == torch.float64 else torch.complex64
656
- return H.to(H_type), B, A
665
+ return H, B, A
657
666
 
658
667
  def compute_biquad_coeff(self, f, R, G, type='peaking'):
659
668
  # f : freq, R : resonance, G : gain in dB
@@ -805,6 +814,7 @@ class parallelFirstOrderShelving(dsp.parallelFilter):
805
814
  delays: torch.Tensor = None,
806
815
  alias_decay_db: float = 0.0,
807
816
  device: str = None,
817
+ dtype: torch.dtype = torch.float32
808
818
  ):
809
819
  size = (2,) # rt at DC and crossover frequency
810
820
  assert (delays is not None), "Delays must be provided"
@@ -816,7 +826,8 @@ class parallelFirstOrderShelving(dsp.parallelFilter):
816
826
  nfft=nfft,
817
827
  map=map,
818
828
  alias_decay_db=alias_decay_db,
819
- device=device
829
+ device=device,
830
+ dtype=dtype
820
831
  )
821
832
  gamma = 10 ** (
822
833
  -torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20
@@ -4,10 +4,6 @@ import numpy as np
4
4
  from typing import Optional
5
5
  from flamo.utils import to_complex
6
6
 
7
- torch.random.manual_seed(0)
8
- np.random.seed(0)
9
-
10
-
11
7
  class ScatteringMapping(nn.Module):
12
8
  r"""
13
9
  Class mapping an orthogonal matrix to a paraunitary matrix using sparse scattering.
@@ -47,23 +43,26 @@ class ScatteringMapping(nn.Module):
47
43
  m_L: Optional[torch.tensor] = None,
48
44
  m_R: Optional[torch.tensor] = None,
49
45
  device: str = "cpu",
46
+ dtype: torch.dtype = torch.float32
50
47
  ):
51
48
  super(ScatteringMapping, self).__init__()
52
49
 
53
50
  self.n_stages = n_stages
54
51
  self.sparsity = sparsity
55
52
  self.gain_per_sample = gain_per_sample
53
+ self.device = device
54
+ self.dtype = dtype
56
55
  if m_L is None:
57
- self.m_L = torch.zeros(N, device=device)
56
+ self.m_L = torch.zeros(N, device=device, dtype=self.dtype)
58
57
  else:
59
58
  self.m_L = m_L
60
59
  if m_R is None:
61
- self.m_R = torch.zeros(N, device=device)
60
+ self.m_R = torch.zeros(N, device=device, dtype=self.dtype)
62
61
  else:
63
62
  self.m_R = m_R
64
- self.sparsity_vect = torch.ones((n_stages), device=device)
63
+ self.sparsity_vect = torch.ones((n_stages), device=device, dtype=self.dtype)
65
64
  self.sparsity_vect[0] = sparsity
66
- self.shifts = get_random_shifts(N, self.sparsity_vect, pulse_size)
65
+ self.shifts = get_random_shifts(N, self.sparsity_vect, pulse_size, dtype=self.dtype)
67
66
 
68
67
  def forward(self, U):
69
68
  r"""
@@ -81,7 +80,7 @@ class ScatteringMapping(nn.Module):
81
80
 
82
81
  G = (
83
82
  torch.diag(self.gain_per_sample ** self.shifts[k - 1, :])
84
- .to(torch.float32)
83
+ .to(self.dtype)
85
84
  .to(U.device)
86
85
  )
87
86
  R = torch.matmul(U[:, :, k], G)
@@ -103,6 +102,7 @@ def cascaded_paraunit_matrix(
103
102
  pulse_size: int = 1,
104
103
  m_L: Optional[torch.tensor] = None,
105
104
  m_R: Optional[torch.tensor] = None,
105
+ dtype: torch.dtype = torch.float32,
106
106
  ):
107
107
  r"""
108
108
  Creates paraunitary matrix from input orthogonal matrix.
@@ -122,7 +122,7 @@ def cascaded_paraunit_matrix(
122
122
  """
123
123
 
124
124
  K = n_stages + 1
125
- sparsity_vect = torch.ones((n_stages), device=U.device)
125
+ sparsity_vect = torch.ones((n_stages), device=U.device, dtype=dtype)
126
126
  sparsity_vect[0] = sparsity
127
127
  # check that the input matrix is of correct shape
128
128
  assert U.shape[0] == K, "The input matrix must have n_stages+1 stages"
@@ -133,14 +133,14 @@ def cascaded_paraunit_matrix(
133
133
  N = V.shape[0]
134
134
 
135
135
  if m_L is None:
136
- m_L = torch.zeros(N, device=U.device)
136
+ m_L = torch.zeros(N, device=U.device, dtype=dtype)
137
137
  if m_R is None:
138
- m_R = torch.zeros(N, device=U.device)
138
+ m_R = torch.zeros(N, device=U.device, dtype=dtype)
139
139
 
140
- shift_L = get_random_shifts(N, sparsity_vect, pulse_size)
140
+ shift_L = get_random_shifts(N, sparsity_vect, pulse_size, dtype=dtype)
141
141
  for k in range(1, K):
142
142
 
143
- G = torch.diag(gain_per_sample ** shift_L[k - 1, :]).to(torch.float32)
143
+ G = torch.diag(gain_per_sample ** shift_L[k - 1, :]).to(dtype)
144
144
  R = torch.matmul(U[:, :, k], G)
145
145
 
146
146
  V = shift_matrix(V, shift_L[k - 1, :], direction="left")
@@ -168,7 +168,7 @@ def poly_matrix_conv(A: torch.tensor, B: torch.tensor):
168
168
  if szA[1] != szB[0]:
169
169
  raise ValueError("Invalid matrix dimension.")
170
170
 
171
- C = torch.zeros((szA[0], szB[1], szA[2] + szB[2] - 1), device=A.device)
171
+ C = torch.zeros((szA[0], szB[1], szA[2] + szB[2] - 1), device=A.device, dtype=A.dtype)
172
172
 
173
173
  A = A.permute(2, 0, 1)
174
174
  B = B.permute(2, 0, 1)
@@ -202,7 +202,7 @@ def shift_matrix(X: torch.tensor, shift: torch.tensor, direction: str = "left"):
202
202
  required_space = order + shift.reshape(-1, 1)
203
203
  additional_space = int((required_space.max() - X.shape[-1]) + 1)
204
204
  X = torch.cat(
205
- (X, torch.zeros((N, N, additional_space), device=shift.device)), dim=-1
205
+ (X, torch.zeros((N, N, additional_space), device=shift.device, dtype=X.dtype)), dim=-1
206
206
  )
207
207
  for i in range(N):
208
208
  X[i, :, :] = torch.roll(X[i, :, :], int(shift[i].item()), dims=-1)
@@ -210,7 +210,7 @@ def shift_matrix(X: torch.tensor, shift: torch.tensor, direction: str = "left"):
210
210
  required_space = order + shift.reshape(1, -1)
211
211
  additional_space = int((required_space.max() - X.shape[-1]) + 1)
212
212
  X = torch.cat(
213
- (X, torch.zeros((N, N, additional_space), device=shift.device)), dim=-1
213
+ (X, torch.zeros((N, N, additional_space), device=shift.device, dtype=X.dtype)), dim=-1
214
214
  )
215
215
  for i in range(N):
216
216
  X[:, i, :] = torch.roll(X[:, i, :], int(shift[i].item()), dims=-1)
@@ -228,14 +228,14 @@ def shift_mat_distribute(X: torch.tensor, sparsity: int, pulse_size: int):
228
228
  return (rand_shift * pulse_size).int()
229
229
 
230
230
 
231
- def get_random_shifts(N, sparsity_vect, pulse_size):
232
- rand_shift = torch.zeros(sparsity_vect.shape[0], N, device=sparsity_vect.device)
231
+ def get_random_shifts(N, sparsity_vect, pulse_size, dtype=torch.float32):
232
+ rand_shift = torch.zeros(sparsity_vect.shape[0], N, device=sparsity_vect.device, dtype=dtype)
233
233
  for k in range(sparsity_vect.shape[0]):
234
234
  temp = torch.floor(
235
235
  sparsity_vect[k]
236
236
  * (
237
- torch.arange(0, N, device=sparsity_vect.device)
238
- + torch.rand((N), device=sparsity_vect.device) * 0.99
237
+ torch.arange(0, N, device=sparsity_vect.device, dtype=dtype)
238
+ + torch.rand((N), device=sparsity_vect.device, dtype=dtype) * 0.99
239
239
  )
240
240
  )
241
241
  rand_shift[k, :] = (temp * pulse_size).int()