ml4gw 0.7.6__py3-none-any.whl → 0.7.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw has been flagged as potentially problematic; further details were linked from the original page.

ml4gw/distributions.py CHANGED
@@ -137,7 +137,7 @@ class PowerLaw(dist.TransformedDistribution):
137
137
  support = dist.constraints.nonnegative
138
138
 
139
139
  def __init__(
140
- self, minimum: float, maximum: float, index: int, validate_args=None
140
+ self, minimum: float, maximum: float, index: float, validate_args=None
141
141
  ):
142
142
  if index == 0:
143
143
  raise ValueError("Index of 0 is the same as Uniform")
ml4gw/transforms/__init__.py CHANGED
@@ -5,6 +5,6 @@ from .scaler import ChannelWiseScaler
5
5
  from .snr_rescaler import SnrRescaler
6
6
  from .spectral import SpectralDensity
7
7
  from .spectrogram import MultiResolutionSpectrogram
8
- from .spline_interpolation import SplineInterpolate
8
+ from .spline_interpolation import SplineInterpolate1D, SplineInterpolate2D
9
9
  from .waveforms import WaveformProjector, WaveformSampler
10
10
  from .whitening import FixedWhiten, Whiten
ml4gw/transforms/qtransform.py CHANGED
@@ -8,7 +8,7 @@ from jaxtyping import Float, Int
8
8
  from torch import Tensor
9
9
 
10
10
  from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
11
- from .spline_interpolation import SplineInterpolate
11
+ from .spline_interpolation import SplineInterpolate1D
12
12
 
13
13
  """
14
14
  All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
@@ -260,18 +260,15 @@ class SingleQTransform(torch.nn.Module):
260
260
  )
261
261
  self.qtile_interpolators = torch.nn.ModuleList(
262
262
  [
263
- SplineInterpolate(
263
+ SplineInterpolate1D(
264
264
  kx=3,
265
265
  x_in=torch.arange(0, self.duration, self.duration / tiles),
266
- y_in=torch.arange(len(idx)),
267
266
  x_out=t_out,
268
- y_out=torch.arange(len(idx)),
269
267
  )
270
- for tiles, idx in zip(unique_ntiles, self.stack_idx)
268
+ for tiles in unique_ntiles
271
269
  ]
272
270
  )
273
271
 
274
- t_in = t_out
275
272
  f_in = self.freqs
276
273
  f_out = torch.logspace(
277
274
  math.log10(self.frange[0]),
@@ -279,13 +276,10 @@ class SingleQTransform(torch.nn.Module):
279
276
  self.spectrogram_shape[0],
280
277
  )
281
278
 
282
- self.interpolator = SplineInterpolate(
279
+ self.interpolator = SplineInterpolate1D(
283
280
  kx=3,
284
- ky=3,
285
- x_in=t_in,
286
- y_in=f_in,
287
- x_out=t_out,
288
- y_out=f_out,
281
+ x_in=f_in,
282
+ x_out=f_out,
289
283
  )
290
284
 
291
285
  def get_freqs(self) -> Float[Tensor, " nfreq"]:
@@ -379,14 +373,15 @@ class SingleQTransform(torch.nn.Module):
379
373
  ]
380
374
  time_interped = torch.cat(
381
375
  [
382
- interpolator(qtile)
383
- for qtile, interpolator in zip(
376
+ qtile_interpolator(qtile)
377
+ for qtile, qtile_interpolator in zip(
384
378
  qtiles, self.qtile_interpolators
385
379
  )
386
380
  ],
387
381
  dim=-2,
388
382
  )
389
- return self.interpolator(time_interped)
383
+ # Transpose because the final dimension gets interpolated
384
+ return self.interpolator(time_interped.mT).mT
390
385
  num_f_bins, num_t_bins = self.spectrogram_shape
391
386
  resampled = [
392
387
  F.interpolate(
ml4gw/transforms/spline_interpolation.py CHANGED
@@ -1,131 +1,29 @@
1
1
  """
