ml4gw 0.7.6__tar.gz → 0.7.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

Files changed (122) hide show
  1. {ml4gw-0.7.6 → ml4gw-0.7.8}/PKG-INFO +33 -12
  2. {ml4gw-0.7.6 → ml4gw-0.7.8}/README.md +21 -2
  3. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/augmentations.py +5 -0
  4. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/dataloading/__init__.py +5 -0
  5. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/dataloading/chunked_dataset.py +2 -4
  6. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/dataloading/hdf5_dataset.py +12 -10
  7. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/dataloading/in_memory_dataset.py +12 -12
  8. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/distributions.py +3 -3
  9. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/gw.py +18 -21
  10. ml4gw-0.7.8/ml4gw/nn/__init__.py +6 -0
  11. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/autoencoder/base.py +5 -9
  12. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/autoencoder/convolutional.py +7 -10
  13. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/autoencoder/skip_connection.py +3 -5
  14. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/norm.py +4 -4
  15. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/resnet/resnet_1d.py +12 -13
  16. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/resnet/resnet_2d.py +13 -14
  17. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/streaming/online_average.py +3 -5
  18. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/streaming/snapshotter.py +10 -14
  19. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/spectral.py +20 -23
  20. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/__init__.py +7 -1
  21. ml4gw-0.7.8/ml4gw/transforms/decimator.py +183 -0
  22. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/iirfilter.py +3 -5
  23. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/pearson.py +3 -4
  24. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/qtransform.py +20 -26
  25. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/scaler.py +3 -5
  26. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/snr_rescaler.py +7 -11
  27. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/spectral.py +6 -13
  28. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/spectrogram.py +6 -3
  29. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/spline_interpolation.py +312 -143
  30. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/transform.py +4 -6
  31. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/waveforms.py +8 -15
  32. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/transforms/whitening.py +11 -16
  33. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/types.py +8 -5
  34. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/utils/interferometer.py +20 -3
  35. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/utils/slicing.py +26 -30
  36. ml4gw-0.7.8/ml4gw/waveforms/__init__.py +8 -0
  37. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/cbc/phenom_p.py +7 -9
  38. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/conversion.py +2 -4
  39. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/generator.py +3 -3
  40. ml4gw-0.7.8/ml4gw.egg-info/PKG-INFO +77 -0
  41. ml4gw-0.7.8/ml4gw.egg-info/SOURCES.txt +64 -0
  42. ml4gw-0.7.8/ml4gw.egg-info/dependency_links.txt +1 -0
  43. ml4gw-0.7.8/ml4gw.egg-info/requires.txt +5 -0
  44. ml4gw-0.7.8/ml4gw.egg-info/top_level.txt +1 -0
  45. {ml4gw-0.7.6 → ml4gw-0.7.8}/pyproject.toml +21 -11
  46. ml4gw-0.7.8/setup.cfg +4 -0
  47. {ml4gw-0.7.6 → ml4gw-0.7.8}/tests/test_augmentations.py +2 -2
  48. {ml4gw-0.7.6 → ml4gw-0.7.8}/tests/test_gw.py +7 -3
  49. {ml4gw-0.7.6 → ml4gw-0.7.8}/tests/test_spectral.py +1 -1
  50. ml4gw-0.7.6/.coverage +0 -0
  51. ml4gw-0.7.6/.gitattributes +0 -2
  52. ml4gw-0.7.6/.github/workflows/coverage.yaml +0 -31
  53. ml4gw-0.7.6/.github/workflows/docs.yaml +0 -29
  54. ml4gw-0.7.6/.github/workflows/pre-commit.yaml +0 -17
  55. ml4gw-0.7.6/.github/workflows/publish.yaml +0 -27
  56. ml4gw-0.7.6/.github/workflows/unit-tests.yaml +0 -80
  57. ml4gw-0.7.6/.gitignore +0 -3
  58. ml4gw-0.7.6/.pre-commit-config.yaml +0 -23
  59. ml4gw-0.7.6/.readthedocs.yaml +0 -36
  60. ml4gw-0.7.6/CITATION.cff +0 -37
  61. ml4gw-0.7.6/docs/.gitignore +0 -2
  62. ml4gw-0.7.6/docs/Makefile +0 -20
  63. ml4gw-0.7.6/docs/conf.py +0 -78
  64. ml4gw-0.7.6/docs/examples/augmentations.rst +0 -20
  65. ml4gw-0.7.6/docs/examples/distributions.rst +0 -50
  66. ml4gw-0.7.6/docs/examples/gw.rst +0 -65
  67. ml4gw-0.7.6/docs/examples/transforms.qtransform.rst +0 -40
  68. ml4gw-0.7.6/docs/examples/transforms.spectral.rst +0 -32
  69. ml4gw-0.7.6/docs/examples/transforms.whitening.rst +0 -29
  70. ml4gw-0.7.6/docs/images/distribution_samples.png +0 -0
  71. ml4gw-0.7.6/docs/images/qscan_spectrogram.png +0 -0
  72. ml4gw-0.7.6/docs/index.rst +0 -62
  73. ml4gw-0.7.6/docs/installation.rst +0 -19
  74. ml4gw-0.7.6/docs/make.bat +0 -35
  75. ml4gw-0.7.6/docs/requirements.txt +0 -4
  76. ml4gw-0.7.6/docs/tutorials/ml4gw_tutorial.ipynb +0 -1744
  77. ml4gw-0.7.6/docs/usage.rst +0 -17
  78. ml4gw-0.7.6/ml4gw/nn/__init__.py +0 -0
  79. ml4gw-0.7.6/ml4gw/waveforms/__init__.py +0 -2
  80. ml4gw-0.7.6/tests/conftest.py +0 -200
  81. ml4gw-0.7.6/tests/dataloading/test_chunked_dataset.py +0 -82
  82. ml4gw-0.7.6/tests/dataloading/test_hdf5_dataset.py +0 -188
  83. ml4gw-0.7.6/tests/dataloading/test_in_memory_dataset.py +0 -357
  84. ml4gw-0.7.6/tests/nn/resnet/test_resnet_1d.py +0 -166
  85. ml4gw-0.7.6/tests/nn/resnet/test_resnet_2d.py +0 -160
  86. ml4gw-0.7.6/tests/nn/streaming/test_online_average.py +0 -88
  87. ml4gw-0.7.6/tests/nn/streaming/test_snapshotter.py +0 -120
  88. ml4gw-0.7.6/tests/nn/test_norm.py +0 -75
  89. ml4gw-0.7.6/tests/transforms/test_iirfilter.py +0 -321
  90. ml4gw-0.7.6/tests/transforms/test_pearson.py +0 -81
  91. ml4gw-0.7.6/tests/transforms/test_qtransform.py +0 -184
  92. ml4gw-0.7.6/tests/transforms/test_scaler.py +0 -123
  93. ml4gw-0.7.6/tests/transforms/test_snr_rescaler.py +0 -86
  94. ml4gw-0.7.6/tests/transforms/test_spectral_transform.py +0 -300
  95. ml4gw-0.7.6/tests/transforms/test_spectrogram.py +0 -109
  96. ml4gw-0.7.6/tests/transforms/test_spline_interpolation.py +0 -119
  97. ml4gw-0.7.6/tests/transforms/test_waveforms.py +0 -105
  98. ml4gw-0.7.6/tests/transforms/test_whitening.py +0 -194
  99. ml4gw-0.7.6/tests/utils/test_slicing.py +0 -360
  100. ml4gw-0.7.6/tests/waveforms/adhoc/test_sine_gaussian.py +0 -100
  101. ml4gw-0.7.6/tests/waveforms/cbc/test_cbc_waveforms.py +0 -534
  102. ml4gw-0.7.6/tests/waveforms/cbc/test_utils.py +0 -115
  103. ml4gw-0.7.6/tests/waveforms/test_conversion.py +0 -82
  104. ml4gw-0.7.6/tests/waveforms/test_generator.py +0 -220
  105. ml4gw-0.7.6/uv.lock +0 -4013
  106. {ml4gw-0.7.6 → ml4gw-0.7.8}/LICENSE +0 -0
  107. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/__init__.py +0 -0
  108. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/constants.py +0 -0
  109. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/autoencoder/__init__.py +0 -0
  110. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/autoencoder/utils.py +0 -0
  111. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/resnet/__init__.py +0 -0
  112. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/nn/streaming/__init__.py +0 -0
  113. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/adhoc/__init__.py +0 -0
  114. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/adhoc/ringdown.py +0 -0
  115. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/adhoc/sine_gaussian.py +0 -0
  116. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/cbc/__init__.py +0 -0
  117. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/cbc/coefficients.py +0 -0
  118. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/cbc/phenom_d.py +0 -0
  119. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/cbc/phenom_d_data.py +0 -0
  120. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/cbc/taylorf2.py +0 -0
  121. {ml4gw-0.7.6 → ml4gw-0.7.8}/ml4gw/waveforms/cbc/utils.py +0 -0
  122. {ml4gw-0.7.6 → ml4gw-0.7.8}/tests/test_distributions.py +1 -1
