ml4gw 0.7.7__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. ml4gw/augmentations.py +5 -0
  2. ml4gw/dataloading/__init__.py +5 -0
  3. ml4gw/dataloading/chunked_dataset.py +2 -4
  4. ml4gw/dataloading/hdf5_dataset.py +12 -10
  5. ml4gw/dataloading/in_memory_dataset.py +12 -12
  6. ml4gw/distributions.py +2 -2
  7. ml4gw/gw.py +18 -21
  8. ml4gw/nn/__init__.py +6 -0
  9. ml4gw/nn/autoencoder/base.py +5 -9
  10. ml4gw/nn/autoencoder/convolutional.py +7 -10
  11. ml4gw/nn/autoencoder/skip_connection.py +3 -5
  12. ml4gw/nn/norm.py +4 -4
  13. ml4gw/nn/resnet/resnet_1d.py +12 -13
  14. ml4gw/nn/resnet/resnet_2d.py +13 -14
  15. ml4gw/nn/streaming/online_average.py +3 -5
  16. ml4gw/nn/streaming/snapshotter.py +10 -14
  17. ml4gw/spectral.py +20 -23
  18. ml4gw/transforms/__init__.py +6 -0
  19. ml4gw/transforms/decimator.py +183 -0
  20. ml4gw/transforms/iirfilter.py +3 -5
  21. ml4gw/transforms/pearson.py +3 -4
  22. ml4gw/transforms/qtransform.py +10 -11
  23. ml4gw/transforms/scaler.py +3 -5
  24. ml4gw/transforms/snr_rescaler.py +7 -11
  25. ml4gw/transforms/spectral.py +6 -13
  26. ml4gw/transforms/spectrogram.py +6 -3
  27. ml4gw/transforms/spline_interpolation.py +7 -9
  28. ml4gw/transforms/transform.py +4 -6
  29. ml4gw/transforms/waveforms.py +8 -15
  30. ml4gw/transforms/whitening.py +11 -16
  31. ml4gw/types.py +8 -5
  32. ml4gw/utils/interferometer.py +20 -3
  33. ml4gw/utils/slicing.py +26 -30
  34. ml4gw/waveforms/__init__.py +6 -0
  35. ml4gw/waveforms/cbc/phenom_p.py +7 -9
  36. ml4gw/waveforms/conversion.py +2 -4
  37. ml4gw/waveforms/generator.py +3 -3
  38. {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/METADATA +28 -8
  39. ml4gw-0.7.8.dist-info/RECORD +57 -0
  40. ml4gw-0.7.7.dist-info/RECORD +0 -56
  41. {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +0 -0
  42. {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
  43. {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/top_level.txt +0 -0
ml4gw/spectral.py CHANGED
@@ -1,4 +1,7 @@
1
1
  """
2
+ This module provides functions for calculation of spectral densities
3
+ and for whitening.
4
+
2
5
  Several implementation details are derived from the scipy csd and welch
3
6
  implementations. For more info, see
4
7
 
@@ -9,8 +12,6 @@ and
9
12
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
10
13
  """
11
14
 
12
- from typing import Optional, Union
13
-
14
15
  import torch
15
16
  from jaxtyping import Float
16
17
  from torch import Tensor
@@ -36,13 +37,11 @@ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
36
37
  return torch.quantile(x, q=0.5, axis=axis) / bias
37
38
 
38
39
 
39
- def _validate_shapes(
40
- x: Tensor, nperseg: int, y: Optional[Tensor] = None
41
- ) -> None:
40
+ def _validate_shapes(x: Tensor, nperseg: int, y: Tensor | None = None) -> None:
42
41
  if x.shape[-1] < nperseg:
43
42
  raise ValueError(
44
- "Number of samples {} in input x is insufficient "
45
- "for number of fft samples {}".format(x.shape[-1], nperseg)
43
+ f"Number of samples {x.shape[-1]} in input x is insufficient "
44
+ f"for number of fft samples {nperseg}"
46
45
  )
47
46
  elif x.ndim > 3:
48
47
  raise ValueError(
@@ -59,30 +58,30 @@ def _validate_shapes(
59
58
  if x.shape[-1] != y.shape[-1]:
60
59
  raise ValueError(
61
60
  "Time dimensions of x and y tensors must "
62
- "be the same, found {} and {}".format(x.shape[-1], y.shape[-1])
61
+ f"be the same, found {x.shape[-1]} and {y.shape[-1]}"
63
62
  )
64
63
  elif x.ndim == 1 and not y.ndim == 1:
65
64
  raise ValueError(
66
65
  "Can't compute cross spectral density of "
67
- "1D tensor x with {}D tensor y".format(y.ndim)
66
+ f"1D tensor x with {y.ndim}D tensor y"
68
67
  )
69
68
  elif x.ndim > 1 and y.ndim == x.ndim:
70
69
  if not y.shape == x.shape:
71
70
  raise ValueError(
72
71
  "If x and y tensors have the same number "
73
72
  "of dimensions, shapes must fully match. "
74
- "Found shapes {} and {}".format(x.shape, y.shape)
73
+ f"Found shapes {x.shape} and {y.shape}"
75
74
  )
76
75
  elif x.ndim > 1 and y.ndim != (x.ndim - 1):
77
76
  raise ValueError(
78
77
  "Can't compute cross spectral density of "
79
- "tensors with shapes {} and {}".format(x.shape, y.shape)
78
+ f"tensors with shapes {x.shape} and {y.shape}"
80
79
  )
81
80
  elif x.ndim > 2 and y.shape[0] != x.shape[0]:
82
81
  raise ValueError(
83
82
  "If x is a 3D tensor and y is a 2D tensor, "
84
83
  "0th batch dimensions must match, but found "
85
- "values {} and {}".format(x.shape[0], y.shape[0])
84
+ f"values {x.shape[0]} and {y.shape[0]}"
86
85
  )
87
86
 
88
87
 
@@ -93,7 +92,7 @@ def fast_spectral_density(
93
92
  window: Float[Tensor, " {nperseg//2+1}"],
94
93
  scale: float,
95
94
  average: str = "median",
96
- y: Optional[TimeSeries1to3d] = None,
95
+ y: TimeSeries1to3d | None = None,
97
96
  ) -> FrequencySeries1to3d:
98
97
  """
99
98
  Compute the power spectral density of a multichannel
@@ -340,10 +339,10 @@ def spectral_density(
340
339
 
341
340
  def truncate_inverse_power_spectrum(
342
341
  psd: PSDTensor,
343
- fduration: Union[Float[Tensor, " time"], float],
342
+ fduration: Float[Tensor, " time"] | float,
344
343
  sample_rate: float,
345
- highpass: Optional[float] = None,
346
- lowpass: Optional[float] = None,
344
+ highpass: float | None = None,
345
+ lowpass: float | None = None,
347
346
  ) -> PSDTensor:
348
347
  """
349
348
  Truncate the length of the time domain response
@@ -460,10 +459,10 @@ def normalize_by_psd(
460
459
  def whiten(
461
460
  X: WaveformTensor,
462
461
  psd: PSDTensor,
463
- fduration: Union[Float[Tensor, " time"], float],
462
+ fduration: Float[Tensor, " time"] | float,
464
463
  sample_rate: float,
465
- highpass: Optional[float] = None,
466
- lowpass: Optional[float] = None,
464
+ highpass: float | None = None,
465
+ lowpass: float | None = None,
467
466
  ) -> WaveformTensor:
468
467
  """
469
468
  Whiten a batch of timeseries using the specified
@@ -522,10 +521,8 @@ def whiten(
522
521
  N = X.size(-1)
523
522
  if N <= (2 * pad):
524
523
  raise ValueError(
525
- (
526
- "Not enough timeseries samples {} for number of "
527
- "padded samples {}"
528
- ).format(N, 2 * pad)
524
+ f"Not enough timeseries samples {N} for number of "
525
+ f"padded samples {2 * pad}"
529
526
  )
530
527
 
531
528
  # normalize the number of expected dimensions in the PSD
@@ -1,3 +1,9 @@
1
+ """
2
+ This module contains a variety of data transformation classes,
3
+ including objects to calculate spectral densities, whiten data,
4
+ and compute Q-transforms.
5
+ """
6
+
1
7
  from .iirfilter import IIRFilter
2
8
  from .pearson import ShiftedPearsonCorrelation
3
9
  from .qtransform import QScan, SingleQTransform
@@ -0,0 +1,183 @@
1
+ import torch
2
+
3
+
4
+ class Decimator(torch.nn.Module):
5
+ r"""
6
+ Downsample (decimate) a timeseries according to a user-defined schedule.
7
+
8
+ .. note::
9
+
10
+ This is a naive decimator that does not use any IIR/FIR filtering
11
+ and selects every M-th sample according to the schedule.
12
+
13
+ The schedule specifies which segments of the input to keep and at what
14
+ sampling rate. Each row of the schedule has the form:
15
+
16
+ `[start_time, end_time, target_sample_rate]`
17
+
18
+ Args:
19
+ sample_rate (int):
20
+ Sampling rate (Hz) of the input timeseries.
21
+ schedule (torch.Tensor):
22
+ Tensor of shape `(N, 3)` defining start time, end time,
23
+ and target sample rate for each segment.
24
+
25
+ Shape:
26
+ - Input: `(B, C, T)` where
27
+ - B = batch size
28
+ - C = channels
29
+ - T = number of timesteps
30
+ (must equal schedule duration × sample_rate)
31
+ - Output:
32
+ - If ``split=False`` → `(B, C, T')` where `T'` is total
33
+ number of decimated samples across all segments.
34
+ - If ``split=True`` → list of tensors, one per segment.
35
+
36
+ Returns:
37
+ torch.Tensor or List[torch.Tensor]:
38
+ The decimated timeseries, or list of decimated segments if
39
+ ``split=True``.
40
+
41
+ Example:
42
+ .. code-block:: python
43
+
44
+ >>> import torch
45
+ >>> from ml4gw.transforms.decimator import Decimator
46
+
47
+ >>> sample_rate = 2048
48
+ >>> X_duration = 60
49
+
50
+ >>> schedule = torch.tensor(
51
+ ... [[0, 40, 256], [40, 58, 512], [58, 60, 2048]],
52
+ ... dtype=torch.int,
53
+ ... )
54
+
55
+ >>> decimator = Decimator(sample_rate=sample_rate,
56
+ ... schedule=schedule)
57
+ >>> X = torch.randn(1, 1, sample_rate * X_duration)
58
+ >>> X_dec = decimator(X)
59
+ >>> X_seg = decimator(X, split=True)
60
+
61
+ >>> print("Original shape:", X.shape)
62
+ Original shape: torch.Size([1, 1, 122880])
63
+ >>> print("Decimated shape:", X_dec.shape)
64
+ Decimated shape: torch.Size([1, 1, 23552])
65
+ >>> for i, seg in enumerate(X_seg):
66
+ ... print(f"Segment {i} shape:", seg.shape)
67
+ Segment 0 shape: torch.Size([1, 1, 10240])
68
+ Segment 1 shape: torch.Size([1, 1, 9216])
69
+ Segment 2 shape: torch.Size([1, 1, 4096])
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ sample_rate: int = None,
75
+ schedule: torch.Tensor = None,
76
+ ) -> None:
77
+ super().__init__()
78
+ self.sample_rate = sample_rate
79
+ self.schedule = schedule
80
+
81
+ self._validate_inputs()
82
+ idx = self.build_variable_indices()
83
+ self.register_buffer("idx", idx)
84
+
85
+ self.expected_len = int(
86
+ (self.schedule[:, 1][-1] - self.schedule[:, 0][0])
87
+ * self.sample_rate
88
+ )
89
+
90
+ def _validate_inputs(self) -> None:
91
+ r"""
92
+ Validate the schedule and sample_rate.
93
+ """
94
+ if self.schedule.ndim != 2 or self.schedule.shape[1] != 3:
95
+ raise ValueError(
96
+ f"Schedule must be of shape (N, 3), got {self.schedule.shape}"
97
+ )
98
+
99
+ if not torch.all(self.schedule[:, 1] > self.schedule[:, 0]):
100
+ raise ValueError(
101
+ "Each schedule segment must have end_time > start_time"
102
+ )
103
+
104
+ if torch.any(self.sample_rate % self.schedule[:, 2].long() != 0):
105
+ raise ValueError(
106
+ f"Sample rate {self.sample_rate} must be divisible by all "
107
+ f"target rates {self.schedule[:, 2].tolist()}"
108
+ )
109
+
110
+ def build_variable_indices(self) -> torch.Tensor:
111
+ r"""
112
+ Compute the time indices to keep based on the schedule.
113
+
114
+ Returns:
115
+ torch.Tensor:
116
+ 1D tensor of indices used to decimate the input.
117
+ """
118
+ idx = torch.tensor([], dtype=torch.long)
119
+
120
+ for s in self.schedule:
121
+ if idx.size(0) == 0:
122
+ start = int(s[0] * self.sample_rate)
123
+ else:
124
+ start = int(idx[-1]) + int(idx[-1] - idx[-2])
125
+ stop = int(start + (s[1] - s[0]) * self.sample_rate)
126
+ step = int(self.sample_rate // s[2])
127
+ new_idx = torch.arange(start, stop, step, dtype=torch.long)
128
+ idx = torch.cat((idx, new_idx))
129
+ return idx
130
+
131
+ def split_by_schedule(self, X: torch.Tensor) -> tuple[torch.Tensor, ...]:
132
+ r"""
133
+ Split a decimated timeseries into segments according to the schedule.
134
+
135
+ Args:
136
+ X (torch.Tensor):
137
+ Decimated input of shape `(B, C, T')`.
138
+
139
+ Returns:
140
+ tuple of torch.Tensor:
141
+ Each segment has shape :math:`(B, C, T_i)`
142
+ where :math:`T_i` is the length implied by
143
+ the corresponding schedule row.
144
+ """
145
+ split_sizes = (
146
+ ((self.schedule[:, 1] - self.schedule[:, 0]) * self.schedule[:, 2])
147
+ .long()
148
+ .tolist()
149
+ )
150
+
151
+ return torch.split(X, split_sizes, dim=-1)
152
+
153
+ def forward(
154
+ self,
155
+ X: torch.Tensor,
156
+ split: bool = False,
157
+ ) -> torch.Tensor | list[torch.Tensor]:
158
+ r"""
159
+ Apply decimation to the input timeseries.
160
+
161
+ Args:
162
+ X (torch.Tensor):
163
+ Input tensor of shape `(B, C, T)`, where `T` must equal
164
+ schedule duration × sample_rate.
165
+ split (bool, optional):
166
+ If True, return a list of segments instead of a single
167
+ concatenated tensor. Default: False.
168
+
169
+ Returns:
170
+ torch.Tensor or List[torch.Tensor]:
171
+ Decimated timeseries, or list of decimated segments.
172
+ """
173
+ if X.shape[-1] != self.expected_len:
174
+ raise ValueError(
175
+ f"X length {X.shape[-1]} does not match "
176
+ f"expected schedule duration {self.expected_len}"
177
+ )
178
+
179
+ X_dec = X.index_select(dim=-1, index=self.idx)
180
+
181
+ if split:
182
+ X_dec = self.split_by_schedule(X_dec)
183
+ return X_dec
@@ -1,5 +1,3 @@
1
- from typing import Union
2
-
3
1
  import torch
4
2
  from scipy.signal import iirfilter
5
3
  from torchaudio.functional import filtfilt
@@ -55,9 +53,9 @@ class IIRFilter(torch.nn.Module):
55
53
  def __init__(
56
54
  self,
57
55
  N: int,
58
- Wn: Union[float, torch.Tensor],
59
- rs: Union[None, float, torch.Tensor] = None,
60
- rp: Union[None, float, torch.Tensor] = None,
56
+ Wn: float | torch.Tensor,
57
+ rs: None | float | torch.Tensor = None,
58
+ rp: None | float | torch.Tensor = None,
61
59
  btype="band",
62
60
  analog=False,
63
61
  ftype="butter",
@@ -52,15 +52,14 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
52
52
  raise ValueError(
53
53
  "y may not have more dimensions that x for "
54
54
  "ShiftedPearsonCorrelation, but found shapes "
55
- "{} and {}".format(y.shape, x.shape)
55
+ f"{y.shape} and {x.shape}"
56
56
  )
57
57
  for dim in range(y.ndim):
58
58
  if y.size(-dim - 1) != x.size(-dim - 1):
59
59
  raise ValueError(
60
60
  "x and y expected to have same size along "
61
- "last dimensions, but found shapes {} and {}".format(
62
- x.shape, y.shape
63
- )
61
+ f"last dimensions, but found shapes {x.shape} and "
62
+ f"{y.shape}"
64
63
  )
65
64
 
66
65
  def forward(
@@ -1,6 +1,5 @@
1
1
  import math
2
2
  import warnings
3
- from typing import List, Tuple
4
3
 
5
4
  import torch
6
5
  import torch.nn.functional as F
@@ -146,7 +145,7 @@ class QTile(torch.nn.Module):
146
145
  means = torch.mean(energy, dim=-1, keepdim=True)
147
146
  energy /= means
148
147
  else:
149
- raise ValueError("Invalid normalisation %r" % norm)
148
+ raise ValueError(f"Invalid normalisation {norm}")
150
149
  energy = energy.type(torch.float32)
151
150
  return energy
152
151
 
@@ -194,9 +193,9 @@ class SingleQTransform(torch.nn.Module):
194
193
  self,
195
194
  duration: float,
196
195
  sample_rate: float,
197
- spectrogram_shape: Tuple[int, int],
196
+ spectrogram_shape: tuple[int, int],
198
197
  q: float = 12,
199
- frange: List[float] = None,
198
+ frange: list[float] = None,
200
199
  mismatch: float = 0.2,
201
200
  interpolation_method: str = "bicubic",
202
201
  ) -> None:
@@ -303,7 +302,7 @@ class SingleQTransform(torch.nn.Module):
303
302
  return torch.unique(freqs)
304
303
 
305
304
  def get_max_energy(
306
- self, fsearch_range: List[float] = None, dimension: str = "both"
305
+ self, fsearch_range: list[float] = None, dimension: str = "both"
307
306
  ):
308
307
  """
309
308
  Gets the maximum energy value among the QTiles. The maximum can
@@ -375,7 +374,7 @@ class SingleQTransform(torch.nn.Module):
375
374
  [
376
375
  qtile_interpolator(qtile)
377
376
  for qtile, qtile_interpolator in zip(
378
- qtiles, self.qtile_interpolators
377
+ qtiles, self.qtile_interpolators, strict=True
379
378
  )
380
379
  ],
381
380
  dim=-2,
@@ -459,9 +458,9 @@ class QScan(torch.nn.Module):
459
458
  self,
460
459
  duration: float,
461
460
  sample_rate: float,
462
- spectrogram_shape: Tuple[int, int],
463
- qrange: List[float] = None,
464
- frange: List[float] = None,
461
+ spectrogram_shape: tuple[int, int],
462
+ qrange: list[float] = None,
463
+ frange: list[float] = None,
465
464
  interpolation_method="bicubic",
466
465
  mismatch: float = 0.2,
467
466
  ) -> None:
@@ -500,7 +499,7 @@ class QScan(torch.nn.Module):
500
499
  ]
501
500
  )
502
501
 
503
- def get_qs(self) -> List[float]:
502
+ def get_qs(self) -> list[float]:
504
503
  """
505
504
  Determine the values of Q to try for the set of Q-transforms
506
505
  """
@@ -518,7 +517,7 @@ class QScan(torch.nn.Module):
518
517
  def forward(
519
518
  self,
520
519
  X: TimeSeries1to3d,
521
- fsearch_range: List[float] = None,
520
+ fsearch_range: list[float] = None,
522
521
  norm: str = "median",
523
522
  ):
524
523
  """
@@ -1,5 +1,3 @@
1
- from typing import Optional
2
-
3
1
  import torch
4
2
  from jaxtyping import Float
5
3
  from torch import Tensor
@@ -24,7 +22,7 @@ class ChannelWiseScaler(FittableTransform):
24
22
  to be 1D (single channel).
25
23
  """
26
24
 
27
- def __init__(self, num_channels: Optional[int] = None) -> None:
25
+ def __init__(self, num_channels: int | None = None) -> None:
28
26
  super().__init__()
29
27
 
30
28
  shape = (num_channels or 1,)
@@ -37,7 +35,7 @@ class ChannelWiseScaler(FittableTransform):
37
35
  self.register_buffer("std", std)
38
36
 
39
37
  def fit(
40
- self, X: Float[Tensor, "... time"], std_reg: Optional[float] = 0.0
38
+ self, X: Float[Tensor, "... time"], std_reg: float | None = 0.0
41
39
  ) -> None:
42
40
  """Fit the scaling parameters to a timeseries
43
41
 
@@ -59,7 +57,7 @@ class ChannelWiseScaler(FittableTransform):
59
57
  else:
60
58
  raise ValueError(
61
59
  "Can't fit channel wise mean and standard deviation "
62
- "from tensor of shape {}".format(X.shape)
60
+ f"from tensor of shape {X.shape}"
63
61
  )
64
62
  std += std_reg * torch.ones_like(std)
65
63
  super().build(mean=mean, std=std)
@@ -1,5 +1,3 @@
1
- from typing import Optional
2
-
3
1
  import torch
4
2
 
5
3
  from ..gw import compute_network_snr
@@ -13,8 +11,8 @@ class SnrRescaler(FittableSpectralTransform):
13
11
  num_channels: int,
14
12
  sample_rate: float,
15
13
  waveform_duration: float,
16
- highpass: Optional[float] = None,
17
- lowpass: Optional[float] = None,
14
+ highpass: float | None = None,
15
+ lowpass: float | None = None,
18
16
  dtype: torch.dtype = torch.float32,
19
17
  ) -> None:
20
18
  super().__init__()
@@ -45,15 +43,13 @@ class SnrRescaler(FittableSpectralTransform):
45
43
  def fit(
46
44
  self,
47
45
  *background: TimeSeries2d,
48
- fftlength: Optional[float] = None,
49
- overlap: Optional[float] = None,
46
+ fftlength: float | None = None,
47
+ overlap: float | None = None,
50
48
  ):
51
49
  if len(background) != self.num_channels:
52
50
  raise ValueError(
53
- "Expected to fit whitening transform on {} background "
54
- "timeseries, but was passed {}".format(
55
- self.num_channels, len(background)
56
- )
51
+ f"Expected to fit whitening transform on {self.num_channels} "
52
+ f"background timeseries, but was passed {len(background)}"
57
53
  )
58
54
 
59
55
  num_freqs = self.background.size(1)
@@ -69,7 +65,7 @@ class SnrRescaler(FittableSpectralTransform):
69
65
  def forward(
70
66
  self,
71
67
  responses: WaveformTensor,
72
- target_snrs: Optional[BatchTensor] = None,
68
+ target_snrs: BatchTensor | None = None,
73
69
  ):
74
70
  snrs = compute_network_snr(
75
71
  responses,
@@ -1,5 +1,3 @@
1
- from typing import Optional
2
-
3
1
  import torch
4
2
  from jaxtyping import Float
5
3
  from torch import Tensor
@@ -52,20 +50,17 @@ class SpectralDensity(torch.nn.Module):
52
50
  self,
53
51
  sample_rate: float,
54
52
  fftlength: float,
55
- overlap: Optional[float] = None,
53
+ overlap: float | None = None,
56
54
  average: str = "mean",
57
- window: Optional[
58
- Float[Tensor, " {int(fftlength*sample_rate)}"]
59
- ] = None,
55
+ window: Float[Tensor, " {int(fftlength*sample_rate)}"] | None = None,
60
56
  fast: bool = False,
61
57
  ) -> None:
62
58
  if overlap is None:
63
59
  overlap = fftlength / 2
64
60
  elif overlap >= fftlength:
65
61
  raise ValueError(
66
- "Can't have overlap {} longer than fftlength {}".format(
67
- overlap, fftlength
68
- )
62
+ f"Can't have overlap {overlap} longer than fftlength "
63
+ f"{fftlength}"
69
64
  )
70
65
 
71
66
  super().__init__()
@@ -80,9 +75,7 @@ class SpectralDensity(torch.nn.Module):
80
75
 
81
76
  if window.size(0) != self.nperseg:
82
77
  raise ValueError(
83
- "Window must have length {} got {}".format(
84
- self.nperseg, window.size(0)
85
- )
78
+ f"Window must have length {self.nperseg} got {window.size(0)}"
86
79
  )
87
80
  self.register_buffer("window", window)
88
81
 
@@ -99,7 +92,7 @@ class SpectralDensity(torch.nn.Module):
99
92
  self.fast = fast
100
93
 
101
94
  def forward(
102
- self, x: TimeSeries1to3d, y: Optional[TimeSeries1to3d] = None
95
+ self, x: TimeSeries1to3d, y: TimeSeries1to3d | None = None
103
96
  ) -> FrequencySeries1to3d:
104
97
  if self.fast:
105
98
  return fast_spectral_density(
@@ -1,5 +1,4 @@
1
1
  import warnings
2
- from typing import Dict, List
3
2
 
4
3
  import torch
5
4
  import torch.nn.functional as F
@@ -104,7 +103,7 @@ class MultiResolutionSpectrogram(torch.nn.Module):
104
103
  self.register_buffer("freq_idxs", freq_idxs)
105
104
  self.register_buffer("time_idxs", time_idxs)
106
105
 
107
- def _check_and_format_kwargs(self, kwargs: Dict[str, List]) -> List:
106
+ def _check_and_format_kwargs(self, kwargs: dict[str, list]) -> list:
108
107
  lengths = sorted(len(v) for v in kwargs.values())
109
108
  lengths = list(set(lengths))
110
109
 
@@ -127,7 +126,10 @@ class MultiResolutionSpectrogram(torch.nn.Module):
127
126
  size = lengths[1]
128
127
  kwargs = {k: v * int(size / len(v)) for k, v in kwargs.items()}
129
128
 
130
- return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]
129
+ return [
130
+ dict(zip(kwargs, col, strict=True))
131
+ for col in zip(*kwargs.values(), strict=True)
132
+ ]
131
133
 
132
134
  def forward(
133
135
  self, X: TimeSeries3d
@@ -161,6 +163,7 @@ class MultiResolutionSpectrogram(torch.nn.Module):
161
163
  self.right_pad,
162
164
  self.top_pad,
163
165
  self.bottom_pad,
166
+ strict=True,
164
167
  ):
165
168
  padded_specs.append(F.pad(spec, (left, right, top, bottom)))
166
169
 
@@ -2,8 +2,6 @@
2
2
  Adaptation of code from https://github.com/dottormale/Qtransform_torch/
3
3
  """
4
4
 
5
- from typing import Optional, Tuple
6
-
7
5
  import torch
8
6
  from torch import Tensor
9
7
 
@@ -50,7 +48,7 @@ class SplineInterpolateBase(torch.nn.Module):
50
48
  t: Tensor,
51
49
  d: int,
52
50
  m: int,
53
- ) -> Tuple[Tensor, Tensor]:
51
+ ) -> tuple[Tensor, Tensor]:
54
52
  """
55
53
  Compute the L and R values for B-spline basis functions.
56
54
  L and R are respectively the first and second coefficient multiplying
@@ -208,7 +206,7 @@ class SplineInterpolate1D(SplineInterpolateBase):
208
206
  x_in: Tensor,
209
207
  kx: int = 3,
210
208
  sx: float = 0.0,
211
- x_out: Optional[Tensor] = None,
209
+ x_out: Tensor | None = None,
212
210
  ):
213
211
  super().__init__()
214
212
 
@@ -284,7 +282,7 @@ class SplineInterpolate1D(SplineInterpolateBase):
284
282
  def forward(
285
283
  self,
286
284
  Z: Tensor,
287
- x_out: Optional[Tensor] = None,
285
+ x_out: Tensor | None = None,
288
286
  ) -> Tensor:
289
287
  """
290
288
  Compute the interpolated data
@@ -377,8 +375,8 @@ class SplineInterpolate2D(SplineInterpolateBase):
377
375
  ky: int = 3,
378
376
  sx: float = 0.0,
379
377
  sy: float = 0.0,
380
- x_out: Optional[Tensor] = None,
381
- y_out: Optional[Tensor] = None,
378
+ x_out: Tensor | None = None,
379
+ y_out: Tensor | None = None,
382
380
  ):
383
381
  super().__init__()
384
382
 
@@ -483,8 +481,8 @@ class SplineInterpolate2D(SplineInterpolateBase):
483
481
  def forward(
484
482
  self,
485
483
  Z: Tensor,
486
- x_out: Optional[Tensor] = None,
487
- y_out: Optional[Tensor] = None,
484
+ x_out: Tensor | None = None,
485
+ y_out: Tensor | None = None,
488
486
  ) -> Tensor:
489
487
  """
490
488
  Compute the interpolated data
@@ -1,5 +1,3 @@
1
- from typing import Optional
2
-
3
1
  import torch
4
2
 
5
3
  from ..spectral import spectral_density
@@ -20,8 +18,8 @@ class FittableTransform(torch.nn.Module):
20
18
  def _check_built(self):
21
19
  if not self.built:
22
20
  raise ValueError(
23
- "Must fit parameters of {} transform to data "
24
- "before calling forward step".format(self.__class__.__name__)
21
+ f"Must fit parameters of {self.__class__.__name__} transform "
22
+ "to data before calling forward step"
25
23
  )
26
24
 
27
25
  def __call__(self, *args, **kwargs):
@@ -47,8 +45,8 @@ class FittableSpectralTransform(FittableTransform):
47
45
  x: TimeSeries1to3d,
48
46
  sample_rate: float,
49
47
  num_freqs: int,
50
- fftlength: Optional[float] = None,
51
- overlap: Optional[float] = None,
48
+ fftlength: float | None = None,
49
+ overlap: float | None = None,
52
50
  ) -> FrequencySeries1to3d:
53
51
  # if we specified an FFT length, convert
54
52
  # the (assumed) time-domain data to the