ml4gw 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. ml4gw/augmentations.py +5 -0
  2. ml4gw/dataloading/__init__.py +5 -0
  3. ml4gw/dataloading/chunked_dataset.py +2 -4
  4. ml4gw/dataloading/hdf5_dataset.py +12 -10
  5. ml4gw/dataloading/in_memory_dataset.py +12 -12
  6. ml4gw/distributions.py +3 -3
  7. ml4gw/gw.py +18 -21
  8. ml4gw/nn/__init__.py +6 -0
  9. ml4gw/nn/autoencoder/base.py +5 -9
  10. ml4gw/nn/autoencoder/convolutional.py +7 -10
  11. ml4gw/nn/autoencoder/skip_connection.py +3 -5
  12. ml4gw/nn/norm.py +4 -4
  13. ml4gw/nn/resnet/resnet_1d.py +12 -13
  14. ml4gw/nn/resnet/resnet_2d.py +13 -14
  15. ml4gw/nn/streaming/online_average.py +3 -5
  16. ml4gw/nn/streaming/snapshotter.py +10 -14
  17. ml4gw/spectral.py +20 -23
  18. ml4gw/transforms/__init__.py +7 -1
  19. ml4gw/transforms/decimator.py +183 -0
  20. ml4gw/transforms/iirfilter.py +3 -5
  21. ml4gw/transforms/pearson.py +3 -4
  22. ml4gw/transforms/qtransform.py +20 -26
  23. ml4gw/transforms/scaler.py +3 -5
  24. ml4gw/transforms/snr_rescaler.py +7 -11
  25. ml4gw/transforms/spectral.py +6 -13
  26. ml4gw/transforms/spectrogram.py +6 -3
  27. ml4gw/transforms/spline_interpolation.py +312 -143
  28. ml4gw/transforms/transform.py +4 -6
  29. ml4gw/transforms/waveforms.py +8 -15
  30. ml4gw/transforms/whitening.py +11 -16
  31. ml4gw/types.py +8 -5
  32. ml4gw/utils/interferometer.py +20 -3
  33. ml4gw/utils/slicing.py +26 -30
  34. ml4gw/waveforms/__init__.py +6 -0
  35. ml4gw/waveforms/cbc/phenom_p.py +7 -9
  36. ml4gw/waveforms/conversion.py +2 -4
  37. ml4gw/waveforms/generator.py +3 -3
  38. {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/METADATA +33 -12
  39. ml4gw-0.7.8.dist-info/RECORD +57 -0
  40. {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +2 -1
  41. ml4gw-0.7.8.dist-info/top_level.txt +1 -0
  42. ml4gw-0.7.6.dist-info/RECORD +0 -55
  43. {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,4 @@
1
1
  import warnings
2
- from typing import Dict, List
3
2
 
4
3
  import torch
5
4
  import torch.nn.functional as F
@@ -104,7 +103,7 @@ class MultiResolutionSpectrogram(torch.nn.Module):
104
103
  self.register_buffer("freq_idxs", freq_idxs)
105
104
  self.register_buffer("time_idxs", time_idxs)
106
105
 
107
- def _check_and_format_kwargs(self, kwargs: Dict[str, List]) -> List:
106
+ def _check_and_format_kwargs(self, kwargs: dict[str, list]) -> list:
108
107
  lengths = sorted(len(v) for v in kwargs.values())
109
108
  lengths = list(set(lengths))
110
109
 
@@ -127,7 +126,10 @@ class MultiResolutionSpectrogram(torch.nn.Module):
127
126
  size = lengths[1]
128
127
  kwargs = {k: v * int(size / len(v)) for k, v in kwargs.items()}
129
128
 
130
- return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]
129
+ return [
130
+ dict(zip(kwargs, col, strict=True))
131
+ for col in zip(*kwargs.values(), strict=True)
132
+ ]
131
133
 
132
134
  def forward(
133
135
  self, X: TimeSeries3d
@@ -161,6 +163,7 @@ class MultiResolutionSpectrogram(torch.nn.Module):
161
163
  self.right_pad,
162
164
  self.top_pad,
163
165
  self.bottom_pad,
166
+ strict=True,
164
167
  ):
165
168
  padded_specs.append(F.pad(spec, (left, right, top, bottom)))
166
169
 
@@ -1,131 +1,27 @@
1
1
  """
2
- Adaptation of code from https://github.com/dottormale/Qtransform
2
+ Adaptation of code from https://github.com/dottormale/Qtransform_torch/
3
3
  """
4
4
 
5
- from typing import Optional, Tuple
6
-
7
5
  import torch
8
- import torch.nn.functional as F
9
6
  from torch import Tensor
10
7
 
11
8
 
12
- class SplineInterpolate(torch.nn.Module):
9
+ class SplineInterpolateBase(torch.nn.Module):
13
10
  """
14
- Perform 1D or 2D spline interpolation based on De Boor's method.
15
- Supports batched, multi-channel inputs, so acceptable data
16
- shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
17
- ``(batch, height, width)``, ``(batch, channel, width)``, and
18
- ``(batch, channel, height, width)``.
19
-
20
- During initialization of this Module, both the desired input
21
- and output coordinate Tensors can be specified to allow
22
- pre-computation of the B-spline basis matrices, though the only
23
- mandatory argument is the coordinates of the data along the
24
- ``width`` dimension. If no argument is given for coordinates along
25
- the ``height`` dimension, it is assumed that 1D interpolation is
26
- desired.
27
-
28
- Unlike scipy's implementation of spline interpolation, the data
29
- to be interpolated is not passed until actually calling the
30
- object. This is useful for cases where the input and output
31
- coordinates are known in advance, but the data is not, so that
32
- the interpolator can be set up ahead of time.
33
-
34
- WARNING: compared to scipy's spline interpolation, this method
35
- produces edge artifacts when the output coordinates are near
36
- the boundaries of the input coordinates. Therefore, it is
37
- recommended to interpolate only to coordinates that are well
38
- within the input coordinate range. Unfortunately, the specific
39
- definition of "well within" changes based on the size of the
40
- data, so some testing may be required to get good results.
41
-
42
- Args:
43
- x_in:
44
- Coordinates of the width dimension of the data
45
- y_in:
46
- Coordinates of the height dimension of the data. If not
47
- specified, it is assumed the 1D interpolation is desired,
48
- and so the default value is a Tensor of length 1
49
- kx:
50
- Degree of spline interpolation along the width dimension.
51
- Default is cubic.
52
- ky:
53
- Degree of spline interpolation along the height dimension.
54
- Default is cubic.
55
- sx:
56
- Regularization factor to avoid singularities during matrix
57
- inversion for interpolation along the width dimension. Not
58
- to be confused with the ``s`` parameter in scipy's spline
59
- methods, which controls the number of knots.
60
- sy:
61
- Regularization factor to avoid singularities during matrix
62
- inversion for interpolation along the height dimension.
63
- x_out:
64
- Coordinates for the data to be interpolated to along the
65
- width dimension. If not specified during initialization,
66
- this must be specified during the object call.
67
- y_out:
68
- Coordinates for the data to be interpolated to along the
69
- height dimension. If not specified during initialization,
70
- this must be specified during the object call.
71
-
11
+ Base class for spline interpolation.
72
12
  """
73
13
 
74
- def __init__(
75
- self,
76
- x_in: Tensor,
77
- y_in: Tensor = None,
78
- kx: int = 3,
79
- ky: int = 3,
80
- sx: float = 0.001,
81
- sy: float = 0.001,
82
- x_out: Optional[Tensor] = None,
83
- y_out: Optional[Tensor] = None,
84
- ):
85
- super().__init__()
86
- if y_in is None:
87
- y_in = Tensor([1])
88
- self.kx = kx
89
- self.ky = ky
90
- self.sx = sx
91
- self.sy = sy
92
- self.register_buffer("x_in", x_in)
93
- self.register_buffer("y_in", y_in)
94
- self.register_buffer("x_out", x_out)
95
- self.register_buffer("y_out", y_out)
96
-
97
- tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
98
- self.register_buffer("tx", tx)
99
- self.register_buffer("Bx", Bx)
100
- self.register_buffer("BxT_Bx", BxT_Bx)
101
-
102
- ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
103
- self.register_buffer("ty", ty)
104
- self.register_buffer("By", By)
105
- self.register_buffer("ByT_By", ByT_By)
106
-
107
- if self.x_out is not None:
108
- Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
109
- self.register_buffer("Bx_out", Bx_out)
110
- if self.y_out is not None:
111
- By_out = self.bspline_basis_natural(y_out, ky, self.ty)
112
- self.register_buffer("By_out", By_out)
113
-
114
14
  def _compute_knots_and_basis_matrices(self, x, k, s):
115
- knots = self.generate_natural_knots(x, k)
15
+ knots = self.generate_fitpack_knots(x, k)
116
16
  basis_matrix = self.bspline_basis_natural(x, k, knots)
117
17
  identity = torch.eye(basis_matrix.shape[-1])
118
18
  B_T_B = basis_matrix.T @ basis_matrix + s * identity
119
19
  return knots, basis_matrix, B_T_B
120
20
 
121
- def generate_natural_knots(self, x: Tensor, k: int) -> Tensor:
21
+ def generate_fitpack_knots(self, x: Tensor, k: int) -> Tensor:
122
22
  """
123
- Generates a natural knot sequence for B-spline interpolation.
124
- Natural knot sequence means that 2*k knots are added to the beginning
125
- and end of datapoints as replicas of first and last datapoint
126
- respectively in order to enforce natural boundary conditions,
127
- i.e. second derivative = 0.
128
- The other n nodes are placed in correspondece of the data points.
23
+ Generates a knot sequence for B-spline interpolation
24
+ in the same way as the FITPACK algorithm used by SciPy.
129
25
 
130
26
  Args:
131
27
  x: Tensor of data point positions.
@@ -134,7 +30,17 @@ class SplineInterpolate(torch.nn.Module):
134
30
  Returns:
135
31
  Tensor of knot positions.
136
32
  """
137
- return F.pad(x[None], (k, k), mode="replicate")[0]
33
+ num_knots = x.shape[-1] + k + 1
34
+ knots = torch.zeros(num_knots, dtype=x.dtype)
35
+ knots[: k + 1] = x[0]
36
+ knots[-(k + 1) :] = x[-1]
37
+
38
+ # Interior knots are the rolling average of the data points
39
+ # excluding the first and last points
40
+ windows = x[1:-1].unfold(dimension=-1, size=k, step=1)
41
+ knots[k + 1 : -k - 1] = windows.mean(dim=-1)
42
+
43
+ return knots
138
44
 
139
45
  def compute_L_R(
140
46
  self,
@@ -142,7 +48,7 @@ class SplineInterpolate(torch.nn.Module):
142
48
  t: Tensor,
143
49
  d: int,
144
50
  m: int,
145
- ) -> Tuple[Tensor, Tensor]:
51
+ ) -> tuple[Tensor, Tensor]:
146
52
  """
147
53
  Compute the L and R values for B-spline basis functions.
148
54
  L and R are respectively the first and second coefficient multiplying
@@ -233,9 +139,6 @@ class SplineInterpolate(torch.nn.Module):
233
139
  Returns:
234
140
  Tensor containing the kth-order B-spline basis functions
235
141
  """
236
-
237
- if len(x) == 1:
238
- return torch.eye(1)
239
142
  n = x.shape[0]
240
143
  m = t.shape[0] - k - 1
241
144
 
@@ -255,6 +158,271 @@ class SplineInterpolate(torch.nn.Module):
255
158
 
256
159
  return b[:, :, -1]
257
160
 
161
+
162
+ class SplineInterpolate1D(SplineInterpolateBase):
163
+ """
164
+ Perform 1D spline interpolation based on De Boor's method.
165
+ It is allowed to have two spatial dimensions, but the second
166
+ dimension cannot be interpolated along. To interpolate along both
167
+ dimensions, use :class:`SplineInterpolate2D`.
168
+
169
+ Supports batched, multi-channel inputs, so acceptable data
170
+ shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
171
+ ``(batch, height, width)``, ``(batch, channel, width)``, and
172
+ ``(batch, channel, height, width)``.
173
+
174
+ During initialization of this Module, both the desired input
175
+ and output coordinate Tensors can be specified to allow
176
+ pre-computation of the B-spline basis matrices, though the only
177
+ mandatory argument is the coordinates of the data along the
178
+ ``width`` dimension.
179
+
180
+ Unlike scipy's implementation of spline interpolation, the data
181
+ to be interpolated is not passed until actually calling the
182
+ object. This is useful for cases where the input and output
183
+ coordinates are known in advance, but the data is not, so that
184
+ the interpolator can be set up ahead of time.
185
+
186
+ Args:
187
+ x_in:
188
+ Coordinates of the width dimension of the data
189
+ kx:
190
+ Degree of spline interpolation along the width dimension.
191
+ Default is cubic.
192
+ sx:
193
+ Regularization factor to avoid singularities during matrix
194
+ inversion for interpolation along the width dimension. Not
195
+ to be confused with the ``s`` parameter in scipy's spline
196
+ methods, which controls the number of knots.
197
+ x_out:
198
+ Coordinates for the data to be interpolated to along the
199
+ width dimension. If not specified during initialization,
200
+ this must be specified during the object call.
201
+
202
+ """
203
+
204
+ def __init__(
205
+ self,
206
+ x_in: Tensor,
207
+ kx: int = 3,
208
+ sx: float = 0.0,
209
+ x_out: Tensor | None = None,
210
+ ):
211
+ super().__init__()
212
+
213
+ if len(x_in) < kx + 2:
214
+ raise ValueError(
215
+ "Input x-coordinates must have at least kx + 2 points."
216
+ )
217
+
218
+ # Ensure that coordinates are floats
219
+ x_in = x_in.float()
220
+ x_out = x_out.float() if x_out is not None else None
221
+
222
+ self.kx = kx
223
+ self.sx = sx
224
+ self.register_buffer("x_in", x_in)
225
+ self.register_buffer("x_out", x_out)
226
+
227
+ tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
228
+ self.register_buffer("tx", tx)
229
+ self.register_buffer("Bx", Bx)
230
+ self.register_buffer("BxT_Bx", BxT_Bx)
231
+
232
+ if self.x_out is not None:
233
+ x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
234
+ Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
235
+ self.register_buffer("Bx_out", Bx_out)
236
+
237
+ def spline_fit_natural(self, Z):
238
+ # Adding batch/channel dimension handling
239
+ # Bx @ Z
240
+ BxT_Z = torch.einsum("ij,bchj->bchi", self.Bx.T, Z)
241
+ # (BxT @ Bx)^-1 @ (BxT @ Z) = Bx^-1 @ Z
242
+ C = torch.linalg.solve(self.BxT_Bx, BxT_Z.unsqueeze(-1))
243
+ return C.squeeze(-1)
244
+
245
+ def evaluate_spline(self, C: Tensor):
246
+ """
247
+ Evaluate a bivariate spline on a grid of x and y points.
248
+
249
+ Args:
250
+ C: Coefficient tensor of shape (batch_size, mx, my).
251
+
252
+ Returns:
253
+ Z_interp: Interpolated values at the grid points.
254
+ """
255
+ # Perform matrix multiplication using einsum to get Z_interp
256
+ return torch.einsum("ij,bchj->bchi", self.Bx_out, C)
257
+
258
+ def _validate_inputs(self, Z, x_out):
259
+ if x_out is None and self.x_out is None:
260
+ raise ValueError(
261
+ "Output x-coordinates were not specified in either object "
262
+ "creation or in forward call"
263
+ )
264
+
265
+ dims = len(Z.shape)
266
+ if dims > 4:
267
+ raise ValueError("Input data has more than 4 dimensions")
268
+
269
+ if Z.shape[-1] != len(self.x_in):
270
+ raise ValueError(
271
+ "The spatial dimensions of the data tensor do not match "
272
+ "the given input dimensions. "
273
+ f"Expected {len(self.x_in)}, but got {Z.shape[-1]}"
274
+ )
275
+
276
+ # Expand Z to have a batch, channel, and height dimension if needed
277
+ while len(Z.shape) < 4:
278
+ Z = Z.unsqueeze(0)
279
+
280
+ return Z
281
+
282
+ def forward(
283
+ self,
284
+ Z: Tensor,
285
+ x_out: Tensor | None = None,
286
+ ) -> Tensor:
287
+ """
288
+ Compute the interpolated data
289
+
290
+ Args:
291
+ Z:
292
+ Tensor of data to be interpolated. Must be between 2 and 4
293
+ dimensions. The shape of the tensor must agree with the
294
+ input coordinates given on initialization.
295
+ x_out:
296
+ Coordinates to interpolate the data to along the width
297
+ dimension. Overrides any value that was set during
298
+ initialization.
299
+
300
+ Returns:
301
+ A 4D tensor with shape ``(batch, channel, height, width)``.
302
+ Depending on the input data shape, many of these dimensions
303
+ may have length 1.
304
+ """
305
+
306
+ Z = self._validate_inputs(Z, x_out)
307
+
308
+ if x_out is not None:
309
+ x_out = x_out.float()
310
+ x_clamped = torch.clamp(
311
+ x_out, self.tx[self.kx], self.tx[-self.kx - 1]
312
+ )
313
+ self.Bx_out = self.bspline_basis_natural(
314
+ x_clamped, self.kx, self.tx
315
+ )
316
+
317
+ coef = self.spline_fit_natural(Z)
318
+ Z_interp = self.evaluate_spline(coef)
319
+ return Z_interp
320
+
321
+
322
+ class SplineInterpolate2D(SplineInterpolateBase):
323
+ """
324
+ Perform 2D spline interpolation based on De Boor's method.
325
+ Supports batched, multi-channel inputs, so acceptable data
326
+ shapes are ``(height, width)``, ``(batch, height, width)``,
327
+ and ``(batch, channel, height, width)``.
328
+
329
+ During initialization of this Module, both the desired input
330
+ and output coordinate Tensors can be specified to allow
331
+ pre-computation of the B-spline basis matrices, though the only
332
+ mandatory arguments are the input coordinates.
333
+
334
+ Unlike scipy's implementation of spline interpolation, the data
335
+ to be interpolated is not passed until actually calling the
336
+ object. This is useful for cases where the input and output
337
+ coordinates are known in advance, but the data is not, so that
338
+ the interpolator can be set up ahead of time.
339
+
340
+ Args:
341
+ x_in:
342
+ Coordinates of the width dimension of the data
343
+ y_in:
344
+ Coordinates of the height dimension of the data.
345
+ kx:
346
+ Degree of spline interpolation along the width dimension.
347
+ Default is cubic.
348
+ ky:
349
+ Degree of spline interpolation along the height dimension.
350
+ Default is cubic.
351
+ sx:
352
+ Regularization factor to avoid singularities during matrix
353
+ inversion for interpolation along the width dimension. Not
354
+ to be confused with the ``s`` parameter in scipy's spline
355
+ methods, which controls the number of knots.
356
+ sy:
357
+ Regularization factor to avoid singularities during matrix
358
+ inversion for interpolation along the height dimension.
359
+ x_out:
360
+ Coordinates for the data to be interpolated to along the
361
+ width dimension. If not specified during initialization,
362
+ this must be specified during the object call.
363
+ y_out:
364
+ Coordinates for the data to be interpolated to along the
365
+ height dimension. If not specified during initialization,
366
+ this must be specified during the object call.
367
+
368
+ """
369
+
370
+ def __init__(
371
+ self,
372
+ x_in: Tensor,
373
+ y_in: Tensor,
374
+ kx: int = 3,
375
+ ky: int = 3,
376
+ sx: float = 0.0,
377
+ sy: float = 0.0,
378
+ x_out: Tensor | None = None,
379
+ y_out: Tensor | None = None,
380
+ ):
381
+ super().__init__()
382
+
383
+ if len(x_in) < kx + 2:
384
+ raise ValueError(
385
+ "Input x-coordinates must have at least kx + 2 points."
386
+ )
387
+ if len(y_in) < ky + 2:
388
+ raise ValueError(
389
+ "Input y-coordinates must have at least ky + 2 points."
390
+ )
391
+
392
+ # Ensure that coordinates are floats
393
+ x_in = x_in.float()
394
+ y_in = y_in.float()
395
+ x_out = x_out.float() if x_out is not None else None
396
+ y_out = y_out.float() if y_out is not None else None
397
+
398
+ self.kx = kx
399
+ self.ky = ky
400
+ self.sx = sx
401
+ self.sy = sy
402
+ self.register_buffer("x_in", x_in)
403
+ self.register_buffer("y_in", y_in)
404
+ self.register_buffer("x_out", x_out)
405
+ self.register_buffer("y_out", y_out)
406
+
407
+ tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
408
+ self.register_buffer("tx", tx)
409
+ self.register_buffer("Bx", Bx)
410
+ self.register_buffer("BxT_Bx", BxT_Bx)
411
+
412
+ ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
413
+ self.register_buffer("ty", ty)
414
+ self.register_buffer("By", By)
415
+ self.register_buffer("ByT_By", ByT_By)
416
+
417
+ if self.x_out is not None:
418
+ x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
419
+ Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
420
+ self.register_buffer("Bx_out", Bx_out)
421
+ if self.y_out is not None:
422
+ y_clamped = torch.clamp(y_out, ty[ky], ty[-ky - 1])
423
+ By_out = self.bspline_basis_natural(y_clamped, ky, self.ty)
424
+ self.register_buffer("By_out", By_out)
425
+
258
426
  def bivariate_spline_fit_natural(self, Z):
259
427
  # Adding batch/channel dimension handling
260
428
  # ByT @ Z @ BxW
@@ -285,29 +453,16 @@ class SplineInterpolate(torch.nn.Module):
285
453
  )