@@ -1,22 +1,24 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ml4gw
3
- Version: 0.7.6
3
+ Version: 0.7.8
4
4
  Summary: Tools for training torch models on gravitational wave data
5
- Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>
6
- License-File: LICENSE
7
- Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
8
- Classifier: Programming Language :: Python :: 3.9
5
+ Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>, Ravi Kumar <ravi.kumar@ligo.org>
6
+ Maintainer-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>
7
+ License-Expression: GPL-3.0-or-later
9
8
  Classifier: Programming Language :: Python :: 3.10
10
9
  Classifier: Programming Language :: Python :: 3.11
11
10
  Classifier: Programming Language :: Python :: 3.12
12
- Classifier: Programming Language :: Python :: 3.13
13
- Requires-Python: <3.13,>=3.9
11
+ Classifier: Topic :: Scientific/Engineering :: Astronomy
12
+ Classifier: Topic :: Scientific/Engineering :: Physics
13
+ Requires-Python: <3.13,>=3.10
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE
14
16
  Requires-Dist: jaxtyping<0.3,>=0.2
17
+ Requires-Dist: torch~=2.0
18
+ Requires-Dist: torchaudio~=2.0
15
19
  Requires-Dist: numpy<2.0.0
16
20
  Requires-Dist: scipy<1.15,>=1.9.0