2
- Adaptation of code from https://github.com/dottormale/Qtransform
2
+ Adaptation of code from https://github.com/dottormale/Qtransform_torch/
3
3
  """
4
4
 
5
5
  from typing import Optional, Tuple
6
6
 
7
7
  import torch
8
- import torch.nn.functional as F
9
8
  from torch import Tensor
10
9
 
11
10
 
12
- class SplineInterpolate(torch.nn.Module):
11
+ class SplineInterpolateBase(torch.nn.Module):
13
12
  """
14
- Perform 1D or 2D spline interpolation based on De Boor's method.
15
- Supports batched, multi-channel inputs, so acceptable data
16
- shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
17
- ``(batch, height, width)``, ``(batch, channel, width)``, and
18
- ``(batch, channel, height, width)``.
19
-
20
- During initialization of this Module, both the desired input
21
- and output coordinate Tensors can be specified to allow
22
- pre-computation of the B-spline basis matrices, though the only
23
- mandatory argument is the coordinates of the data along the
24
- ``width`` dimension. If no argument is given for coordinates along
25
- the ``height`` dimension, it is assumed that 1D interpolation is
26
- desired.
27
-
28
- Unlike scipy's implementation of spline interpolation, the data
29
- to be interpolated is not passed until actually calling the
30
- object. This is useful for cases where the input and output
31
- coordinates are known in advance, but the data is not, so that
32
- the interpolator can be set up ahead of time.
33
-
34
- WARNING: compared to scipy's spline interpolation, this method
35
- produces edge artifacts when the output coordinates are near
36
- the boundaries of the input coordinates. Therefore, it is
37
- recommended to interpolate only to coordinates that are well
38
- within the input coordinate range. Unfortunately, the specific
39
- definition of "well within" changes based on the size of the
40
- data, so some testing may be required to get good results.
41
-
42
- Args:
43
- x_in:
44
- Coordinates of the width dimension of the data
45
- y_in:
46
- Coordinates of the height dimension of the data. If not
47
- specified, it is assumed the 1D interpolation is desired,
48
- and so the default value is a Tensor of length 1
49
- kx:
50
- Degree of spline interpolation along the width dimension.
51
- Default is cubic.
52
- ky:
53
- Degree of spline interpolation along the height dimension.
54
- Default is cubic.
55
- sx:
56
- Regularization factor to avoid singularities during matrix
57
- inversion for interpolation along the width dimension. Not
58
- to be confused with the ``s`` parameter in scipy's spline
59
- methods, which controls the number of knots.
60
- sy:
61
- Regularization factor to avoid singularities during matrix
62
- inversion for interpolation along the height dimension.
63
- x_out:
64
- Coordinates for the data to be interpolated to along the
65
- width dimension. If not specified during initialization,
66
- this must be specified during the object call.
67
- y_out:
68
- Coordinates for the data to be interpolated to along the
69
- height dimension. If not specified during initialization,
70
- this must be specified during the object call.
71
-
13
+ Base class for spline interpolation.
72
14
  """
73
15
 
74
- def __init__(
75
- self,
76
- x_in: Tensor,
77
- y_in: Tensor = None,
78
- kx: int = 3,
79
- ky: int = 3,
80
- sx: float = 0.001,
81
- sy: float = 0.001,
82
- x_out: Optional[Tensor] = None,
83
- y_out: Optional[Tensor] = None,
84
- ):
85
- super().__init__()
86
- if y_in is None:
87
- y_in = Tensor([1])
88
- self.kx = kx
89
- self.ky = ky
90
- self.sx = sx
91
- self.sy = sy
92
- self.register_buffer("x_in", x_in)
93
- self.register_buffer("y_in", y_in)
94
- self.register_buffer("x_out", x_out)
95
- self.register_buffer("y_out", y_out)
96
-
97
- tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
98
- self.register_buffer("tx", tx)
99
- self.register_buffer("Bx", Bx)
100
- self.register_buffer("BxT_Bx", BxT_Bx)
101
-
102
- ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
103
- self.register_buffer("ty", ty)
104
- self.register_buffer("By", By)
105
- self.register_buffer("ByT_By", ByT_By)
106
-
107
- if self.x_out is not None:
108
- Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
109
- self.register_buffer("Bx_out", Bx_out)
110
- if self.y_out is not None:
111
- By_out = self.bspline_basis_natural(y_out, ky, self.ty)
112
- self.register_buffer("By_out", By_out)
113
-
114
16
  def _compute_knots_and_basis_matrices(self, x, k, s):
115
- knots = self.generate_natural_knots(x, k)
17
+ knots = self.generate_fitpack_knots(x, k)
116
18
  basis_matrix = self.bspline_basis_natural(x, k, knots)
117
19
  identity = torch.eye(basis_matrix.shape[-1])
118
20
  B_T_B = basis_matrix.T @ basis_matrix + s * identity
119
21
  return knots, basis_matrix, B_T_B
120
22
 
121
- def generate_natural_knots(self, x: Tensor, k: int) -> Tensor:
23
+ def generate_fitpack_knots(self, x: Tensor, k: int) -> Tensor:
122
24
  """