286
454
 
287
455
  if y_out is None and self.y_out is None:
288
- y_out = self.y_in
456
+ raise ValueError(
457
+ "Output y-coordinates were not specified in either object "
458
+ "creation or in forward call"
459
+ )
289
460
 
290
461
  dims = len(Z.shape)
291
462
  if dims > 4:
292
463
  raise ValueError("Input data has more than 4 dimensions")
293
-
294
- if len(self.y_in) > 1 and dims == 1:
295
- raise ValueError(
296
- "An input y-coordinate array with length greater than 1 "
297
- "was given, but the input data is 1-dimensional. Expected "
298
- "input data to be at least 2-dimensional"
299
- )
300
-
301
- # Expand Z to have 4 dimensions
302
- # There are 6 valid input shapes: (w), (b, w), (b, c, w),
303
- # (h, w), (b, h, w), and (b, c, h, w).
304
-
305
- # If the input y coordinate array has length 1,
306
- # assume the first dimension(s) are batch dimensions
307
- # and that no height dimension is included in Z
308
- idx = -2 if len(self.y_in) == 1 else -3
309
- while len(Z.shape) < 4:
310
- Z = Z.unsqueeze(idx)
464
+ if dims < 2:
465
+ raise ValueError("Input data has fewer than 2 dimensions")
311
466
 
