ml4gw 0.5.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

ml4gw/constants.py CHANGED
@@ -4,42 +4,33 @@ Various constants, all in SI units.
4
4
 
5
5
  EulerGamma = 0.577215664901532860606512090082402431
6
6
 
7
+ # solar mass
7
8
  MSUN = 1.988409902147041637325262574352366540e30 # kg
8
- """Solar mass"""
9
9
 
10
+ # Geometrized nominal solar mass, m
10
11
  MRSUN = 1.476625038050124729627979840144936351e3
11
- """Geometrized nominal solar mass, m"""
12
12
 
13
+ # Newton's gravitational constant
13
14
  G = 6.67430e-11 # m^3 / kg / s^2
14
- """Newton's gravitational constant"""
15
15
 
16
+ # Speed of light
16
17
  C = 299792458.0 # m / s
17
- """Speed of light"""
18
18
 
19
- """Pi"""
19
+ # pi and 2pi
20
20
  PI = 3.141592653589793238462643383279502884
21
-
22
21
  TWO_PI = 6.283185307179586476925286766559005768
23
22
 
23
+ # G MSUN / C^3 in seconds
24
24
  gt = G * MSUN / (C**3.0)
25
- """
26
- G MSUN / C^3 in seconds
27
- """
28
25
 
26
+ # 1 solar mass in seconds. Same value as lal.MTSUN_SI
29
27
  MTSUN_SI = 4.925490947641266978197229498498379006e-6
30
- """1 solar mass in seconds. Same value as lal.MTSUN_SI"""
31
28
 
29
+ # Meters per Mpc.
32
30
  m_per_Mpc = 3.085677581491367278913937957796471611e22
33
- """
34
- Meters per Mpc.
35
- """
36
31
 
32
+ # 1 Mpc in seconds.
37
33
  MPC_SEC = m_per_Mpc / C
38
- """
39
- 1 Mpc in seconds.
40
- """
41
34
 
35
+ # Speed of light in vacuum (:math:`c`), in gigaparsecs per second
42
36
  clightGpc = C / 3.0856778570831e22
43
- """
44
- Speed of light in vacuum (:math:`c`), in gigaparsecs per second
45
- """
ml4gw/spectral.py CHANGED
@@ -441,7 +441,7 @@ def normalize_by_psd(
441
441
 
442
442
  # convert back to the time domain and normalize
443
443
  # TODO: what's this normalization factor?
444
- X = torch.fft.irfft(X_tilde, norm="forward", dim=-1)
444
+ X = torch.fft.irfft(X_tilde, n=X.shape[-1], norm="forward", dim=-1)
445
445
  X = X.float() / sample_rate**0.5
446
446
 
447
447
  # slice off corrupted data at edges of kernel
@@ -4,5 +4,6 @@ from .scaler import ChannelWiseScaler
4
4
  from .snr_rescaler import SnrRescaler
5
5
  from .spectral import SpectralDensity
6
6
  from .spectrogram import MultiResolutionSpectrogram
7
+ from .spline_interpolation import SplineInterpolate
7
8
  from .waveforms import WaveformProjector, WaveformSampler
8
9
  from .whitening import FixedWhiten, Whiten
@@ -1,11 +1,13 @@
1
1
  import math
2
- from typing import List, Optional, Tuple
2
+ import warnings
3
+ from typing import List, Tuple
3
4
 
4
5
  import torch
5
6
  import torch.nn.functional as F
6
7
  from jaxtyping import Float, Int
7
8
  from torch import Tensor
8
9
 
10
+ from ml4gw.transforms.spline_interpolation import SplineInterpolate
9
11
  from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
10
12
 
11
13
  """
@@ -38,7 +40,6 @@ class QTile(torch.nn.Module):
38
40
  mismatch:
39
41
  The maximum fractional mismatch between neighboring tiles
40
42
 
41
-
42
43
  """
43
44
 
44
45
  def __init__(
@@ -100,7 +101,9 @@ class QTile(torch.nn.Module):
100
101
  ).type(torch.long)
101
102
 
102
103
  def forward(
103
- self, fseries: FrequencySeries1to3d, norm: str = "median"
104
+ self,
105
+ fseries: FrequencySeries1to3d,
106
+ norm: str = "median",
104
107
  ) -> TimeSeries1to3d:
105
108
  """
106
109
  Compute the transform for this row
@@ -144,7 +147,7 @@ class QTile(torch.nn.Module):
144
147
  energy /= means
145
148
  else:
146
149
  raise ValueError("Invalid normalisation %r" % norm)
147
- return energy.type(torch.float32)
150
+ energy = energy.type(torch.float32)
148
151
  return energy
149
152
 
150
153
 
@@ -172,6 +175,19 @@ class SingleQTransform(torch.nn.Module):
172
175
  be chosen based on q, sample_rate, and duration
173
176
  mismatch:
174
177
  The maximum fractional mismatch between neighboring tiles
178
+ interpolation_method:
179
+ The method by which to interpolate each `QTile` to the specified
180
+ number of time and frequency bins. The acceptable values are
181
+ "bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
182
+ options will use PyTorch's built-in interpolation modes, while
183
+ "spline" will use the custom Torch-based implementation in
184
+ `ml4gw`, as PyTorch does not have spline-based interpolation.
185
+ The "spline" mode is most similar to the results of GWpy's
186
+ Q-transform, which uses `scipy` to do spline interpolation.
187
+ However, it is also the slowest and most memory intensive due to
188
+ the matrix equation solving steps. Therefore, the default method
189
+ is "bicubic" as it produces the most similar results while
190
+ optimizing for computing performance.
175
191
  """
176
192
 
177
193
  def __init__(
@@ -182,6 +198,7 @@ class SingleQTransform(torch.nn.Module):
182
198
  q: float = 12,
183
199
  frange: List[float] = [0, torch.inf],
184
200
  mismatch: float = 0.2,
201
+ interpolation_method: str = "bicubic",
185
202
  ) -> None:
186
203
  super().__init__()
187
204
  self.q = q
@@ -190,20 +207,87 @@ class SingleQTransform(torch.nn.Module):
190
207
  self.duration = duration
191
208
  self.mismatch = mismatch
192
209
 
210
+ # If q is too large, the minimum of the frange computed
211
+ # below will be larger than the maximum
212
+ max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
213
+ if q >= max_q:
214
+ raise ValueError(
215
+ "The given q value is too large for the given duration and "
216
+ f"sample rate. The maximum allowable value is {max_q}"
217
+ )
218
+
219
+ if interpolation_method not in ["bilinear", "bicubic", "spline"]:
220
+ raise ValueError(
221
+ "Interpolation method must be either 'bilinear', 'bicubic', "
222
+ f"or 'spline'; got {interpolation_method}"
223
+ )
224
+ self.interpolation_method = interpolation_method
225
+
193
226
  qprime = self.q / 11 ** (1 / 2.0)
194
227
  if self.frange[0] <= 0: # set non-zero lower frequency
195
228
  self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
196
229
  if math.isinf(self.frange[1]): # set non-infinite upper frequency
197
230
  self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)
231
+
198
232
  self.freqs = self.get_freqs()
199
233
  self.qtile_transforms = torch.nn.ModuleList(
200
234
  [
201
- QTile(self.q, freq, self.duration, sample_rate, self.mismatch)
235
+ QTile(
236
+ q=self.q,
237
+ frequency=freq,
238
+ duration=self.duration,
239
+ sample_rate=sample_rate,
240
+ mismatch=self.mismatch,
241
+ )
202
242
  for freq in self.freqs
203
243
  ]
204
244
  )