123
- Generates a natural knot sequence for B-spline interpolation.
124
- Natural knot sequence means that 2*k knots are added to the beginning
125
- and end of datapoints as replicas of first and last datapoint
126
- respectively in order to enforce natural boundary conditions,
127
- i.e. second derivative = 0.
128
- The other n nodes are placed in correspondece of the data points.
25
+ Generates a knot sequence for B-spline interpolation
26
+ in the same way as the FITPACK algorithm used by SciPy.
129
27
 
130
28
  Args:
131
29
  x: Tensor of data point positions.
@@ -134,7 +32,17 @@ class SplineInterpolate(torch.nn.Module):
134
32
  Returns:
135
33
  Tensor of knot positions.
136
34
  """
137
- return F.pad(x[None], (k, k), mode="replicate")[0]
35
+ num_knots = x.shape[-1] + k + 1
36
+ knots = torch.zeros(num_knots, dtype=x.dtype)
37
+ knots[: k + 1] = x[0]
38
+ knots[-(k + 1) :] = x[-1]
39
+
40
+ # Interior knots are the rolling average of the data points
41
+ # excluding the first and last points
42
+ windows = x[1:-1].unfold(dimension=-1, size=k, step=1)
43
+ knots[k + 1 : -k - 1] = windows.mean(dim=-1)
44
+
45
+ return knots
138
46
 
139
47
  def compute_L_R(
140
48
  self,
@@ -233,9 +141,6 @@ class SplineInterpolate(torch.nn.Module):
233
141
  Returns:
234
142
  Tensor containing the kth-order B-spline basis functions
235
143
  """
236
-
237
- if len(x) == 1:
238
- return torch.eye(1)
239
144
  n = x.shape[0]
240
145
  m = t.shape[0] - k - 1
241
146
 
@@ -255,6 +160,271 @@ class SplineInterpolate(torch.nn.Module):
255
160
 
256
161
  return b[:, :, -1]
257
162
 
