ml4gw 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +8 -2
- ml4gw/constants.py +45 -0
- ml4gw/dataloading/chunked_dataset.py +4 -2
- ml4gw/dataloading/hdf5_dataset.py +1 -1
- ml4gw/dataloading/in_memory_dataset.py +8 -4
- ml4gw/distributions.py +18 -12
- ml4gw/gw.py +21 -27
- ml4gw/nn/autoencoder/base.py +11 -6
- ml4gw/nn/autoencoder/convolutional.py +7 -4
- ml4gw/nn/autoencoder/skip_connection.py +7 -6
- ml4gw/nn/autoencoder/utils.py +2 -1
- ml4gw/nn/norm.py +11 -1
- ml4gw/nn/streaming/online_average.py +7 -5
- ml4gw/nn/streaming/snapshotter.py +7 -5
- ml4gw/spectral.py +40 -36
- ml4gw/transforms/pearson.py +7 -3
- ml4gw/transforms/qtransform.py +20 -14
- ml4gw/transforms/scaler.py +6 -2
- ml4gw/transforms/snr_rescaler.py +6 -5
- ml4gw/transforms/spectral.py +25 -6
- ml4gw/transforms/spectrogram.py +7 -1
- ml4gw/transforms/transform.py +4 -3
- ml4gw/transforms/waveforms.py +10 -7
- ml4gw/transforms/whitening.py +12 -4
- ml4gw/types.py +25 -10
- ml4gw/utils/interferometer.py +7 -1
- ml4gw/utils/slicing.py +24 -16
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/generator.py +9 -5
- ml4gw/waveforms/phenom_d.py +1338 -1256
- ml4gw/waveforms/phenom_p.py +796 -0
- ml4gw/waveforms/ringdown.py +109 -0
- ml4gw/waveforms/sine_gaussian.py +10 -11
- ml4gw/waveforms/taylorf2.py +304 -279
- {ml4gw-0.4.2.dist-info → ml4gw-0.5.1.dist-info}/METADATA +5 -3
- ml4gw-0.5.1.dist-info/RECORD +47 -0
- ml4gw-0.4.2.dist-info/RECORD +0 -44
- {ml4gw-0.4.2.dist-info → ml4gw-0.5.1.dist-info}/WHEEL +0 -0
ml4gw/waveforms/ringdown.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from ml4gw.constants import PI, C, G, m_per_Mpc
|
|
5
|
+
from ml4gw.types import BatchTensor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Ringdown(torch.nn.Module):
    """
    Callable class for generating ringdown waveforms.

    Each waveform is an exponentially damped sinusoid sampled on a fixed
    time grid that starts at t=0 (the onset of the ringdown).

    Args:
        sample_rate: Sample rate of waveform in Hz
        duration: Duration of waveform in seconds
    """

    def __init__(self, sample_rate: float, duration: float):
        super().__init__()
        # Sample times start at t=0 and run for ``duration`` seconds.
        # They are NOT shifted/centered: the damped sinusoid starts at the
        # beginning of the window, where its amplitude is largest.
        num = int(duration * sample_rate)
        times = torch.arange(num, dtype=torch.float64) / sample_rate

        # registered as a buffer so it follows the module across devices
        self.register_buffer("times", times)

    def forward(
        self,
        frequency: BatchTensor,
        quality: BatchTensor,
        epsilon: BatchTensor,
        phase: BatchTensor,
        inclination: BatchTensor,
        distance: BatchTensor,
    ):
        """
        Generate ringdown waveform based on the damped sinusoid equation.

        The remnant spin and mass are inferred from ``frequency`` and
        ``quality``. The strain amplitude is
        ``sqrt(5 * epsilon / 2) * (G M / c^2) / r * Q^-1/2 * F(Q)^-1/2 *
        g(a)^-1/2``, with ``F(Q) = 1 + 7 / (24 Q^2)`` and
        ``g(a) = 1 - 0.63 (1 - a)^(3/10)``.

        Args:
            frequency:
                Central frequency of the ringdown waveform in Hz
            quality:
                Quality factor of the ringdown waveform
            epsilon:
                Fraction of black hole's mass radiated as gravitational waves
            phase:
                Initial phase of the ringdown waveform in rad
            inclination:
                Inclination angle of the source in rad
            distance:
                Distance to the source in Mpc

        Returns:
            Tensors of cross and plus polarizations, each of shape
            ``(batch, num_samples)`` and dtype ``torch.float64``
        """

        # add a trailing dimension so each per-waveform parameter
        # broadcasts against the (num_samples,) time grid
        frequency = frequency.view(-1, 1)
        quality = quality.view(-1, 1)
        epsilon = epsilon.view(-1, 1)
        phase = phase.view(-1, 1)
        inclination = inclination.view(-1, 1)
        distance = distance.view(-1, 1)

        # convert Mpc to m
        distance = distance * m_per_Mpc

        # pi lives on the inputs' device and in double precision: a default
        # float32 tensor truncates pi, and that error grows with the number
        # of cycles in the phase 2*pi*f*t. The shape-[1] (non-scalar) tensor
        # takes part in dtype promotion, so float32 inputs are promoted.
        pi = torch.tensor([PI], dtype=torch.float64, device=frequency.device)

        # remnant spin and mass implied by the frequency / quality factor
        spin = 1 - (2 / quality) ** (20 / 9)
        mass = (
            (1 / (2 * pi))
            * (C**3 / (G * frequency))
            * (1 - 0.63 * (2 / quality) ** (2 / 3))
        )

        # strain amplitude; radiated energy scales with h^2, so the
        # amplitude must scale with sqrt(epsilon) rather than epsilon
        F_Q = 1 + ((7 / 24) / quality**2)
        g_a = 1 - 0.63 * (1 - spin) ** (3 / 10)
        amplitude = (
            torch.sqrt(5 / 2 * epsilon)
            * (G * mass / C**2)
            * quality ** (-0.5)
            * F_Q ** (-0.5)
            * g_a ** (-0.5)
        )

        # inclination dependence: a face-on source (inclination = 0) is
        # circularly polarized, i.e. plus and cross have equal amplitude,
        # so the cross factor is 2*cos(i) and matches (1 + cos^2 i) at i=0
        cos_i = torch.cos(inclination)
        plus_factor = 1 + cos_i**2
        cross_factor = 2 * cos_i

        # shared damped envelope and oscillation phase, (batch, num_samples)
        envelope = (amplitude / distance) * torch.exp(
            -pi * frequency * self.times / quality
        )
        phase_term = 2 * pi * frequency * self.times + phase

        h_plus = envelope * plus_factor * torch.cos(phase_term)
        h_cross = envelope * cross_factor * torch.sin(phase_term)

        # ensure the dtype is double
        return h_cross.double(), h_plus.double()