205
245
  self.qtiles = None
206
246
 
247
+ if self.interpolation_method == "spline":
248
+ self._set_up_spline_interp()
249
+
250
+ def _set_up_spline_interp(self):
251
+ ntiles = [qtile.ntiles() for qtile in self.qtile_transforms]
252
+ # For efficiency, we'll stack all qtiles of the same length before
253
+ # interpolating, so we need to figure out which those are
254
+ unique_ntiles = sorted(list(set(ntiles)))
255
+ idx = torch.arange(len(ntiles))
256
+ self.stack_idx = [idx[Tensor(ntiles) == n] for n in unique_ntiles]
257
+
258
+ t_out = torch.arange(
259
+ 0, self.duration, self.duration / self.spectrogram_shape[1]
260
+ )
261
+ self.qtile_interpolators = torch.nn.ModuleList(
262
+ [
263
+ SplineInterpolate(
264
+ kx=3,
265
+ x_in=torch.arange(0, self.duration, self.duration / tiles),
266
+ y_in=torch.arange(len(idx)),
267
+ x_out=t_out,
268
+ y_out=torch.arange(len(idx)),
269
+ )
270
+ for tiles, idx in zip(unique_ntiles, self.stack_idx)
271
+ ]
272
+ )
273
+
274
+ t_in = t_out
275
+ f_in = self.freqs
276
+ f_out = torch.logspace(
277
+ math.log10(self.frange[0]),
278
+ math.log10(self.frange[-1]),
279
+ self.spectrogram_shape[0],
280
+ )
281
+
282
+ self.interpolator = SplineInterpolate(
283
+ kx=3,
284
+ ky=3,
285
+ x_in=t_in,
286
+ y_in=f_in,
287
+ x_out=t_out,
288
+ y_out=f_out,
289
+ )
290
+
207
291
  def get_freqs(self) -> Float[Tensor, " nfreq"]:
208
292
  """
209
293
  Calculate the frequencies that will be used in this transform.
@@ -220,7 +304,8 @@ class SingleQTransform(torch.nn.Module):
220
304
 
221
305
  freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
222
306
  freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
223
- freqs = (minf * freqs // fstepmin) * fstepmin
307
+ # Cast freqs to float64 to avoid off-by-ones from rounding
308
+ freqs = (minf * freqs.double() // fstepmin) * fstepmin
224
309
  return torch.unique(freqs)
225
310
 
226
311
  def get_max_energy(
@@ -268,7 +353,11 @@ class SingleQTransform(torch.nn.Module):
268
353
  if dimension == "batch":
269
354
  return torch.max(max_across_ft, dim=-1).values
270
355
 
271
- def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None:
356
+ def compute_qtiles(
357
+ self,
358
+ X: TimeSeries1to3d,
359
+ norm: str = "median",
360
+ ) -> None:
272
361
  """
273
362
  Take the FFT of the input timeseries and calculate the transform
274
363
  for each `QTile`
@@ -278,28 +367,40 @@ class SingleQTransform(torch.nn.Module):
278
367
  X[..., 1:] *= 2
279
368
  self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
280
369
 
281
- def interpolate(self, num_f_bins: int, num_t_bins: int) -> TimeSeries3d:
282
- """
283
- Interpolate each `QTile` to the specified number of time and
284
- frequency bins. Note that PyTorch does not have the same
285
- interpolation methods that GWpy uses, and so the interpolated
286
- spectrograms will be different even though the uninterpolated
287
- values match. The `bicubic` interpolation method is used as
288
- it seems to match GWpy most closely.
289
- """
370
+ def interpolate(self) -> TimeSeries3d:
290
371
  if self.qtiles is None:
291
372
  raise RuntimeError(
292
373
  "Q-tiles must first be computed with .compute_qtiles()"
293
374
  )
375
+ if self.interpolation_method == "spline":
376
+ qtiles = [
377
+ torch.stack([self.qtiles[i] for i in idx], dim=-2)
378
+ for idx in self.stack_idx
379
+ ]
380
+ time_interped = torch.cat(
381
+ [
382
+ interpolator(qtile)
383
+ for qtile, interpolator in zip(
384
+ qtiles, self.qtile_interpolators
385
+ )
386
+ ],
387
+ dim=-2,
388
+ )
389
+ return self.interpolator(time_interped)
390
+ num_f_bins, num_t_bins = self.spectrogram_shape
294
391
  resampled = [
295
392
  F.interpolate(
296
- qtile[None], (qtile.shape[-2], num_t_bins), mode="bicubic"
393
+ qtile[None],
394
+ (qtile.shape[-2], num_t_bins),
395
+ mode=self.interpolation_method,
297
396
  )
298
397
  for qtile in self.qtiles
299
398
  ]
300
399
  resampled = torch.stack(resampled, dim=-2)
301
400
  resampled = F.interpolate(
302
- resampled[0], (num_f_bins, num_t_bins), mode="bicubic"
401
+ resampled[0],
402
+ (num_f_bins, num_t_bins),
403
+ mode=self.interpolation_method,
303
404
  )
304
405
  return torch.squeeze(resampled)
305
406
 
@@ -307,7 +408,6 @@ class SingleQTransform(torch.nn.Module):
307
408
  self,
308
409
  X: TimeSeries1to3d,
309
410
  norm: str = "median",
310
- spectrogram_shape: Optional[Tuple[int, int]] = None,
311
411
  ):
312
412
  """
313
413
  Compute the Q-tiles and interpolate
@@ -321,24 +421,15 @@ class SingleQTransform(torch.nn.Module):
321
421
  three-dimensional, axes will be added during Q-tile
322
422
  computation.
323
423
  norm:
324
- The method of interpolation used by each QTile
325
- spectrogram_shape:
326
- The shape of the interpolated spectrogram, specified as
327
- `(num_f_bins, num_t_bins)`. Because the
328
- frequency spacing of the Q-tiles is in log-space, the frequency
329
- interpolation is log-spaced as well. If not given, the shape
330
- used to initialize the transform will be used.
424
+ The method of normalization used by each QTile
331
425
 
332
426
  Returns:
333
427
  The interpolated Q-transform for the batch of data. Output will
334
428
  have one more dimension than the input
335
429
  """
336
430
 
337
- if spectrogram_shape is None:
338
- spectrogram_shape = self.spectrogram_shape
339
- num_f_bins, num_t_bins = spectrogram_shape
340
431
  self.compute_qtiles(X, norm)
341
- return self.interpolate(num_f_bins, num_t_bins)
432
+ return self.interpolate()
342
433
 
343
434
 
344
435
  class QScan(torch.nn.Module):
@@ -376,14 +467,22 @@ class QScan(torch.nn.Module):
376
467
  spectrogram_shape: Tuple[int, int],
377
468
  qrange: List[float] = [4, 64],
378
469
  frange: List[float] = [0, torch.inf],
470
+ interpolation_method="bicubic",
379
471
  mismatch: float = 0.2,
380
472
  ) -> None:
381
473
  super().__init__()
382
474
  self.qrange = qrange
383
475
  self.mismatch = mismatch
384
- self.qs = self.get_qs()
385
476
  self.frange = frange
386
477
  self.spectrogram_shape = spectrogram_shape
478
+ max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
479
+ self.qs = self.get_qs()
480
+ if self.qs[-1] >= max_q:
481
+ warnings.warn(
482
+ "Some Q values exceed the maximum allowable Q value of "
483
+ f"{max_q}. The list of Q values to be tested in this "
484
+ "scan will be truncated to avoid those values."
485
+ )
387
486
 
