ml4gw 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

@@ -0,0 +1,796 @@
1
+ from typing import Dict, Tuple
2
+
3
+ import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
6
+
7
+ from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
8
+ from ml4gw.types import BatchTensor, FrequencySeries1d
9
+
10
+ from .phenom_d import IMRPhenomD
11
+
12
+
13
+ class IMRPhenomPv2(IMRPhenomD):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(
18
+ self,
19
+ fs: FrequencySeries1d,
20
+ chirp_mass: BatchTensor,
21
+ mass_ratio: BatchTensor,
22
+ s1x: BatchTensor,
23
+ s1y: BatchTensor,
24
+ s1z: BatchTensor,
25
+ s2x: BatchTensor,
26
+ s2y: BatchTensor,
27
+ s2z: BatchTensor,
28
+ dist_mpc: BatchTensor,
29
+ tc: BatchTensor,
30
+ phiRef: BatchTensor,
31
+ incl: BatchTensor,
32
+ f_ref: float,
33
+ ):
34
+ """
35
+ IMRPhenomPv2 waveform
36
+
37
+ Args:
38
+ fs :
39
+ Frequency series in Hz.
40
+ chirp_mass :
41
+ Chirp mass in solar masses.
42
+ mass_ratio :
43
+ Mass ratio m1/m2.
44
+ s1x :
45
+ Spin component x of the first BH.
46
+ s1y :
47
+ Spin component y of the first BH.
48
+ s1z :
49
+ Spin component z of the first BH.
50
+ s2x :
51
+ Spin component x of the second BH.
52
+ s2y :
53
+ Spin component y of the second BH.
54
+ s2z :
55
+ Spin component z of the second BH.
56
+ dist_mpc :
57
+ Luminosity distance in Mpc.
58
+ tc :
59
+ Coalescence time.
60
+ phiRef :
61
+ Reference phase.
62
+ incl :
63
+ Inclination angle.
64
+ f_ref :
65
+ Reference frequency in Hz.
66
+
67
+ Returns:
68
+ hc, hp: Tuple[torch.Tensor, torch.Tensor]
69
+ Cross and plus polarizations
70
+
71
+ Note: m1 must be larger than m2.
72
+ """
73
+
74
+ m2 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio**0.6
75
+ m1 = m2 * mass_ratio
76
+
77
+ # # flip m1 m2. For some reason LAL uses this convention for PhenomPv2
78
+ m1, m2 = m2, m1
79
+ s1x, s2x = s2x, s1x
80
+ s1y, s2y = s2y, s1y
81
+ s1z, s2z = s2z, s1z
82
+
83
+ (
84
+ chi1_l,
85
+ chi2_l,
86
+ chip,
87
+ thetaJN,
88
+ alpha0,
89
+ phi_aligned,
90
+ zeta_polariz,
91
+ ) = self.convert_spins(
92
+ m1, m2, f_ref, phiRef, incl, s1x, s1y, s1z, s2x, s2y, s2z
93
+ )
94
+
95
+ phic = 2 * phi_aligned
96
+ q = m2 / m1 # q>=1
97
+ M = m1 + m2
98
+ chi_eff = (m1 * chi1_l + m2 * chi2_l) / M
99
+ chil = (1.0 + q) / q * chi_eff
100
+ eta = m1 * m2 / (M * M)
101
+ eta2 = eta * eta
102
+ Seta = torch.sqrt(1.0 - 4.0 * eta)
103
+ chi = self.chiPN(Seta, eta, chi2_l, chi1_l)
104
+ chi22 = chi2_l * chi2_l
105
+ chi12 = chi1_l * chi1_l
106
+ xi = -1.0 + chi
107
+ m_sec = M * MTSUN_SI
108
+ piM = PI * m_sec
109
+
110
+ omega_ref = piM * f_ref
111
+ logomega_ref = torch.log(omega_ref)
112
+ omega_ref_cbrt = (piM * f_ref) ** (1 / 3) # == v0
113
+ omega_ref_cbrt2 = omega_ref_cbrt * omega_ref_cbrt
114
+
115
+ angcoeffs = self.ComputeNNLOanglecoeffs(q, chil, chip)
116
+
117
+ alphaNNLOoffset = (
118
+ angcoeffs["alphacoeff1"] / omega_ref
119
+ + angcoeffs["alphacoeff2"] / omega_ref_cbrt2
120
+ + angcoeffs["alphacoeff3"] / omega_ref_cbrt
121
+ + angcoeffs["alphacoeff4"] * logomega_ref
122
+ + angcoeffs["alphacoeff5"] * omega_ref_cbrt
123
+ )
124
+
125
+ epsilonNNLOoffset = (
126
+ angcoeffs["epsiloncoeff1"] / omega_ref
127
+ + angcoeffs["epsiloncoeff2"] / omega_ref_cbrt2
128
+ + angcoeffs["epsiloncoeff3"] / omega_ref_cbrt
129
+ + angcoeffs["epsiloncoeff4"] * logomega_ref
130
+ + angcoeffs["epsiloncoeff5"] * omega_ref_cbrt
131
+ )
132
+
133
+ Y2m2 = self.SpinWeightedY(thetaJN, 0, -2, 2, -2)
134
+ Y2m1 = self.SpinWeightedY(thetaJN, 0, -2, 2, -1)
135
+ Y20 = self.SpinWeightedY(thetaJN, 0, -2, 2, -0)
136
+ Y21 = self.SpinWeightedY(thetaJN, 0, -2, 2, 1)
137
+ Y22 = self.SpinWeightedY(thetaJN, 0, -2, 2, 2)
138
+ Y2 = torch.stack((Y2m2, Y2m1, Y20, Y21, Y22))
139
+
140
+ hPhenomDs, diffRDphase = self.PhenomPOneFrequency(
141
+ fs,
142
+ m2,
143
+ m1,
144
+ eta,
145
+ eta2,
146
+ Seta,
147
+ chi2_l,
148
+ chi1_l,
149
+ chi12,
150
+ chi22,
151
+ chip,
152
+ phic,
153
+ M,
154
+ xi,
155
+ dist_mpc,
156
+ )
157
+
158
+ hp, hc = self.PhenomPCoreTwistUp(
159
+ fs,
160
+ hPhenomDs,
161
+ eta,
162
+ chi1_l,
163
+ chi2_l,
164
+ chip,
165
+ M,
166
+ angcoeffs,
167
+ Y2,
168
+ alphaNNLOoffset - alpha0,
169
+ epsilonNNLOoffset,
170
+ )
171
+ t0 = (diffRDphase.unsqueeze(1)) / (2 * PI)
172
+ phase_corr = torch.cos(2 * PI * fs * (t0)) - 1j * torch.sin(
173
+ 2 * PI * fs * (t0)
174
+ )
175
+ M_s = (m1 + m2) * MTSUN_SI
176
+ phase_corr_tc = torch.exp(
177
+ -1j * fs * M_s.unsqueeze(1) * tc.unsqueeze(1)
178
+ )
179
+ hp *= phase_corr * phase_corr_tc
180
+ hc *= phase_corr * phase_corr_tc
181
+
182
+ c2z = torch.cos(2 * zeta_polariz).unsqueeze(1)
183
+ s2z = torch.sin(2 * zeta_polariz).unsqueeze(1)
184
+ hplus = c2z * hp + s2z * hc
185
+ hcross = c2z * hc - s2z * hp
186
+ return hcross, hplus
187
+
188
+ def PhenomPCoreTwistUp(
189
+ self,
190
+ fHz: FrequencySeries1d,
191
+ hPhenom: BatchTensor,
192
+ eta: BatchTensor,
193
+ chi1_l: BatchTensor,
194
+ chi2_l: BatchTensor,
195
+ chip: BatchTensor,
196
+ M: BatchTensor,
197
+ angcoeffs: Dict[str, BatchTensor],
198
+ Y2m: BatchTensor,
199
+ alphaoffset: BatchTensor,
200
+ epsilonoffset: BatchTensor,
201
+ ) -> Tuple[BatchTensor, BatchTensor]:
202
+ assert angcoeffs is not None
203
+ assert Y2m is not None
204
+ f = fHz * MTSUN_SI * M.unsqueeze(1) # Frequency in geometric units
205
+ q = (1.0 + torch.sqrt(1.0 - 4.0 * eta) - 2.0 * eta) / (2.0 * eta)
206
+ m1 = 1.0 / (1.0 + q) # Mass of the smaller BH for unit total mass M=1.
207
+ m2 = q / (1.0 + q) # Mass of the larger BH for unit total mass M=1.
208
+ Sperp = chip * (
209
+ m2 * m2
210
+ ) # Dimensionfull spin component in the orbital plane.
211
+ # S_perp = S_2_perp chi_eff = m1 * chi1_l + m2 * chi2_l
212
+ # effective spin for M=1
213
+
214
+ SL = chi1_l * m1 * m1 + chi2_l * m2 * m2 # Dimensionfull aligned spin.
215
+
216
+ omega = PI * f
217
+ logomega = torch.log(omega)
218
+ omega_cbrt = (omega) ** (1 / 3)
219
+ omega_cbrt2 = omega_cbrt * omega_cbrt
220
+ alpha = (
221
+ (
222
+ angcoeffs["alphacoeff1"] / omega.mT
223
+ + angcoeffs["alphacoeff2"] / omega_cbrt2.mT
224
+ + angcoeffs["alphacoeff3"] / omega_cbrt.mT
225
+ + angcoeffs["alphacoeff4"] * logomega.mT
226
+ + angcoeffs["alphacoeff5"] * omega_cbrt.mT
227
+ )
228
+ - alphaoffset
229
+ ).mT
230
+
231
+ epsilon = (
232
+ (
233
+ angcoeffs["epsiloncoeff1"] / omega.mT
234
+ + angcoeffs["epsiloncoeff2"] / omega_cbrt2.mT
235
+ + angcoeffs["epsiloncoeff3"] / omega_cbrt.mT
236
+ + angcoeffs["epsiloncoeff4"] * logomega.mT
237
+ + angcoeffs["epsiloncoeff5"] * omega_cbrt.mT
238
+ )
239
+ - epsilonoffset
240
+ ).mT
241
+
242
+ cBetah, sBetah = self.WignerdCoefficients(
243
+ omega_cbrt.mT, SL, eta, Sperp
244
+ )
245
+
246
+ cBetah2 = cBetah * cBetah
247
+ cBetah3 = cBetah2 * cBetah
248
+ cBetah4 = cBetah3 * cBetah
249
+ sBetah2 = sBetah * sBetah
250
+ sBetah3 = sBetah2 * sBetah
251
+ sBetah4 = sBetah3 * sBetah
252
+
253
+ hp_sum = 0
254
+ hc_sum = 0
255
+
256
+ cexp_i_alpha = torch.exp(1j * alpha)
257
+ cexp_2i_alpha = cexp_i_alpha * cexp_i_alpha
258
+ cexp_mi_alpha = 1.0 / cexp_i_alpha
259
+ cexp_m2i_alpha = cexp_mi_alpha * cexp_mi_alpha
260
+ T2m = (
261
+ cexp_2i_alpha.mT * cBetah4.mT * Y2m[0]
262
+ - cexp_i_alpha.mT * 2 * cBetah3.mT * sBetah.mT * Y2m[1]
263
+ + 1
264
+ * torch.sqrt(torch.tensor(6))
265
+ * sBetah2.mT
266
+ * cBetah2.mT
267
+ * Y2m[2]
268
+ - cexp_mi_alpha.mT * 2 * cBetah.mT * sBetah3.mT * Y2m[3]
269
+ + cexp_m2i_alpha.mT * sBetah4.mT * Y2m[4]
270
+ ).mT
271
+ Tm2m = (
272
+ cexp_m2i_alpha.mT * sBetah4.mT * torch.conj(Y2m[0])
273
+ + cexp_mi_alpha.mT
274
+ * 2
275
+ * cBetah.mT
276
+ * sBetah3.mT
277
+ * torch.conj(Y2m[1])
278
+ + 1
279
+ * torch.sqrt(torch.tensor(6))
280
+ * sBetah2.mT
281
+ * cBetah2.mT
282
+ * torch.conj(Y2m[2])
283
+ + cexp_i_alpha.mT * 2 * cBetah3.mT * sBetah.mT * torch.conj(Y2m[3])
284
+ + cexp_2i_alpha.mT * cBetah4.mT * torch.conj(Y2m[4])
285
+ ).mT
286
+ hp_sum = T2m + Tm2m
287
+ hc_sum = 1j * (T2m - Tm2m)
288
+
289
+ eps_phase_hP = torch.exp(-2j * epsilon) * hPhenom / 2.0
290
+
291
+ hp = eps_phase_hP * hp_sum
292
+ hc = eps_phase_hP * hc_sum
293
+
294
+ return hp, hc
295
+
296
+ def PhenomPOneFrequency(
297
+ self,
298
+ fs,
299
+ m1,
300
+ m2,
301
+ eta,
302
+ eta2,
303
+ Seta,
304
+ chi1,
305
+ chi2,
306
+ chi12,
307
+ chi22,
308
+ chip,
309
+ phic,
310
+ M,
311
+ xi,
312
+ dist_mpc,
313
+ ):
314
+ """
315
+ m1, m2: in solar masses
316
+ phic: Orbital phase at peak of the underlying non precessing model
317
+ M: Total mass (Solar masses)
318
+ """
319
+
320
+ M_s = M * MTSUN_SI
321
+ Mf = torch.outer(M_s, fs)
322
+ fRD, _ = self.phP_get_fRD_fdamp(m1, m2, chi1, chi2, chip)
323
+
324
+ phase, _ = self.phenom_d_phase(Mf, m1, m2, eta, eta2, chi1, chi2, xi)
325
+ phase = (phase.mT - (phic + PI / 4.0)).mT
326
+ Amp = self.phenom_d_amp(
327
+ Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, dist_mpc
328
+ )[0]
329
+ Amp0 = self.get_Amp0(Mf, eta)
330
+ dist_s = dist_mpc * MPC_SEC
331
+ Amp = ((Amp0 * Amp).mT * (M_s**2.0) / dist_s).mT
332
+ # phase -= 2. * phic; # line 1316 ???
333
+ hPhenom = Amp * (torch.exp(-1j * phase))
334
+
335
+ fRDs = torch.outer(
336
+ fRD, torch.linspace(0.5, 1.5, 101, device=fRD.device)
337
+ )
338
+ delta_fRds = torch.median(torch.diff(fRDs, axis=1), axis=1)[0]
339
+ MfRDs = torch.zeros_like(fRDs)
340
+ for i in range(fRD.shape[0]):
341
+ MfRDs[i, :] = torch.outer(M_s, fRDs[i, :])[i, :]
342
+ RD_phase = self.phenom_d_phase(
343
+ MfRDs, m1, m2, eta, eta2, chi1, chi2, xi
344
+ )[0]
345
+ diff = torch.diff(RD_phase, axis=1)
346
+ diffRDphase = (diff[:, 1:] + diff[:, :-1]) / (
347
+ 2 * delta_fRds.unsqueeze(1)
348
+ )
349
+ diffRDphase = -diffRDphase[:, 50]
350
+ # MfRD = torch.outer(M_s, fRD)
351
+ # Dphase = torch.diag(
352
+ # -self.phenom_d_phase(
353
+ # MfRD, m1, m2, eta, eta2, chi1, chi2, xi)[1] * M_s
354
+ # ).view(-1, 1)
355
+ return hPhenom, diffRDphase
356
+
357
+ # Utility functions
358
+
359
+ def interpolate(
360
+ self,
361
+ x: Float[Tensor, " new_series"],
362
+ xp: Float[Tensor, " series"],
363
+ fp: Float[Tensor, " series"],
364
+ ) -> Float[Tensor, " new_series"]:
365
+ """One-dimensional linear interpolation for monotonically
366
+ increasing sample points.
367
+
368
+ Returns the one-dimensional piecewise linear interpolant to a function
369
+ with given data points :math:`(xp, fp)`, evaluated at :math:`x`
370
+
371
+ Args:
372
+ x: the :math:`x`-coordinates at which to evaluate the interpolated
373
+ values.
374
+ xp: the :math:`x`-coordinates of data points, must be increasing.
375
+ fp: the :math:`y`-coordinates of data points, same length as `xp`.
376
+
377
+ Returns:
378
+ the interpolated values, same size as `x`.
379
+ """
380
+ original_shape = x.shape
381
+ x = x.flatten()
382
+ xp = xp.flatten()
383
+ fp = fp.flatten()
384
+
385
+ m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1]) # slope
386
+ b = fp[:-1] - (m * xp[:-1])
387
+
388
+ indices = torch.searchsorted(xp, x, right=False) - 1
389
+
390
+ interpolated = m[indices] * x + b[indices]
391
+
392
+ return interpolated.reshape(original_shape)
393
+
394
+ def ROTATEZ(self, angle: BatchTensor, x, y, z):
395
+ tmp_x = x * torch.cos(angle) - y * torch.sin(angle)
396
+ tmp_y = x * torch.sin(angle) + y * torch.cos(angle)
397
+ return tmp_x, tmp_y, z
398
+
399
+ def ROTATEY(self, angle, x, y, z):
400
+ tmp_x = x * torch.cos(angle) + z * torch.sin(angle)
401
+ tmp_z = -x * torch.sin(angle) + z * torch.cos(angle)
402
+ return tmp_x, y, tmp_z
403
+
404
+ def L2PNR(
405
+ self,
406
+ v: BatchTensor,
407
+ eta: BatchTensor,
408
+ ) -> BatchTensor:
409
+ eta2 = eta**2
410
+ x = v**2
411
+ x2 = x**2
412
+ tmp = (
413
+ eta
414
+ * (
415
+ 1.0
416
+ + (1.5 + eta / 6.0) * x
417
+ + (3.375 - (19.0 * eta) / 8.0 - eta2 / 24.0) * x2
418
+ )
419
+ ) / x**0.5
420
+
421
+ return tmp
422
+
423
+ def convert_spins(
424
+ self,
425
+ m1: BatchTensor,
426
+ m2: BatchTensor,
427
+ f_ref: float,
428
+ phiRef: BatchTensor,
429
+ incl: BatchTensor,
430
+ s1x: BatchTensor,
431
+ s1y: BatchTensor,
432
+ s1z: BatchTensor,
433
+ s2x: BatchTensor,
434
+ s2y: BatchTensor,
435
+ s2z: BatchTensor,
436
+ ) -> Tuple[
437
+ BatchTensor,
438
+ BatchTensor,
439
+ BatchTensor,
440
+ BatchTensor,
441
+ BatchTensor,
442
+ BatchTensor,
443
+ BatchTensor,
444
+ ]:
445
+ M = m1 + m2
446
+ m1_2 = m1 * m1
447
+ m2_2 = m2 * m2
448
+ eta = m1 * m2 / (M * M) # Symmetric mass-ratio
449
+
450
+ # From the components in the source frame, we can easily determine
451
+ # chi1_l, chi2_l, chip and phi_aligned, which we need to return.
452
+ # We also compute the spherical angles of J,
453
+ # which we need to transform to the J frame
454
+
455
+ # Aligned spins
456
+ chi1_l = s1z # Dimensionless aligned spin on BH 1
457
+ chi2_l = s2z # Dimensionless aligned spin on BH 2
458
+
459
+ # Magnitude of the spin projections in the orbital plane
460
+ S1_perp = m1_2 * torch.sqrt(s1x**2 + s1y**2)
461
+ S2_perp = m2_2 * torch.sqrt(s2x**2 + s2y**2)
462
+
463
+ A1 = 2 + (3 * m2) / (2 * m1)
464
+ A2 = 2 + (3 * m1) / (2 * m2)
465
+ ASp1 = A1 * S1_perp
466
+ ASp2 = A2 * S2_perp
467
+ num = torch.maximum(ASp1, ASp2)
468
+ den = A2 * m2_2 # warning: this assumes m2 > m1
469
+ chip = num / den
470
+
471
+ m_sec = M * MTSUN_SI
472
+ piM = PI * m_sec
473
+ v_ref = (piM * f_ref) ** (1 / 3)
474
+ L0 = M * M * self.L2PNR(v_ref, eta)
475
+ J0x_sf = m1_2 * s1x + m2_2 * s2x
476
+ J0y_sf = m1_2 * s1y + m2_2 * s2y
477
+ J0z_sf = L0 + m1_2 * s1z + m2_2 * s2z
478
+ J0 = torch.sqrt(J0x_sf * J0x_sf + J0y_sf * J0y_sf + J0z_sf * J0z_sf)
479
+
480
+ thetaJ_sf = torch.arccos(J0z_sf / J0)
481
+
482
+ phiJ_sf = torch.arctan2(J0y_sf, J0x_sf)
483
+
484
+ phi_aligned = -phiJ_sf
485
+
486
+ # First we determine kappa
487
+ # in the source frame, the components of N are given in
488
+ # Eq (35c) of T1500606-v6
489
+ Nx_sf = torch.sin(incl) * torch.cos(PI / 2.0 - phiRef)
490
+ Ny_sf = torch.sin(incl) * torch.sin(PI / 2.0 - phiRef)
491
+ Nz_sf = torch.cos(incl)
492
+
493
+ tmp_x = Nx_sf
494
+ tmp_y = Ny_sf
495
+ tmp_z = Nz_sf
496
+
497
+ tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
498
+ tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
499
+
500
+ kappa = -torch.arctan2(tmp_y, tmp_x)
501
+
502
+ # Then we determine alpha0, by rotating LN
503
+ tmp_x, tmp_y, tmp_z = 0, 0, 1
504
+ tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
505
+ tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
506
+ tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
507
+
508
+ alpha0 = torch.arctan2(tmp_y, tmp_x)
509
+
510
+ # Finally we determine thetaJ, by rotating N
511
+ tmp_x, tmp_y, tmp_z = Nx_sf, Ny_sf, Nz_sf
512
+ tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
513
+ tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
514
+ tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
515
+ Nx_Jf, Nz_Jf = tmp_x, tmp_z
516
+ thetaJN = torch.arccos(Nz_Jf)
517
+
518
+ # Finally, we need to redefine the polarizations:
519
+ # PhenomP's polarizations are defined following Arun et al
520
+ # (arXiv:0810.5336)
521
+ # i.e. projecting the metric onto the P,Q,N triad defined with
522
+ # P=NxJ/|NxJ| (see (2.6) in there).
523
+ # By contrast, the triad X,Y,N used in LAL
524
+ # ("waveframe" in the nomenclature of T1500606-v6)
525
+ # is defined in e.g. eq (35) of this document
526
+ # (via its components in the source frame;
527
+ # note we use the default Omega=Pi/2).
528
+ # Both triads differ from each other by a rotation around N by an angle
529
+ # \zeta and we need to rotate the polarizations accordingly by 2\zeta
530
+
531
+ Xx_sf = -torch.cos(incl) * torch.sin(phiRef)
532
+ Xy_sf = -torch.cos(incl) * torch.cos(phiRef)
533
+ Xz_sf = torch.sin(incl)
534
+ tmp_x, tmp_y, tmp_z = Xx_sf, Xy_sf, Xz_sf
535
+ tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
536
+ tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
537
+ tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
538
+
539
+ # Now the tmp_a are the components of X in the J frame
540
+ # We need the polar angle of that vector in the P,Q basis of Arun et al
541
+ # P = NxJ/|NxJ| and since we put N in the (pos x)z half plane of the J
542
+ # frame
543
+ PArunx_Jf = 0.0
544
+ PAruny_Jf = -1.0
545
+ PArunz_Jf = 0.0
546
+
547
+ # Q = NxP
548
+ QArunx_Jf = Nz_Jf
549
+ QAruny_Jf = 0.0
550
+ QArunz_Jf = -Nx_Jf
551
+
552
+ # Calculate the dot products XdotPArun and XdotQArun
553
+ XdotPArun = tmp_x * PArunx_Jf + tmp_y * PAruny_Jf + tmp_z * PArunz_Jf
554
+ XdotQArun = tmp_x * QArunx_Jf + tmp_y * QAruny_Jf + tmp_z * QArunz_Jf
555
+
556
+ zeta_polariz = torch.arctan2(XdotQArun, XdotPArun)
557
+ return chi1_l, chi2_l, chip, thetaJN, alpha0, phi_aligned, zeta_polariz
558
+
559
+ # TODO: add input and output types
560
+ def SpinWeightedY(self, theta, phi, s, l, m): # noqa: E741
561
+ "copied from SphericalHarmonics.c in LAL"
562
+ if s == -2:
563
+ if l == 2: # noqa: E741
564
+ if m == -2:
565
+ fac = (
566
+ torch.sqrt(torch.tensor(5.0 / (64.0 * PI)))
567
+ * (1.0 - torch.cos(theta))
568
+ * (1.0 - torch.cos(theta))
569
+ )
570
+ elif m == -1:
571
+ fac = (
572
+ torch.sqrt(torch.tensor(5.0 / (16.0 * PI)))
573
+ * torch.sin(theta)
574
+ * (1.0 - torch.cos(theta))
575
+ )
576
+ elif m == 0:
577
+ fac = (
578
+ torch.sqrt(torch.tensor(15.0 / (32.0 * PI)))
579
+ * torch.sin(theta)
580
+ * torch.sin(theta)
581
+ )
582
+ elif m == 1:
583
+ fac = (
584
+ torch.sqrt(torch.tensor(5.0 / (16.0 * PI)))
585
+ * torch.sin(theta)
586
+ * (1.0 + torch.cos(theta))
587
+ )
588
+ elif m == 2:
589
+ fac = (
590
+ torch.sqrt(torch.tensor(5.0 / (64.0 * PI)))
591
+ * (1.0 + torch.cos(theta))
592
+ * (1.0 + torch.cos(theta))
593
+ )
594
+ else:
595
+ raise ValueError(
596
+ f"Invalid mode s={s}, l={l}, m={m} - require |m| <= l"
597
+ )
598
+ return fac * torch.complex(
599
+ torch.cos(torch.tensor(m * phi)),
600
+ torch.sin(torch.tensor(m * phi)),
601
+ )
602
+
603
+ def WignerdCoefficients(
604
+ self,
605
+ v: BatchTensor,
606
+ SL: BatchTensor,
607
+ eta: BatchTensor,
608
+ Sp: BatchTensor,
609
+ ) -> Tuple[BatchTensor, BatchTensor]:
610
+ # We define the shorthand s := Sp / (L + SL)
611
+ L = self.L2PNR(v, eta)
612
+ s = (Sp / (L + SL)).mT
613
+ s2 = s**2
614
+ cos_beta = 1.0 / (1.0 + s2) ** 0.5
615
+ cos_beta_half = ((1.0 + cos_beta) / 2.0) ** 0.5 # cos(beta/2)
616
+ sin_beta_half = ((1.0 - cos_beta) / 2.0) ** 0.5 # sin(beta/2)
617
+
618
+ return cos_beta_half, sin_beta_half
619
+
620
    def ComputeNNLOanglecoeffs(
        self,
        q: BatchTensor,
        chil: BatchTensor,
        chip: BatchTensor,
    ) -> Dict[str, BatchTensor]:
        """
        Coefficients of the NNLO precession angles alpha and epsilon.

        Callers combine each set of five coefficients as
        c1 / omega + c2 / omega**(2/3) + c3 / omega**(1/3)
        + c4 * log(omega) + c5 * omega**(1/3), where omega = pi * M * f.

        The expressions are written for unit total mass. The epsilon
        coefficients equal the alpha coefficients with every chip-dependent
        term removed. NOTE(review): the name and numerical coefficients
        appear to mirror LALSimulation's ComputeNNLOanglecoeffs; confirm
        against LAL if any coefficient is edited.

        Args:
            q: mass ratio, assumed >= 1 so that m2 = q / (1 + q) below is the
                heavier component.
            chil: aligned spin combination ((1 + q) / q * chi_eff in
                forward).
            chip: effective precession spin.

        Returns:
            Dict with keys "alphacoeff1".."alphacoeff5" and
            "epsiloncoeff1".."epsiloncoeff5".
        """
        # Component masses for unit total mass.
        m2 = q / (1.0 + q)
        m1 = 1.0 / (1.0 + q)
        dm = m1 - m2
        mtot = 1.0
        eta = m1 * m2  # mtot = 1
        # Powers of the recurring quantities, precomputed once.
        eta2 = eta * eta
        eta3 = eta2 * eta
        eta4 = eta3 * eta
        mtot2 = mtot * mtot
        mtot4 = mtot2 * mtot2
        mtot6 = mtot4 * mtot2
        mtot8 = mtot6 * mtot2
        chil2 = chil * chil
        chip2 = chip * chip
        chip4 = chip2 * chip2
        dm2 = dm * dm
        dm3 = dm2 * dm
        m2_2 = m2 * m2
        m2_3 = m2_2 * m2
        m2_4 = m2_3 * m2
        m2_5 = m2_4 * m2
        m2_6 = m2_5 * m2
        m2_7 = m2_6 * m2
        m2_8 = m2_7 * m2

        angcoeffs = {}
        # alpha: coefficient of omega**-1
        angcoeffs["alphacoeff1"] = -0.18229166666666666 - (5 * dm) / (
            64.0 * m2
        )

        # alpha: coefficient of omega**-(2/3)
        angcoeffs["alphacoeff2"] = (-15 * dm * m2 * chil) / (
            128.0 * mtot2 * eta
        ) - (35 * m2_2 * chil) / (128.0 * mtot2 * eta)

        # alpha: coefficient of omega**-(1/3)
        angcoeffs["alphacoeff3"] = (
            -1.7952473958333333
            - (4555 * dm) / (7168.0 * m2)
            - (15 * chip2 * dm * m2_3) / (128.0 * mtot4 * eta2)
            - (35 * chip2 * m2_4) / (128.0 * mtot4 * eta2)
            - (515 * eta) / 384.0
            - (15 * dm2 * eta) / (256.0 * m2_2)
            - (175 * dm * eta) / (256.0 * m2)
        )

        # alpha: coefficient of log(omega)
        angcoeffs["alphacoeff4"] = (
            -(35 * PI) / 48.0
            - (5 * dm * PI) / (16.0 * m2)
            + (5 * dm2 * chil) / (16.0 * mtot2)
            + (5 * dm * m2 * chil) / (3.0 * mtot2)
            + (2545 * m2_2 * chil) / (1152.0 * mtot2)
            - (5 * chip2 * dm * m2_5 * chil) / (128.0 * mtot6 * eta3)
            - (35 * chip2 * m2_6 * chil) / (384.0 * mtot6 * eta3)
            + (2035 * dm * m2 * chil) / (21504.0 * mtot2 * eta)
            + (2995 * m2_2 * chil) / (9216.0 * mtot2 * eta)
        )

        # alpha: coefficient of omega**(1/3)
        angcoeffs["alphacoeff5"] = (
            4.318908476114694
            + (27895885 * dm) / (2.1676032e7 * m2)
            - (15 * chip4 * dm * m2_7) / (512.0 * mtot8 * eta4)
            - (35 * chip4 * m2_8) / (512.0 * mtot8 * eta4)
            - (485 * chip2 * dm * m2_3) / (14336.0 * mtot4 * eta2)
            + (475 * chip2 * m2_4) / (6144.0 * mtot4 * eta2)
            + (15 * chip2 * dm2 * m2_2) / (256.0 * mtot4 * eta)
            + (145 * chip2 * dm * m2_3) / (512.0 * mtot4 * eta)
            + (575 * chip2 * m2_4) / (1536.0 * mtot4 * eta)
            + (39695 * eta) / 86016.0
            + (1615 * dm2 * eta) / (28672.0 * m2_2)
            - (265 * dm * eta) / (14336.0 * m2)
            + (955 * eta2) / 576.0
            + (15 * dm3 * eta2) / (1024.0 * m2_3)
            + (35 * dm2 * eta2) / (256.0 * m2_2)
            + (2725 * dm * eta2) / (3072.0 * m2)
            - (15 * dm * m2 * PI * chil) / (16.0 * mtot2 * eta)
            - (35 * m2_2 * PI * chil) / (16.0 * mtot2 * eta)
            + (15 * chip2 * dm * m2_7 * chil2) / (128.0 * mtot8 * eta4)
            + (35 * chip2 * m2_8 * chil2) / (128.0 * mtot8 * eta4)
            + (375 * dm2 * m2_2 * chil2) / (256.0 * mtot4 * eta)
            + (1815 * dm * m2_3 * chil2) / (256.0 * mtot4 * eta)
            + (1645 * m2_4 * chil2) / (192.0 * mtot4 * eta)
        )

        # epsilon: same five powers of omega as alpha, chip terms dropped.
        angcoeffs["epsiloncoeff1"] = -0.18229166666666666 - (5 * dm) / (
            64.0 * m2
        )
        angcoeffs["epsiloncoeff2"] = (-15 * dm * m2 * chil) / (
            128.0 * mtot2 * eta
        ) - (35 * m2_2 * chil) / (128.0 * mtot2 * eta)
        angcoeffs["epsiloncoeff3"] = (
            -1.7952473958333333
            - (4555 * dm) / (7168.0 * m2)
            - (515 * eta) / 384.0
            - (15 * dm2 * eta) / (256.0 * m2_2)
            - (175 * dm * eta) / (256.0 * m2)
        )
        angcoeffs["epsiloncoeff4"] = (
            -(35 * PI) / 48.0
            - (5 * dm * PI) / (16.0 * m2)
            + (5 * dm2 * chil) / (16.0 * mtot2)
            + (5 * dm * m2 * chil) / (3.0 * mtot2)
            + (2545 * m2_2 * chil) / (1152.0 * mtot2)
            + (2035 * dm * m2 * chil) / (21504.0 * mtot2 * eta)
            + (2995 * m2_2 * chil) / (9216.0 * mtot2 * eta)
        )
        angcoeffs["epsiloncoeff5"] = (
            4.318908476114694
            + (27895885 * dm) / (2.1676032e7 * m2)
            + (39695 * eta) / 86016.0
            + (1615 * dm2 * eta) / (28672.0 * m2_2)
            - (265 * dm * eta) / (14336.0 * m2)
            + (955 * eta2) / 576.0
            + (15 * dm3 * eta2) / (1024.0 * m2_3)
            + (35 * dm2 * eta2) / (256.0 * m2_2)
            + (2725 * dm * eta2) / (3072.0 * m2)
            - (15 * dm * m2 * PI * chil) / (16.0 * mtot2 * eta)
            - (35 * m2_2 * PI * chil) / (16.0 * mtot2 * eta)
            + (375 * dm2 * m2_2 * chil2) / (256.0 * mtot4 * eta)
            + (1815 * dm * m2_3 * chil2) / (256.0 * mtot4 * eta)
            + (1645 * m2_4 * chil2) / (192.0 * mtot4 * eta)
        )
        return angcoeffs