163
+
164
+ class SplineInterpolate1D(SplineInterpolateBase):
165
+ """
166
+ Perform 1D spline interpolation based on De Boor's method.
167
+ It is allowed to have two spatial dimensions, but the second
168
+ dimension cannot be interpolated along. To interpolate along both
169
+ dimensions, use :class:`SplineInterpolate2D`.
170
+
171
+ Supports batched, multi-channel inputs, so acceptable data
172
+ shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
173
+ ``(batch, height, width)``, ``(batch, channel, width)``, and
174
+ ``(batch, channel, height, width)``.
175
+
176
+ During initialization of this Module, both the desired input
177
+ and output coordinate Tensors can be specified to allow
178
+ pre-computation of the B-spline basis matrices, though the only
179
+ mandatory argument is the coordinates of the data along the
180
+ ``width`` dimension.
181
+
182
+ Unlike scipy's implementation of spline interpolation, the data
183
+ to be interpolated is not passed until actually calling the
184
+ object. This is useful for cases where the input and output
185
+ coordinates are known in advance, but the data is not, so that
186
+ the interpolator can be set up ahead of time.
187
+
188
+ Args:
189
+ x_in:
190
+ Coordinates of the width dimension of the data
191
+ kx:
192
+ Degree of spline interpolation along the width dimension.
193
+ Default is cubic.
194
+ sx:
195
+ Regularization factor to avoid singularities during matrix
196
+ inversion for interpolation along the width dimension. Not
197
+ to be confused with the ``s`` parameter in scipy's spline
198
+ methods, which controls the number of knots.
199
+ x_out:
200
+ Coordinates for the data to be interpolated to along the
201
+ width dimension. If not specified during initialization,
202
+ this must be specified during the object call.
203
+
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ x_in: Tensor,
209
+ kx: int = 3,
210
+ sx: float = 0.0,
211
+ x_out: Optional[Tensor] = None,
212
+ ):
213
+ super().__init__()
214
+
215
+ if len(x_in) < kx + 2:
216
+ raise ValueError(
217
+ "Input x-coordinates must have at least kx + 2 points."
218
+ )
219
+
220
+ # Ensure that coordinates are floats
221
+ x_in = x_in.float()
222
+ x_out = x_out.float() if x_out is not None else None
223
+
224
+ self.kx = kx
225
+ self.sx = sx
226
+ self.register_buffer("x_in", x_in)
227
+ self.register_buffer("x_out", x_out)
228
+
229
+ tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
230
+ self.register_buffer("tx", tx)
231
+ self.register_buffer("Bx", Bx)
232
+ self.register_buffer("BxT_Bx", BxT_Bx)
233
+
234
+ if self.x_out is not None:
235
+ x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
236
+ Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
237
+ self.register_buffer("Bx_out", Bx_out)
238
+
239
+ def spline_fit_natural(self, Z):
240
+ # Adding batch/channel dimension handling
241
+ # Bx @ Z
242
+ BxT_Z = torch.einsum("ij,bchj->bchi", self.Bx.T, Z)
243
+ # (BxT @ Bx)^-1 @ (BxT @ Z) = Bx^-1 @ Z
244
+ C = torch.linalg.solve(self.BxT_Bx, BxT_Z.unsqueeze(-1))
245
+ return C.squeeze(-1)
246
+
247
+ def evaluate_spline(self, C: Tensor):
248
+ """
249
+ Evaluate a bivariate spline on a grid of x and y points.
250
+
251
+ Args:
252
+ C: Coefficient tensor of shape (batch_size, mx, my).
253
+
254
+ Returns:
255
+ Z_interp: Interpolated values at the grid points.
256
+ """
257
+ # Perform matrix multiplication using einsum to get Z_interp
258
+ return torch.einsum("ij,bchj->bchi", self.Bx_out, C)
259
+
260
+ def _validate_inputs(self, Z, x_out):
261
+ if x_out is None and self.x_out is None:
262
+ raise ValueError(
263
+ "Output x-coordinates were not specified in either object "
264
+ "creation or in forward call"
265
+ )
266
+
267
+ dims = len(Z.shape)
268
+ if dims > 4:
269
+ raise ValueError("Input data has more than 4 dimensions")
270
+
271
+ if Z.shape[-1] != len(self.x_in):
272
+ raise ValueError(
273
+ "The spatial dimensions of the data tensor do not match "
274
+ "the given input dimensions. "
275
+ f"Expected {len(self.x_in)}, but got {Z.shape[-1]}"
276
+ )
277
+
278
+ # Expand Z to have a batch, channel, and height dimension if needed
279
+ while len(Z.shape) < 4:
280
+ Z = Z.unsqueeze(0)
281
+
282
+ return Z
283
+
284
+ def forward(
285
+ self,
286
+ Z: Tensor,
287
+ x_out: Optional[Tensor] = None,
288
+ ) -> Tensor:
289
+ """
290
+ Compute the interpolated data
291
+
292
+ Args:
293
+ Z:
294
+ Tensor of data to be interpolated. Must be between 2 and 4
295
+ dimensions. The shape of the tensor must agree with the
296
+ input coordinates given on initialization.
297
+ x_out:
298
+ Coordinates to interpolate the data to along the width
299
+ dimension. Overrides any value that was set during
300
+ initialization.
301
+
302
+ Returns:
303
+ A 4D tensor with shape ``(batch, channel, height, width)``.
304
+ Depending on the input data shape, many of these dimensions
305
+ may have length 1.
306
+ """
307
+
308
+ Z = self._validate_inputs(Z, x_out)
309
+
310
+ if x_out is not None:
311
+ x_out = x_out.float()
312
+ x_clamped = torch.clamp(
313
+ x_out, self.tx[self.kx], self.tx[-self.kx - 1]
314
+ )
315
+ self.Bx_out = self.bspline_basis_natural(
316
+ x_clamped, self.kx, self.tx
317
+ )
318
+
319
+ coef = self.spline_fit_natural(Z)
320
+ Z_interp = self.evaluate_spline(coef)
321
+ return Z_interp
322
+
323
+
324
+ class SplineInterpolate2D(SplineInterpolateBase):
325
+ """
326
+ Perform 2D spline interpolation based on De Boor's method.
327
+ Supports batched, multi-channel inputs, so acceptable data
328
+ shapes are ``(height, width)``, ``(batch, height, width)``,
329
+ and ``(batch, channel, height, width)``.
330
+
331
+ During initialization of this Module, both the desired input
332
+ and output coordinate Tensors can be specified to allow
333
+ pre-computation of the B-spline basis matrices, though the only
334
+ mandatory arguments are the input coordinates.
335
+
336
+ Unlike scipy's implementation of spline interpolation, the data
337
+ to be interpolated is not passed until actually calling the
338
+ object. This is useful for cases where the input and output
339
+ coordinates are known in advance, but the data is not, so that
340
+ the interpolator can be set up ahead of time.
341
+
342
+ Args:
343
+ x_in:
344
+ Coordinates of the width dimension of the data
345
+ y_in:
346
+ Coordinates of the height dimension of the data.
347
+ kx:
348
+ Degree of spline interpolation along the width dimension.
349
+ Default is cubic.
350
+ ky:
351
+ Degree of spline interpolation along the height dimension.
352
+ Default is cubic.
353
+ sx:
354
+ Regularization factor to avoid singularities during matrix
355
+ inversion for interpolation along the width dimension. Not
356
+ to be confused with the ``s`` parameter in scipy's spline
357
+ methods, which controls the number of knots.
358
+ sy:
359
+ Regularization factor to avoid singularities during matrix
360
+ inversion for interpolation along the height dimension.
361
+ x_out:
362
+ Coordinates for the data to be interpolated to along the
363
+ width dimension. If not specified during initialization,
364
+ this must be specified during the object call.
365
+ y_out:
366
+ Coordinates for the data to be interpolated to along the
367
+ height dimension. If not specified during initialization,
368
+ this must be specified during the object call.
369
+
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ x_in: Tensor,
375
+ y_in: Tensor,
376
+ kx: int = 3,
377
+ ky: int = 3,
378
+ sx: float = 0.0,
379
+ sy: float = 0.0,
380
+ x_out: Optional[Tensor] = None,
381
+ y_out: Optional[Tensor] = None,
382
+ ):
383
+ super().__init__()
384
+
385
+ if len(x_in) < kx + 2:
386
+ raise ValueError(
387
+ "Input x-coordinates must have at least kx + 2 points."
388
+ )
389
+ if len(y_in) < ky + 2:
390
+ raise ValueError(
391
+ "Input y-coordinates must have at least ky + 2 points."
392
+ )
393
+
394
+ # Ensure that coordinates are floats
395
+ x_in = x_in.float()
396
+ y_in = y_in.float()
397
+ x_out = x_out.float() if x_out is not None else None
398
+ y_out = y_out.float() if y_out is not None else None
399
+
400
+ self.kx = kx
401
+ self.ky = ky
402
+ self.sx = sx
403
+ self.sy = sy
404
+ self.register_buffer("x_in", x_in)
405
+ self.register_buffer("y_in", y_in)
406
+ self.register_buffer("x_out", x_out)
407
+ self.register_buffer("y_out", y_out)
408
+
409
+ tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
410
+ self.register_buffer("tx", tx)
411
+ self.register_buffer("Bx", Bx)
412
+ self.register_buffer("BxT_Bx", BxT_Bx)
413
+
414
+ ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
415
+ self.register_buffer("ty", ty)
416
+ self.register_buffer("By", By)
417
+ self.register_buffer("ByT_By", ByT_By)
418
+
419
+ if self.x_out is not None:
420
+ x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
421
+ Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
422
+ self.register_buffer("Bx_out", Bx_out)
423
+ if self.y_out is not None:
424
+ y_clamped = torch.clamp(y_out, ty[ky], ty[-ky - 1])
425
+ By_out = self.bspline_basis_natural(y_clamped, ky, self.ty)
426
+ self.register_buffer("By_out", By_out)
427
+
258
428
  def bivariate_spline_fit_natural(self, Z):