388
487
  # Deliberately doing something different from GWpy here.
389
488
  # Their final frange is the intersection of the frange
@@ -397,9 +496,11 @@ class QScan(torch.nn.Module):
397
496
  spectrogram_shape=spectrogram_shape,
398
497
  q=q,
399
498
  frange=self.frange.copy(),
499
+ interpolation_method=interpolation_method,
400
500
  mismatch=self.mismatch,
401
501
  )
402
502
  for q in self.qs
503
+ if q < max_q
403
504
  ]
404
505
  )
405
506
 
@@ -415,6 +516,7 @@ class QScan(torch.nn.Module):
415
516
  self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
416
517
  for i in range(nplanes)
417
518
  ]
519
+
418
520
  return qs
419
521
 
420
522
  def forward(
@@ -422,7 +524,6 @@ class QScan(torch.nn.Module):
422
524
  X: TimeSeries1to3d,
423
525
  fsearch_range: List[float] = None,
424
526
  norm: str = "median",
425
- spectrogram_shape: Optional[Tuple[int, int]] = None,
426
527
  ):
427
528
  """
428
529
  Compute the set of QTiles for each Q transform and determine which
@@ -442,12 +543,6 @@ class QScan(torch.nn.Module):
442
543
  for the maximum energy
443
544
  norm:
444
545
  The method of normalization used by each QTile
445
- spectrogram_shape:
446
- The shape of the interpolated spectrogram, specified as
447
- `(num_f_bins, num_t_bins)`. Because the
448
- frequency spacing of the Q-tiles is in log-space, the frequency
449
- interpolation is log-spaced as well. If not given, the shape
450
- used to initialize the transform will be used.
451
546
 
452
547
  Returns:
453
548
  An interpolated Q-transform for the batch of data. Output will
@@ -463,7 +558,4 @@ class QScan(torch.nn.Module):
463
558
  ]
464
559
  )
465
560
  )
466
- if spectrogram_shape is None:
467
- spectrogram_shape = self.spectrogram_shape
468
- num_f_bins, num_t_bins = spectrogram_shape
469
- return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
561
+ return self.q_transforms[idx].interpolate()
@@ -36,7 +36,9 @@ class ChannelWiseScaler(FittableTransform):
36
36
  self.register_buffer("mean", mean)
37
37
  self.register_buffer("std", std)
38
38
 
39
- def fit(self, X: Float[Tensor, "... time"]) -> None:
39
+ def fit(
40
+ self, X: Float[Tensor, "... time"], std_reg: Optional[float] = 0.0
41
+ ) -> None:
40
42
  """Fit the scaling parameters to a timeseries
41
43
 
42
44
  Computes the channel-wise mean and standard deviation
@@ -59,7 +61,7 @@ class ChannelWiseScaler(FittableTransform):
59
61
  "Can't fit channel wise mean and standard deviation "
60
62
  "from tensor of shape {}".format(X.shape)
61
63
  )
62
-
64
+ std += std_reg * torch.ones_like(std)
63
65
  super().build(mean=mean, std=std)
64
66
 