312
467
  if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
313
468
  raise ValueError(
@@ -317,24 +472,26 @@ class SplineInterpolate(torch.nn.Module):
317
472
  f"[{Z.shape[-2]}, {Z.shape[-1]}]"
318
473
  )
319
474
 
475
+ # Expand Z to have a batch and channel dimension if needed
476
+ while len(Z.shape) < 4:
477
+ Z = Z.unsqueeze(0)
478
+
320
479
  return Z, y_out
321
480
 
322
481
  def forward(
323
482
  self,
324
483
  Z: Tensor,
325
- x_out: Optional[Tensor] = None,
326
- y_out: Optional[Tensor] = None,
484
+ x_out: Tensor | None = None,
485
+ y_out: Tensor | None = None,
327
486
  ) -> Tensor:
328
487
  """
329
488
  Compute the interpolated data
330
489
 
331
490
  Args:
332
491
  Z:
333
- Tensor of data to be interpolated. Must be between 1 and 4
492
+ Tensor of data to be interpolated. Must be between 2 and 4
334
493
  dimensions. The shape of the tensor must agree with the
335
- input coordinates given on initialization. If ``y_in`` was
336
- not specified during initialization, it is assumed that
337
- Z does not have a height dimension.
494
+ input coordinates given on initialization.
338
495
  x_out:
339
496
  Coordinates to interpolate the data to along the width
340
497
  dimension. Overrides any value that was set during
@@ -353,9 +510,21 @@ class SplineInterpolate(torch.nn.Module):
353
510
  Z, y_out = self._validate_inputs(Z, x_out, y_out)
354
511
 
355
512
  if x_out is not None:
356
- self.Bx_out = self.bspline_basis_natural(x_out, self.kx, self.tx)
513
+ x_out = x_out.float()
514
+ x_clamped = torch.clamp(
515
+ x_out, self.tx[self.kx], self.tx[-self.kx - 1]
516
+ )
517
+ self.Bx_out = self.bspline_basis_natural(
518
+ x_clamped, self.kx, self.tx
519
+ )
357
520
  if y_out is not None:
358
- self.By_out = self.bspline_basis_natural(y_out, self.ky, self.ty)
521
+ y_out = y_out.float()
522
+ y_clamped = torch.clamp(
523
+ y_out, self.ty[self.ky], self.ty[-self.ky - 1]
524
+ )
525
+ self.By_out = self.bspline_basis_natural(
526
+ y_clamped, self.ky, self.ty
527
+ )
359
528
 
360
529
  coef = self.bivariate_spline_fit_natural(Z)
361
530
  Z_interp = self.evaluate_bivariate_spline(coef)
@@ -1,5 +1,3 @@
1
- from typing import Optional
2
-
3
1
  import torch
4
2
 
5
3
  from ..spectral import spectral_density
@@ -20,8 +18,8 @@ class FittableTransform(torch.nn.Module):
20
18
  def _check_built(self):
21
19
  if not self.built:
22
20
  raise ValueError(
23
- "Must fit parameters of {} transform to data "
24
- "before calling forward step".format(self.__class__.__name__)
21
+ f"Must fit parameters of {self.__class__.__name__} transform "
22
+ "to data before calling forward step"
25
23
  )
26
24
 
27
25
  def __call__(self, *args, **kwargs):
@@ -47,8 +45,8 @@ class FittableSpectralTransform(FittableTransform):
47
45
  x: TimeSeries1to3d,
48
46
  sample_rate: float,
49
47
  num_freqs: int,
50
- fftlength: Optional[float] = None,
51
- overlap: Optional[float] = None,
48
+ fftlength: float | None = None,
49
+ overlap: float | None = None,
52
50
  ) -> FrequencySeries1to3d:
53
51
  # if we specified an FFT length, convert
54
52
  # the (assumed) time-domain data to the
@@ -1,5 +1,3 @@
1
- from typing import List, Optional
2
-
3
1
  import torch
4
2
  from jaxtyping import Float
5
3
  from torch import Tensor
@@ -13,7 +11,7 @@ from ..types import BatchTensor
13
11
  class WaveformSampler(torch.nn.Module):
14
12
  def __init__(
15
13
  self,
16
- parameters: Optional[Float[Tensor, "batch num_params"]] = None,
14
+ parameters: Float[Tensor, "batch num_params"] | None = None,
17
15
  **polarizations: Float[Tensor, "batch time"],
18
16
  ):
19
17
  super().__init__()
@@ -24,10 +22,8 @@ class WaveformSampler(torch.nn.Module):
24
22
  for polarization, tensor in polarizations.items():
25
23
  if num_waveforms is not None and len(tensor) != num_waveforms:
26
24
  raise ValueError(
27
- "Polarization {} has {} waveforms "
28
- "associated with it, expected {}".format(
29
- polarization, len(tensor), num_waveforms
30
- )
25
+ f"Polarization {polarization} has {len(tensor)} waveforms "
26
+ f"associated with it, expected {num_waveforms}"
31
27
  )
32
28
  elif num_waveforms is None:
33
29
  num_waveforms = tensor.shape[0]
@@ -36,10 +32,8 @@ class WaveformSampler(torch.nn.Module):
36
32
 
37
33
  if parameters is not None and len(parameters) != num_waveforms:
38
34
  raise ValueError(
39
- "Waveform parameters has {} waveforms "
40
- "associated with it, expected {}".format(
41
- len(parameters), num_waveforms
42
- )
35
+ f"Waveform parameters has {len(parameters)} waveforms "
36
+ f"associated with it, expected {num_waveforms}"
43
37
  )
44
38
  self.num_waveforms = num_waveforms
45
39
  self.parameters = parameters
@@ -48,9 +42,8 @@ class WaveformSampler(torch.nn.Module):
48
42
  # TODO: should we allow sampling with replacement?
49
43
  if N > self.num_waveforms:
50
44
  raise ValueError(
51
- "Requested {} waveforms, but only {} are available".format(
52
- N, self.num_waveforms
53
- )
45
+ f"Requested {N} waveforms, but only {self.num_waveforms} are "
46
+ "available"
54
47
  )
55
48
  # TODO: do we still really want this behavior here when a
56
49
  # user can do this without instantiating a WaveformSampler?
@@ -67,7 +60,7 @@ class WaveformSampler(torch.nn.Module):
67
60
 
68
61
 
69
62
  class WaveformProjector(torch.nn.Module):
70
- def __init__(self, ifos: List[str], sample_rate: float):
63
+ def __init__(self, ifos: list[str], sample_rate: float):
71
64
  super().__init__()
72
65
  tensors, vertices = gw.get_ifo_geometry(*ifos)
73
66
  self.sample_rate = sample_rate