747
+
748
+ def FinalSpin_inplane(
749
+ self,
750
+ m1: BatchTensor,
751
+ m2: BatchTensor,
752
+ chi1_l: BatchTensor,
753
+ chi2_l: BatchTensor,
754
+ chip: BatchTensor,
755
+ ) -> BatchTensor:
756
+ M = m1 + m2
757
+ eta = m1 * m2 / (M * M)
758
+ eta2 = eta * eta
759
+ # m1 > m2, the convention used in phenomD
760
+ # (not the convention of internal phenomP)
761
+ mass_ratio = m1 / m2
762
+ af_parallel = self.FinalSpin0815(eta, eta2, chi1_l, chi2_l)
763
+ Sperp = chip * mass_ratio * mass_ratio
764
+ af = torch.copysign(
765
+ torch.ones_like(af_parallel), af_parallel
766
+ ) * torch.sqrt(Sperp * Sperp + af_parallel * af_parallel)
767
+ return af
768
+
769
+ def phP_get_fRD_fdamp(
770
+ self, m1, m2, chi1_l, chi2_l, chip
771
+ ) -> Tuple[BatchTensor, BatchTensor]:
772
+ # m1 > m2 should hold here
773
+ finspin = self.FinalSpin_inplane(m1, m2, chi1_l, chi2_l, chip)
774
+ m1_s = m1 * MTSUN_SI
775
+ m2_s = m2 * MTSUN_SI
776
+ M_s = m1_s + m2_s
777
+ eta_s = m1_s * m2_s / (M_s**2.0)
778
+ eta_s2 = eta_s * eta_s
779
+ Erad = self.PhenomInternal_EradRational0815(
780
+ eta_s, eta_s2, chi1_l, chi2_l
781
+ )
782
+ fRD = self.interpolate(finspin, self.qnmdata_a, self.qnmdata_fring) / (
783
+ 1.0 - Erad
784
+ )
785
+ fdamp = self.interpolate(
786
+ finspin, self.qnmdata_a, self.qnmdata_fdamp
787
+ ) / (1.0 - Erad)
788
+ return fRD / M_s, fdamp / M_s
789
+
790
+ def get_Amp0(self, fM_s: BatchTensor, eta: BatchTensor) -> BatchTensor:
791
+ Amp0 = (
792
+ (2.0 / 3.0 * eta.unsqueeze(1)) ** (1.0 / 2.0)
793
+ * (fM_s) ** (-7.0 / 6.0)
794
+ * PI ** (-1.0 / 6.0)
795
+ )
796
+ return Amp0