65
67
  def forward(
@@ -0,0 +1,370 @@
1
+ """
2
+ Adaptation of code from https://github.com/dottormale/Qtransform
3
+ """
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+
12
+ class SplineInterpolate(torch.nn.Module):
13
+ """
14
+ Perform 1D or 2D spline interpolation based on De Boor's method.
15
+ Supports batched, multi-channel inputs, so acceptable data
16
+ shapes are `(width)`, `(height, width)`, `(batch, width)`,
17
+ `(batch, height, width)`, `(batch, channel, width)`, and
18
+ `(batch, channel, height, width)`.
19
+
20
+ During initialization of this Module, both the desired input
21
+ and output coordinate Tensors can be specified to allow
22
+ pre-computation of the B-spline basis matrices, though the only
23
+ mandatory argument is the coordinates of the data along the
24
+ `width` dimension. If no argument is given for coordinates along
25
+ the `height` dimension, it is assumed that 1D interpolation is
26
+ desired.
27
+
28
+ Unlike scipy's implementation of spline interpolation, the data
29
+ to be interpolated is not passed until actually calling the
30
+ object. This is useful for cases where the input and output
31
+ coordinates are known in advance, but the data is not, so that
32
+ the interpolator can be set up ahead of time.
33
+
34
+ WARNING: compared to scipy's spline interpolation, this method
35
+ produces edge artifacts when the output coordinates are near
36
+ the boundaries of the input coordinates. Therefore, it is
37
+ recommended to interpolate only to coordinates that are well
38
+ within the input coordinate range. Unfortunately, the specific
39
+ definition of "well within" changes based on the size of the
40
+ data, so some testing may be required to get good results.
41
+
42
+ Args:
43
+ x_in:
44
+ Coordinates of the width dimension of the data
45
+ y_in:
46
+ Coordinates of the height dimension of the data. If not
47
+ specified, it is assumed that 1D interpolation is desired,
48
+ and so the default value is a Tensor of length 1
49
+ kx:
50
+ Degree of spline interpolation along the width dimension.
51
+ Default is cubic.
52
+ ky:
53
+ Degree of spline interpolation along the height dimension.
54
+ Default is cubic.
55
+ sx:
56
+ Regularization factor to avoid singularities during matrix
57
+ inversion for interpolation along the width dimension. Not
58
+ to be confused with the `s` parameter in scipy's spline
59
+ methods, which controls the number of knots.
60
+ sy:
61
+ Regularization factor to avoid singularities during matrix
62
+ inversion for interpolation along the height dimension.
63
+ x_out:
64
+ Coordinates for the data to be interpolated to along the
65
+ width dimension. If not specified during initialization,
66
+ this must be specified during the object call.
67
+ y_out:
68
+ Coordinates for the data to be interpolated to along the
69
+ height dimension. If not specified during initialization,
70
+ this must be specified during the object call.
71
+
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ x_in: Tensor,
77
+ y_in: Tensor = Tensor([1]),
78
+ kx: int = 3,
79
+ ky: int = 3,
80
+ sx: float = 0.001,
81
+ sy: float = 0.001,
82
+ x_out: Optional[Tensor] = None,
83
+ y_out: Optional[Tensor] = None,
84
+ ):
85
+ super().__init__()
86
+ self.kx = kx
87
+ self.ky = ky
88
+ self.sx = sx
89
+ self.sy = sy
90
+ self.register_buffer("x_in", x_in)
91
+ self.register_buffer("y_in", y_in)
92
+ self.register_buffer("x_out", x_out)
93
+ self.register_buffer("y_out", y_out)
94
+
95
+ tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
96
+ self.register_buffer("tx", tx)
97
+ self.register_buffer("Bx", Bx)
98
+ self.register_buffer("BxT_Bx", BxT_Bx)
99
+
100
+ ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
101
+ self.register_buffer("ty", ty)
102
+ self.register_buffer("By", By)
103
+ self.register_buffer("ByT_By", ByT_By)
104
+
105
+ if self.x_out is not None:
106
+ Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
107
+ self.register_buffer("Bx_out", Bx_out)
108
+ if self.y_out is not None:
109
+ By_out = self.bspline_basis_natural(y_out, ky, self.ty)
110
+ self.register_buffer("By_out", By_out)
111
+
112
+ def _compute_knots_and_basis_matrices(self, x, k, s):
113
+ knots = self.generate_natural_knots(x, k)
114
+ basis_matrix = self.bspline_basis_natural(x, k, knots)
115
+ identity = torch.eye(basis_matrix.shape[-1])
116
+ B_T_B = basis_matrix.T @ basis_matrix + s * identity
117
+ return knots, basis_matrix, B_T_B
118
+
119
+ def generate_natural_knots(self, x: Tensor, k: int) -> Tensor:
120
+ """
121
+ Generates a natural knot sequence for B-spline interpolation.
122
+ Natural knot sequence means that 2*k knots are added to the beginning
123
+ and end of datapoints as replicas of first and last datapoint
124
+ respectively in order to enforce natural boundary conditions,
125
+ i.e. second derivative = 0.
126
+ The other n nodes are placed in correspondence with the data points.
127
+
128
+ Args:
129
+ x: Tensor of data point positions.
130
+ k: Degree of the spline.
131
+
132
+ Returns:
133
+ Tensor of knot positions.
134
+ """
135
+ return F.pad(x[None], (k, k), mode="replicate")[0]
136
+
137
+ def compute_L_R(
138
+ self,
139
+ x: Tensor,
140
+ t: Tensor,
141
+ d: int,
142
+ m: int,
143
+ ) -> Tuple[Tensor, Tensor]:
144
+
145
+ """
146
+ Compute the L and R values for B-spline basis functions.
147
+ L and R are respectively the first and second coefficient multiplying
148
+ B_{i,p-1}(x) and B_{i+1,p-1}(x) in De Boor's recursive formula for
149
+ B-spline basis function computation.
150
+ See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for details
151
+
152
+ Args:
153
+ x:
154
+ Tensor of data point positions.
155
+ t:
156
+ Tensor of knot positions.
157
+ d:
158
+ Current degree of the basis function.
159
+ m:
160
+ Number of intervals (n - k - 1, where n is the number of knots
161
+ and k is the degree).
162
+
163
+ Returns:
164
+ L: Tensor containing left values for the B-spline basis functions.
165
+ R: Tensor containing right values for the B-spline basis functions.
166
+ """
167
+ left_num = x.unsqueeze(1) - t[:m].unsqueeze(0)
168
+ left_den = t[d : m + d] - t[:m]
169
+ L = left_num / left_den.unsqueeze(0)
170
+ L = torch.nan_to_num_(L, nan=0.0, posinf=0.0, neginf=0.0)
171
+
172
+ right_num = t[d + 1 : m + d + 1] - x.unsqueeze(1)
173
+ right_den = t[d + 1 : m + d + 1] - t[1 : m + 1]
174
+ R = right_num / right_den.unsqueeze(0)
175
+ R = torch.nan_to_num_(R, nan=0.0, posinf=0.0, neginf=0.0)
176
+
177
+ return L, R
178
+
179
+ def zeroth_order(
180
+ self,
181
+ x: Tensor,
182
+ k: int,
183
+ t: Tensor,
184
+ n: int,
185
+ m: int,
186
+ ) -> Tensor:
187
+
188
+ """
189
+ Compute the zeroth-order B-spline basis functions
190
+ according to De Boor's recursive formula.
191
+ See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for reference
192
+
193
+ Args:
194
+ x:
195
+ Tensor of data point positions.
196
+ k:
197
+ Degree of the spline.
198
+ t:
199
+ Tensor of knot positions.
200
+ n:
201
+ Number of data points.
202
+ m:
203
+ Number of intervals (n - k - 1, where n is the number of knots
204
+ and k is the degree).
205
+
206
+ Returns:
207
+ b: Tensor containing the zeroth-order B-spline basis functions.
208
+ """
209
+ b = torch.zeros((n, m, k + 1))
210
+
211
+ mask_lower = t[: m + 1].unsqueeze(0)[:, :-1] <= x.unsqueeze(1)
212
+ mask_upper = x.unsqueeze(1) < t[: m + 1].unsqueeze(0)[:, 1:]
213
+
214
+ b[:, :, 0] = mask_lower & mask_upper
215
+ b[:, 0, 0] = torch.where(x < t[1], torch.ones_like(x), b[:, 0, 0])
216
+ b[:, -1, 0] = torch.where(x >= t[-2], torch.ones_like(x), b[:, -1, 0])
217
+ return b
218
+
219
+ def bspline_basis_natural(
220
+ self,
221
+ x: Tensor,
222
+ k: int,
223
+ t: Tensor,
224
+ ) -> Tensor:
225
+ """
226
+ Compute bspline basis function using de Boor's recursive formula
227
+ (See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for reference)
228
+ Args:
229
+ x: Tensor of data point positions.
230
+ k: Degree of the spline.
231
+ t: Tensor of knot positions.
232
+
233
+ Returns:
234
+ Tensor containing the kth-order B-spline basis functions
235
+ """
236
+
237
+ if len(x) == 1:
238
+ return torch.eye(1)
239
+ n = x.shape[0]
240
+ m = t.shape[0] - k - 1
241
+
242
+ # calculate zeroth order basis function
243
+ b = self.zeroth_order(x, k, t, n, m)
244
+
245
+ zeros_tensor = torch.zeros(b.shape[0], 1)
246
+ # recursive De Boor's formula for B-spline basis functions
247
+ for d in range(1, k + 1):
248
+ L, R = self.compute_L_R(x, t, d, m)
249
+ left = L * b[:, :, d - 1]
250
+
251
+ temp_b = torch.cat([b[:, 1:, d - 1], zeros_tensor], dim=1)
252
+
253
+ right = R * temp_b
254
+ b[:, :, d] = left + right
255
+
256
+ return b[:, :, -1]
257
+
258
+ def bivariate_spline_fit_natural(self, Z):
259
+
260
+ if len(Z.shape) == 3:
261
+ Z_Bx = torch.matmul(Z, self.Bx)
262
+ # ((BxT @ Bx)^-1 @ (Z @ Bx)T)T = Z @ BxT^-1
263
+ return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT
264
+
265
+ # Adding batch/channel dimension handling
266
+ # ByT @ Z @ Bx
267
+ ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx)
268
+ # (ByT @ By)^-1 @ (ByT @ Z @ Bx) = By^-1 @ Z @ Bx
269
+ E = torch.linalg.solve(self.ByT_By, ByT_Z_Bx)
270
+ # ((BxT @ Bx)^-1 @ (By^-1 @ Z @ Bx)T)T = By^-1 @ Z @ BxT^-1
271
+ return torch.linalg.solve(self.BxT_Bx, E.mT).mT
272
+
273
+ def evaluate_bivariate_spline(self, C: Tensor):
274
+ """
275
+ Evaluate a bivariate spline on a grid of x and y points.
276
+
277
+ Args:
278
+ C: Coefficient tensor of shape (batch_size, mx, my).
279
+
280
+ Returns:
281
+ Z_interp: Interpolated values at the grid points.
282
+ """
283
+ # Perform matrix multiplication using einsum to get Z_interp
284
+ if len(C.shape) == 3:
285
+ return torch.matmul(C, self.Bx_out.mT)
286
+ return torch.einsum("ik,bckm,mj->bcij", self.By_out, C, self.Bx_out.mT)
287
+
288
+ def _validate_inputs(self, Z, x_out, y_out):
289
+ if x_out is None and self.x_out is None:
290
+ raise ValueError(
291
+ "Output x-coordinates were not specified in either object "
292
+ "creation or in forward call"
293
+ )
294
+
295
+ if y_out is None and self.y_out is None:
296
+ y_out = self.y_in
297
+
298
+ dims = len(Z.shape)
299
+ if dims > 4:
300
+ raise ValueError("Input data has more than 4 dimensions")
301
+
302
+ if len(self.y_in) > 1 and dims == 1:
303
+ raise ValueError(
304
+ "An input y-coordinate array with length greater than 1 "
305
+ "was given, but the input data is 1-dimensional. Expected "
306
+ "input data to be at least 2-dimensional"
307
+ )
308
+
309
+ # Expand Z to have 4 dimensions
310
+ # There are 6 valid input shapes: (w), (b, w), (b, c, w),
311
+ # (h, w), (b, h, w), and (b, c, h, w).
312
+
313
+ # If the input y coordinate array has length 1,
314
+ # assume the first dimension(s) are batch dimensions
315
+ # and that no height dimension is included in Z
316
+ idx = -2 if len(self.y_in) == 1 else -3
317
+ while len(Z.shape) < 4:
318
+ Z = Z.unsqueeze(idx)
319
+
320
+ if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
321
+ raise ValueError(
322
+ "The spatial dimensions of the data tensor do not match "
323
+ "the given input dimensions. "
324
+ f"Expected [{len(self.y_in)}, {len(self.x_in)}], but got "
325
+ f"[{Z.shape[-2]}, {Z.shape[-1]}]"
326
+ )
327
+
328
+ return Z, y_out
329
+
330
+ def forward(
331
+ self,
332
+ Z: Tensor,
333
+ x_out: Optional[Tensor] = None,
334
+ y_out: Optional[Tensor] = None,
335
+ ) -> Tensor:
336
+ """
337
+ Compute the interpolated data
338
+
339
+ Args:
340
+ Z:
341
+ Tensor of data to be interpolated. Must be between 1 and 4
342
+ dimensions. The shape of the tensor must agree with the
343
+ input coordinates given on initialization. If `y_in` was
344
+ not specified during initialization, it is assumed that
345
+ Z does not have a height dimension.
346
+ x_out:
347
+ Coordinates to interpolate the data to along the width
348
+ dimension. Overrides any value that was set during
349
+ initialization.
350
+ y_out:
351
+ Coordinates to interpolate the data to along the height
352
+ dimension. Overrides any value that was set during
353
+ initialization.
354
+
355
+ Returns:
356
+ A 4D tensor with shape `(batch, channel, height, width)`.
357
+ Depending on the input data shape, many of these dimensions
358
+ may have length 1.
359
+ """
360
+
361
+ Z, y_out = self._validate_inputs(Z, x_out, y_out)
362
+
363
+ if x_out is not None:
364
+ self.Bx_out = self.bspline_basis_natural(x_out, self.kx, self.tx)
365
+ if y_out is not None:
366
+ self.By_out = self.bspline_basis_natural(y_out, self.ky, self.ty)
367
+
368
+ coef = self.bivariate_spline_fit_natural(Z)
369
+ Z_interp = self.evaluate_bivariate_spline(coef)
370
+ return Z_interp
@@ -1,5 +1,2 @@
1
- from .phenom_d import IMRPhenomD
2
- from .phenom_p import IMRPhenomPv2
3
- from .ringdown import Ringdown
4
- from .sine_gaussian import SineGaussian
5
- from .taylorf2 import TaylorF2
1
+ from .adhoc import *
2
+ from .cbc import *
@@ -0,0 +1,2 @@
1
+ from .ringdown import Ringdown
2
+ from .sine_gaussian import SineGaussian
@@ -0,0 +1,3 @@
1
+ from .phenom_d import IMRPhenomD
2
+ from .phenom_p import IMRPhenomPv2
3
+ from .taylorf2 import TaylorF2
@@ -1,4 +1,4 @@
1
- from typing import Dict, Tuple
1
+ from typing import Dict, Optional, Tuple
2
2
 
3
3
  import torch
4
4
  from jaxtyping import Float
@@ -6,6 +6,7 @@ from torch import Tensor
6
6
 
7
7
  from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
8
8
  from ml4gw.types import BatchTensor, FrequencySeries1d
9
+ from ml4gw.waveforms.conversion import rotate_y, rotate_z
9
10
 
10
11
  from .phenom_d import IMRPhenomD
11
12
 
@@ -25,11 +26,11 @@ class IMRPhenomPv2(IMRPhenomD):
25
26
  s2x: BatchTensor,
26
27
  s2y: BatchTensor,
27
28
  s2z: BatchTensor,
28
- dist_mpc: BatchTensor,
29
- tc: BatchTensor,
30
- phiRef: BatchTensor,
31
- incl: BatchTensor,
29
+ distance: BatchTensor,
30
+ phic: BatchTensor,
31
+ inclination: BatchTensor,
32
32
  f_ref: float,
33
+ tc: Optional[BatchTensor] = None,
33
34
  ):
34
35
  """
35
36
  IMRPhenomPv2 waveform
@@ -53,13 +54,13 @@ class IMRPhenomPv2(IMRPhenomD):
53
54
  Spin component y of the second BH.
54
55
  s2z :
55
56
  Spin component z of the second BH.
56
- dist_mpc :
57
+ distance :
57
58
  Luminosity distance in Mpc.
58
59
  tc :
59
60
  Coalescence time.
60
- phiRef :
61
+ phic :
61
62
  Reference phase.
62
- incl :
63
+ inclination :
63
64
  Inclination angle.
64
65
  f_ref :
65
66
  Reference frequency in Hz.
@@ -71,6 +72,9 @@ class IMRPhenomPv2(IMRPhenomD):
71
72
  Note: m1 must be larger than m2.
72
73
  """
73
74
 
75
+ if tc is None:
76
+ tc = torch.zeros_like(chirp_mass)
77
+
74
78
  m2 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio**0.6
75
79
  m1 = m2 * mass_ratio
76
80
 
@@ -89,7 +93,7 @@ class IMRPhenomPv2(IMRPhenomD):
89
93
  phi_aligned,
90
94
  zeta_polariz,
91
95
  ) = self.convert_spins(
92
- m1, m2, f_ref, phiRef, incl, s1x, s1y, s1z, s2x, s2y, s2z
96
+ m1, m2, f_ref, phic, inclination, s1x, s1y, s1z, s2x, s2y, s2z
93
97
  )
94
98
 
95
99
  phic = 2 * phi_aligned
@@ -152,7 +156,7 @@ class IMRPhenomPv2(IMRPhenomD):
152
156
  phic,
153
157
  M,
154
158
  xi,
155
- dist_mpc,
159
+ distance,
156
160
  )
157
161
 
158
162
  hp, hc = self.PhenomPCoreTwistUp(
@@ -309,7 +313,7 @@ class IMRPhenomPv2(IMRPhenomD):
309
313
  phic,
310
314
  M,
311
315
  xi,
312
- dist_mpc,
316
+ distance,
313
317
  ):
314
318
  """
315
319
  m1, m2: in solar masses
@@ -324,10 +328,10 @@ class IMRPhenomPv2(IMRPhenomD):
324
328
  phase, _ = self.phenom_d_phase(Mf, m1, m2, eta, eta2, chi1, chi2, xi)
325
329
  phase = (phase.mT - (phic + PI / 4.0)).mT
326
330
  Amp = self.phenom_d_amp(
327
- Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, dist_mpc
331
+ Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, distance
328
332
  )[0]
329
333
  Amp0 = self.get_Amp0(Mf, eta)
330
- dist_s = dist_mpc * MPC_SEC
334
+ dist_s = distance * MPC_SEC
331
335
  Amp = ((Amp0 * Amp).mT * (M_s**2.0) / dist_s).mT
332
336
  # phase -= 2. * phic; # line 1316 ???
333
337
  hPhenom = Amp * (torch.exp(-1j * phase))
@@ -391,16 +395,6 @@ class IMRPhenomPv2(IMRPhenomD):
391
395
 
392
396
  return interpolated.reshape(original_shape)
393
397
 
394
- def ROTATEZ(self, angle: BatchTensor, x, y, z):
395
- tmp_x = x * torch.cos(angle) - y * torch.sin(angle)
396
- tmp_y = x * torch.sin(angle) + y * torch.cos(angle)
397
- return tmp_x, tmp_y, z
398
-
399
- def ROTATEY(self, angle, x, y, z):
400
- tmp_x = x * torch.cos(angle) + z * torch.sin(angle)
401
- tmp_z = -x * torch.sin(angle) + z * torch.cos(angle)
402
- return tmp_x, y, tmp_z
403
-
404
398
  def L2PNR(
405
399
  self,
406
400
  v: BatchTensor,
@@ -425,8 +419,8 @@ class IMRPhenomPv2(IMRPhenomD):
425
419
  m1: BatchTensor,
426
420
  m2: BatchTensor,
427
421
  f_ref: float,
428
- phiRef: BatchTensor,
429
- incl: BatchTensor,
422
+ phic: BatchTensor,
423
+ inclination: BatchTensor,
430
424
  s1x: BatchTensor,
431
425
  s1y: BatchTensor,
432
426
  s1z: BatchTensor,
@@ -486,32 +480,32 @@ class IMRPhenomPv2(IMRPhenomD):
486
480
  # First we determine kappa
487
481
  # in the source frame, the components of N are given in
488
482
  # Eq (35c) of T1500606-v6
489
- Nx_sf = torch.sin(incl) * torch.cos(PI / 2.0 - phiRef)
490
- Ny_sf = torch.sin(incl) * torch.sin(PI / 2.0 - phiRef)
491
- Nz_sf = torch.cos(incl)
483
+ Nx_sf = torch.sin(inclination) * torch.cos(PI / 2.0 - phic)
484
+ Ny_sf = torch.sin(inclination) * torch.sin(PI / 2.0 - phic)
485
+ Nz_sf = torch.cos(inclination)
492
486
 
493
487
  tmp_x = Nx_sf
494
488
  tmp_y = Ny_sf
495
489
  tmp_z = Nz_sf
496
490
 
497
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
498
- tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
491
+ tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z)
492
+ tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
499
493
 
500
494
  kappa = -torch.arctan2(tmp_y, tmp_x)
501
495
 
502
496
  # Then we determine alpha0, by rotating LN
503
497
  tmp_x, tmp_y, tmp_z = 0, 0, 1
504
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
505
- tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
506
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
498
+ tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z)
499
+ tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
500
+ tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z)
507
501
 
