ml4gw 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. See the package's entry in its public registry for more details.

ml4gw/gw.py CHANGED
@@ -16,8 +16,6 @@ import torch
16
16
  from jaxtyping import Float
17
17
  from torch import Tensor
18
18
 
19
- from ml4gw.utils.interferometer import InterferometerGeometry
20
-
21
19
  from .constants import C
22
20
  from .types import (
23
21
  BatchTensor,
@@ -28,6 +26,7 @@ from .types import (
28
26
  VectorGeometry,
29
27
  WaveformTensor,
30
28
  )
29
+ from .utils.interferometer import InterferometerGeometry
31
30
 
32
31
 
33
32
  def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
@@ -285,6 +284,7 @@ def compute_ifo_snr(
285
284
  psd: PSDTensor,
286
285
  sample_rate: float,
287
286
  highpass: Union[float, Float[Tensor, " frequency"], None] = None,
287
+ lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
288
288
  ) -> Float[Tensor, "batch num_ifos"]:
289
289
  r"""Compute the SNRs of a batch of interferometer responses
290
290
 
@@ -300,7 +300,8 @@ def compute_ifo_snr(
300
300
  {S_n^{(j)}(f)}df$$
301
301
 
302
302
  Where $f_{\text{min}}$ is a minimum frequency denoted
303
- by `highpass`, `f_{\text{max}}` is the Nyquist frequency
303
+ by `highpass`, `f_{\text{max}}` is the maximum frequency
304
+ denoted by `lowpass`, which defaults to the Nyquist frequency
304
305
  dictated by `sample_rate`; `\tilde{h_{ij}}` and `\tilde{h_{ij}}*`
305
306
  indicate the fourier transform of the $i$th waveform at
306
307
  the $j$th inteferometer and its complex conjugate, respectively;
@@ -328,8 +329,15 @@ def compute_ifo_snr(
328
329
  If a tensor is provided, it will be assumed to be a
329
330
  pre-computed mask used to 0-out low frequency components.
330
331
  If a float, it will be used to compute such a mask. If
331
- left as `None`, all frequencies up to `sample_rate / 2`
332
+ left as `None`, all frequencies up to `lowpass`
332
333
  will contribute to the SNR calculation.
334
+ lowpass:
335
+ The maximum frequency below which to compute the SNR.
336
+ If a tensor is provided, it will be assumed to be a
337
+ pre-computed mask used to 0-out high frequency components.
338
+ If a float, it will be used to compute such a mask. If
339
+ left as `None`, all frequencies from `highpass` up to
340
+ the Nyquist freqyency will contribute to the SNR calculation.
333
341
  Returns:
334
342
  Batch of SNRs computed for each interferometer
335
343
  """
@@ -346,7 +354,7 @@ def compute_ifo_snr(
346
354
  integrand = fft / (psd**0.5)
347
355
  integrand = integrand.type(torch.float32) ** 2
348
356
 
349
- # mask out low frequency components if a critical
357
+ # mask out frequency components if a critical
350
358
  # frequency or frequency mask was provided
351
359
  if highpass is not None:
352
360
  if not isinstance(highpass, torch.Tensor):
@@ -354,12 +362,24 @@ def compute_ifo_snr(
354
362
  highpass = freqs >= highpass
355
363
  elif len(highpass) != integrand.shape[-1]:
356
364
  raise ValueError(
357
- "Can't apply highpass filter mask with {} frequecy bins"
365
+ "Can't apply highpass filter mask with {} frequency bins"
358
366
  "to signal fft with {} frequency bins".format(
359
367
  len(highpass), integrand.shape[-1]
360
368
  )
361
369
  )
362
370
  integrand *= highpass.to(integrand.device)
371
+ if lowpass is not None:
372
+ if not isinstance(lowpass, torch.Tensor):
373
+ freqs = torch.fft.rfftfreq(responses.shape[-1], 1 / sample_rate)
374
+ lowpass = freqs < lowpass
375
+ elif len(lowpass) != integrand.shape[-1]:
376
+ raise ValueError(
377
+ "Can't apply lowpass filter mask with {} frequency bins"
378
+ "to signal fft with {} frequency bins".format(
379
+ len(lowpass), integrand.shape[-1]
380
+ )
381
+ )
382
+ integrand *= lowpass.to(integrand.device)
363
383
 
364
384
  # sum over the desired frequency range and multiply
365
385
  # by df to turn it into an integration (and get
@@ -386,6 +406,7 @@ def compute_network_snr(
386
406
  psd: PSDTensor,
387
407
  sample_rate: float,
388
408
  highpass: Union[float, Float[Tensor, " frequency"], None] = None,
409
+ lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
389
410
  ) -> BatchTensor:
390
411
  r"""
391
412
  Compute the total SNR from a gravitational waveform
@@ -422,10 +443,17 @@ def compute_network_snr(
422
443
  If a float, it will be used to compute such a mask. If
423
444
  left as `None`, all frequencies up to `sample_rate / 2`
424
445
  will contribute to the SNR calculation.
446
+ lowpass:
447
+ The maximum frequency below which to compute the SNR.
448
+ If a tensor is provided, it will be assumed to be a
449
+ pre-computed mask used to 0-out high frequency components.
450
+ If a float, it will be used to compute such a mask. If
451
+ left as `None`, all frequencies from `highpass` up to
452
+ the Nyquist freqyency will contribute to the SNR calculation.
425
453
  Returns:
426
454
  Batch of SNRs for each waveform across the interferometer network
427
455
  """
428
- snrs = compute_ifo_snr(responses, psd, sample_rate, highpass)
456
+ snrs = compute_ifo_snr(responses, psd, sample_rate, highpass, lowpass)
429
457
  snrs = snrs**2
430
458
  return snrs.sum(axis=-1) ** 0.5
431
459
 
@@ -436,6 +464,7 @@ def reweight_snrs(
436
464
  psd: PSDTensor,
437
465
  sample_rate: float,
438
466
  highpass: Union[float, Float[Tensor, " frequency"], None] = None,
467
+ lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
439
468
  ) -> WaveformTensor:
440
469
  """Scale interferometer responses such that they have a desired SNR
441
470
 
@@ -466,10 +495,17 @@ def reweight_snrs(
466
495
  If a float, it will be used to compute such a mask. If
467
496
  left as `None`, all frequencies up to `sample_rate / 2`
468
497
  will contribute to the SNR calculation.
498
+ lowpass:
499
+ The maximum frequency below which to compute the SNR.
500
+ If a tensor is provided, it will be assumed to be a
501
+ pre-computed mask used to 0-out high frequency components.
502
+ If a float, it will be used to compute such a mask. If
503
+ left as `None`, all frequencies from `highpass` up to
504
+ the Nyquist freqyency will contribute to the SNR calculation.
469
505
  Returns:
470
506
  Rescaled interferometer responses
471
507
  """
472
508
 
473
- snrs = compute_network_snr(responses, psd, sample_rate, highpass)
509
+ snrs = compute_network_snr(responses, psd, sample_rate, highpass, lowpass)
474
510
  weights = target_snrs / snrs
475
511
  return responses * weights[:, None, None]
ml4gw/spectral.py CHANGED
@@ -382,7 +382,7 @@ def truncate_inverse_power_spectrum(
382
382
  as `None`, no lowpass filtering will be applied.
383
383
  Returns:
384
384
  The PSD with its time domain response truncated
385
- to `fduration` and any highpassed frequencies
385
+ to `fduration` and any filtered frequencies
386
386
  tapered.
387
387
  """
388
388
 
@@ -14,10 +14,10 @@ class SnrRescaler(FittableSpectralTransform):
14
14
  sample_rate: float,
15
15
  waveform_duration: float,
16
16
  highpass: Optional[float] = None,
17
+ lowpass: Optional[float] = None,
17
18
  dtype: torch.dtype = torch.float32,
18
19
  ) -> None:
19
20
  super().__init__()
20
- self.highpass = highpass
21
21
  self.sample_rate = sample_rate
22
22
  self.num_channels = num_channels
23
23
 
@@ -29,9 +29,18 @@ class SnrRescaler(FittableSpectralTransform):
29
29
 
30
30
  if highpass is not None:
31
31
  freqs = torch.fft.rfftfreq(waveform_size, 1 / sample_rate)
32
- self.register_buffer("mask", freqs >= highpass, persistent=False)
32
+ self.register_buffer(
33
+ "highpass_mask", freqs >= highpass, persistent=False
34
+ )
35
+ else:
36
+ self.highpass_mask = None
37
+ if lowpass is not None:
38
+ freqs = torch.fft.rfftfreq(waveform_size, 1 / sample_rate)
39
+ self.register_buffer(
40
+ "lowpass_mask", freqs < lowpass, persistent=False
41
+ )
33
42
  else:
34
- self.mask = None
43
+ self.lowpass_mask = None
35
44
 
36
45
  def fit(
37
46
  self,
@@ -63,7 +72,11 @@ class SnrRescaler(FittableSpectralTransform):
63
72
  target_snrs: Optional[BatchTensor] = None,
64
73
  ):
65
74
  snrs = compute_network_snr(
66
- responses, self.background, self.sample_rate, self.mask
75
+ responses,
76
+ self.background,
77
+ self.sample_rate,
78
+ self.highpass_mask,
79
+ self.lowpass_mask,
67
80
  )
68
81
  if target_snrs is None:
69
82
  idx = torch.randperm(len(snrs))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ml4gw
3
- Version: 0.7.1
3
+ Version: 0.7.3
4
4
  Summary: Tools for training torch models on gravitational wave data
5
5
  Author: Alec Gunny
6
6
  Author-email: alec.gunny@ligo.org
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.11
12
12
  Classifier: Programming Language :: Python :: 3.12
13
13
  Requires-Dist: jaxtyping (>=0.2,<0.3)
14
14
  Requires-Dist: numpy (<2.0.0)
15
+ Requires-Dist: scipy (>=1.9.0,<1.15)
15
16
  Requires-Dist: torch (>=2.0,<3.0)
16
17
  Requires-Dist: torchaudio (>=2.0,<3.0)
17
18
  Description-Content-Type: text/markdown
@@ -6,7 +6,7 @@ ml4gw/dataloading/chunked_dataset.py,sha256=j96Rd67cRpsvotR_dzgfbrqxcoGDWnTV5cmf
6
6
  ml4gw/dataloading/hdf5_dataset.py,sha256=bVcXzS1LHVj7zMeMtRkxx1Q76MQS6wEApJJlUAI6iC8,7879
7
7
  ml4gw/dataloading/in_memory_dataset.py,sha256=1oUchfNBq3rx1NgNqrcg6AGdJ-dvm56o-TGFwPn5wm8,9546
8
8
  ml4gw/distributions.py,sha256=tUuaOiX5enjKLYWD7uiN8rdRVQcrIKps64xBkTl8fMs,4991
9
- ml4gw/gw.py,sha256=aUPSXgwyqJUBGGaKtUa-O3qkSbRYZwhhXIlkhvjgJgI,17684
9
+ ml4gw/gw.py,sha256=0I9MhoHWksWG9a5EUI0GkHD1skuOXiaQgSgxNKYXCxE,19778
10
10
  ml4gw/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  ml4gw/nn/autoencoder/__init__.py,sha256=ZaT1XhJTHpMuPQqu5E__Jezeh9uwtjcXlT7IZ18byq4,161
12
12
  ml4gw/nn/autoencoder/base.py,sha256=eSWrDdpblI609oqa7RDSvZiY3YcV8WfhTioWKFn_7eE,3205
@@ -20,13 +20,13 @@ ml4gw/nn/resnet/resnet_2d.py,sha256=fVzYRuO0xR9yGjjQExv30mouokvupOAW-Kfdbs5WYDA,
20
20
  ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
21
21
  ml4gw/nn/streaming/online_average.py,sha256=_nrul4ygTC_ln4wpSWGRWTgWlfGeOUGXxeGrhU4oJms,4716
22
22
  ml4gw/nn/streaming/snapshotter.py,sha256=1vWDpebRQBZIUVeksbXoqngqMnlSzQFkcsgYNrHB9tc,4473
23
- ml4gw/spectral.py,sha256=lLpnho02i-0zPSi96b0xOPEIgQMnBrmO8JiV1KvPGEw,19811
23
+ ml4gw/spectral.py,sha256=rnxd1ObPjyQMAu3D83_sw2lEEHZF7f87YQBV_pxHLxM,19809
24
24
  ml4gw/transforms/__init__.py,sha256=OaTQJD4GFkDkcxt0DIwt2AzeEcv9t21ciKXxQnqDiuI,447
25
25
  ml4gw/transforms/iirfilter.py,sha256=RwgC3DWgYmBnHe7bYjvr9njM1WrRZ9EjBJsZNmtOY8s,3186
26
26
  ml4gw/transforms/pearson.py,sha256=CM9FTRxI4384-36FIaJFOcMZwsA7BkgberToJkMU1PA,3227
27
27
  ml4gw/transforms/qtransform.py,sha256=5S9y3PxkOmqMAarQmme0Tiy58vRvberpqhg6IeyDJLI,20675
28
28
  ml4gw/transforms/scaler.py,sha256=K5mp4w2zGZbpH1AcBUfpQS4n3aVSNzkaGWXedwk2LXs,2508
29
- ml4gw/transforms/snr_rescaler.py,sha256=XHKTeJXM3F_VOmjWOZetQuVZ6PMum8pEBPaOVbS16-w,2327
29
+ ml4gw/transforms/snr_rescaler.py,sha256=lfuwdwMY117gB-emmn0_22gsK_A9xnkHJv2-76HFWc4,2728
30
30
  ml4gw/transforms/spectral.py,sha256=4uCLNEcDff4kLheUA5v64L0y_MSOvUTJ92IH4TVcEys,4385
31
31
  ml4gw/transforms/spectrogram.py,sha256=8HDStoup7vlwpw9qTKshAuEpa85-lw5_SwYxjxxu1sQ,6158
32
32
  ml4gw/transforms/spline_interpolation.py,sha256=oXih-gLMbIbI36DPKLTk39WcjiNUJN_rcQia_k3OrMY,13527
@@ -49,7 +49,7 @@ ml4gw/waveforms/cbc/taylorf2.py,sha256=2ga_lG_xkYOsF-BdxgjbU0pgLDjeAO0p5IWuCPvib
49
49
  ml4gw/waveforms/cbc/utils.py,sha256=CvZ79PQygb-zwulMV-wRuBcGEsHbVOtJz60UnOJFKoM,3051
50
50
  ml4gw/waveforms/conversion.py,sha256=MyADWEZVoEkRkKaHk1ZuQDsGfPYx5xUTtyApj5P3ueQ,6895
51
51
  ml4gw/waveforms/generator.py,sha256=i2lgaJzH5eA6gzc-bLQZYYEgEQ8OBLJgE9yNXU3FsKM,12005
52
- ml4gw-0.7.1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
53
- ml4gw-0.7.1.dist-info/METADATA,sha256=xEVSE7PX32I8b4YIneUVVvTAHLS4WemuQ8bpCKskIXE,3904
54
- ml4gw-0.7.1.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
55
- ml4gw-0.7.1.dist-info/RECORD,,
52
+ ml4gw-0.7.3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
53
+ ml4gw-0.7.3.dist-info/METADATA,sha256=LSSAqGyEFUoINHjSh1iCc3TRRY-ytTU6pPZ6TVeZdLI,3941
54
+ ml4gw-0.7.3.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
55
+ ml4gw-0.7.3.dist-info/RECORD,,
File without changes
File without changes