17
- Requires-Dist: torchaudio~=2.0
18
- Requires-Dist: torch~=2.0
19
- Description-Content-Type: text/markdown
21
+ Dynamic: license-file
20
22
 
21
23
  # ML4GW
22
24
  ![PyPI - Version](https://img.shields.io/pypi/v/ml4gw)
@@ -29,7 +31,10 @@ Torch utilities for training neural networks in gravitational wave physics appli
29
31
 
30
32
  ## Documentation
31
33
  Please visit our [documentation page](https://ml4gw.github.io/ml4gw/) to see descriptions and examples of the functions and modules available in `ml4gw`.
32
- We also have an interactive Jupyter notebook that demonstrates much of the core functionality available in the `examples` directory.
34
+ We also have an interactive Jupyter notebook demonstrating much of the core functionality available [here](https://github.com/ML4GW/ml4gw/blob/main/docs/tutorials/ml4gw_tutorial.ipynb).
35
+ To run this notebook, download it from the above link and follow the instructions within it to install the required packages.
36
+ See also the [documentation page](https://ml4gw.github.io/ml4gw/tutorials/ml4gw_tutorial.html) for the tutorial to look
37
+ through it without running the code.
33
38
 
34
39
  ## Installation
35
40
  ### Pip installation
@@ -45,9 +50,25 @@ To build with a specific version of PyTorch/CUDA, please see the PyTorch install
45
50
  pip install ml4gw torch==2.5.1--extra-index-url=https://download.pytorch.org/whl/cu118
46
51
  ```
47
52
 
53
+ ### uv installation
54
+ If you want to develop `ml4gw`, you can use [uv](https://docs.astral.sh/uv/getting-started/installation/) to install the project in editable mode.
55
+ For example, after cloning the repository, create a virtualenv using
56
+ ```bash
57
+ uv venv --python=3.11
58
+ ```
59
+ Then sync the dependencies from the [uv lock file](/uv.lock) using
60
+ ```bash
61
+ uv sync --all-extras
62
+ ```
63
+ Code changes can be tested using
64
+ ```bash
65
+ uv run pytest
66
+ ```
67
+ See [contribution guide](/CONTRIBUTING.md) for more details.
68
+
48
69
  ## Contributing
49
70
  If you come across errors in the code, have difficulties using this software, or simply find that the current version doesn't cover your use case, please file an issue on our GitHub page, and we'll be happy to offer support.
50
- We encourage users who encounter these difficulties to file issues on GitHub, and we'll be happy to offer support to extend our coverage to new or improved functionality.
71
+ If you want to add feature, please refer to the [contribution guide](/CONTRIBUTING.md) for more details.
51
72
  We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
52
73
  For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu).
53
74
  By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool that makes deep learning more accessible for gravitational wave physicists everywhere.
@@ -9,7 +9,10 @@ Torch utilities for training neural networks in gravitational wave physics appli
9
9
 
10
10
  ## Documentation
11
11
  Please visit our [documentation page](https://ml4gw.github.io/ml4gw/) to see descriptions and examples of the functions and modules available in `ml4gw`.
12
- We also have an interactive Jupyter notebook that demonstrates much of the core functionality available in the `examples` directory.
12
+ We also have an interactive Jupyter notebook demonstrating much of the core functionality available [here](https://github.com/ML4GW/ml4gw/blob/main/docs/tutorials/ml4gw_tutorial.ipynb).
13
+ To run this notebook, download it from the above link and follow the instructions within it to install the required packages.
14
+ See also the [documentation page](https://ml4gw.github.io/ml4gw/tutorials/ml4gw_tutorial.html) for the tutorial to look
15
+ through it without running the code.
13
16
 
14
17
  ## Installation
15
18
  ### Pip installation
@@ -25,9 +28,25 @@ To build with a specific version of PyTorch/CUDA, please see the PyTorch install
25
28
  pip install ml4gw torch==2.5.1--extra-index-url=https://download.pytorch.org/whl/cu118
26
29
  ```
27
30
 
31
+ ### uv installation
32
+ If you want to develop `ml4gw`, you can use [uv](https://docs.astral.sh/uv/getting-started/installation/) to install the project in editable mode.
33
+ For example, after cloning the repository, create a virtualenv using
34
+ ```bash
35
+ uv venv --python=3.11
36
+ ```
37
+ Then sync the dependencies from the [uv lock file](/uv.lock) using
38
+ ```bash
39
+ uv sync --all-extras
40
+ ```
41
+ Code changes can be tested using
42
+ ```bash
43
+ uv run pytest
44
+ ```
45
+ See [contribution guide](/CONTRIBUTING.md) for more details.
46
+
28
47
  ## Contributing
29
48
  If you come across errors in the code, have difficulties using this software, or simply find that the current version doesn't cover your use case, please file an issue on our GitHub page, and we'll be happy to offer support.
30
- We encourage users who encounter these difficulties to file issues on GitHub, and we'll be happy to offer support to extend our coverage to new or improved functionality.
49
+ If you want to add feature, please refer to the [contribution guide](/CONTRIBUTING.md) for more details.
31
50
  We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
32
51
  For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu).
33
52
  By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool that makes deep learning more accessible for gravitational wave physicists everywhere.
@@ -1,3 +1,8 @@
1
+ """
2
+ This module contains transformations that may be useful
3
+ for augmenting timeseries data during training
4
+ """
5
+
1
6
  import torch
2
7
  from jaxtyping import Float
3
8
  from torch import Tensor
@@ -1,3 +1,8 @@
1
+ """
2
+ This module contains tools for efficient in-memory and
3
+ out-of-memory dataloading.
4
+ """
5
+
1
6
  from .chunked_dataset import ChunkedTimeSeriesDataset
2
7
  from .hdf5_dataset import Hdf5TimeSeriesDataset
3
8
  from .in_memory_dataset import InMemoryDataset
@@ -94,10 +94,8 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
94
94
  # flatten it to make it easier to slice
95
95
  if chunk_size < self.kernel_size:
96
96
  raise ValueError(
97
- (
98
- "Can't sample kernels of size {} from chunk "
99
- "with size {}"
100
- ).format(self.kernel_size, chunk_size)
97
+ f"Can't sample kernels of size {self.kernel_size} from "
98
+ f"chunk with size {chunk_size}"
101
99
  )
102
100
  chunk = chunk.reshape(-1)
103
101
 
@@ -1,5 +1,5 @@
1
1
  import warnings
2
- from typing import Optional, Sequence, Union
2
+ from collections.abc import Sequence
3
3
 
4
4
  import h5py
5
5
  import numpy as np
@@ -63,13 +63,13 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
63
63
  kernel_size: int,
64
64
  batch_size: int,
65
65
  batches_per_epoch: int,
66
- coincident: Union[bool, str],
67
- num_files_per_batch: Optional[int] = None,
66
+ coincident: bool | str,
67
+ num_files_per_batch: int | None = None,
68
68
  ) -> None:
69
69
  if not isinstance(coincident, bool) and coincident != "files":
70
70
  raise ValueError(
71
71
  "coincident must be either a boolean or 'files', "
72
- "got unrecognized value {}".format(coincident)
72
+ f"got unrecognized value {coincident}"
73
73
  )
74
74
 
75
75
  self.fnames = np.array(fnames)
@@ -94,13 +94,11 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
94
94
  dset = f[channels[0]]
95
95
  if dset.chunks is None:
96
96
  warnings.warn(
97
- "File {} contains datasets that were generated "
97
+ f"File {fname} contains datasets that were generated "
98
98
  "without using chunked storage. This can have "
99
99
  "severe performance impacts at data loading time. "
100
100
  "If you need faster loading, try re-generating "
101
- "your dataset with chunked storage turned on.".format(
102
- fname
103
- ),
101
+ "your dataset with chunked storage turned on.",
104
102
  category=ContiguousHdf5Warning,
105
103
  stacklevel=2,
106
104
  )
@@ -153,7 +151,9 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
153
151
  unique_fnames, inv, counts = np.unique(
154
152
  fnames, return_inverse=True, return_counts=True
155
153
  )
156
- for i, (fname, count) in enumerate(zip(unique_fnames, counts)):
154
+ for i, (fname, count) in enumerate(
155
+ zip(unique_fnames, counts, strict=True)
156
+ ):
157
157
  size = self.sizes[fname]
158
158
  max_idx = size - self.kernel_size
159
159
 
@@ -185,7 +185,9 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
185
185
  # open the file and sample a different set of
186
186
  # kernels for each batch element it occupies
187
187
  with h5py.File(fname, "r") as f:
188
- for b, c, i in zip(batch_indices, channel_indices, idx):
188
+ for b, c, i in zip(
189
+ batch_indices, channel_indices, idx, strict=True
190
+ ):
189
191
  x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
190
192
  return torch.Tensor(x)
191
193
 
@@ -1,5 +1,4 @@
1
1
  import itertools
2
- from typing import Optional, Tuple, Union
3
2
 
4
3
  import torch
5
4
  from jaxtyping import Float
@@ -79,10 +78,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
79
78
  self,
80
79
  X: Float[Tensor, "channels time"],
81
80
  kernel_size: int,
82
- y: Optional[Float[Tensor, " time"]] = None,
81
+ y: Float[Tensor, " time"] | None = None,
83
82
  batch_size: int = 32,
84
83
  stride: int = 1,
85
- batches_per_epoch: Optional[int] = None,
84
+ batches_per_epoch: int | None = None,
86
85
  coincident: bool = True,
87
86
  shuffle: bool = True,
88
87
  device: str = "cpu",
@@ -122,10 +121,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
122
121
  batch_size * batches_per_epoch
123
122
  ):
124
123
  raise ValueError(
125
- "Number of kernels {} in timeseries insufficient "
126
- "to generate {} batches of size {}".format(
127
- self.num_kernels, batch_size, batches_per_epoch
128
- )
124
+ f"Number of kernels {self.num_kernels} in timeseries "
125
+ f"insufficient to generate {batch_size} batches of size "
126
+ f"{batches_per_epoch}"
129
127
  )
130
128
 
131
129
  self.batch_size = batch_size
@@ -191,7 +189,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
191
189
  # indices we'll need rather than having to generate
192
190
  # everything.
193
191
  idx = [range(self.num_kernels) for _ in range(len(self.X))]
194
- idx = zip(range(num_kernels), itertools.product(*idx))
192
+ idx = zip(
193
+ range(num_kernels), itertools.product(*idx), strict=False
194
+ )
195
195
  idx = torch.stack([torch.Tensor(i[1]) for i in idx])
196
196
  idx = idx.type(torch.int64).to(device)
197
197
  elif self.shuffle:
@@ -208,10 +208,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
208
208
 
209
209
  def __iter__(
210
210
  self,
211
- ) -> Union[
212
- Float[Tensor, "batch channel time"],
213
- Tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]],
214
- ]:
211
+ ) -> (
212
+ Float[Tensor, "batch channel time"]
213
+ | tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]]
214
+ ):
215
215
  indices = self.init_indices()