508
502
  alpha0 = torch.arctan2(tmp_y, tmp_x)
509
503
 
510
504
  # Finally we determine thetaJ, by rotating N
511
505
  tmp_x, tmp_y, tmp_z = Nx_sf, Ny_sf, Nz_sf
512
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
513
- tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
514
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
506
+ tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z)
507
+ tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
508
+ tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z)
515
509
  Nx_Jf, Nz_Jf = tmp_x, tmp_z
516
510
  thetaJN = torch.arccos(Nz_Jf)
517
511
 
@@ -528,13 +522,13 @@ class IMRPhenomPv2(IMRPhenomD):
528
522
  # Both triads differ from each other by a rotation around N by an angle
529
523
  # \zeta and we need to rotate the polarizations accordingly by 2\zeta
530
524
 
531
- Xx_sf = -torch.cos(incl) * torch.sin(phiRef)
532
- Xy_sf = -torch.cos(incl) * torch.cos(phiRef)
533
- Xz_sf = torch.sin(incl)
525
+ Xx_sf = -torch.cos(inclination) * torch.sin(phic)
526
+ Xy_sf = -torch.cos(inclination) * torch.cos(phic)
527
+ Xz_sf = torch.sin(inclination)
534
528
  tmp_x, tmp_y, tmp_z = Xx_sf, Xy_sf, Xz_sf
