ml4gw-0.5.0-py3-none-any.whl → ml4gw-0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (44)
  1. ml4gw/augmentations.py +8 -2
  2. ml4gw/constants.py +10 -19
  3. ml4gw/dataloading/chunked_dataset.py +4 -2
  4. ml4gw/dataloading/hdf5_dataset.py +1 -1
  5. ml4gw/dataloading/in_memory_dataset.py +8 -4
  6. ml4gw/distributions.py +5 -3
  7. ml4gw/gw.py +21 -27
  8. ml4gw/nn/autoencoder/base.py +11 -6
  9. ml4gw/nn/autoencoder/convolutional.py +7 -4
  10. ml4gw/nn/autoencoder/skip_connection.py +7 -6
  11. ml4gw/nn/autoencoder/utils.py +2 -1
  12. ml4gw/nn/norm.py +5 -1
  13. ml4gw/nn/streaming/online_average.py +7 -5
  14. ml4gw/nn/streaming/snapshotter.py +7 -5
  15. ml4gw/spectral.py +41 -37
  16. ml4gw/transforms/__init__.py +1 -0
  17. ml4gw/transforms/pearson.py +7 -3
  18. ml4gw/transforms/qtransform.py +151 -53
  19. ml4gw/transforms/scaler.py +9 -3
  20. ml4gw/transforms/snr_rescaler.py +6 -5
  21. ml4gw/transforms/spectral.py +9 -2
  22. ml4gw/transforms/spectrogram.py +7 -1
  23. ml4gw/transforms/spline_interpolation.py +370 -0
  24. ml4gw/transforms/transform.py +4 -3
  25. ml4gw/transforms/waveforms.py +10 -7
  26. ml4gw/transforms/whitening.py +12 -4
  27. ml4gw/types.py +25 -10
  28. ml4gw/utils/interferometer.py +1 -1
  29. ml4gw/utils/slicing.py +24 -16
  30. ml4gw/waveforms/__init__.py +2 -5
  31. ml4gw/waveforms/adhoc/__init__.py +2 -0
  32. ml4gw/waveforms/{ringdown.py → adhoc/ringdown.py} +8 -9
  33. ml4gw/waveforms/{sine_gaussian.py → adhoc/sine_gaussian.py} +6 -6
  34. ml4gw/waveforms/cbc/__init__.py +3 -0
  35. ml4gw/waveforms/{phenom_d.py → cbc/phenom_d.py} +20 -18
  36. ml4gw/waveforms/{phenom_p.py → cbc/phenom_p.py} +106 -95
  37. ml4gw/waveforms/{taylorf2.py → cbc/taylorf2.py} +33 -27
  38. ml4gw/waveforms/conversion.py +187 -0
  39. ml4gw/waveforms/generator.py +9 -5
  40. {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/METADATA +4 -3
  41. ml4gw-0.6.0.dist-info/RECORD +51 -0
  42. {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/WHEEL +1 -1
  43. ml4gw-0.5.0.dist-info/RECORD +0 -47
  44. /ml4gw/waveforms/{phenom_d_data.py → cbc/phenom_d_data.py} +0 -0
@@ -1,8 +1,11 @@
1
1
  from typing import Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.spectral import fast_spectral_density, spectral_density
8
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
6
9
 
7
10
 
8
11
  class SpectralDensity(torch.nn.Module):
@@ -51,7 +54,9 @@ class SpectralDensity(torch.nn.Module):
51
54
  fftlength: float,
52
55
  overlap: Optional[float] = None,
53
56
  average: str = "mean",
54
- window: Optional[torch.Tensor] = None,
57
+ window: Optional[
58
+ Float[Tensor, " {int(fftlength*sample_rate)}"]
59
+ ] = None,
55
60
  fast: bool = False,
56
61
  ) -> None:
57
62
  if overlap is None:
@@ -93,7 +98,9 @@ class SpectralDensity(torch.nn.Module):
93
98
  self.average = average
94
99
  self.fast = fast
95
100
 
96
- def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
101
+ def forward(
102
+ self, x: TimeSeries1to3d, y: Optional[TimeSeries1to3d] = None
103
+ ) -> FrequencySeries1to3d:
97
104
  if self.fast:
98
105
  return fast_spectral_density(
99
106
  x,
@@ -3,8 +3,12 @@ from typing import Dict, List
3
3
 
4
4
  import torch
5
5
  import torch.nn.functional as F
6
+ from jaxtyping import Float
7
+ from torch import Tensor
6
8
  from torchaudio.transforms import Spectrogram
7
9
 
10
+ from ml4gw.types import TimeSeries3d
11
+
8
12
 
9
13
  class MultiResolutionSpectrogram(torch.nn.Module):
10
14
  """
@@ -122,7 +126,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
122
126
 
123
127
  return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]
124
128
 
125
- def forward(self, X: torch.Tensor) -> torch.Tensor:
129
+ def forward(
130
+ self, X: TimeSeries3d
131
+ ) -> Float[Tensor, "batch channel frequency time"]:
126
132
  """
127
133
  Calculate spectrograms of the input tensor and
128
134
  combine them into a single spectrogram
@@ -0,0 +1,370 @@
1
+ """
2
+ Adaptation of code from https://github.com/dottormale/Qtransform
3
+ """
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+
12
class SplineInterpolate(torch.nn.Module):
    """
    Perform 1D or 2D spline interpolation based on De Boor's method.
    Supports batched, multi-channel inputs, so acceptable data
    shapes are `(width)`, `(height, width)`, `(batch, width)`,
    `(batch, height, width)`, `(batch, channel, width)`, and
    `(batch, channel, height, width)`.

    During initialization of this Module, both the desired input
    and output coordinate Tensors can be specified to allow
    pre-computation of the B-spline basis matrices, though the only
    mandatory argument is the coordinates of the data along the
    `width` dimension. If no argument is given for coordinates along
    the `height` dimension, it is assumed that 1D interpolation is
    desired.

    Unlike scipy's implementation of spline interpolation, the data
    to be interpolated is not passed until actually calling the
    object. This is useful for cases where the input and output
    coordinates are known in advance, but the data is not, so that
    the interpolator can be set up ahead of time.

    The spline coefficients are obtained by solving ridge-regularized
    normal equations along each axis, so the result approximates,
    rather than exactly reproduces, an interpolating spline.

    WARNING: compared to scipy's spline interpolation, this method
    produces edge artifacts when the output coordinates are near
    the boundaries of the input coordinates. Therefore, it is
    recommended to interpolate only to coordinates that are well
    within the input coordinate range. Unfortunately, the specific
    definition of "well within" changes based on the size of the
    data, so some testing may be required to get good results.

    Args:
        x_in:
            Coordinates of the width dimension of the data. The
            interval lookups assume these are sorted in ascending
            order.
        y_in:
            Coordinates of the height dimension of the data. If not
            specified, it is assumed that 1D interpolation is desired,
            and a fresh length-1 placeholder Tensor is used instead.
        kx:
            Degree of spline interpolation along the width dimension.
            Default is cubic.
        ky:
            Degree of spline interpolation along the height dimension.
            Default is cubic.
        sx:
            Regularization factor to avoid singularities during matrix
            inversion for interpolation along the width dimension. Not
            to be confused with the `s` parameter in scipy's spline
            methods, which controls the number of knots.
        sy:
            Regularization factor to avoid singularities during matrix
            inversion for interpolation along the height dimension.
            It is not applied when the height dimension has a single
            coordinate (e.g. 1D interpolation), where it would only
            rescale the output by `1 / (1 + sy)`.
        x_out:
            Coordinates for the data to be interpolated to along the
            width dimension. If not specified during initialization,
            this must be specified during the object call.
        y_out:
            Coordinates for the data to be interpolated to along the
            height dimension. If not specified during initialization,
            this must be specified during the object call.

    """

    def __init__(
        self,
        x_in: Tensor,
        y_in: Optional[Tensor] = None,
        kx: int = 3,
        ky: int = 3,
        sx: float = 0.001,
        sy: float = 0.001,
        x_out: Optional[Tensor] = None,
        y_out: Optional[Tensor] = None,
    ):
        super().__init__()
        # Build the placeholder per instance: a Tensor default argument
        # would be one shared object registered as a buffer on every
        # instance, so in-place buffer updates (e.g. load_state_dict)
        # on one module would silently affect all the others.
        if y_in is None:
            y_in = Tensor([1])

        self.kx = kx
        self.ky = ky
        self.sx = sx
        self.sy = sy
        self.register_buffer("x_in", x_in)
        self.register_buffer("y_in", y_in)
        self.register_buffer("x_out", x_out)
        self.register_buffer("y_out", y_out)

        tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
        self.register_buffer("tx", tx)
        self.register_buffer("Bx", Bx)
        self.register_buffer("BxT_Bx", BxT_Bx)

        # A single height coordinate has a single (constant) basis
        # function, so there is nothing to regularize; adding sy to the
        # 1x1 normal matrix would only shrink every output by 1/(1+sy).
        sy_fit = 0.0 if len(y_in) == 1 else sy
        ty, By, ByT_By = self._compute_knots_and_basis_matrices(
            y_in, ky, sy_fit
        )
        self.register_buffer("ty", ty)
        self.register_buffer("By", By)
        self.register_buffer("ByT_By", ByT_By)

        if self.x_out is not None:
            Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
            self.register_buffer("Bx_out", Bx_out)
        if self.y_out is not None:
            By_out = self.bspline_basis_natural(y_out, ky, self.ty)
            self.register_buffer("By_out", By_out)

    def _compute_knots_and_basis_matrices(self, x, k, s):
        """
        Build the knot vector, basis matrix, and regularized normal
        matrix for one axis.

        Args:
            x: Coordinates along the axis.
            k: Degree of the spline along the axis.
            s: Regularization added to the diagonal of `B^T B`.

        Returns:
            Tuple of `(knots, B, B^T B + s * I)`.
        """
        knots = self.generate_natural_knots(x, k)
        basis_matrix = self.bspline_basis_natural(x, k, knots)
        # match the basis matrix so the module works on any device
        identity = torch.eye(
            basis_matrix.shape[-1],
            dtype=basis_matrix.dtype,
            device=basis_matrix.device,
        )
        B_T_B = basis_matrix.T @ basis_matrix + s * identity
        return knots, basis_matrix, B_T_B

    def generate_natural_knots(self, x: Tensor, k: int) -> Tensor:
        """
        Generate the knot sequence used for B-spline interpolation.

        The knots are the data positions themselves, with the first
        and last position each repeated `k` additional times (2*k
        extra knots in total), so the interior knots coincide with
        the data points.

        NOTE(review): repeating each end knot to multiplicity k+1
        yields a clamped knot vector; by itself this does not impose
        a zero second derivative at the boundaries, despite the
        "natural" in the method name (kept for compatibility).

        Args:
            x: Tensor of data point positions.
            k: Degree of the spline.

        Returns:
            Tensor of knot positions with length `len(x) + 2 * k`.
        """
        return F.pad(x[None], (k, k), mode="replicate")[0]

    def compute_L_R(
        self,
        x: Tensor,
        t: Tensor,
        d: int,
        m: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Compute the L and R values for B-spline basis functions.

        L and R are respectively the coefficients multiplying
        B_{i,d-1}(x) and B_{i+1,d-1}(x) in the Cox-de Boor recursion
        for B-spline basis functions. See
        https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for details.
        Repeated knots produce zero denominators; the resulting
        NaN/inf entries are replaced with 0.

        Args:
            x:
                Tensor of data point positions.
            t:
                Tensor of knot positions.
            d:
                Current degree of the basis function.
            m:
                Number of degree-k basis functions
                (`len(t) - k - 1`, where k is the spline degree).

        Returns:
            L: Tensor of shape `(len(x), m)` with the left coefficients.
            R: Tensor of shape `(len(x), m)` with the right coefficients.
        """
        left_num = x.unsqueeze(1) - t[:m].unsqueeze(0)
        left_den = t[d : m + d] - t[:m]
        L = left_num / left_den.unsqueeze(0)
        L = torch.nan_to_num_(L, nan=0.0, posinf=0.0, neginf=0.0)

        right_num = t[d + 1 : m + d + 1] - x.unsqueeze(1)
        right_den = t[d + 1 : m + d + 1] - t[1 : m + 1]
        R = right_num / right_den.unsqueeze(0)
        R = torch.nan_to_num_(R, nan=0.0, posinf=0.0, neginf=0.0)

        return L, R

    def zeroth_order(
        self,
        x: Tensor,
        k: int,
        t: Tensor,
        n: int,
        m: int,
    ) -> Tensor:
        """
        Compute the zeroth-order B-spline basis functions according to
        the Cox-de Boor recursion. See
        https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for reference.

        `b[p, i, 0]` is 1 where `t[i] <= x[p] < t[i + 1]`. Points left
        of `t[1]` are assigned to the first interval, and points at or
        beyond `t[-2]` to the last one.

        Args:
            x:
                Tensor of data point positions.
            k:
                Degree of the spline.
            t:
                Tensor of knot positions.
            n:
                Number of data points.
            m:
                Number of degree-k basis functions
                (`len(t) - k - 1`).

        Returns:
            b: Tensor of shape `(n, m, k + 1)` whose `[..., 0]` slice
            holds the zeroth-order basis and whose remaining slices
            are zero, ready to be filled by the recursion.
        """
        # allocate on the input's device so GPU coordinates work
        b = torch.zeros((n, m, k + 1), device=x.device)

        mask_lower = t[: m + 1].unsqueeze(0)[:, :-1] <= x.unsqueeze(1)
        mask_upper = x.unsqueeze(1) < t[: m + 1].unsqueeze(0)[:, 1:]

        b[:, :, 0] = mask_lower & mask_upper
        b[:, 0, 0] = torch.where(x < t[1], torch.ones_like(x), b[:, 0, 0])
        b[:, -1, 0] = torch.where(x >= t[-2], torch.ones_like(x), b[:, -1, 0])
        return b

    def bspline_basis_natural(
        self,
        x: Tensor,
        k: int,
        t: Tensor,
    ) -> Tensor:
        """
        Compute the B-spline basis matrix using the Cox-de Boor
        recursion (see https://en.wikipedia.org/wiki/De_Boor%27s_algorithm).

        Args:
            x: Tensor of positions at which to evaluate the basis.
            k: Degree of the spline.
            t: Tensor of knot positions.

        Returns:
            Tensor of shape `(len(x), num_basis)` holding the degree-k
            basis functions evaluated at each position.
        """
        n = x.shape[0]

        # A knot vector of length 2k + 1 comes from a single input
        # coordinate, which supports exactly one constant basis function,
        # so every position gets weight 1 on it. This covers the
        # placeholder height axis used for 1D interpolation. Keying on
        # len(x) instead would wrongly give a single *output* coordinate
        # a 1x1 basis.
        if t.shape[0] == 2 * k + 1:
            return torch.ones((n, 1), device=x.device)

        m = t.shape[0] - k - 1

        # calculate zeroth order basis function
        b = self.zeroth_order(x, k, t, n, m)

        zeros_tensor = torch.zeros(b.shape[0], 1, device=b.device)
        # recursive Cox-de Boor formula for the higher-order basis
        for d in range(1, k + 1):
            L, R = self.compute_L_R(x, t, d, m)
            left = L * b[:, :, d - 1]

            # B_{i+1,d-1}: shift by one basis index, zero-padding the end
            temp_b = torch.cat([b[:, 1:, d - 1], zeros_tensor], dim=1)

            right = R * temp_b
            b[:, :, d] = left + right

        return b[:, :, -1]

    def bivariate_spline_fit_natural(self, Z: Tensor) -> Tensor:
        """
        Fit spline coefficients to data sampled on the input grid.

        Each axis is fit by solving the regularized normal equations
        `(B^T B + s I) C = B^T Z` for that axis' basis matrix `B`.

        Args:
            Z:
                Data tensor. A 4D tensor is treated as
                `(batch, channel, height, width)` and fit along both
                spatial axes; a 3D tensor is fit along its last
                (width) axis only.

        Returns:
            Coefficient tensor in which each fitted spatial axis is
            replaced by the number of basis functions along that axis.
        """
        if Z.dim() == 3:
            # width-only fit: solve (BxT @ Bx) @ C^T = (Z @ Bx)^T
            Z_Bx = torch.matmul(Z, self.Bx)
            return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT

        # ByT @ Z @ Bx for every batch/channel entry
        ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx)
        # height axis: solve (ByT @ By) @ E = ByT @ Z @ Bx
        E = torch.linalg.solve(self.ByT_By, ByT_Z_Bx)
        # width axis: solve (BxT @ Bx) @ C^T = E^T
        return torch.linalg.solve(self.BxT_Bx, E.mT).mT

    def evaluate_bivariate_spline(self, C: Tensor) -> Tensor:
        """
        Evaluate fitted spline coefficients at the output coordinates
        using the `Bx_out` (and, for 4D input, `By_out`) basis matrices.

        Args:
            C:
                Coefficient tensor, either 4D with shape
                `(batch, channel, my, mx)` or 3D with the width
                coefficients along the last axis.

        Returns:
            Z_interp: Interpolated values at the output grid points.
        """
        if C.dim() == 3:
            return torch.matmul(C, self.Bx_out.mT)
        return torch.einsum("ik,bckm,mj->bcij", self.By_out, C, self.Bx_out.mT)

    def _validate_inputs(self, Z, x_out, y_out):
        """
        Check the data and output coordinates and expand `Z` to 4D.

        Args:
            Z: Data tensor with between 1 and 4 dimensions.
            x_out: Output width coordinates passed to `forward`, or None.
            y_out: Output height coordinates passed to `forward`, or None.

        Returns:
            Tuple of `(Z, y_out)`. `Z` has shape
            `(batch, channel, height, width)`. `y_out` falls back to
            `y_in` when no output height coordinates were given anywhere.

        Raises:
            ValueError: if no output width coordinates are available,
                `Z` has more than 4 dimensions, `Z` is 1D while `y_in`
                has more than one element, or the spatial dimensions of
                `Z` do not match the input coordinates.
        """
        if x_out is None and self.x_out is None:
            raise ValueError(
                "Output x-coordinates were not specified in either object "
                "creation or in forward call"
            )

        if y_out is None and self.y_out is None:
            y_out = self.y_in

        dims = len(Z.shape)
        if dims > 4:
            raise ValueError("Input data has more than 4 dimensions")

        if len(self.y_in) > 1 and dims == 1:
            raise ValueError(
                "An input y-coordinate array with length greater than 1 "
                "was given, but the input data is 1-dimensional. Expected "
                "input data to be at least 2-dimensional"
            )

        # Expand Z to have 4 dimensions
        # There are 6 valid input shapes: (w), (b, w), (b, c, w),
        # (h, w), (b, h, w), and (b, c, h, w).

        # If the input y coordinate array has length 1,
        # assume the first dimension(s) are batch dimensions
        # and that no height dimension is included in Z
        idx = -2 if len(self.y_in) == 1 else -3
        while len(Z.shape) < 4:
            Z = Z.unsqueeze(idx)

        if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
            raise ValueError(
                "The spatial dimensions of the data tensor do not match "
                "the given input dimensions. "
                f"Expected [{len(self.y_in)}, {len(self.x_in)}], but got "
                f"[{Z.shape[-2]}, {Z.shape[-1]}]"
            )

        return Z, y_out

    def forward(
        self,
        Z: Tensor,
        x_out: Optional[Tensor] = None,
        y_out: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Compute the interpolated data

        Args:
            Z:
                Tensor of data to be interpolated. Must be between 1 and 4
                dimensions. The shape of the tensor must agree with the
                input coordinates given on initialization. If `y_in` was
                not specified during initialization, it is assumed that
                Z does not have a height dimension.
            x_out:
                Coordinates to interpolate the data to along the width
                dimension. Overrides any value that was set during
                initialization.
            y_out:
                Coordinates to interpolate the data to along the height
                dimension. Overrides any value that was set during
                initialization.

        Returns:
            A 4D tensor with shape `(batch, channel, height, width)`.
            Depending on the input data shape, many of these dimensions
            may have length 1.
        """

        Z, y_out = self._validate_inputs(Z, x_out, y_out)

        # NOTE: output bases passed here are stored on the module, so
        # they also apply to later calls that omit them
        if x_out is not None:
            self.Bx_out = self.bspline_basis_natural(x_out, self.kx, self.tx)
        if y_out is not None:
            self.By_out = self.bspline_basis_natural(y_out, self.ky, self.ty)

        coef = self.bivariate_spline_fit_natural(Z)
        Z_interp = self.evaluate_bivariate_spline(coef)
        return Z_interp
@@ -3,6 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
 
5
5
  from ml4gw.spectral import spectral_density
6
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
6
7
 
7
8
 
8
9
  class FittableTransform(torch.nn.Module):
@@ -43,12 +44,12 @@ class FittableTransform(torch.nn.Module):
43
44
  class FittableSpectralTransform(FittableTransform):
44
45
  def normalize_psd(
45
46
  self,
46
- x,
47
+ x: TimeSeries1to3d,
47
48
  sample_rate: float,
48
49
  num_freqs: int,
49
50
  fftlength: Optional[float] = None,
50
51
  overlap: Optional[float] = None,
51
- ):
52
+ ) -> FrequencySeries1to3d:
52
53
  # if we specified an FFT length, convert
53
54
  # the (assumed) time-domain data to the
54
55
  # frequency domain
@@ -68,7 +69,7 @@ class FittableSpectralTransform(FittableTransform):
68
69
  scale=scale,
69
70
  )
70
71
 
71
- # add two dummy dimensions in case we need to inerpolate
72
+ # add two dummy dimensions in case we need to interpolate
72
73
  # the frequency dimension, since `interpolate` expects
73
74
  # a (batch, channel, spatial) formatted tensor as input
74
75
  x = x.view(1, 1, -1)
@@ -1,8 +1,11 @@
1
1
  from typing import List, Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw import gw
8
+ from ml4gw.types import BatchTensor
6
9
 
7
10
 
8
11
  # TODO: should these live in ml4gw.waveforms submodule?
@@ -10,8 +13,8 @@ from ml4gw import gw
10
13
  class WaveformSampler(torch.nn.Module):
11
14
  def __init__(
12
15
  self,
13
- parameters: Optional[torch.Tensor] = None,
14
- **polarizations: torch.Tensor,
16
+ parameters: Optional[Float[Tensor, "batch num_params"]] = None,
17
+ **polarizations: Float[Tensor, "batch time"],
15
18
  ):
16
19
  super().__init__()
17
20
  # make sure we have the same number of waveforms
@@ -29,7 +32,7 @@ class WaveformSampler(torch.nn.Module):
29
32
  elif num_waveforms is None:
30
33
  num_waveforms = tensor.shape[0]
31
34
 
32
- self.polarizations[polarization] = torch.Tensor(tensor)
35
+ self.polarizations[polarization] = Tensor(tensor)
33
36
 
34
37
  if parameters is not None and len(parameters) != num_waveforms:
35
38
  raise ValueError(
@@ -73,10 +76,10 @@ class WaveformProjector(torch.nn.Module):
73
76
 
74
77
  def forward(
75
78
  self,
76
- dec: gw.ScalarTensor,
77
- psi: gw.ScalarTensor,
78
- phi: gw.ScalarTensor,
79
- **polarizations,
79
+ dec: BatchTensor,
80
+ psi: BatchTensor,
81
+ phi: BatchTensor,
82
+ **polarizations: Float[Tensor, "batch time"],
80
83
  ):
81
84
  ifo_responses = gw.compute_observed_strain(
82
85
  dec,
@@ -1,9 +1,15 @@
1
- from typing import Optional
1
+ from typing import Optional, Union
2
2
 
3
3
  import torch
4
4
 
5
5
  from ml4gw import spectral
6
6
  from ml4gw.transforms.transform import FittableSpectralTransform
7
+ from ml4gw.types import (
8
+ FrequencySeries1d,
9
+ FrequencySeries1to3d,
10
+ TimeSeries1d,
11
+ TimeSeries3d,
12
+ )
7
13
 
8
14
 
9
15
  class Whiten(torch.nn.Module):
@@ -58,7 +64,9 @@ class Whiten(torch.nn.Module):
58
64
  window = torch.hann_window(size, dtype=torch.float64)
59
65
  self.register_buffer("window", window)
60
66
 
61
- def forward(self, X: torch.Tensor, psd: torch.Tensor) -> torch.Tensor:
67
+ def forward(
68
+ self, X: TimeSeries3d, psd: FrequencySeries1to3d
69
+ ) -> TimeSeries3d:
62
70
  """
63
71
  Whiten a batch of multichannel timeseries by a
64
72
  background power spectral density.
@@ -142,7 +150,7 @@ class FixedWhiten(FittableSpectralTransform):
142
150
  def fit(
143
151
  self,
144
152
  fduration: float,
145
- *background: torch.Tensor,
153
+ *background: Union[TimeSeries1d, FrequencySeries1d],
146
154
  fftlength: Optional[float] = None,
147
155
  highpass: Optional[float] = None,
148
156
  overlap: Optional[float] = None
@@ -224,7 +232,7 @@ class FixedWhiten(FittableSpectralTransform):
224
232
  fduration = torch.Tensor([fduration])
225
233
  self.build(psd=psd, fduration=fduration)
226
234
 
227
- def forward(self, X: torch.Tensor) -> torch.Tensor:
235
+ def forward(self, X: TimeSeries3d) -> TimeSeries3d:
228
236
  """
229
237
  Whiten the input timeseries tensor using the
230
238
  PSD fit by the `.fit` method, which must be
ml4gw/types.py CHANGED
@@ -1,10 +1,25 @@
1
- from torchtyping import TensorType
2
-
3
- WaveformTensor = TensorType["batch", "num_ifos", "time"]
4
- PSDTensor = TensorType["num_ifos", "frequency"]
5
- ScalarTensor = TensorType["batch"]
6
- VectorGeometry = TensorType["batch", "space"]
7
- TensorGeometry = TensorType["batch", "space", "space"]
8
- NetworkVertices = TensorType["num_ifos", 3]
9
- NetworkDetectorTensors = TensorType["num_ifos", 3, 3]
10
- TimeSeriesTensor = TensorType["num_channels", "time"]
1
+ from typing import Union
2
+
3
+ from jaxtyping import Float
4
+ from torch import Tensor
5
+
6
+ WaveformTensor = Float[Tensor, "batch num_ifos time"]
7
+ PSDTensor = Float[Tensor, "num_ifos frequency"]
8
+ BatchTensor = Float[Tensor, "batch"]
9
+ VectorGeometry = Float[Tensor, "batch space"]
10
+ TensorGeometry = Float[Tensor, "batch space space"]
11
+ NetworkVertices = Float[Tensor, "num_ifos 3"]
12
+ NetworkDetectorTensors = Float[Tensor, "num_ifos 3 3"]
13
+
14
+
15
+ TimeSeries1d = Float[Tensor, "time"]
16
+ TimeSeries2d = Float[TimeSeries1d, "channel"]
17
+ TimeSeries3d = Float[TimeSeries2d, "batch"]
18
+ TimeSeries1to3d = Union[TimeSeries1d, TimeSeries2d, TimeSeries3d]
19
+
20
+ FrequencySeries1d = Float[Tensor, "frequency"]
21
+ FrequencySeries2d = Float[FrequencySeries1d, "channel"]
22
+ FrequencySeries3d = Float[FrequencySeries2d, "batch"]
23
+ FrequencySeries1to3d = Union[
24
+ FrequencySeries1d, FrequencySeries2d, FrequencySeries3d
25
+ ]
@@ -4,7 +4,7 @@ import torch
4
4
  # based on values from
5
5
  # https://lscsoft.docs.ligo.org/lalsuite/lal/_l_a_l_detectors_8h_source.html
6
6
  class InterferometerGeometry:
7
- def __init__(self, name: str):
7
+ def __init__(self, name: str) -> None:
8
8
  if name == "H1":
9
9
  self.x_arm = torch.Tensor(
10
10
  (-0.22389266154, +0.79983062746, +0.55690487831)
ml4gw/utils/slicing.py CHANGED
@@ -1,25 +1,30 @@
1
1
  from typing import Optional, Union
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float, Int64
5
+ from torch import Tensor
4
6
  from torch.nn.functional import unfold
5
- from torchtyping import TensorType
6
7
 
7
- # need to define these for flake8 compatibility
8
- batch = time = channel = None # noqa
8
+ from ml4gw.types import (
9
+ TimeSeries1d,
10
+ TimeSeries1to3d,
11
+ TimeSeries2d,
12
+ TimeSeries3d,
13
+ )
9
14
 
10
- TimeSeriesTensor = Union[TensorType["time"], TensorType["channel", "time"]]
11
-
12
- BatchTimeSeriesTensor = Union[
13
- TensorType["batch", "time"], TensorType["batch", "channel", "time"]
14
- ]
15
+ BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]
15
16
 
16
17
 
17
18
  def unfold_windows(
18
- x: torch.Tensor,
19
+ x: TimeSeries1to3d,
19
20
  window_size: int,
20
21
  stride: int,
21
22
  drop_last: bool = True,
22
- ):
23
+ ) -> Union[
24
+ Float[TimeSeries1d, " window"],
25
+ Float[TimeSeries2d, " window"],
26
+ Float[TimeSeries3d, " window"],
27
+ ]:
23
28
  """Unfold a timeseries into windows
24
29
 
25
30
  Args:
@@ -83,8 +88,8 @@ def unfold_windows(
83
88
 
84
89
 
85
90
  def slice_kernels(
86
- x: Union[TimeSeriesTensor, TensorType["batch", "channel", "time"]],
87
- idx: TensorType[..., torch.int64],
91
+ x: TimeSeries1to3d,
92
+ idx: Int64[Tensor, "..."],
88
93
  kernel_size: int,
89
94
  ) -> BatchTimeSeriesTensor:
90
95
  """Slice kernels from single or multichannel timeseries
@@ -96,7 +101,8 @@ def slice_kernels(
96
101
  one more dimension than `x`.
97
102
 
98
103
  Args:
99
- x: The timeseries tensor to slice kernels from
104
+ x:
105
+ The timeseries tensor to slice kernels from
100
106
  idx:
101
107
  The indices in `x` of the first sample of each
102
108
  kernel. If `x` is 1D, `idx` must be 1D as well.
@@ -114,6 +120,7 @@ def slice_kernels(
114
120
  coincidentally among the channels.
115
121
  kernel_size:
116
122
  The length of the kernels to slice from the timeseries
123
+
117
124
  Returns:
118
125
  A tensor of shape `(batch_size, kernel_size)` if `x` is
119
126
  1D and `(batch_size, num_channels, kernel_size)` if `x`
@@ -225,7 +232,7 @@ def slice_kernels(
225
232
 
226
233
 
227
234
  def sample_kernels(
228
- X: TimeSeriesTensor,
235
+ X: TimeSeries1to3d,
229
236
  kernel_size: int,
230
237
  N: Optional[int] = None,
231
238
  max_center_offset: Optional[int] = None,
@@ -245,8 +252,9 @@ def sample_kernels(
245
252
  either be `None` or be equal to `len(X)`.
246
253
 
247
254
  Args:
248
- X: The timeseries tensor from which to sample kernels
249
- kernel_size: The size of the kernels to sample
255
+ X:
256
+ The timeseries tensor from which to sample kernels
257
+ kernel_size: The size of the kernels to sample
250
258
  N:
251
259
  The number of kernels to sample. Can be left as
252
260
  `None` if `X` is 3D, otherwise must be specified