ml4gw 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. ml4gw/augmentations.py +5 -0
  2. ml4gw/dataloading/__init__.py +5 -0
  3. ml4gw/dataloading/chunked_dataset.py +2 -4
  4. ml4gw/dataloading/hdf5_dataset.py +12 -10
  5. ml4gw/dataloading/in_memory_dataset.py +12 -12
  6. ml4gw/distributions.py +3 -3
  7. ml4gw/gw.py +18 -21
  8. ml4gw/nn/__init__.py +6 -0
  9. ml4gw/nn/autoencoder/base.py +5 -9
  10. ml4gw/nn/autoencoder/convolutional.py +7 -10
  11. ml4gw/nn/autoencoder/skip_connection.py +3 -5
  12. ml4gw/nn/norm.py +4 -4
  13. ml4gw/nn/resnet/resnet_1d.py +12 -13
  14. ml4gw/nn/resnet/resnet_2d.py +13 -14
  15. ml4gw/nn/streaming/online_average.py +3 -5
  16. ml4gw/nn/streaming/snapshotter.py +10 -14
  17. ml4gw/spectral.py +20 -23
  18. ml4gw/transforms/__init__.py +7 -1
  19. ml4gw/transforms/decimator.py +183 -0
  20. ml4gw/transforms/iirfilter.py +3 -5
  21. ml4gw/transforms/pearson.py +3 -4
  22. ml4gw/transforms/qtransform.py +20 -26
  23. ml4gw/transforms/scaler.py +3 -5
  24. ml4gw/transforms/snr_rescaler.py +7 -11
  25. ml4gw/transforms/spectral.py +6 -13
  26. ml4gw/transforms/spectrogram.py +6 -3
  27. ml4gw/transforms/spline_interpolation.py +312 -143
  28. ml4gw/transforms/transform.py +4 -6
  29. ml4gw/transforms/waveforms.py +8 -15
  30. ml4gw/transforms/whitening.py +11 -16
  31. ml4gw/types.py +8 -5
  32. ml4gw/utils/interferometer.py +20 -3
  33. ml4gw/utils/slicing.py +26 -30
  34. ml4gw/waveforms/__init__.py +6 -0
  35. ml4gw/waveforms/cbc/phenom_p.py +7 -9
  36. ml4gw/waveforms/conversion.py +2 -4
  37. ml4gw/waveforms/generator.py +3 -3
  38. {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/METADATA +33 -12
  39. ml4gw-0.7.8.dist-info/RECORD +57 -0
  40. {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +2 -1
  41. ml4gw-0.7.8.dist-info/top_level.txt +1 -0
  42. ml4gw-0.7.6.dist-info/RECORD +0 -55
  43. {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
ml4gw/nn/streaming/snapshotter.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence, Tuple
+from collections.abc import Sequence
 
 import torch
 from jaxtyping import Float
@@ -58,15 +58,13 @@ class Snapshotter(torch.nn.Module):
         snapshot_size: int,
         stride_size: int,
         batch_size: int,
-        channels_per_snapshot: Optional[Sequence[int]] = None,
+        channels_per_snapshot: Sequence[int] | None = None,
     ) -> None:
         super().__init__()
         if stride_size >= snapshot_size:
             raise ValueError(
-                "Snapshotter can't accommodate stride {} "
-                "which is greater than snapshot size {}".format(
-                    stride_size, snapshot_size
-                )
+                f"Snapshotter can't accommodate stride {stride_size} "
+                f"which is greater than snapshot size {snapshot_size}"
             )
 
         self.snapshot_size = snapshot_size
@@ -77,9 +75,8 @@ class Snapshotter(torch.nn.Module):
         if channels_per_snapshot is not None:
             if sum(channels_per_snapshot) != num_channels:
                 raise ValueError(
-                    "Can't break {} channels into {}".format(
-                        num_channels, channels_per_snapshot
-                    )
+                    f"Can't break {num_channels} channels into "
+                    f"{channels_per_snapshot}"
                 )
             self.channels_per_snapshot = channels_per_snapshot
             self.num_channels = num_channels
@@ -90,8 +87,8 @@ class Snapshotter(torch.nn.Module):
     def forward(
         self,
         update: Float[Tensor, "channel time1"],
-        snapshot: Optional[Float[Tensor, "channel time2"]] = None,
-    ) -> Tuple[Tensor, ...]:
+        snapshot: Float[Tensor, "channel time2"] | None = None,
+    ) -> tuple[Tensor, ...]:
         if snapshot is None:
             snapshot = self.get_initial_state()
 
@@ -108,9 +105,8 @@ class Snapshotter(torch.nn.Module):
        if self.channels_per_snapshot is not None:
            if snapshots.size(1) != self.num_channels:
                raise ValueError(
-                    "Expected {} channels, found {}".format(
-                        self.num_channels, snapshots.size(1)
-                    )
+                    f"Expected {self.num_channels} channels, found "
+                    f"{snapshots.size(1)}"
                )
            snapshots = torch.split(
                snapshots, self.channels_per_snapshot, dim=1
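
Note: the hunks above also document the constructor's contract: stride_size must be smaller than snapshot_size, and channels_per_snapshot, if given, must sum to the channel count. A minimal sketch of that contract; the keyword names come from the hunks, but the position of num_channels in the signature is an assumption (it is assigned in __init__ but not shown there):

```python
import torch

from ml4gw.nn.streaming.snapshotter import Snapshotter

# Valid: stride smaller than snapshot, channel split sums to num_channels.
# num_channels=2 is assumed to be a constructor argument; its exact
# position in the signature is not visible in this diff.
snapshotter = Snapshotter(
    num_channels=2,
    snapshot_size=2048,
    stride_size=256,
    batch_size=1,
    channels_per_snapshot=[1, 1],
)

# Invalid: stride >= snapshot size raises the ValueError shown above.
try:
    Snapshotter(
        num_channels=2,
        snapshot_size=256,
        stride_size=2048,
        batch_size=1,
    )
except ValueError as e:
    print(e)  # "Snapshotter can't accommodate stride 2048 ..."
```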
ml4gw/spectral.py CHANGED
@@ -1,4 +1,7 @@
 """
+This module provides functions for calculation of spectral densities
+and for whitening.
+
 Several implementation details are derived from the scipy csd and welch
 implementations. For more info, see
 
@@ -9,8 +12,6 @@ and
 https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
 """
 
-from typing import Optional, Union
-
 import torch
 from jaxtyping import Float
 from torch import Tensor
@@ -36,13 +37,11 @@ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
     return torch.quantile(x, q=0.5, axis=axis) / bias
 
 
-def _validate_shapes(
-    x: Tensor, nperseg: int, y: Optional[Tensor] = None
-) -> None:
+def _validate_shapes(x: Tensor, nperseg: int, y: Tensor | None = None) -> None:
     if x.shape[-1] < nperseg:
         raise ValueError(
-            "Number of samples {} in input x is insufficient "
-            "for number of fft samples {}".format(x.shape[-1], nperseg)
+            f"Number of samples {x.shape[-1]} in input x is insufficient "
+            f"for number of fft samples {nperseg}"
         )
     elif x.ndim > 3:
         raise ValueError(
@@ -59,30 +58,30 @@ def _validate_shapes(
     if x.shape[-1] != y.shape[-1]:
         raise ValueError(
             "Time dimensions of x and y tensors must "
-            "be the same, found {} and {}".format(x.shape[-1], y.shape[-1])
+            f"be the same, found {x.shape[-1]} and {y.shape[-1]}"
         )
     elif x.ndim == 1 and not y.ndim == 1:
         raise ValueError(
             "Can't compute cross spectral density of "
-            "1D tensor x with {}D tensor y".format(y.ndim)
+            f"1D tensor x with {y.ndim}D tensor y"
        )
     elif x.ndim > 1 and y.ndim == x.ndim:
         if not y.shape == x.shape:
             raise ValueError(
                 "If x and y tensors have the same number "
                 "of dimensions, shapes must fully match. "
-                "Found shapes {} and {}".format(x.shape, y.shape)
+                f"Found shapes {x.shape} and {y.shape}"
             )
     elif x.ndim > 1 and y.ndim != (x.ndim - 1):
         raise ValueError(
             "Can't compute cross spectral density of "
-            "tensors with shapes {} and {}".format(x.shape, y.shape)
+            f"tensors with shapes {x.shape} and {y.shape}"
         )
     elif x.ndim > 2 and y.shape[0] != x.shape[0]:
         raise ValueError(
             "If x is a 3D tensor and y is a 2D tensor, "
             "0th batch dimensions must match, but found "
-            "values {} and {}".format(x.shape[0], y.shape[0])
+            f"values {x.shape[0]} and {y.shape[0]}"
         )
 
 
@@ -93,7 +92,7 @@ def fast_spectral_density(
     window: Float[Tensor, " {nperseg//2+1}"],
     scale: float,
     average: str = "median",
-    y: Optional[TimeSeries1to3d] = None,
+    y: TimeSeries1to3d | None = None,
 ) -> FrequencySeries1to3d:
     """
     Compute the power spectral density of a multichannel
@@ -340,10 +339,10 @@ def spectral_density(
 
 def truncate_inverse_power_spectrum(
     psd: PSDTensor,
-    fduration: Union[Float[Tensor, " time"], float],
+    fduration: Float[Tensor, " time"] | float,
     sample_rate: float,
-    highpass: Optional[float] = None,
-    lowpass: Optional[float] = None,
+    highpass: float | None = None,
+    lowpass: float | None = None,
 ) -> PSDTensor:
     """
     Truncate the length of the time domain response
@@ -460,10 +459,10 @@ def normalize_by_psd(
 def whiten(
     X: WaveformTensor,
     psd: PSDTensor,
-    fduration: Union[Float[Tensor, " time"], float],
+    fduration: Float[Tensor, " time"] | float,
     sample_rate: float,
-    highpass: Optional[float] = None,
-    lowpass: Optional[float] = None,
+    highpass: float | None = None,
+    lowpass: float | None = None,
 ) -> WaveformTensor:
     """
     Whiten a batch of timeseries using the specified
@@ -522,10 +521,8 @@ def whiten(
     N = X.size(-1)
     if N <= (2 * pad):
         raise ValueError(
-            (
-                "Not enough timeseries samples {} for number of "
-                "padded samples {}"
-            ).format(N, 2 * pad)
+            f"Not enough timeseries samples {N} for number of "
+            f"padded samples {2 * pad}"
        )
 
     # normalize the number of expected dimensions in the PSD
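
Note: the new module docstring summarizes what ml4gw.spectral provides. A hedged sketch of calling whiten with the signature shown in the hunks above; the tensor shapes are assumptions about WaveformTensor and PSDTensor, which this diff does not define:

```python
import torch

from ml4gw.spectral import whiten

# Hypothetical shapes: a batch of 8 two-channel, 2 s timeseries at 2048 Hz,
# and a one-sided PSD per channel. The exact PSD resolution whiten expects
# is not shown in this diff, so treat these shapes as illustrative only.
sample_rate = 2048
X = torch.randn(8, 2, 2 * sample_rate)
psd = torch.rand(2, 2 * sample_rate // 2 + 1)

whitened = whiten(
    X,
    psd,
    fduration=0.5,             # Tensor | float per the updated annotation
    sample_rate=sample_rate,
    highpass=32.0,             # float | None
)
```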
ml4gw/transforms/__init__.py CHANGED
@@ -1,3 +1,9 @@
+"""
+This module contains a variety of data transformation classes,
+including objects to calculate spectral densities, whiten data,
+and compute Q-transforms.
+"""
+
 from .iirfilter import IIRFilter
 from .pearson import ShiftedPearsonCorrelation
 from .qtransform import QScan, SingleQTransform
@@ -5,6 +11,6 @@ from .scaler import ChannelWiseScaler
 from .snr_rescaler import SnrRescaler
 from .spectral import SpectralDensity
 from .spectrogram import MultiResolutionSpectrogram
-from .spline_interpolation import SplineInterpolate
+from .spline_interpolation import SplineInterpolate1D, SplineInterpolate2D
 from .waveforms import WaveformProjector, WaveformSampler
 from .whitening import FixedWhiten, Whiten
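
Note: code that previously imported SplineInterpolate from ml4gw.transforms will need updating, since that name is replaced by the 1D/2D variants:

```python
# Before (0.7.6):
# from ml4gw.transforms import SplineInterpolate

# After (0.7.8):
from ml4gw.transforms import SplineInterpolate1D, SplineInterpolate2D
```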
ml4gw/transforms/decimator.py ADDED
@@ -0,0 +1,183 @@
+import torch
+
+
+class Decimator(torch.nn.Module):
+    r"""
+    Downsample (decimate) a timeseries according to a user-defined schedule.
+
+    .. note::
+
+        This is a naive decimator that does not use any IIR/FIR filtering
+        and selects every M-th sample according to the schedule.
+
+    The schedule specifies which segments of the input to keep and at what
+    sampling rate. Each row of the schedule has the form:
+
+    `[start_time, end_time, target_sample_rate]`
+
+    Args:
+        sample_rate (int):
+            Sampling rate (Hz) of the input timeseries.
+        schedule (torch.Tensor):
+            Tensor of shape `(N, 3)` defining start time, end time,
+            and target sample rate for each segment.
+
+    Shape:
+        - Input: `(B, C, T)` where
+            - B = batch size
+            - C = channels
+            - T = number of timesteps
+              (must equal schedule duration × sample_rate)
+        - Output:
+            - If ``split=False`` → `(B, C, T')` where `T'` is total
+              number of decimated samples across all segments.
+            - If ``split=True`` → list of tensors, one per segment.
+
+    Returns:
+        torch.Tensor or List[torch.Tensor]:
+            The decimated timeseries, or list of decimated segments if
+            ``split=True``.
+
+    Example:
+        .. code-block:: python
+
+            >>> import torch
+            >>> from ml4gw.transforms.decimator import Decimator
+
+            >>> sample_rate = 2048
+            >>> X_duration = 60
+
+            >>> schedule = torch.tensor(
+            ...     [[0, 40, 256], [40, 58, 512], [58, 60, 2048]],
+            ...     dtype=torch.int,
+            ... )
+
+            >>> decimator = Decimator(sample_rate=sample_rate,
+            ...                       schedule=schedule)
+            >>> X = torch.randn(1, 1, sample_rate * X_duration)
+            >>> X_dec = decimator(X)
+            >>> X_seg = decimator(X, split=True)
+
+            >>> print("Original shape:", X.shape)
+            Original shape: torch.Size([1, 1, 122880])
+            >>> print("Decimated shape:", X_dec.shape)
+            Decimated shape: torch.Size([1, 1, 23552])
+            >>> for i, seg in enumerate(X_seg):
+            ...     print(f"Segment {i} shape:", seg.shape)
+            Segment 0 shape: torch.Size([1, 1, 10240])
+            Segment 1 shape: torch.Size([1, 1, 9216])
+            Segment 2 shape: torch.Size([1, 1, 4096])
+    """
+
+    def __init__(
+        self,
+        sample_rate: int = None,
+        schedule: torch.Tensor = None,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.schedule = schedule
+
+        self._validate_inputs()
+        idx = self.build_variable_indices()
+        self.register_buffer("idx", idx)
+
+        self.expected_len = int(
+            (self.schedule[:, 1][-1] - self.schedule[:, 0][0])
+            * self.sample_rate
+        )
+
+    def _validate_inputs(self) -> None:
+        r"""
+        Validate the schedule and sample_rate.
+        """
+        if self.schedule.ndim != 2 or self.schedule.shape[1] != 3:
+            raise ValueError(
+                f"Schedule must be of shape (N, 3), got {self.schedule.shape}"
+            )
+
+        if not torch.all(self.schedule[:, 1] > self.schedule[:, 0]):
+            raise ValueError(
+                "Each schedule segment must have end_time > start_time"
+            )
+
+        if torch.any(self.sample_rate % self.schedule[:, 2].long() != 0):
+            raise ValueError(
+                f"Sample rate {self.sample_rate} must be divisible by all "
+                f"target rates {self.schedule[:, 2].tolist()}"
+            )
+
+    def build_variable_indices(self) -> torch.Tensor:
+        r"""
+        Compute the time indices to keep based on the schedule.
+
+        Returns:
+            torch.Tensor:
+                1D tensor of indices used to decimate the input.
+        """
+        idx = torch.tensor([], dtype=torch.long)
+
+        for s in self.schedule:
+            if idx.size(0) == 0:
+                start = int(s[0] * self.sample_rate)
+            else:
+                start = int(idx[-1]) + int(idx[-1] - idx[-2])
+            stop = int(start + (s[1] - s[0]) * self.sample_rate)
+            step = int(self.sample_rate // s[2])
+            new_idx = torch.arange(start, stop, step, dtype=torch.long)
+            idx = torch.cat((idx, new_idx))
+        return idx
+
+    def split_by_schedule(self, X: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        r"""
+        Split a decimated timeseries into segments according to the schedule.
+
+        Args:
+            X (torch.Tensor):
+                Decimated input of shape `(B, C, T')`.
+
+        Returns:
+            tuple of torch.Tensor:
+                Each segment has shape :math:`(B, C, T_i)`
+                where :math:`T_i` is the length implied by
+                the corresponding schedule row.
+        """
+        split_sizes = (
+            ((self.schedule[:, 1] - self.schedule[:, 0]) * self.schedule[:, 2])
+            .long()
+            .tolist()
+        )
+
+        return torch.split(X, split_sizes, dim=-1)
+
+    def forward(
+        self,
+        X: torch.Tensor,
+        split: bool = False,
+    ) -> torch.Tensor | list[torch.Tensor]:
+        r"""
+        Apply decimation to the input timeseries.
+
+        Args:
+            X (torch.Tensor):
+                Input tensor of shape `(B, C, T)`, where `T` must equal
+                schedule duration × sample_rate.
+            split (bool, optional):
+                If True, return a list of segments instead of a single
+                concatenated tensor. Default: False.
+
+        Returns:
+            torch.Tensor or List[torch.Tensor]:
+                Decimated timeseries, or list of decimated segments.
+        """
+        if X.shape[-1] != self.expected_len:
+            raise ValueError(
+                f"X length {X.shape[-1]} does not match "
+                f"expected schedule duration {self.expected_len}"
+            )
+
+        X_dec = X.index_select(dim=-1, index=self.idx)
+
+        if split:
+            X_dec = self.split_by_schedule(X_dec)
+        return X_dec
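
Note: the docstring's shape arithmetic can be checked directly: each schedule row keeps every (sample_rate // target_rate)-th sample over its span, so a row contributes (end - start) * target_rate samples. A standalone re-derivation of the example above in plain Python, independent of the package:

```python
sample_rate = 2048
schedule = [(0, 40, 256), (40, 58, 512), (58, 60, 2048)]

# Samples kept per segment: duration times target rate.
lengths = [(end - start) * rate for start, end, rate in schedule]
print(lengths)       # [10240, 9216, 4096] -> the split=True shapes
print(sum(lengths))  # 23552 -> the concatenated decimated length

# Step between kept samples in each segment:
steps = [sample_rate // rate for _, _, rate in schedule]
print(steps)         # [8, 4, 1]
```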
ml4gw/transforms/iirfilter.py CHANGED
@@ -1,5 +1,3 @@
-from typing import Union
-
 import torch
 from scipy.signal import iirfilter
 from torchaudio.functional import filtfilt
@@ -55,9 +53,9 @@ class IIRFilter(torch.nn.Module):
     def __init__(
         self,
         N: int,
-        Wn: Union[float, torch.Tensor],
-        rs: Union[None, float, torch.Tensor] = None,
-        rp: Union[None, float, torch.Tensor] = None,
+        Wn: float | torch.Tensor,
+        rs: None | float | torch.Tensor = None,
+        rp: None | float | torch.Tensor = None,
         btype="band",
         analog=False,
         ftype="butter",
ml4gw/transforms/pearson.py CHANGED
@@ -52,15 +52,14 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
             raise ValueError(
                 "y may not have more dimensions that x for "
                 "ShiftedPearsonCorrelation, but found shapes "
-                "{} and {}".format(y.shape, x.shape)
+                f"{y.shape} and {x.shape}"
             )
         for dim in range(y.ndim):
             if y.size(-dim - 1) != x.size(-dim - 1):
                 raise ValueError(
                     "x and y expected to have same size along "
-                    "last dimensions, but found shapes {} and {}".format(
-                        x.shape, y.shape
-                    )
+                    f"last dimensions, but found shapes {x.shape} and "
+                    f"{y.shape}"
                 )
 
     def forward(
ml4gw/transforms/qtransform.py CHANGED
@@ -1,6 +1,5 @@
 import math
 import warnings
-from typing import List, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -8,7 +7,7 @@ from jaxtyping import Float, Int
 from torch import Tensor
 
 from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
-from .spline_interpolation import SplineInterpolate
+from .spline_interpolation import SplineInterpolate1D
 
 """
 All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
@@ -146,7 +145,7 @@ class QTile(torch.nn.Module):
             means = torch.mean(energy, dim=-1, keepdim=True)
             energy /= means
         else:
-            raise ValueError("Invalid normalisation %r" % norm)
+            raise ValueError(f"Invalid normalisation {norm}")
         energy = energy.type(torch.float32)
         return energy
 
@@ -194,9 +193,9 @@ class SingleQTransform(torch.nn.Module):
         self,
         duration: float,
         sample_rate: float,
-        spectrogram_shape: Tuple[int, int],
+        spectrogram_shape: tuple[int, int],
         q: float = 12,
-        frange: List[float] = None,
+        frange: list[float] = None,
         mismatch: float = 0.2,
         interpolation_method: str = "bicubic",
     ) -> None:
@@ -260,18 +259,15 @@ class SingleQTransform(torch.nn.Module):
         )
         self.qtile_interpolators = torch.nn.ModuleList(
             [
-                SplineInterpolate(
+                SplineInterpolate1D(
                     kx=3,
                     x_in=torch.arange(0, self.duration, self.duration / tiles),
-                    y_in=torch.arange(len(idx)),
                     x_out=t_out,
-                    y_out=torch.arange(len(idx)),
                 )
-                for tiles, idx in zip(unique_ntiles, self.stack_idx)
+                for tiles in unique_ntiles
             ]
         )
 
-        t_in = t_out
         f_in = self.freqs
         f_out = torch.logspace(
             math.log10(self.frange[0]),
@@ -279,13 +275,10 @@ class SingleQTransform(torch.nn.Module):
             self.spectrogram_shape[0],
         )
 
-        self.interpolator = SplineInterpolate(
+        self.interpolator = SplineInterpolate1D(
             kx=3,
-            ky=3,
-            x_in=t_in,
-            y_in=f_in,
-            x_out=t_out,
-            y_out=f_out,
+            x_in=f_in,
+            x_out=f_out,
         )
 
     def get_freqs(self) -> Float[Tensor, " nfreq"]:
@@ -309,7 +302,7 @@ class SingleQTransform(torch.nn.Module):
         return torch.unique(freqs)
 
     def get_max_energy(
-        self, fsearch_range: List[float] = None, dimension: str = "both"
+        self, fsearch_range: list[float] = None, dimension: str = "both"
     ):
         """
         Gets the maximum energy value among the QTiles. The maximum can
@@ -379,14 +372,15 @@ class SingleQTransform(torch.nn.Module):
             ]
             time_interped = torch.cat(
                 [
-                    interpolator(qtile)
-                    for qtile, interpolator in zip(
-                        qtiles, self.qtile_interpolators
+                    qtile_interpolator(qtile)
+                    for qtile, qtile_interpolator in zip(
+                        qtiles, self.qtile_interpolators, strict=True
                     )
                 ],
                 dim=-2,
            )
-            return self.interpolator(time_interped)
+            # Transpose because the final dimension gets interpolated
+            return self.interpolator(time_interped.mT).mT
         num_f_bins, num_t_bins = self.spectrogram_shape
         resampled = [
             F.interpolate(
@@ -464,9 +458,9 @@ class QScan(torch.nn.Module):
         self,
         duration: float,
         sample_rate: float,
-        spectrogram_shape: Tuple[int, int],
-        qrange: List[float] = None,
-        frange: List[float] = None,
+        spectrogram_shape: tuple[int, int],
+        qrange: list[float] = None,
+        frange: list[float] = None,
         interpolation_method="bicubic",
         mismatch: float = 0.2,
     ) -> None:
@@ -505,7 +499,7 @@ class QScan(torch.nn.Module):
             ]
         )
 
-    def get_qs(self) -> List[float]:
+    def get_qs(self) -> list[float]:
         """
         Determine the values of Q to try for the set of Q-transforms
         """
@@ -523,7 +517,7 @@ class QScan(torch.nn.Module):
     def forward(
         self,
         X: TimeSeries1to3d,
-        fsearch_range: List[float] = None,
+        fsearch_range: list[float] = None,
         norm: str = "median",
     ):
         """
ml4gw/transforms/scaler.py CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import torch
 from jaxtyping import Float
 from torch import Tensor
@@ -24,7 +22,7 @@ class ChannelWiseScaler(FittableTransform):
     to be 1D (single channel).
     """
 
-    def __init__(self, num_channels: Optional[int] = None) -> None:
+    def __init__(self, num_channels: int | None = None) -> None:
         super().__init__()
 
         shape = (num_channels or 1,)
@@ -37,7 +35,7 @@ class ChannelWiseScaler(FittableTransform):
         self.register_buffer("std", std)
 
     def fit(
-        self, X: Float[Tensor, "... time"], std_reg: Optional[float] = 0.0
+        self, X: Float[Tensor, "... time"], std_reg: float | None = 0.0
     ) -> None:
         """Fit the scaling parameters to a timeseries
 
@@ -59,7 +57,7 @@ class ChannelWiseScaler(FittableTransform):
         else:
             raise ValueError(
                 "Can't fit channel wise mean and standard deviation "
-                "from tensor of shape {}".format(X.shape)
+                f"from tensor of shape {X.shape}"
             )
         std += std_reg * torch.ones_like(std)
         super().build(mean=mean, std=std)
ml4gw/transforms/snr_rescaler.py CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import torch
 
 from ..gw import compute_network_snr
@@ -13,8 +11,8 @@ class SnrRescaler(FittableSpectralTransform):
         num_channels: int,
         sample_rate: float,
         waveform_duration: float,
-        highpass: Optional[float] = None,
-        lowpass: Optional[float] = None,
+        highpass: float | None = None,
+        lowpass: float | None = None,
         dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
@@ -45,15 +43,13 @@ class SnrRescaler(FittableSpectralTransform):
     def fit(
         self,
         *background: TimeSeries2d,
-        fftlength: Optional[float] = None,
-        overlap: Optional[float] = None,
+        fftlength: float | None = None,
+        overlap: float | None = None,
     ):
         if len(background) != self.num_channels:
             raise ValueError(
-                "Expected to fit whitening transform on {} background "
-                "timeseries, but was passed {}".format(
-                    self.num_channels, len(background)
-                )
+                f"Expected to fit whitening transform on {self.num_channels} "
+                f"background timeseries, but was passed {len(background)}"
             )
 
         num_freqs = self.background.size(1)
@@ -69,7 +65,7 @@ class SnrRescaler(FittableSpectralTransform):
     def forward(
         self,
         responses: WaveformTensor,
-        target_snrs: Optional[BatchTensor] = None,
+        target_snrs: BatchTensor | None = None,
     ):
         snrs = compute_network_snr(
             responses,
ml4gw/transforms/spectral.py CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import torch
 from jaxtyping import Float
 from torch import Tensor
@@ -52,20 +50,17 @@ class SpectralDensity(torch.nn.Module):
         self,
         sample_rate: float,
         fftlength: float,
-        overlap: Optional[float] = None,
+        overlap: float | None = None,
         average: str = "mean",
-        window: Optional[
-            Float[Tensor, " {int(fftlength*sample_rate)}"]
-        ] = None,
+        window: Float[Tensor, " {int(fftlength*sample_rate)}"] | None = None,
         fast: bool = False,
     ) -> None:
         if overlap is None:
             overlap = fftlength / 2
         elif overlap >= fftlength:
             raise ValueError(
-                "Can't have overlap {} longer than fftlength {}".format(
-                    overlap, fftlength
-                )
+                f"Can't have overlap {overlap} longer than fftlength "
+                f"{fftlength}"
             )
 
         super().__init__()
@@ -80,9 +75,7 @@ class SpectralDensity(torch.nn.Module):
 
         if window.size(0) != self.nperseg:
             raise ValueError(
-                "Window must have length {} got {}".format(
-                    self.nperseg, window.size(0)
-                )
+                f"Window must have length {self.nperseg} got {window.size(0)}"
             )
         self.register_buffer("window", window)
 
@@ -99,7 +92,7 @@ class SpectralDensity(torch.nn.Module):
         self.fast = fast
 
     def forward(
-        self, x: TimeSeries1to3d, y: Optional[TimeSeries1to3d] = None
+        self, x: TimeSeries1to3d, y: TimeSeries1to3d | None = None
     ) -> FrequencySeries1to3d:
         if self.fast:
             return fast_spectral_density(
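
Note: a hedged usage sketch of SpectralDensity built only from the constructor arguments visible in the hunks above; the input shape is an assumption about what TimeSeries1to3d admits:

```python
import torch

from ml4gw.transforms import SpectralDensity

sample_rate = 2048
# overlap defaults to fftlength / 2; overlap >= fftlength raises the
# ValueError shown in the hunk above.
sd = SpectralDensity(sample_rate=sample_rate, fftlength=2)

# Hypothetical batch of 4 two-channel, 8 s timeseries.
x = torch.randn(4, 2, 8 * sample_rate)
psd = sd(x)
```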