flamo 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- flamo/auxiliary/config/config.py +3 -1
- flamo/auxiliary/eq.py +22 -18
- flamo/auxiliary/reverb.py +28 -17
- flamo/auxiliary/scattering.py +21 -21
- flamo/functional.py +74 -52
- flamo/optimize/dataset.py +7 -5
- flamo/optimize/surface.py +15 -12
- flamo/processor/dsp.py +158 -99
- flamo/processor/system.py +15 -10
- flamo/utils.py +2 -2
- {flamo-0.1.13.dist-info → flamo-0.2.0.dist-info}/METADATA +1 -1
- flamo-0.2.0.dist-info/RECORD +24 -0
- flamo-0.1.13.dist-info/RECORD +0 -24
- {flamo-0.1.13.dist-info → flamo-0.2.0.dist-info}/WHEEL +0 -0
- {flamo-0.1.13.dist-info → flamo-0.2.0.dist-info}/licenses/LICENSE +0 -0
flamo/auxiliary/config/config.py
CHANGED
@@ -23,6 +23,8 @@ class HomogeneousFDNConfig(BaseModel):
     nfft: int = 96000
     # device to run the model
     device: str = 'cpu'
+    # data type
+    dtype: torch.dtype = torch.float32
     # delays in samples
     delays: Optional[List[int]] = None
     # delay lengths range in ms
@@ -76,4 +78,4 @@ class HomogeneousFDNConfig(BaseModel):
     ), "CUDA is not available for training"

     # forbid extra fields - adding this to help prevent errors in config file creation
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
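This release threads a configurable torch.dtype through the whole toolchain, starting from this config field. A minimal sketch of how it might be used, assuming the import path follows the file layout above and that any other required fields are left at their defaults:

    import torch
    from flamo.auxiliary.config.config import HomogeneousFDNConfig

    # arbitrary_types_allowed=True is what lets pydantic accept a torch.dtype field
    config = HomogeneousFDNConfig(nfft=96000, device="cpu", dtype=torch.float64)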
flamo/auxiliary/eq.py
CHANGED
@@ -5,7 +5,7 @@ from flamo.functional import db2mag, shelving_filter, peak_filter, probe_sos
 from flamo.auxiliary.minimize import minimize_LBFGS


-def eq_freqs(interval: int = 1, start_freq: float = 31.25, end_freq: float = 16000.0):
+def eq_freqs(interval: int = 1, start_freq: float = 31.25, end_freq: float = 16000.0, device: str = "cpu", dtype: torch.dtype = torch.float32):
     r"""
     Calculate the center frequencies and shelving crossover frequencies for an equalizer.

@@ -19,13 +19,13 @@ def eq_freqs(interval: int = 1, start_freq: float = 31.25, end_freq: float = 160

     """
     center_freq = torch.tensor(
-        octave_bands(interval=interval, start_freq=start_freq, end_freq=end_freq)
+        octave_bands(interval=interval, start_freq=start_freq, end_freq=end_freq), device=device, dtype=dtype
     )
     shelving_crossover = torch.tensor(
         [
             center_freq[0] / np.power(2, 1 / interval / 2),
             center_freq[-1] * np.power(2, 1 / interval / 2),
-        ]
+        ], device=device, dtype=dtype
     )

     return center_freq, shelving_crossover
@@ -61,6 +61,7 @@ def geq(
     gain_db: torch.Tensor,
     fs: int = 48000,
     device: str = "cpu",
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Computes the second-order sections coefficients of a graphic equalizer.
@@ -73,6 +74,7 @@ def geq(
     - **gain_db** (torch.Tensor): Tensor containing the gain values in decibels for each frequency band.
     - **fs** (int, optional): Sampling frequency. Default: 48000 Hz.
     - **device** (str, optional): Device to use for constructing tensors. Default: cpu.
+    - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.

     **Returns**:
     - tuple: A tuple containing the numerator and denominator coefficients of the GEQ filter.
@@ -82,25 +84,25 @@ def geq(
     assert (
         len(gain_db) == num_bands
     ), "The number of gains must be equal to the number of frequencies."
-    sos = torch.zeros((6, num_bands), device=device)
+    sos = torch.zeros((6, num_bands), device=device, dtype=dtype)

     for band in range(num_bands):
         if band == 0:
-            b = torch.zeros(3, device=device)
+            b = torch.zeros(3, device=device, dtype=dtype)
             b[0] = db2mag(gain_db[band])
-            a = torch.tensor([1, 0, 0], device=device)
+            a = torch.tensor([1, 0, 0], device=device, dtype=dtype)
         elif band == 1:
             b, a = shelving_filter(
-                shelving_freq[0], db2mag(gain_db[band]), "low", fs=fs, device=device
+                shelving_freq[0], db2mag(gain_db[band]), "low", fs=fs, device=device, dtype=dtype
             )
         elif band == num_bands - 1:
             b, a = shelving_filter(
-                shelving_freq[1], db2mag(gain_db[band]), "high", fs=fs, device=device
+                shelving_freq[1], db2mag(gain_db[band]), "high", fs=fs, device=device, dtype=dtype
             )
         else:
             Q = torch.sqrt(R) / (R - 1)
             b, a = peak_filter(
-                center_freq[band - 2], db2mag(gain_db[band]), Q, fs=fs, device=device
+                center_freq[band - 2], db2mag(gain_db[band]), Q, fs=fs, device=device, dtype=dtype
             )

         sos_band = torch.hstack((b, a))
@@ -115,6 +117,7 @@ def accurate_geq(
     shelving_crossover: torch.Tensor,
     fs=48000,
     device: str = "cpu",
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Design a Graphic Equalizer (GEQ) filter.
@@ -125,6 +128,7 @@ def accurate_geq(
     - shelving_crossover (torch.Tensor): Crossover frequencies for shelving filters.
     - fs (int, optional): Sampling frequency. Default: 48000 Hz.
     - device (str, optional): Device to use for constructing tensors. Default: 'cpu'.
+    - dtype (torch.dtype, optional): Data type for tensors. Default: torch.float32.

     **Returns**:
     - tuple: A tuple containing the numerator and denominator coefficients of the GEQ filter.
@@ -141,38 +145,38 @@ def accurate_geq(

     nfft = 2**16
     num_freq = len(center_freq) + len(shelving_crossover)
-    R = torch.tensor(2.7)
+    R = torch.tensor(2.7, dtype=dtype)
     # Control frequencies are spaced logarithmically
     num_control = 100
     control_freq = torch.round(
-        torch.logspace(np.log10(1), np.log10(fs / 2.1), num_control + 1)
+        torch.logspace(np.log10(1), np.log10(fs / 2.1), num_control + 1, dtype=dtype)
     )
     # interpolate the target gain values at control frequencies
-    target_freq = torch.cat((torch.tensor([1]), center_freq, torch.tensor([fs / 2.1])))
+    target_freq = torch.cat((torch.tensor([1], dtype=dtype), center_freq, torch.tensor([fs / 2.1], dtype=dtype)))
     # targetInterp = torch.tensor(np.interp(control_freq, target_freq, target_gain.squeeze()))
     interp = RegularGridInterpolator([target_freq], target_gain)
     targetInterp = interp([control_freq])

     # Design prototype of the biquad sections
     prototype_gain = 10 # dB
-    prototype_gain_array = torch.full((num_freq + 1, 1), prototype_gain)
+    prototype_gain_array = torch.full((num_freq + 1, 1), prototype_gain, dtype=dtype)
     prototype_b, prototype_a = geq(
-        center_freq, shelving_crossover, R, prototype_gain_array, fs
+        center_freq, shelving_crossover, R, prototype_gain_array, fs, dtype=dtype
     )
     prototype_sos = torch.vstack((prototype_b, prototype_a))
-    G, _, _ = probe_sos(prototype_sos, control_freq, nfft, fs)
+    G, _, _ = probe_sos(prototype_sos, control_freq, nfft, fs, dtype=dtype)
     G = G / prototype_gain # dB vs control frequencies

     # Define the optimization bounds
     upperBound = torch.tensor(
-        [torch.inf] + [2 * prototype_gain] * num_freq, device=device
+        [torch.inf] + [2 * prototype_gain] * num_freq, device=device, dtype=dtype
     )
-    lowerBound = torch.tensor([-val for val in upperBound], device=device)
+    lowerBound = torch.tensor([-val for val in upperBound], device=device, dtype=dtype)

     # Optimization
     opt_gains = minimize_LBFGS(G, targetInterp, lowerBound, upperBound, num_freq)

     # Generate the SOS coefficients
-    b, a = geq(center_freq, shelving_crossover, R, opt_gains, fs, device=device)
+    b, a = geq(center_freq, shelving_crossover, R, opt_gains, fs, device=device, dtype=dtype)

     return b, a
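A sketch of the new dtype plumbing in the GEQ design, assuming the flamo.auxiliary.eq import path and the gain layout implied by geq above (one broadband gain, a low shelf, one peak per center frequency, a high shelf):

    import torch
    from flamo.auxiliary.eq import eq_freqs, geq

    center_freq, shelving_crossover = eq_freqs(interval=1, device="cpu", dtype=torch.float64)
    gain_db = torch.zeros(len(center_freq) + 3, dtype=torch.float64)  # assumed band layout
    R = torch.tensor(2.7, dtype=torch.float64)  # same resonance default that accurate_geq uses
    b, a = geq(center_freq, shelving_crossover, R, gain_db, fs=48000, device="cpu", dtype=torch.float64)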
flamo/auxiliary/reverb.py
CHANGED
@@ -104,11 +104,12 @@ class HomogeneousFDN:
     def set_model(self, input_layer=None, output_layer=None):
         # set the input and output layers of the FDN model
         if input_layer is None:
-            input_layer = dsp.FFT(self.config_dict.nfft)
+            input_layer = dsp.FFT(self.config_dict.nfft, dtype=self.config_dict.dtype)
         if output_layer is None:
             output_layer = dsp.iFFTAntiAlias(
                 nfft=self.config_dict.nfft,
                 alias_decay_db=self.config_dict.alias_decay_db,
+                dtype=self.config_dict.dtype,
             )

         self.model = self.get_shell(input_layer, output_layer)
@@ -125,6 +126,7 @@ class HomogeneousFDN:
             requires_grad=self.config_dict.input_gain_grad,
             alias_decay_db=self.config_dict.alias_decay_db,
             device=self.config_dict.device,
+            dtype=self.config_dict.dtype,
         )
         output_gain = dsp.Gain(
             size=(1, self.N),
@@ -132,6 +134,7 @@ class HomogeneousFDN:
             requires_grad=self.config_dict.output_gain_grad,
             alias_decay_db=self.config_dict.alias_decay_db,
             device=self.config_dict.device,
+            dtype=self.config_dict.dtype,
         )

         # RECURSION
@@ -144,6 +147,7 @@ class HomogeneousFDN:
             requires_grad=self.config_dict.delays_grad,
             alias_decay_db=self.config_dict.alias_decay_db,
             device=self.config_dict.device,
+            dtype=self.config_dict.dtype,
         )
         # assign the required delay line lengths
         delays.assign_value(delays.sample2s(delay_lines))
@@ -156,6 +160,7 @@ class HomogeneousFDN:
             requires_grad=self.config_dict.mixing_matrix_grad,
             alias_decay_db=self.config_dict.alias_decay_db,
             device=self.config_dict.device,
+            dtype=self.config_dict.dtype,
         )

         # homogeneous attenuation
@@ -165,6 +170,7 @@ class HomogeneousFDN:
             requires_grad=self.config_dict.attenuation_grad,
             alias_decay_db=self.config_dict.alias_decay_db,
             device=self.config_dict.device,
+            dtype=self.config_dict.dtype,
         )
         attenuation.map = map_gamma(delay_lines)
         attenuation.assign_value(
@@ -328,7 +334,8 @@ class parallelFDNAccurateGEQ(dsp.parallelAccurateGEQ):
         alias_decay_db: float = 0.0,
         start_freq: float = 31.25,
         end_freq: float = 16000.0,
-        device=None
+        device=None,
+        dtype=torch.float32
     ):
         assert (delays is not None), "Delays must be provided"
         self.delays = delays
@@ -342,7 +349,8 @@ class parallelFDNAccurateGEQ(dsp.parallelAccurateGEQ):
             alias_decay_db=alias_decay_db,
             start_freq=start_freq,
             end_freq=end_freq,
-            device=device
+            device=device,
+            dtype=dtype
         )


@@ -394,7 +402,8 @@ class parallelGFDNAccurateGEQ(parallelFDNAccurateGEQ):
         alias_decay_db: float = 0.0,
         start_freq: float = 31.25,
         end_freq: float = 16000.0,
-        device=None
+        device=None,
+        dtype=torch.float32
     ):
         assert (delays is not None), "Delays must be provided"
         self.delays = delays
@@ -482,6 +491,7 @@ class parallelFDNGEQ(dsp.parallelGEQ):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype=torch.float32
     ):
         assert (delays is not None), "Delays must be provided"
         self.delays = delays
@@ -497,7 +507,8 @@ class parallelFDNGEQ(dsp.parallelGEQ):
             map=map,
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
-            device=device
+            device=device,
+            dtype=dtype
         )

     def get_poly_coeff(self, param):
@@ -516,14 +527,13 @@ class parallelFDNGEQ(dsp.parallelGEQ):
             fs=self.fs,
             device=self.device
         )
-        b_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy
-        a_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy
+        b_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy, b)
+        a_aa = torch.einsum('p, pon -> pon', self.alias_envelope_dcy, a)
         B = torch.fft.rfft(b_aa, self.nfft, dim=0)
         A = torch.fft.rfft(a_aa, self.nfft, dim=0)
         H_temp = torch.prod(B, dim=1) / (torch.prod(A, dim=1))
         H = torch.where(torch.abs(torch.prod(A, dim=1)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-
-        return H.to(H_type), B, A
+        return H, B, A

     def check_param_shape(self):
         assert (
@@ -560,6 +570,7 @@ class parallelFDNPEQ(Filter):
         requires_grad: bool = False,
         alias_decay_db: float = 0.0,
         device: Optional[str] = None,
+        dtype=torch.float32
     ):
         self.delays = delays
         self.is_twostage = is_twostage
@@ -576,12 +587,13 @@ class parallelFDNPEQ(Filter):
         self.center_freq_bias = f_min * (f_max / f_min) ** ((k - 1) / (self.n_bands - 1))
         self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=device)
         super().__init__(
-            size=(self.n_bands+1 if self.is_twostage else self.
+            size=(self.n_bands+1 if self.is_twostage else self.n_bands, 3, 1 if self.is_proportional else len(delays)),
             nfft=nfft,
             map=map,
             requires_grad=requires_grad,
             alias_decay_db=alias_decay_db,
             device=device,
+            dtype=dtype
         )

     def get_poly_coeff(self, param):
@@ -644,16 +656,13 @@ class parallelFDNPEQ(Filter):
             type='highshelf'
         )

-        b_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy
-        a_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy
+        b_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy, b)
+        a_aa = torch.einsum("p, opn -> opn", self.alias_envelope_dcy, a)
         B = torch.fft.rfft(b_aa, self.nfft, dim=1)
         A = torch.fft.rfft(a_aa, self.nfft, dim=1)
         H_temp = (torch.prod(B, dim=0) / (torch.prod(A, dim=0)))
-        # H_temp = (torch.prod(B, dim=0) / (torch.prod(A, dim=0)))
-
         H = torch.where(torch.abs(torch.prod(A, dim=0)) != 0, H_temp, torch.finfo(H_temp.dtype).eps*torch.ones_like(H_temp))
-
-        return H.to(H_type), B, A
+        return H, B, A

     def compute_biquad_coeff(self, f, R, G, type='peaking'):
         # f : freq, R : resonance, G : gain in dB
@@ -805,6 +814,7 @@ class parallelFirstOrderShelving(dsp.parallelFilter):
         delays: torch.Tensor = None,
         alias_decay_db: float = 0.0,
         device: str = None,
+        dtype: torch.dtype = torch.float32
     ):
         size = (2,) # rt at DC and crossover frequency
         assert (delays is not None), "Delays must be provided"
@@ -816,7 +826,8 @@ class parallelFirstOrderShelving(dsp.parallelFilter):
             nfft=nfft,
             map=map,
             alias_decay_db=alias_decay_db,
-            device=device
+            device=device,
+            dtype=dtype
         )
         gamma = 10 ** (
             -torch.abs(torch.tensor(alias_decay_db, device=device)) / (nfft) / 20
flamo/auxiliary/scattering.py
CHANGED
@@ -4,10 +4,6 @@ import numpy as np
 from typing import Optional
 from flamo.utils import to_complex

-torch.random.manual_seed(0)
-np.random.seed(0)
-
-
 class ScatteringMapping(nn.Module):
     r"""
     Class mapping an orthogonal matrix to a paraunitary matrix using sparse scattering.
@@ -47,23 +43,26 @@ class ScatteringMapping(nn.Module):
         m_L: Optional[torch.tensor] = None,
         m_R: Optional[torch.tensor] = None,
         device: str = "cpu",
+        dtype: torch.dtype = torch.float32
     ):
         super(ScatteringMapping, self).__init__()

         self.n_stages = n_stages
         self.sparsity = sparsity
         self.gain_per_sample = gain_per_sample
+        self.device = device
+        self.dtype = dtype
         if m_L is None:
-            self.m_L = torch.zeros(N, device=device)
+            self.m_L = torch.zeros(N, device=device, dtype=self.dtype)
         else:
             self.m_L = m_L
         if m_R is None:
-            self.m_R = torch.zeros(N, device=device)
+            self.m_R = torch.zeros(N, device=device, dtype=self.dtype)
         else:
             self.m_R = m_R
-        self.sparsity_vect = torch.ones((n_stages), device=device)
+        self.sparsity_vect = torch.ones((n_stages), device=device, dtype=self.dtype)
         self.sparsity_vect[0] = sparsity
-        self.shifts = get_random_shifts(N, self.sparsity_vect, pulse_size)
+        self.shifts = get_random_shifts(N, self.sparsity_vect, pulse_size, dtype=self.dtype)

     def forward(self, U):
         r"""
@@ -81,7 +80,7 @@ class ScatteringMapping(nn.Module):

         G = (
             torch.diag(self.gain_per_sample ** self.shifts[k - 1, :])
-            .to(
+            .to(self.dtype)
             .to(U.device)
         )
         R = torch.matmul(U[:, :, k], G)
@@ -103,6 +102,7 @@ def cascaded_paraunit_matrix(
     pulse_size: int = 1,
     m_L: Optional[torch.tensor] = None,
     m_R: Optional[torch.tensor] = None,
+    dtype: torch.dtype = torch.float32,
 ):
     r"""
     Creates paraunitary matrix from input orthogonal matrix.
@@ -122,7 +122,7 @@ def cascaded_paraunit_matrix(
     """

     K = n_stages + 1
-    sparsity_vect = torch.ones((n_stages), device=U.device)
+    sparsity_vect = torch.ones((n_stages), device=U.device, dtype=dtype)
     sparsity_vect[0] = sparsity
     # check that the input matrix is of correct shape
     assert U.shape[0] == K, "The input matrix must have n_stages+1 stages"
@@ -133,14 +133,14 @@ def cascaded_paraunit_matrix(
     N = V.shape[0]

     if m_L is None:
-        m_L = torch.zeros(N, device=U.device)
+        m_L = torch.zeros(N, device=U.device, dtype=dtype)
     if m_R is None:
-        m_R = torch.zeros(N, device=U.device)
+        m_R = torch.zeros(N, device=U.device, dtype=dtype)

-    shift_L = get_random_shifts(N, sparsity_vect, pulse_size)
+    shift_L = get_random_shifts(N, sparsity_vect, pulse_size, dtype=dtype)
     for k in range(1, K):

-        G = torch.diag(gain_per_sample ** shift_L[k - 1, :]).to(
+        G = torch.diag(gain_per_sample ** shift_L[k - 1, :]).to(dtype)
         R = torch.matmul(U[:, :, k], G)

         V = shift_matrix(V, shift_L[k - 1, :], direction="left")
@@ -168,7 +168,7 @@ def poly_matrix_conv(A: torch.tensor, B: torch.tensor):
     if szA[1] != szB[0]:
         raise ValueError("Invalid matrix dimension.")

-    C = torch.zeros((szA[0], szB[1], szA[2] + szB[2] - 1), device=A.device)
+    C = torch.zeros((szA[0], szB[1], szA[2] + szB[2] - 1), device=A.device, dtype=A.dtype)

     A = A.permute(2, 0, 1)
     B = B.permute(2, 0, 1)
@@ -202,7 +202,7 @@ def shift_matrix(X: torch.tensor, shift: torch.tensor, direction: str = "left"):
     required_space = order + shift.reshape(-1, 1)
     additional_space = int((required_space.max() - X.shape[-1]) + 1)
     X = torch.cat(
-        (X, torch.zeros((N, N, additional_space), device=shift.device)), dim=-1
+        (X, torch.zeros((N, N, additional_space), device=shift.device, dtype=X.dtype)), dim=-1
     )
     for i in range(N):
         X[i, :, :] = torch.roll(X[i, :, :], int(shift[i].item()), dims=-1)
@@ -210,7 +210,7 @@ def shift_matrix(X: torch.tensor, shift: torch.tensor, direction: str = "left"):
     required_space = order + shift.reshape(1, -1)
     additional_space = int((required_space.max() - X.shape[-1]) + 1)
     X = torch.cat(
-        (X, torch.zeros((N, N, additional_space), device=shift.device)), dim=-1
+        (X, torch.zeros((N, N, additional_space), device=shift.device, dtype=X.dtype)), dim=-1
     )
     for i in range(N):
         X[:, i, :] = torch.roll(X[:, i, :], int(shift[i].item()), dims=-1)
@@ -228,14 +228,14 @@ def shift_mat_distribute(X: torch.tensor, sparsity: int, pulse_size: int):
     return (rand_shift * pulse_size).int()


-def get_random_shifts(N, sparsity_vect, pulse_size):
-    rand_shift = torch.zeros(sparsity_vect.shape[0], N, device=sparsity_vect.device)
+def get_random_shifts(N, sparsity_vect, pulse_size, dtype=torch.float32):
+    rand_shift = torch.zeros(sparsity_vect.shape[0], N, device=sparsity_vect.device, dtype=dtype)
     for k in range(sparsity_vect.shape[0]):
         temp = torch.floor(
             sparsity_vect[k]
             * (
-                torch.arange(0, N, device=sparsity_vect.device)
-                + torch.rand((N), device=sparsity_vect.device) * 0.99
+                torch.arange(0, N, device=sparsity_vect.device, dtype=dtype)
+                + torch.rand((N), device=sparsity_vect.device, dtype=dtype) * 0.99
             )
         )
         rand_shift[k, :] = (temp * pulse_size).int()
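Note that scattering.py no longer seeds torch and numpy globally at import time, so reproducibility is now the caller's responsibility. A sketch, assuming N is the first positional argument of ScatteringMapping and the remaining arguments keep their defaults:

    import torch
    from flamo.auxiliary.scattering import ScatteringMapping

    torch.manual_seed(0)  # previously done at module import; now up to the caller
    mapping = ScatteringMapping(4, device="cpu", dtype=torch.float64)
    print(mapping.shifts.dtype)  # the random shifts are created with the requested dtype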