216
216
  for i in range(len(self)):
217
217
  # slice the array of _indices_ we'll be using to
@@ -6,7 +6,7 @@ from the corresponding distribution.
6
6
  """
7
7
 
8
8
  import math
9
- from typing import Callable, Optional
9
+ from collections.abc import Callable
10
10
 
11
11
  import torch
12
12
  import torch.distributions as dist
@@ -104,7 +104,7 @@ class LogNormal(dist.LogNormal):
104
104
  self,
105
105
  mean: float,
106
106
  std: float,
107
- low: Optional[float] = None,
107
+ low: float | None = None,
108
108
  validate_args=None,
109
109
  ):
110
110
  self.low = low
@@ -137,7 +137,7 @@ class PowerLaw(dist.TransformedDistribution):
137
137
  support = dist.constraints.nonnegative
138
138
 
139
139
  def __init__(
140
- self, minimum: float, maximum: float, index: int, validate_args=None
140
+ self, minimum: float, maximum: float, index: float, validate_args=None
141
141
  ):
142
142
  if index == 0:
143
143
  raise ValueError("Index of 0 is the same as Uniform")
@@ -1,6 +1,7 @@
1
1
  """
2
- Tools for manipulating raw gravitational waveforms
3
- and projecting them onto interferometer responses.
2
+ Tools for manipulating raw gravitational waveforms,
3
+ projecting them onto interferometer responses, and
4
+ calculating SNRs.
4
5
  Much of the projection code is an extension of the