259
429
  # Adding batch/channel dimension handling
260
430
  # ByT @ Z @ BxW
@@ -285,29 +455,16 @@ class SplineInterpolate(torch.nn.Module):
285
455
  )
286
456
 
287
457
  if y_out is None and self.y_out is None:
288
- y_out = self.y_in
458
+ raise ValueError(
459
+ "Output y-coordinates were not specified in either object "
460
+ "creation or in forward call"
461
+ )
289
462
 
290
463
  dims = len(Z.shape)
291
464
  if dims > 4:
292
465
  raise ValueError("Input data has more than 4 dimensions")
293
-
294
- if len(self.y_in) > 1 and dims == 1:
295
- raise ValueError(
296
- "An input y-coordinate array with length greater than 1 "
297
- "was given, but the input data is 1-dimensional. Expected "
298
- "input data to be at least 2-dimensional"
299
- )
300
-
301
- # Expand Z to have 4 dimensions
302
- # There are 6 valid input shapes: (w), (b, w), (b, c, w),
303
- # (h, w), (b, h, w), and (b, c, h, w).
304
-
305
- # If the input y coordinate array has length 1,
306
- # assume the first dimension(s) are batch dimensions
307
- # and that no height dimension is included in Z
308
- idx = -2 if len(self.y_in) == 1 else -3
309
- while len(Z.shape) < 4:
310
- Z = Z.unsqueeze(idx)
466
+ if dims < 2:
467
+ raise ValueError("Input data has fewer than 2 dimensions")
311
468
 