|
ml4gw/waveforms/sine_gaussian.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from torch import Tensor
|
|
3
3
|
|
|
4
|
-
from ml4gw.types import
|
|
4
|
+
from ml4gw.types import BatchTensor
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def semi_major_minor_from_e(e: Tensor):
|
|
@@ -30,13 +30,13 @@ class SineGaussian(torch.nn.Module):
|
|
|
30
30
|
|
|
31
31
|
self.register_buffer("times", times)
|
|
32
32
|
|
|
33
|
-
def
|
|
33
|
+
def forward(
|
|
34
34
|
self,
|
|
35
|
-
quality:
|
|
36
|
-
frequency:
|
|
37
|
-
hrss:
|
|
38
|
-
phase:
|
|
39
|
-
eccentricity:
|
|
35
|
+
quality: BatchTensor,
|
|
36
|
+
frequency: BatchTensor,
|
|
37
|
+
hrss: BatchTensor,
|
|
38
|
+
phase: BatchTensor,
|
|
39
|
+
eccentricity: BatchTensor,
|
|
40
40
|
):
|
|
41
41
|
"""
|
|
42
42
|
Generate lalinference implementation of a sine-Gaussian waveform.
|
|
@@ -60,7 +60,7 @@ class SineGaussian(torch.nn.Module):
|
|
|
60
60
|
Returns:
|
|
61
61
|
Tensors of cross and plus polarizations
|
|
62
62
|
"""
|
|
63
|
-
|
|
63
|
+
dtype = frequency.dtype
|
|
64
64
|
# add dimension for calculating waveforms in batch
|
|
65
65
|
frequency = frequency.view(-1, 1)
|
|
66
66
|
quality = quality.view(-1, 1)
|
|
@@ -105,8 +105,7 @@ class SineGaussian(torch.nn.Module):
|
|
|
105
105
|
cross = fac.imag * h0_cross
|
|
106
106
|
plus = fac.real * h0_plus
|
|
107
107
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
plus = plus.double()
|
|
108
|
+
cross = cross.to(dtype)
|
|
109
|
+
plus = plus.to(dtype)
|
|
111
110
|
|
|
112
111
|
return cross, plus
|