5
6
  implementation made available in
6
7
  `bilby <https://arxiv.org/abs/1811.02042>`_.
@@ -8,8 +9,6 @@ Specifically code from
8
9
  `this module <https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py>`_.
9
10
  """ # noqa E501
10
11
 
11
- from typing import List, Tuple, Union
12
-
13
12
  import torch
14
13
  from jaxtyping import Float
15
14
  from torch import Tensor
@@ -58,7 +57,7 @@ def compute_antenna_responses(
58
57
  psi: BatchTensor,
59
58
  phi: BatchTensor,
60
59
  detector_tensors: NetworkDetectorTensors,
61
- modes: List[str],
60
+ modes: list[str],
62
61
  ) -> Float[Tensor, "batch polarizations num_ifos"]:
63
62
  """
64
63
  Compute the antenna pattern factors of a batch of
@@ -257,7 +256,7 @@ def compute_observed_strain(
257
256
 
258
257
  def get_ifo_geometry(
259
258
  *ifos: str,
260
- ) -> Tuple[NetworkDetectorTensors, NetworkVertices]:
259
+ ) -> tuple[NetworkDetectorTensors, NetworkVertices]:
261
260
  """
262
261
  For a given list of interferometer names, retrieve and
263
262
  concatenate the associated detector tensors and vertices
@@ -286,8 +285,8 @@ def compute_ifo_snr(
286
285
  responses: WaveformTensor,
287
286
  psd: PSDTensor,
288
287
  sample_rate: float,
289
- highpass: Union[float, Float[Tensor, " frequency"], None] = None,
290
- lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
288
+ highpass: float | Float[Tensor, " frequency"] | None = None,
289
+ lowpass: float | Float[Tensor, " frequency"] | None = None,
291
290
  ) -> Float[Tensor, "batch num_ifos"]:
292
291
  """Compute the SNRs of a batch of interferometer responses