312
469
  if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
313
470
  raise ValueError(
@@ -317,6 +474,10 @@ class SplineInterpolate(torch.nn.Module):
317
474
  f"[{Z.shape[-2]}, {Z.shape[-1]}]"
318
475
  )
319
476
 
477
+ # Expand Z to have a batch and channel dimension if needed
478
+ while len(Z.shape) < 4:
479
+ Z = Z.unsqueeze(0)
480
+
320
481
  return Z, y_out
321
482
 
322
483
  def forward(
@@ -330,11 +491,9 @@ class SplineInterpolate(torch.nn.Module):
330
491
 
331
492
  Args:
332
493
  Z:
333
- Tensor of data to be interpolated. Must be between 1 and 4
494
+ Tensor of data to be interpolated. Must be between 2 and 4
334
495
  dimensions. The shape of the tensor must agree with the
335
- input coordinates given on initialization. If ``y_in`` was
336
- not specified during initialization, it is assumed that
337
- Z does not have a height dimension.
496
+ input coordinates given on initialization.
338
497
  x_out:
339
498
  Coordinates to interpolate the data to along the width
340
499
  dimension. Overrides any value that was set during
@@ -353,9 +512,21 @@ class SplineInterpolate(torch.nn.Module):
353
512
  Z, y_out = self._validate_inputs(Z, x_out, y_out)
354
513
 
355
514
  if x_out is not None:
356
- self.Bx_out = self.bspline_basis_natural(x_out, self.kx, self.tx)
515
+ x_out = x_out.float()
516
+ x_clamped = torch.clamp(
517
+ x_out, self.tx[self.kx], self.tx[-self.kx - 1]
518
+ )
519
+ self.Bx_out = self.bspline_basis_natural(
520
+ x_clamped, self.kx, self.tx
521
+ )
357
522
  if y_out is not None:
358
- self.By_out = self.bspline_basis_natural(y_out, self.ky, self.ty)
523
+ y_out = y_out.float()
524
+ y_clamped = torch.clamp(
525
+ y_out, self.ty[self.ky], self.ty[-self.ky - 1]
526
+ )
527
+ self.By_out = self.bspline_basis_natural(
528
+ y_clamped, self.ky, self.ty
529
+ )
359
530
 
360
531
  coef = self.bivariate_spline_fit_natural(Z)
361
532
  Z_interp = self.evaluate_bivariate_spline(coef)
ml4gw-{0.7.6 → 0.7.7}.dist-info/METADATA CHANGED
@@ -1,9 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ml4gw
3
- Version: 0.7.6
3
+ Version: 0.7.7
4
4
  Summary: Tools for training torch models on gravitational wave data
5
5
  Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>
6
- License-File: LICENSE
7
6
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
8
7
  Classifier: Programming Language :: Python :: 3.9
9
8
  Classifier: Programming Language :: Python :: 3.10
@@ -11,12 +10,14 @@ Classifier: Programming Language :: Python :: 3.11
11
10
  Classifier: Programming Language :: Python :: 3.12
12
11
  Classifier: Programming Language :: Python :: 3.13
13
12
  Requires-Python: <3.13,>=3.9
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
14
15
  Requires-Dist: jaxtyping<0.3,>=0.2
16
+ Requires-Dist: torch~=2.0
17
+ Requires-Dist: torchaudio~=2.0
15
18
  Requires-Dist: numpy<2.0.0
16
19
  Requires-Dist: scipy<1.15,>=1.9.0
17
- Requires-Dist: torchaudio~=2.0
18
- Requires-Dist: torch~=2.0
19
- Description-Content-Type: text/markdown
20
+ Dynamic: license-file
20
21
 
21
22
  # ML4GW
22
23
  ![PyPI - Version](https://img.shields.io/pypi/v/ml4gw)
ml4gw-{0.7.6 → 0.7.7}.dist-info/RECORD CHANGED
@@ -1,7 +1,7 @@
1
1
  ml4gw/__init__.py,sha256=81quoggCuIypZjZs3bbf1Ty70KHdva5RGEJxi0oC57E,25
2
2
  ml4gw/augmentations.py,sha256=4tSWO-I4Eg2QJWzdcLFg9QcOLlvRjNHvnjLCZS8K-Wc,1270
3
3
  ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
4
- ml4gw/distributions.py,sha256=T6H1r5IMWyO38Uyb-BpmYx0AcokWN_ZJHGo-G_20m6w,12830
4
+ ml4gw/distributions.py,sha256=B3heOf_x-71UJBExkmaiqBfjTHiGDds_bYo62wydddU,12832
5
5
  ml4gw/gw.py,sha256=bJ-GCZxanqrhbm373h9muOSZpam7wM-dJBZroy_pVNQ,20291
6
6
  ml4gw/spectral.py,sha256=Mx_zRjZ9tD7N-wknv35oA3fk2X0rDJxQdQzRyuCFryw,19982
7
7
  ml4gw/types.py,sha256=CcctqDcNajR7khGT6BD-WYsfRKpiP0udoSAB0k1qcFw,863
@@ -22,15 +22,15 @@ ml4gw/nn/resnet/resnet_2d.py,sha256=MAbXtkSrP4aWGtY-QC8ox3-y5jDHJrzRPL5ryQ4RBvM,
22
22
  ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
23
23
  ml4gw/nn/streaming/online_average.py,sha256=YSFUHhwNfQjUJbzQCqaCVApSueswzYB4yel981Omiqw,4718
24
24
  ml4gw/nn/streaming/snapshotter.py,sha256=vEQLFi-fEH-o7TO9SmYXy5whxFxXQBDeOQOFhSnofSg,4503
25
- ml4gw/transforms/__init__.py,sha256=OaTQJD4GFkDkcxt0DIwt2AzeEcv9t21ciKXxQnqDiuI,447
25
+ ml4gw/transforms/__init__.py,sha256=sREcTsJYPOHBk_QEvhlwvLeoR61Th8LVGaBicWngmv4,470
26
26
  ml4gw/transforms/iirfilter.py,sha256=HcdsjcSaSi2xe65ojxnaqeSdbYvSQVFIkHKon3nW238,3194
27
27
  ml4gw/transforms/pearson.py,sha256=sFyHD6IdskbRS8V1fY0Kt9N8R2_EhnuL6UjFa6fnmTU,3244
28
- ml4gw/transforms/qtransform.py,sha256=dXE3Genxgg3UdQ5dM-FfcvbX--UGpr0hjX9sO5tpM7k,20754
28
+ ml4gw/transforms/qtransform.py,sha256=2GtPyY5DPhoUMDs68juLIzmxbelfsS4CpI0HyTPe3Oo,20636
29
29
  ml4gw/transforms/scaler.py,sha256=BKn4RQ_TNArdwPI_j5nAe7H2jOH_-MrZPsNByE-8Pl8,2518
30
30
  ml4gw/transforms/snr_rescaler.py,sha256=lfuwdwMY117gB-emmn0_22gsK_A9xnkHJv2-76HFWc4,2728
31
31
  ml4gw/transforms/spectral.py,sha256=ebAuPSdQqha6J3MMzxqJqR31XPKUDrSz3iJaHM3orpk,4449
32
32
  ml4gw/transforms/spectrogram.py,sha256=NIyTD8kZRe8rjMUTy1_-wpFyvAswzTfYwD4TJJcPqgs,6369
33
- ml4gw/transforms/spline_interpolation.py,sha256=iz6CkRzAYFSMjRTLFJAetE5FAI6WmrpfKzMPK4sueNQ,13320
33
+ ml4gw/transforms/spline_interpolation.py,sha256=_D_P2jNIOL8-XiSSjsGClxuVwthO-CxcoqNwvrBWQpk,18668
34
34
  ml4gw/transforms/transform.py,sha256=lpHQbM4PhdijvNBsZigPX-mS04aiVVq5q3HMfxvpFg0,2506
35
35
  ml4gw/transforms/waveforms.py,sha256=koWOuHuUpQWmTT1yawSWa_MOuLfDBuugy91KIyuklOo,3189
36
36
  ml4gw/transforms/whitening.py,sha256=UyFustRhu3zv0ynJBvvxekWA-YOMwEIOYDNpoD5r_qQ,10400
@@ -49,7 +49,8 @@ ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9
49
49
  ml4gw/waveforms/cbc/phenom_p.py,sha256=tOUBoYfr0ub6OGRjDQbquGoW8AnThiGjJvbHhyGnAnk,27680
50
50
  ml4gw/waveforms/cbc/taylorf2.py,sha256=emWbl3vjsCzBOooHOVO7pPlPcj05r4up6InlMkO5m_E,10422
51
51
  ml4gw/waveforms/cbc/utils.py,sha256=LT1ky10_6ZrbwTcxIrWP1O75GUEuU5q2ZE2yYDhadQE,3037
52
- ml4gw-0.7.6.dist-info/METADATA,sha256=dI3qI2Kk4p-XP_hPs7QWPfgjzRGQMcIva-ST6mBdA0A,3380
53
- ml4gw-0.7.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
54
- ml4gw-0.7.6.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
55
- ml4gw-0.7.6.dist-info/RECORD,,
52
+ ml4gw-0.7.7.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
53
+ ml4gw-0.7.7.dist-info/METADATA,sha256=C3IsV7-AGBWfid4FvE1xuGL0HFX_wzRii0Wzb4aSOiM,3402
54
+ ml4gw-0.7.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
55
+ ml4gw-0.7.7.dist-info/top_level.txt,sha256=JnWLyPXJ3_WUcjr6fRV0ZTXj8FR0x4vBzjkg-1bl2tw,6
56
+ ml4gw-0.7.7.dist-info/RECORD,,
ml4gw-{0.7.6 → 0.7.7}.dist-info/WHEEL CHANGED
@@ -1,4 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
+
ml4gw-0.7.7.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ ml4gw