535
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
536
- tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
537
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
529
+ tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z)
530
+ tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
531
+ tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z)
538
532
 
539
533
  # Now the tmp_a are the components of X in the J frame
540
534
  # We need the polar angle of that vector in the P,Q basis of Arun et al
@@ -0,0 +1,187 @@
1
+ import torch
2
+
3
+ from ml4gw.constants import MTSUN_SI, PI
4
+ from ml4gw.types import BatchTensor
5
+
6
+
7
+ def rotate_z(angle: BatchTensor, x, y, z):
8
+ x_tmp = x * torch.cos(angle) - y * torch.sin(angle)
9
+ y_tmp = x * torch.sin(angle) + y * torch.cos(angle)
10
+ return x_tmp, y_tmp, z
11
+
12
+
13
+ def rotate_y(angle, x, y, z):
14
+ x_tmp = x * torch.cos(angle) + z * torch.sin(angle)
15
+ z_tmp = -x * torch.sin(angle) + z * torch.cos(angle)
16
+ return x_tmp, y, z_tmp
17
+
18
+
19
+ def XLALSimInspiralLN(
20
+ total_mass: BatchTensor, eta: BatchTensor, v: BatchTensor
21
+ ):
22
+ """
23
+ See https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L2173 # noqa
24
+ """
25
+ return total_mass**2 * eta / v
26
+
27
+
28
+ def XLALSimInspiralL_2PN(eta: BatchTensor):
29
+ """
30
+ See https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L2181 # noqa
31
+ """
32
+ return 1.5 + eta / 6.0
33
+
34
+
35
+ def bilby_spins_to_lalsim(
36
+ theta_jn: BatchTensor,
37
+ phi_jl: BatchTensor,
38
+ tilt_1: BatchTensor,
39
+ tilt_2: BatchTensor,
40
+ phi_12: BatchTensor,
41
+ a_1: BatchTensor,
42
+ a_2: BatchTensor,
43
+ mass_1: BatchTensor,
44
+ mass_2: BatchTensor,
45
+ f_ref: float,
46
+ phi_ref: BatchTensor,
47
+ ):
48
+ """
49
+ Converts between bilby spin and lalsimulation spin conventions.
50
+
51
+ See https://github.com/bilby-dev/bilby/blob/cccdf891e82d46319e69dbfdf48c4970b4e9a727/bilby/gw/conversion.py#L105 # noqa
52
+ and https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L3594 # noqa
53
+
54
+ Args:
55
+ theta_jn: BatchTensor,
56
+ phi_jl: BatchTensor,
57
+ tilt_1: BatchTensor,
58
+ tilt_2: BatchTensor,
59
+ phi_12: BatchTensor,
60
+ a_1: BatchTensor,
61
+ a_2: BatchTensor,
62
+ mass_1: BatchTensor,
63
+ mass_2: BatchTensor,
64
+ f_ref: float,
65
+ phi_ref: BatchTensor,
66
+ """
67
+
68
+ # check if f_ref is valid
69
+ if f_ref <= 0.0:
70
+ raise ValueError(
71
+ "f_ref <= 0 is invalid. "
72
+ "Please pass in the starting GW frequency instead."
73
+ )
74
+
75
+ # starting frame: LNhat is along the z-axis and the unit
76
+ # spin vectors are defined from the angles relative to LNhat.
77
+ # Note that we put s1hat in the x-z plane, and phi12
78
+ # sets the azimuthal angle of s2hat measured from the x-axis.
79
+ lnh_x = 0
80
+ lnh_y = 0
81
+ lnh_z = 1
82
+ # Spins are given wrt L,
83
+ # but still we cannot fill the spin as we do not know
84
+ # what will be the relative orientation of L and N.
85
+ # Note that these spin components are NOT wrt the binary
86
+ # separation vector, but wrt the binary separation vector
87
+ # at phiref=0.
88
+
89
+ s1hatx = torch.sin(tilt_1) * torch.cos(phi_ref)
90
+ s1haty = torch.sin(tilt_1) * torch.sin(phi_ref)
91
+ s1hatz = torch.cos(tilt_1)
92
+ s2hatx = torch.sin(tilt_2) * torch.cos(phi_12 + phi_ref)
93
+ s2haty = torch.sin(tilt_2) * torch.sin(phi_12 + phi_ref)
94
+ s2hatz = torch.cos(tilt_2)
95
+
96
+ total_mass = mass_1 + mass_2
97
+
98
+ eta = mass_1 * mass_2 / (mass_1 + mass_2) / (mass_1 + mass_2)
99
+
100
+ # v parameter at reference point
101
+ v0 = ((mass_1 + mass_2) * MTSUN_SI * PI * f_ref) ** (1 / 3)
102
+
103
+ # Define S1, S2, J with proper magnitudes
104
+
105
+ l_mag = XLALSimInspiralLN(total_mass, eta, v0) * (
106
+ 1.0 + v0 * v0 * XLALSimInspiralL_2PN(eta)
107
+ )
108
+ s1x = mass_1 * mass_1 * a_1 * s1hatx
109
+ s1y = mass_1 * mass_1 * a_1 * s1haty
110
+ s1z = mass_1 * mass_1 * a_1 * s1hatz
111
+ s2x = mass_2 * mass_2 * a_2 * s2hatx
112
+ s2y = mass_2 * mass_2 * a_2 * s2haty
113
+ s2z = mass_2 * mass_2 * a_2 * s2hatz
114
+ Jx = s1x + s2x
115
+ Jy = s1y + s2y
116
+ Jz = l_mag + s1z + s2z
117
+
118
+ # Normalize J to Jhat, find its angles in starting frame
119
+ Jnorm = torch.sqrt(Jx * Jx + Jy * Jy + Jz * Jz)
120
+ Jhatx = Jx / Jnorm
121
+ Jhaty = Jy / Jnorm
122
+ Jhatz = Jz / Jnorm
123
+ theta0 = torch.acos(Jhatz)
124
+ phi0 = torch.atan2(Jhaty, Jhatx)
125
+
126
+ # Rotation 1: Rotate about z-axis by -phi0 to put Jhat in x-z plane
127
+ s1hatx, s1haty, s1hatz = rotate_z(-phi0, s1hatx, s1haty, s1hatz)
128
+ s2hatx, s2haty, s2hatz = rotate_z(-phi0, s2hatx, s2haty, s2hatz)
129
+
130
+ # Rotation 2: Rotate about new y-axis by -theta0
131
+ # to put Jhat along z-axis
132
+
133
+ lnh_x, lnh_y, lnh_z = rotate_y(-theta0, lnh_x, lnh_y, lnh_z)
134
+ s1hatx, s1haty, s1hatz = rotate_y(-theta0, s1hatx, s1haty, s1hatz)
135
+ s2hatx, s2haty, s2hatz = rotate_y(-theta0, s2hatx, s2haty, s2hatz)
136
+
137
+ # Rotation 3: Rotate about new z-axis by phiJL to put L at desired
138
+ # azimuth about J. Note that L is currently in the x-z plane towards -x
139
+ # (i.e. azimuth=pi). Hence we rotate about z by phiJL - LAL_PI
140
+ lnh_x, lnh_y, lnh_z = rotate_z(phi_jl - PI, lnh_x, lnh_y, lnh_z)
141
+ s1hatx, s1haty, s1hatz = rotate_z(phi_jl - PI, s1hatx, s1haty, s1hatz)
142
+ s2hatx, s2haty, s2hatz = rotate_z(phi_jl - PI, s2hatx, s2haty, s2hatz)
143
+
144
+ # The cosine of the angle between L and N is the scalar
145
+ # product of the two vectors.
146
+ # We do not need to perform an additional rotation to compute it.
147
+ Nx = 0.0
148
+ Ny = torch.sin(theta_jn)
149
+ Nz = torch.cos(theta_jn)
150
+ incl = torch.acos(Nx * lnh_x + Ny * lnh_y + Nz * lnh_z)
151
+
152
+ # Rotation 4-5: Now J is along z and N in y-z plane, inclined from J
153
+ # by thetaJN and with a positive component along y.
154
+ # Now we bring L into the z axis to get spin components.
155
+ thetalj = torch.acos(lnh_z)
156
+ phil = torch.atan2(lnh_y, lnh_x)
157
+
158
+ s1hatx, s1haty, s1hatz = rotate_z(-phil, s1hatx, s1haty, s1hatz)
159
+ s2hatx, s2haty, s2hatz = rotate_z(-phil, s2hatx, s2haty, s2hatz)
160
+ Nx, Ny, Nz = rotate_z(-phil, Nx, Ny, Nz)
161
+
162
+ s1hatx, s1haty, s1hatz = rotate_y(-thetalj, s1hatx, s1haty, s1hatz)
163
+ s2hatx, s2haty, s2hatz = rotate_y(-thetalj, s2hatx, s2haty, s2hatz)
164
+ Nx, Ny, Nz = rotate_y(-thetalj, Nx, Ny, Nz)
165
+
166
+ # Rotation 6: Now L is along z and we have to bring N
167
+ # in the y-z plane with a positive y component.
168
+
169
+ phiN = torch.atan2(Ny, Nx)
170
+ # Note the extra -phiRef here:
171
+ # output spins must be given wrt to two body separations
172
+ # which are rigidly rotated with spins
173
+ s1hatx, s1haty, s1hatz = rotate_z(
174
+ PI / 2.0 - phiN - phi_ref, s1hatx, s1haty, s1hatz
175
+ )
176
+ s2hatx, s2haty, s2hatz = rotate_z(
177
+ PI / 2.0 - phiN - phi_ref, s2hatx, s2haty, s2hatz
178
+ )
179
+
180
+ s1x = s1hatx * a_1
181
+ s1y = s1haty * a_1
182
+ s1z = s1hatz * a_1
183
+ s2x = s2hatx * a_2
184
+ s2y = s2haty * a_2
185
+ s2z = s2hatz * a_2
186
+
187
+ return incl, s1x, s1y, s1z, s2x, s2y, s2z
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ml4gw
3
- Version: 0.5.1
3
+ Version: 0.6.0
4
4
  Summary: Tools for training torch models on gravitational wave data