293
292
 
@@ -367,10 +366,9 @@ def compute_ifo_snr(
367
366
  highpass = freqs >= highpass
368
367
  elif len(highpass) != integrand.shape[-1]:
369
368
  raise ValueError(
370
- "Can't apply highpass filter mask with {} frequency bins"
371
- "to signal fft with {} frequency bins".format(
372
- len(highpass), integrand.shape[-1]
373
- )
369
+ f"Can't apply highpass filter mask with {len(highpass)} "
370
+ f"frequency bins to signal fft with {integrand.shape[-1]} "
371
+ "frequency bins"
374
372
  )
375
373
  integrand *= highpass.to(integrand.device)
376
374
  if lowpass is not None:
@@ -379,10 +377,9 @@ def compute_ifo_snr(
379
377
  lowpass = freqs < lowpass
380
378
  elif len(lowpass) != integrand.shape[-1]:
381
379
  raise ValueError(
382
- "Can't apply lowpass filter mask with {} frequency bins"
383
- "to signal fft with {} frequency bins".format(
384
- len(lowpass), integrand.shape[-1]
385
- )
380
+ f"Can't apply lowpass filter mask with {len(lowpass)} "
381
+ f"frequency bins to signal fft with {integrand.shape[-1]} "
382
+ "frequency bins"
386
383
  )
387
384
  integrand *= lowpass.to(integrand.device)
388
385
 
@@ -410,8 +407,8 @@ def compute_network_snr(
410
407
  responses: WaveformTensor,
411
408
  psd: PSDTensor,
412
409
  sample_rate: float,
413
- highpass: Union[float, Float[Tensor, " frequency"], None] = None,
414
- lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
410
+ highpass: float | Float[Tensor, " frequency"] | None = None,
411
+ lowpass: float | Float[Tensor, " frequency"] | None = None,
415
412
  ) -> BatchTensor:
416
413
  """
417
414
  Compute the total SNR from a gravitational waveform
@@ -467,11 +464,11 @@ def compute_network_snr(
467
464
 
468
465
  def reweight_snrs(
469
466
  responses: WaveformTensor,
470
- target_snrs: Union[float, BatchTensor],
467
+ target_snrs: float | BatchTensor,
471
468
  psd: PSDTensor,
472
469
  sample_rate: float,
473
- highpass: Union[float, Float[Tensor, " frequency"], None] = None,
474
- lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
470
+ highpass: float | Float[Tensor, " frequency"] | None = None,
471
+ lowpass: float | Float[Tensor, " frequency"] | None = None,
475
472
  ) -> WaveformTensor:
476
473
  """Scale interferometer responses such that they have a desired SNR
477
474
 
@@ -0,0 +1,6 @@
1
+ """
2
+ This module contains neural network architectures and
3
+ architecture components. These can be a good place
4
+ to get started, rather than defining your own
5
+ architecture from the start.
6
+ """
@@ -1,5 +1,4 @@
1
1
  from collections.abc import Sequence
2
- from typing import Optional, Tuple, Union
3
2
 
4
3
  import torch
5
4
  from torch import Tensor
@@ -28,16 +27,14 @@ class Autoencoder(torch.nn.Module):
28
27
  and how they operate.
29
28
  """
30
29
 
31
- def __init__(
32
- self, skip_connection: Optional[SkipConnection] = None
33
- ) -> None:
30
+ def __init__(self, skip_connection: SkipConnection | None = None) -> None:
34
31
  super().__init__()
35
32
  self.skip_connection = skip_connection
36
33
  self.blocks = torch.nn.ModuleList()
37
34
 
38
35
  def encode(
39
36
  self, *X: Tensor, return_states: bool = False
40
- ) -> Union[Tensor, Tuple[Tensor, Sequence]]:
37
+ ) -> Tensor | tuple[Tensor, Sequence]:
41
38
  states = []
42
39
  for block in self.blocks:
43
40
  if isinstance(X, tuple):
@@ -53,7 +50,7 @@ class Autoencoder(torch.nn.Module):
53
50
  return X, states[:-1]
54
51
  return X
55
52
 
56
- def decode(self, *X, states: Optional[Sequence[Tensor]] = None) -> Tensor:
53
+ def decode(self, *X, states: Sequence[Tensor] | None = None) -> Tensor:
57
54
  if self.skip_connection is not None and states is None:
58
55
  raise ValueError(
59
56
  "Must pass intermediate states when autoencoder "
@@ -62,9 +59,8 @@ class Autoencoder(torch.nn.Module):
62
59
  elif states is not None:
63
60
  if len(states) != len(self.blocks) - 1:
64
61
  raise ValueError(
65
- "Passed {} intermediate states, expected {}".format(
66
- len(states), len(self.blocks) - 1
67
- )
62
+ f"Passed {len(states)} intermediate states, expected "
63
+ f"{len(self.blocks) - 1}"
68
64
  )
69
65
 
70
66
  # Don't skip connect the output layer
@@ -1,5 +1,4 @@
1
1
  from collections.abc import Callable, Sequence
2
- from typing import Optional
3
2
 
4
3
  import torch
5
4
  from torch import Tensor
@@ -21,9 +20,9 @@ class ConvBlock(Autoencoder):
21
20
  groups: int = 1,
22
21
  activation: torch.nn.Module = torch.nn.ReLU,
23
22
  norm: Module = torch.nn.BatchNorm1d,
24
- decode_channels: Optional[int] = None,
25
- output_activation: Optional[torch.nn.Module] = None,
26
- skip_connection: Optional[SkipConnection] = None,
23
+ decode_channels: int | None = None,
24
+ output_activation: torch.nn.Module | None = None,
25
+ skip_connection: SkipConnection | None = None,
27
26
  ) -> None:
28
27
  super().__init__(skip_connection=None)
29
28
 
@@ -98,10 +97,10 @@ class ConvolutionalAutoencoder(Autoencoder):
98
97
  stride: int = 1,
99
98
  groups: int = 1,
100
99
  activation: torch.nn.Module = torch.nn.ReLU,
101
- output_activation: Optional[torch.nn.Module] = None,
100
+ output_activation: torch.nn.Module | None = None,
102
101
  norm: Module = torch.nn.BatchNorm1d,
103
- decode_channels: Optional[int] = None,
104
- skip_connection: Optional[SkipConnection] = None,
102
+ decode_channels: int | None = None,
103
+ skip_connection: SkipConnection | None = None,
105
104
  ) -> None:
106
105
  # TODO: how to do this dynamically? Maybe the base
107
106
  # architecture looks for overlapping arguments between
@@ -145,9 +144,7 @@ class ConvolutionalAutoencoder(Autoencoder):
145
144
  self.blocks.append(block)
146
145
  in_channels = channels * groups
147
146
 
148
- def decode(
149
- self, *X, states=None, input_size: Optional[int] = None
150
- ) -> Tensor:
147
+ def decode(self, *X, states=None, input_size: int | None = None) -> Tensor:
151
148
  X = super().decode(*X, states=states)
152
149
  if input_size is not None:
153
150
  return match_size(X, input_size)
@@ -35,13 +35,11 @@ class ConcatSkipConnect(SkipConnection):
35
35
  rem = num_channels % self.groups
36
36
  if rem:
37
37
  raise ValueError(
38
- "Number of channels in input tensor {} cannot "
39
- "be divided evenly into {} groups".format(
40
- num_channels, self.groups
41
- )
38
+ f"Number of channels in input tensor {num_channels} cannot "
39
+ f"be divided evenly into {self.groups} groups"
42
40
  )
43
41
 
44
42
  X = torch.split(X, self.groups, dim=1)
45
43
  state = torch.split(state, self.groups, dim=1)
46
- frags = [i for j in zip(X, state) for i in j]
44
+ frags = [i for j in zip(X, state, strict=True) for i in j]
47
45
  return torch.cat(frags, dim=1)
@@ -1,4 +1,4 @@
1
- from typing import Callable, Optional
1
+ from collections.abc import Callable
2
2
 
3
3
  import torch
4
4
  from jaxtyping import Float
@@ -16,7 +16,7 @@ class GroupNorm1D(torch.nn.Module):
16
16
  def __init__(
17
17
  self,
18
18
  num_channels: int,
19
- num_groups: Optional[int] = None,
19
+ num_groups: int | None = None,
20
20
  eps: float = 1e-5,
21
21
  ):
22
22
  super().__init__()
@@ -77,7 +77,7 @@ class GroupNorm1DGetter:
77
77
  for command-line parameterization with jsonargparse.
78
78
  """
79
79
 
80
- def __init__(self, groups: Optional[int] = None) -> None:
80
+ def __init__(self, groups: int | None = None) -> None:
81
81
  self.groups = groups
82
82
 
83
83
  def __call__(self, num_channels: int) -> torch.nn.Module:
@@ -96,7 +96,7 @@ class GroupNorm2DGetter:
96
96
  for command-line parameterization with jsonargparse.
97
97
  """
98
98
 
99
- def __init__(self, groups: Optional[int] = None) -> None:
99
+ def __init__(self, groups: int | None = None) -> None:
100
100
  self.groups = groups
101
101
 
102
102
  def __call__(self, num_channels: int) -> torch.nn.Module:
@@ -7,7 +7,8 @@ where training-time statistics are entirely arbitrary due to
7
7
  simulations.
8
8
  """
9
9
 
10
- from typing import Callable, List, Literal, Optional
10
+ from collections.abc import Callable
11
+ from typing import Literal
11
12
 
12
13
  import torch
13
14
  import torch.nn as nn
@@ -58,11 +59,11 @@ class BasicBlock(nn.Module):
58
59
  planes: int,
59
60
  kernel_size: int = 3,
60
61
  stride: int = 1,
61
- downsample: Optional[nn.Module] = None,
62
+ downsample: nn.Module | None = None,
62
63
  groups: int = 1,
63
64
  base_width: int = 64,
64
65
  dilation: int = 1,
65
- norm_layer: Optional[Callable[..., nn.Module]] = None,
66
+ norm_layer: Callable[..., nn.Module] | None = None,
66
67
  ) -> None:
67
68
  super().__init__()
68
69
  if norm_layer is None:
@@ -123,11 +124,11 @@ class Bottleneck(nn.Module):
123
124
  planes: int,
124
125
  kernel_size: int = 3,
125
126
  stride: int = 1,
126
- downsample: Optional[nn.Module] = None,
127
+ downsample: nn.Module | None = None,
127
128
  groups: int = 1,
128
129
  base_width: int = 64,
129
130
  dilation: int = 1,
130
- norm_layer: Optional[NormLayer] = None,
131
+ norm_layer: NormLayer | None = None,
131
132
  ) -> None:
132
133
  super().__init__()
133
134
  if norm_layer is None:
@@ -231,14 +232,14 @@ class ResNet1D(nn.Module):
231
232
  def __init__(
232
233
  self,
233
234
  in_channels: int,
234
- layers: List[int],
235
+ layers: list[int],
235
236
  classes: int,
236
237
  kernel_size: int = 3,
237
238
  zero_init_residual: bool = False,
238
239
  groups: int = 1,
239
240
  width_per_group: int = 64,
240
- stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
241
- norm_layer: Optional[NormLayer] = None,
241
+ stride_type: list[Literal["stride", "dilation"]] | None = None,
242
+ norm_layer: NormLayer | None = None,
242
243
  ) -> None:
243
244
  super().__init__()
244
245
 
@@ -257,10 +258,8 @@ class ResNet1D(nn.Module):
257
258
  stride_type = ["stride"] * (len(layers) - 1)
258
259
  if len(stride_type) != (len(layers) - 1):
259
260
  raise ValueError(
260
- (
261
- "'stride_type' should be None or a {}-element "
262
- "tuple, got {}"
263
- ).format(len(layers) - 1, stride_type)
261
+ f"'stride_type' should be None or a {len(layers) - 1}-element "
262
+ f"tuple, got {stride_type}"
264
263
  )
265
264
 
266
265
  self.groups = groups
@@ -289,7 +288,7 @@ class ResNet1D(nn.Module):
289
288
  # striding or dilating depending on the stride_type
290
289
  # argument)
291
290
  residual_layers = [self._make_layer(64, layers[0], kernel_size)]
292
- it = zip(layers[1:], stride_type)
291
+ it = zip(layers[1:], stride_type, strict=True)
293
292
  for i, (num_blocks, stride) in enumerate(it):
294
293
  block_size = 64 * 2 ** (i + 1)
295
294
  layer = self._make_layer(