ml4gw 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/gw.py +44 -8
- ml4gw/spectral.py +1 -1
- ml4gw/transforms/snr_rescaler.py +17 -4
- ml4gw/waveforms/generator.py +2 -2
- {ml4gw-0.7.0.dist-info → ml4gw-0.7.2.dist-info}/METADATA +1 -1
- {ml4gw-0.7.0.dist-info → ml4gw-0.7.2.dist-info}/RECORD +8 -8
- {ml4gw-0.7.0.dist-info → ml4gw-0.7.2.dist-info}/LICENSE +0 -0
- {ml4gw-0.7.0.dist-info → ml4gw-0.7.2.dist-info}/WHEEL +0 -0
ml4gw/gw.py
CHANGED
|
@@ -16,8 +16,6 @@ import torch
|
|
|
16
16
|
from jaxtyping import Float
|
|
17
17
|
from torch import Tensor
|
|
18
18
|
|
|
19
|
-
from ml4gw.utils.interferometer import InterferometerGeometry
|
|
20
|
-
|
|
21
19
|
from .constants import C
|
|
22
20
|
from .types import (
|
|
23
21
|
BatchTensor,
|
|
@@ -28,6 +26,7 @@ from .types import (
|
|
|
28
26
|
VectorGeometry,
|
|
29
27
|
WaveformTensor,
|
|
30
28
|
)
|
|
29
|
+
from .utils.interferometer import InterferometerGeometry
|
|
31
30
|
|
|
32
31
|
|
|
33
32
|
def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
|
|
@@ -285,6 +284,7 @@ def compute_ifo_snr(
|
|
|
285
284
|
psd: PSDTensor,
|
|
286
285
|
sample_rate: float,
|
|
287
286
|
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
287
|
+
lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
288
288
|
) -> Float[Tensor, "batch num_ifos"]:
|
|
289
289
|
r"""Compute the SNRs of a batch of interferometer responses
|
|
290
290
|
|
|
@@ -300,7 +300,8 @@ def compute_ifo_snr(
|
|
|
300
300
|
{S_n^{(j)}(f)}df$$
|
|
301
301
|
|
|
302
302
|
Where $f_{\text{min}}$ is a minimum frequency denoted
|
|
303
|
-
by `highpass`, `f_{\text{max}}` is the
|
|
303
|
+
by `highpass`, $f_{\text{max}}$ is the maximum frequency
|
|
304
|
+
denoted by `lowpass`, which defaults to the Nyquist frequency
|
|
304
305
|
dictated by `sample_rate`; $\tilde{h_{ij}}$ and $\tilde{h_{ij}}^*$
|
|
305
306
|
indicate the fourier transform of the $i$th waveform at
|
|
306
307
|
the $j$th interferometer and its complex conjugate, respectively;
|
|
@@ -328,8 +329,15 @@ def compute_ifo_snr(
|
|
|
328
329
|
If a tensor is provided, it will be assumed to be a
|
|
329
330
|
pre-computed mask used to 0-out low frequency components.
|
|
330
331
|
If a float, it will be used to compute such a mask. If
|
|
331
|
-
left as `None`, all frequencies up to `
|
|
332
|
+
left as `None`, all frequencies up to `lowpass`
|
|
332
333
|
will contribute to the SNR calculation.
|
|
334
|
+
lowpass:
|
|
335
|
+
The maximum frequency below which to compute the SNR.
|
|
336
|
+
If a tensor is provided, it will be assumed to be a
|
|
337
|
+
pre-computed mask used to 0-out high frequency components.
|
|
338
|
+
If a float, it will be used to compute such a mask. If
|
|
339
|
+
left as `None`, all frequencies from `highpass` up to
|
|
340
|
+
the Nyquist frequency will contribute to the SNR calculation.
|
|
333
341
|
Returns:
|
|
334
342
|
Batch of SNRs computed for each interferometer
|
|
335
343
|
"""
|
|
@@ -346,7 +354,7 @@ def compute_ifo_snr(
|
|
|
346
354
|
integrand = fft / (psd**0.5)
|
|
347
355
|
integrand = integrand.type(torch.float32) ** 2
|
|
348
356
|
|
|
349
|
-
# mask out
|
|
357
|
+
# mask out frequency components if a critical
|
|
350
358
|
# frequency or frequency mask was provided
|
|
351
359
|
if highpass is not None:
|
|
352
360
|
if not isinstance(highpass, torch.Tensor):
|
|
@@ -354,12 +362,24 @@ def compute_ifo_snr(
|
|
|
354
362
|
highpass = freqs >= highpass
|
|
355
363
|
elif len(highpass) != integrand.shape[-1]:
|
|
356
364
|
raise ValueError(
|
|
357
|
-
"Can't apply highpass filter mask with {}
|
|
365
|
+
"Can't apply highpass filter mask with {} frequency bins"
|
|
358
366
|
"to signal fft with {} frequency bins".format(
|
|
359
367
|
len(highpass), integrand.shape[-1]
|
|
360
368
|
)
|
|
361
369
|
)
|
|
362
370
|
integrand *= highpass.to(integrand.device)
|
|
371
|
+
if lowpass is not None:
|
|
372
|
+
if not isinstance(lowpass, torch.Tensor):
|
|
373
|
+
freqs = torch.fft.rfftfreq(responses.shape[-1], 1 / sample_rate)
|
|
374
|
+
lowpass = freqs < lowpass
|
|
375
|
+
elif len(lowpass) != integrand.shape[-1]:
|
|
376
|
+
raise ValueError(
|
|
377
|
+
"Can't apply lowpass filter mask with {} frequency bins"
|
|
378
|
+
"to signal fft with {} frequency bins".format(
|
|
379
|
+
len(lowpass), integrand.shape[-1]
|
|
380
|
+
)
|
|
381
|
+
)
|
|
382
|
+
integrand *= lowpass.to(integrand.device)
|
|
363
383
|
|
|
364
384
|
# sum over the desired frequency range and multiply
|
|
365
385
|
# by df to turn it into an integration (and get
|
|
@@ -386,6 +406,7 @@ def compute_network_snr(
|
|
|
386
406
|
psd: PSDTensor,
|
|
387
407
|
sample_rate: float,
|
|
388
408
|
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
409
|
+
lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
389
410
|
) -> BatchTensor:
|
|
390
411
|
r"""
|
|
391
412
|
Compute the total SNR from a gravitational waveform
|
|
@@ -422,10 +443,17 @@ def compute_network_snr(
|
|
|
422
443
|
If a float, it will be used to compute such a mask. If
|
|
423
444
|
left as `None`, all frequencies up to `sample_rate / 2`
|
|
424
445
|
will contribute to the SNR calculation.
|
|
446
|
+
lowpass:
|
|
447
|
+
The maximum frequency below which to compute the SNR.
|
|
448
|
+
If a tensor is provided, it will be assumed to be a
|
|
449
|
+
pre-computed mask used to 0-out high frequency components.
|
|
450
|
+
If a float, it will be used to compute such a mask. If
|
|
451
|
+
left as `None`, all frequencies from `highpass` up to
|
|
452
|
+
the Nyquist frequency will contribute to the SNR calculation.
|
|
425
453
|
Returns:
|
|
426
454
|
Batch of SNRs for each waveform across the interferometer network
|
|
427
455
|
"""
|
|
428
|
-
snrs = compute_ifo_snr(responses, psd, sample_rate, highpass)
|
|
456
|
+
snrs = compute_ifo_snr(responses, psd, sample_rate, highpass, lowpass)
|
|
429
457
|
snrs = snrs**2
|
|
430
458
|
return snrs.sum(axis=-1) ** 0.5
|
|
431
459
|
|
|
@@ -436,6 +464,7 @@ def reweight_snrs(
|
|
|
436
464
|
psd: PSDTensor,
|
|
437
465
|
sample_rate: float,
|
|
438
466
|
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
467
|
+
lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
439
468
|
) -> WaveformTensor:
|
|
440
469
|
"""Scale interferometer responses such that they have a desired SNR
|
|
441
470
|
|
|
@@ -466,10 +495,17 @@ def reweight_snrs(
|
|
|
466
495
|
If a float, it will be used to compute such a mask. If
|
|
467
496
|
left as `None`, all frequencies up to `sample_rate / 2`
|
|
468
497
|
will contribute to the SNR calculation.
|
|
498
|
+
lowpass:
|
|
499
|
+
The maximum frequency below which to compute the SNR.
|
|
500
|
+
If a tensor is provided, it will be assumed to be a
|
|
501
|
+
pre-computed mask used to 0-out high frequency components.
|
|
502
|
+
If a float, it will be used to compute such a mask. If
|
|
503
|
+
left as `None`, all frequencies from `highpass` up to
|
|
504
|
+
the Nyquist frequency will contribute to the SNR calculation.
|
|
469
505
|
Returns:
|
|
470
506
|
Rescaled interferometer responses
|
|
471
507
|
"""
|
|
472
508
|
|
|
473
|
-
snrs = compute_network_snr(responses, psd, sample_rate, highpass)
|
|
509
|
+
snrs = compute_network_snr(responses, psd, sample_rate, highpass, lowpass)
|
|
474
510
|
weights = target_snrs / snrs
|
|
475
511
|
return responses * weights[:, None, None]
|
ml4gw/spectral.py
CHANGED
|
@@ -382,7 +382,7 @@ def truncate_inverse_power_spectrum(
|
|
|
382
382
|
as `None`, no lowpass filtering will be applied.
|
|
383
383
|
Returns:
|
|
384
384
|
The PSD with its time domain response truncated
|
|
385
|
-
to `fduration` and any
|
|
385
|
+
to `fduration` and any filtered frequencies
|
|
386
386
|
tapered.
|
|
387
387
|
"""
|
|
388
388
|
|
ml4gw/transforms/snr_rescaler.py
CHANGED
|
@@ -14,10 +14,10 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
14
14
|
sample_rate: float,
|
|
15
15
|
waveform_duration: float,
|
|
16
16
|
highpass: Optional[float] = None,
|
|
17
|
+
lowpass: Optional[float] = None,
|
|
17
18
|
dtype: torch.dtype = torch.float32,
|
|
18
19
|
) -> None:
|
|
19
20
|
super().__init__()
|
|
20
|
-
self.highpass = highpass
|
|
21
21
|
self.sample_rate = sample_rate
|
|
22
22
|
self.num_channels = num_channels
|
|
23
23
|
|
|
@@ -29,9 +29,18 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
29
29
|
|
|
30
30
|
if highpass is not None:
|
|
31
31
|
freqs = torch.fft.rfftfreq(waveform_size, 1 / sample_rate)
|
|
32
|
-
self.register_buffer(
|
|
32
|
+
self.register_buffer(
|
|
33
|
+
"highpass_mask", freqs >= highpass, persistent=False
|
|
34
|
+
)
|
|
35
|
+
else:
|
|
36
|
+
self.highpass_mask = None
|
|
37
|
+
if lowpass is not None:
|
|
38
|
+
freqs = torch.fft.rfftfreq(waveform_size, 1 / sample_rate)
|
|
39
|
+
self.register_buffer(
|
|
40
|
+
"lowpass_mask", freqs < lowpass, persistent=False
|
|
41
|
+
)
|
|
33
42
|
else:
|
|
34
|
-
self.
|
|
43
|
+
self.lowpass_mask = None
|
|
35
44
|
|
|
36
45
|
def fit(
|
|
37
46
|
self,
|
|
@@ -63,7 +72,11 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
63
72
|
target_snrs: Optional[BatchTensor] = None,
|
|
64
73
|
):
|
|
65
74
|
snrs = compute_network_snr(
|
|
66
|
-
responses,
|
|
75
|
+
responses,
|
|
76
|
+
self.background,
|
|
77
|
+
self.sample_rate,
|
|
78
|
+
self.highpass_mask,
|
|
79
|
+
self.lowpass_mask,
|
|
67
80
|
)
|
|
68
81
|
if target_snrs is None:
|
|
69
82
|
idx = torch.randperm(len(snrs))
|
ml4gw/waveforms/generator.py
CHANGED
|
@@ -224,7 +224,7 @@ class TimeDomainCBCWaveformGenerator(torch.nn.Module):
|
|
|
224
224
|
k1s = torch.round(f_min / df)
|
|
225
225
|
|
|
226
226
|
num_freqs = frequencies.size(0)
|
|
227
|
-
frequency_indices = torch.arange(num_freqs)
|
|
227
|
+
frequency_indices = torch.arange(num_freqs, device=device)
|
|
228
228
|
taper_mask = frequency_indices <= k1s[:, None]
|
|
229
229
|
taper_mask &= frequency_indices >= k0s[:, None]
|
|
230
230
|
|
|
@@ -253,7 +253,7 @@ class TimeDomainCBCWaveformGenerator(torch.nn.Module):
|
|
|
253
253
|
# that will translate the coalescense time such that it is `right_pad`
|
|
254
254
|
# seconds from the right edge of the window
|
|
255
255
|
tshift = round(self.right_pad * self.sample_rate) / self.sample_rate
|
|
256
|
-
kvals = torch.arange(num_freqs)
|
|
256
|
+
kvals = torch.arange(num_freqs, device=device)
|
|
257
257
|
phase_shift = torch.exp(1j * 2 * torch.pi * df * tshift * kvals)
|
|
258
258
|
|
|
259
259
|
hc_spectrum *= phase_shift
|
|
@@ -6,7 +6,7 @@ ml4gw/dataloading/chunked_dataset.py,sha256=j96Rd67cRpsvotR_dzgfbrqxcoGDWnTV5cmf
|
|
|
6
6
|
ml4gw/dataloading/hdf5_dataset.py,sha256=bVcXzS1LHVj7zMeMtRkxx1Q76MQS6wEApJJlUAI6iC8,7879
|
|
7
7
|
ml4gw/dataloading/in_memory_dataset.py,sha256=1oUchfNBq3rx1NgNqrcg6AGdJ-dvm56o-TGFwPn5wm8,9546
|
|
8
8
|
ml4gw/distributions.py,sha256=tUuaOiX5enjKLYWD7uiN8rdRVQcrIKps64xBkTl8fMs,4991
|
|
9
|
-
ml4gw/gw.py,sha256=
|
|
9
|
+
ml4gw/gw.py,sha256=0I9MhoHWksWG9a5EUI0GkHD1skuOXiaQgSgxNKYXCxE,19778
|
|
10
10
|
ml4gw/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
ml4gw/nn/autoencoder/__init__.py,sha256=ZaT1XhJTHpMuPQqu5E__Jezeh9uwtjcXlT7IZ18byq4,161
|
|
12
12
|
ml4gw/nn/autoencoder/base.py,sha256=eSWrDdpblI609oqa7RDSvZiY3YcV8WfhTioWKFn_7eE,3205
|
|
@@ -20,13 +20,13 @@ ml4gw/nn/resnet/resnet_2d.py,sha256=fVzYRuO0xR9yGjjQExv30mouokvupOAW-Kfdbs5WYDA,
|
|
|
20
20
|
ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
|
|
21
21
|
ml4gw/nn/streaming/online_average.py,sha256=_nrul4ygTC_ln4wpSWGRWTgWlfGeOUGXxeGrhU4oJms,4716
|
|
22
22
|
ml4gw/nn/streaming/snapshotter.py,sha256=1vWDpebRQBZIUVeksbXoqngqMnlSzQFkcsgYNrHB9tc,4473
|
|
23
|
-
ml4gw/spectral.py,sha256=
|
|
23
|
+
ml4gw/spectral.py,sha256=rnxd1ObPjyQMAu3D83_sw2lEEHZF7f87YQBV_pxHLxM,19809
|
|
24
24
|
ml4gw/transforms/__init__.py,sha256=OaTQJD4GFkDkcxt0DIwt2AzeEcv9t21ciKXxQnqDiuI,447
|
|
25
25
|
ml4gw/transforms/iirfilter.py,sha256=RwgC3DWgYmBnHe7bYjvr9njM1WrRZ9EjBJsZNmtOY8s,3186
|
|
26
26
|
ml4gw/transforms/pearson.py,sha256=CM9FTRxI4384-36FIaJFOcMZwsA7BkgberToJkMU1PA,3227
|
|
27
27
|
ml4gw/transforms/qtransform.py,sha256=5S9y3PxkOmqMAarQmme0Tiy58vRvberpqhg6IeyDJLI,20675
|
|
28
28
|
ml4gw/transforms/scaler.py,sha256=K5mp4w2zGZbpH1AcBUfpQS4n3aVSNzkaGWXedwk2LXs,2508
|
|
29
|
-
ml4gw/transforms/snr_rescaler.py,sha256=
|
|
29
|
+
ml4gw/transforms/snr_rescaler.py,sha256=lfuwdwMY117gB-emmn0_22gsK_A9xnkHJv2-76HFWc4,2728
|
|
30
30
|
ml4gw/transforms/spectral.py,sha256=4uCLNEcDff4kLheUA5v64L0y_MSOvUTJ92IH4TVcEys,4385
|
|
31
31
|
ml4gw/transforms/spectrogram.py,sha256=8HDStoup7vlwpw9qTKshAuEpa85-lw5_SwYxjxxu1sQ,6158
|
|
32
32
|
ml4gw/transforms/spline_interpolation.py,sha256=oXih-gLMbIbI36DPKLTk39WcjiNUJN_rcQia_k3OrMY,13527
|
|
@@ -48,8 +48,8 @@ ml4gw/waveforms/cbc/phenom_p.py,sha256=LBtGVUjBjROcYBPLldFnF6T1jZV6ZyuZEnkn9-oTK
|
|
|
48
48
|
ml4gw/waveforms/cbc/taylorf2.py,sha256=2ga_lG_xkYOsF-BdxgjbU0pgLDjeAO0p5IWuCPvibvQ,10504
|
|
49
49
|
ml4gw/waveforms/cbc/utils.py,sha256=CvZ79PQygb-zwulMV-wRuBcGEsHbVOtJz60UnOJFKoM,3051
|
|
50
50
|
ml4gw/waveforms/conversion.py,sha256=MyADWEZVoEkRkKaHk1ZuQDsGfPYx5xUTtyApj5P3ueQ,6895
|
|
51
|
-
ml4gw/waveforms/generator.py,sha256=
|
|
52
|
-
ml4gw-0.7.
|
|
53
|
-
ml4gw-0.7.
|
|
54
|
-
ml4gw-0.7.
|
|
55
|
-
ml4gw-0.7.
|
|
51
|
+
ml4gw/waveforms/generator.py,sha256=i2lgaJzH5eA6gzc-bLQZYYEgEQ8OBLJgE9yNXU3FsKM,12005
|
|
52
|
+
ml4gw-0.7.2.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
53
|
+
ml4gw-0.7.2.dist-info/METADATA,sha256=enY2IXiMhiIKn4RBDS6JLEdq1xT9g7IDjV6pvHxBpes,3904
|
|
54
|
+
ml4gw-0.7.2.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
|
|
55
|
+
ml4gw-0.7.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|