5
5
  Author: Alec Gunny
6
6
  Author-email: alec.gunny@ligo.org
@@ -1,6 +1,6 @@
1
1
  ml4gw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  ml4gw/augmentations.py,sha256=pZH9tjEpXV0AIqvHHDkpUE-BorG02beOz2pmSipw2EY,1232
3
- ml4gw/constants.py,sha256=W9beA9RDRdIug1I2H7VLPEPv_DFsQWWoYRmzxv7FWgM,891
3
+ ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
4
4
  ml4gw/dataloading/__init__.py,sha256=EHBBqU7y2-Np5iQ_xyufxamUEM1pPEquqFo7oaJnaJE,149
5
5
  ml4gw/dataloading/chunked_dataset.py,sha256=FpDc4gFxt-PMyXs5qSWLuTGXMTuS1B-hH8gUOCOGxZk,5260
6
6
  ml4gw/dataloading/hdf5_dataset.py,sha256=UB1Eog8l7m4M78Owst7oYQZICb0DRJer9WVLVn4hl_I,6645
@@ -20,28 +20,32 @@ ml4gw/nn/resnet/resnet_2d.py,sha256=aK4I0FOZk62JxnYFz0t1O0s5s7J7yRNYSM1flRypvVc,
20
20
  ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
21
21
  ml4gw/nn/streaming/online_average.py,sha256=aI8hkT7I3thXkda9tsXxYrzump9swelSXPdSTwPlJWY,4719
22
22
  ml4gw/nn/streaming/snapshotter.py,sha256=B9qtbHxnPszAHQ5WQppWJLRuMnnYIxGk7MRUlgja7Is,4476
23
- ml4gw/spectral.py,sha256=Mt3-yz4a83z0X7M1sVp00_vB947w-9OjU0iNdEkbQcU,19145
24
- ml4gw/transforms/__init__.py,sha256=24pdP_hIg1wfrtZxxRBPhcEXsCbvVKtNKp7JL8SEogE,362
23
+ ml4gw/spectral.py,sha256=0UPgbqGay-xP-3uJ7orZCb9fSO4eVbu6JTjzZJOFqj4,19160
24
+ ml4gw/transforms/__init__.py,sha256=-DLdjD4usIi0ttSw61ZV7HieCTgHz1vTwfAlRgzbuDw,414
25
25
  ml4gw/transforms/pearson.py,sha256=Ep3mMsY15AF55taRaWNjpHRTvtr1StShUDfqk0dN-qo,3235
26
- ml4gw/transforms/qtransform.py,sha256=umBSpykfmPftjfyMqbniiP2mTh62q4hoYPA55qneJ4o,17702
27
- ml4gw/transforms/scaler.py,sha256=fLZo-m6_yFY3UDoLEaS_YgCnYggxlcKstXcM7749TiU,2433
26
+ ml4gw/transforms/qtransform.py,sha256=TWQsBeKhRoqJdkc4cPt58pKozgb_6-jZivn8u0AzQyQ,20695
27
+ ml4gw/transforms/scaler.py,sha256=souOt-hOO4M6dqPNXOspfmeU2V9622yGoIMNvju5JZI,2524
28
28
  ml4gw/transforms/snr_rescaler.py,sha256=3XXCTaXc2dzzpXRZx7iqRwImvYtRSJLM5fHdBGfpoUs,2351
29
29
  ml4gw/transforms/spectral.py,sha256=gTHUeC0gGYbzgBZHb_FxC_4zdhl5H-XCiLg1hrvKB70,4393
30
30
  ml4gw/transforms/spectrogram.py,sha256=HS3Rf5iB7JjhlSESRDdFGUwCtIBdvUaJUDulkB4Lmos,6162
31
+ ml4gw/transforms/spline_interpolation.py,sha256=GkyAVLrtZODIIDLkBdAngO9jqEHRzvEFTdxjNM7U1Bc,13526
31
32
  ml4gw/transforms/transform.py,sha256=BuzTbPFxp18OEGP9Tu9jBGtvqy3len1cqvqg5X37DiY,2512
32
33
  ml4gw/transforms/waveforms.py,sha256=LkYCvxPqYhHa2yYZTvPE6j0E4HFy16b5ndCRQb7WfcA,3196
33
34
  ml4gw/transforms/whitening.py,sha256=Aw_ogq93CYCATiHWBqSZ-qsUtaHAMA3k009ZRtQTtHA,9596
34
35
  ml4gw/types.py,sha256=CcctqDcNajR7khGT6BD-WYsfRKpiP0udoSAB0k1qcFw,863
35
36
  ml4gw/utils/interferometer.py,sha256=lRS0N3SwUTknhYXX57VACJ99jK1P9M19oUWN_i_nQN0,1814
36
37
  ml4gw/utils/slicing.py,sha256=ilRz_5sJzwmd5VyBlrj81tvyC3uCnXYjd0TO2fzFMr8,13563
37
- ml4gw/waveforms/__init__.py,sha256=dnxfRGX_B3zQPB3_3srLyjZXRxTn4miZqYIRe7PYyrU,170
38
+ ml4gw/waveforms/__init__.py,sha256=QVUzBx_y8A9_AsRuTJruPvL9mqGnBt11Iw1MOYjXyE4,40
39
+ ml4gw/waveforms/adhoc/__init__.py,sha256=XVwP4t8TMUj87WY3yMGRTkXsv7_lVr1w8p8iKBW8iKE,71
40
+ ml4gw/waveforms/adhoc/ringdown.py,sha256=m8IBQTxKBBGFqBtWGEO4KG3DEYR8TTnNyGVdVLaMKa8,3316
41
+ ml4gw/waveforms/adhoc/sine_gaussian.py,sha256=-MtrI7ydwBTk4K0O4tdkC8-w5OifQszdnWN9__I4XzY,3569
42
+ ml4gw/waveforms/cbc/__init__.py,sha256=hGbPsFNAIveYJnff8qKY8RWeBPFtZoYcnGHxraPWtWI,99
43
+ ml4gw/waveforms/cbc/phenom_d.py,sha256=vA60SjOvWSIcsU83-KEw2hnU3ATo4eW8A2mMmuMXo7Y,46941
44
+ ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9n25he9U,53447
45
+ ml4gw/waveforms/cbc/phenom_p.py,sha256=Y8L2r3UPkJeQqJNwknWBmcG_nO2Z_aXJ_DfWc_lzJhg,26720
46
+ ml4gw/waveforms/cbc/taylorf2.py,sha256=ySYLGTT_c3k4NzPDsQ9v822kzvU6TwYpELJEWlCDGQE,10428
47
+ ml4gw/waveforms/conversion.py,sha256=F5fsNeqf6KHY66opDIj8fN9bwUcwrt9f7PCaxLAi9Jk,6367
38
48
  ml4gw/waveforms/generator.py,sha256=dO6RQ96EC87p2q0tEkxA62XkkJc1xARFO1SKcGvyDhM,1272
39
- ml4gw/waveforms/phenom_d.py,sha256=vA60SjOvWSIcsU83-KEw2hnU3ATo4eW8A2mMmuMXo7Y,46941
40
- ml4gw/waveforms/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9n25he9U,53447
41
- ml4gw/waveforms/phenom_p.py,sha256=VybpPlc2_yMGywnPz5B79QAygAj-WAeHZTPiZHets28,26951
42
- ml4gw/waveforms/ringdown.py,sha256=m8IBQTxKBBGFqBtWGEO4KG3DEYR8TTnNyGVdVLaMKa8,3316
43
- ml4gw/waveforms/sine_gaussian.py,sha256=-MtrI7ydwBTk4K0O4tdkC8-w5OifQszdnWN9__I4XzY,3569
44
- ml4gw/waveforms/taylorf2.py,sha256=ySYLGTT_c3k4NzPDsQ9v822kzvU6TwYpELJEWlCDGQE,10428
45
- ml4gw-0.5.1.dist-info/METADATA,sha256=P2uoQtMX_K5SSwAzTY5tyNvWYszxaDADTS54iDOQYKw,5785
46
- ml4gw-0.5.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
47
- ml4gw-0.5.1.dist-info/RECORD,,
49
+ ml4gw-0.6.0.dist-info/METADATA,sha256=6bwcfu6ojmrxgtMnFVViy9FanSmMXjhnN33yAzViFzo,5785
50
+ ml4gw-0.6.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
51
+ ml4gw-0.6.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.9.0
2
+ Generator: poetry-core 1.9.1
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
File without changes
File without changes
File without changes