ml4gw 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/__init__.py +1 -0
- ml4gw/dataloading/chunked_dataset.py +1 -1
- ml4gw/dataloading/hdf5_dataset.py +36 -6
- ml4gw/dataloading/in_memory_dataset.py +1 -1
- ml4gw/gw.py +4 -3
- ml4gw/nn/autoencoder/base.py +1 -1
- ml4gw/nn/autoencoder/convolutional.py +3 -3
- ml4gw/nn/autoencoder/skip_connection.py +1 -1
- ml4gw/nn/resnet/resnet_1d.py +1 -1
- ml4gw/nn/resnet/resnet_2d.py +1 -1
- ml4gw/nn/streaming/online_average.py +1 -1
- ml4gw/nn/streaming/snapshotter.py +1 -1
- ml4gw/spectral.py +24 -6
- ml4gw/transforms/__init__.py +1 -0
- ml4gw/transforms/iirfilter.py +100 -0
- ml4gw/transforms/pearson.py +2 -2
- ml4gw/transforms/qtransform.py +2 -2
- ml4gw/transforms/scaler.py +1 -1
- ml4gw/transforms/snr_rescaler.py +3 -3
- ml4gw/transforms/spectral.py +2 -2
- ml4gw/transforms/spectrogram.py +1 -1
- ml4gw/transforms/transform.py +2 -2
- ml4gw/transforms/waveforms.py +2 -2
- ml4gw/transforms/whitening.py +19 -4
- ml4gw/utils/slicing.py +1 -6
- ml4gw/waveforms/cbc/coefficients.py +35 -0
- ml4gw/waveforms/cbc/phenom_d.py +3 -3
- ml4gw/waveforms/cbc/phenom_p.py +1 -0
- ml4gw/waveforms/cbc/taylorf2.py +5 -4
- ml4gw/waveforms/cbc/utils.py +111 -0
- ml4gw/waveforms/conversion.py +2 -2
- ml4gw/waveforms/generator.py +289 -26
- ml4gw-0.7.1.dist-info/LICENSE +674 -0
- ml4gw-0.7.1.dist-info/METADATA +78 -0
- ml4gw-0.7.1.dist-info/RECORD +55 -0
- {ml4gw-0.6.3.dist-info → ml4gw-0.7.1.dist-info}/WHEEL +1 -1
- ml4gw-0.6.3.dist-info/METADATA +0 -154
- ml4gw-0.6.3.dist-info/RECORD +0 -51
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: ml4gw
|
|
3
|
+
Version: 0.7.1
|
|
4
|
+
Summary: Tools for training torch models on gravitational wave data
|
|
5
|
+
Author: Alec Gunny
|
|
6
|
+
Author-email: alec.gunny@ligo.org
|
|
7
|
+
Requires-Python: >=3.9,<3.13
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Requires-Dist: jaxtyping (>=0.2,<0.3)
|
|
14
|
+
Requires-Dist: numpy (<2.0.0)
|
|
15
|
+
Requires-Dist: torch (>=2.0,<3.0)
|
|
16
|
+
Requires-Dist: torchaudio (>=2.0,<3.0)
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
|
|
19
|
+
# ML4GW
|
|
20
|
+

|
|
21
|
+

|
|
22
|
+

|
|
23
|
+

|
|
24
|
+

|
|
25
|
+
|
|
26
|
+
Torch utilities for training neural networks in gravitational wave physics applications.
|
|
27
|
+
|
|
28
|
+
## Documentation
|
|
29
|
+
Please visit our [documentation page](https://ml4gw.github.io/ml4gw/) to see descriptions and examples of the functions and modules available in `ml4gw`.
|
|
30
|
+
We also have an interactive Jupyter notebook in the `examples` directory that demonstrates much of the core functionality.
|
|
31
|
+
|
|
32
|
+
## Installation
|
|
33
|
+
### Pip installation
|
|
34
|
+
You can install `ml4gw` with pip:
|
|
35
|
+
|
|
36
|
+
```console
|
|
37
|
+
pip install ml4gw
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
To build with a specific version of PyTorch/CUDA, please consult the PyTorch installation instructions [here](https://pytorch.org/get-started/previous-versions/) for how to specify the desired torch version and `--extra-index-url` flag. For example, to install with torch 2.5.1 and CUDA 11.8 support, you would run
|
|
41
|
+
|
|
42
|
+
```console
|
|
43
|
+
pip install ml4gw torch==2.5.1 --extra-index-url=https://download.pytorch.org/whl/cu118
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
### Poetry installation
|
|
47
|
+
`ml4gw` is also fully compatible with Poetry; set up your `pyproject.toml` like this:
|
|
48
|
+
|
|
49
|
+
```toml
|
|
50
|
+
[tool.poetry.dependencies]
|
|
51
|
+
python = ">=3.9,<3.13" # python versions 3.9-3.12 are supported
|
|
52
|
+
ml4gw = "^0.7"
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.8, you would do something like:
|
|
56
|
+
|
|
57
|
+
```toml
|
|
58
|
+
[tool.poetry.dependencies]
|
|
59
|
+
python = ">=3.9,<3.13"
|
|
60
|
+
ml4gw = "^0.7"
|
|
61
|
+
torch = {version = "^2.0", source = "torch"}
|
|
62
|
+
|
|
63
|
+
[[tool.poetry.source]]
|
|
64
|
+
name = "torch"
|
|
65
|
+
url = "https://download.pytorch.org/whl/cu118"
|
|
66
|
+
priority = "explicit"
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## Contributing
|
|
70
|
+
If you come across errors in the code, have difficulties using this software, or simply find that the current version doesn't cover your use case, please file an issue on our GitHub page, and we'll be happy to offer support.
|
|
71
|
+
We also welcome requests to extend our coverage to new or improved functionality.
|
|
72
|
+
We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
|
|
73
|
+
For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu).
|
|
74
|
+
By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool that makes deep learning more accessible for gravitational wave physicists everywhere.
|
|
75
|
+
|
|
76
|
+
## Funding
|
|
77
|
+
We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
|
|
78
|
+
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
ml4gw/__init__.py,sha256=81quoggCuIypZjZs3bbf1Ty70KHdva5RGEJxi0oC57E,25
|
|
2
|
+
ml4gw/augmentations.py,sha256=pZH9tjEpXV0AIqvHHDkpUE-BorG02beOz2pmSipw2EY,1232
|
|
3
|
+
ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
|
|
4
|
+
ml4gw/dataloading/__init__.py,sha256=EHBBqU7y2-Np5iQ_xyufxamUEM1pPEquqFo7oaJnaJE,149
|
|
5
|
+
ml4gw/dataloading/chunked_dataset.py,sha256=j96Rd67cRpsvotR_dzgfbrqxcoGDWnTV5cmfN038cb8,5256
|
|
6
|
+
ml4gw/dataloading/hdf5_dataset.py,sha256=bVcXzS1LHVj7zMeMtRkxx1Q76MQS6wEApJJlUAI6iC8,7879
|
|
7
|
+
ml4gw/dataloading/in_memory_dataset.py,sha256=1oUchfNBq3rx1NgNqrcg6AGdJ-dvm56o-TGFwPn5wm8,9546
|
|
8
|
+
ml4gw/distributions.py,sha256=tUuaOiX5enjKLYWD7uiN8rdRVQcrIKps64xBkTl8fMs,4991
|
|
9
|
+
ml4gw/gw.py,sha256=aUPSXgwyqJUBGGaKtUa-O3qkSbRYZwhhXIlkhvjgJgI,17684
|
|
10
|
+
ml4gw/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
ml4gw/nn/autoencoder/__init__.py,sha256=ZaT1XhJTHpMuPQqu5E__Jezeh9uwtjcXlT7IZ18byq4,161
|
|
12
|
+
ml4gw/nn/autoencoder/base.py,sha256=eSWrDdpblI609oqa7RDSvZiY3YcV8WfhTioWKFn_7eE,3205
|
|
13
|
+
ml4gw/nn/autoencoder/convolutional.py,sha256=VCemfNtzleIaKZHtzfxFDhYBLHdHEODy4-LRA9GiDhY,5359
|
|
14
|
+
ml4gw/nn/autoencoder/skip_connection.py,sha256=9PKoCCvCUj5di9tuFM0Cl1v6gtcOK1bDeE_fS_R__FE,1391
|
|
15
|
+
ml4gw/nn/autoencoder/utils.py,sha256=m_ivYGNwdrhA7cFxJVD4gqM8AHiWIGmlQI3pFNRklXQ,355
|
|
16
|
+
ml4gw/nn/norm.py,sha256=JIOMXQbUtoWlrhncGsqW6f1-DiGDx9zQH2O3CvQml3U,3594
|
|
17
|
+
ml4gw/nn/resnet/__init__.py,sha256=vBI0IftVP_EYAeDlqomtkGqUYE-RE_S4WNioUhniw9s,64
|
|
18
|
+
ml4gw/nn/resnet/resnet_1d.py,sha256=C0H-GuY3-bradnSvUNtkD-o8j3-3uQDhUK4DKbOOrzk,13211
|
|
19
|
+
ml4gw/nn/resnet/resnet_2d.py,sha256=fVzYRuO0xR9yGjjQExv30mouokvupOAW-Kfdbs5WYDA,13294
|
|
20
|
+
ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
|
|
21
|
+
ml4gw/nn/streaming/online_average.py,sha256=_nrul4ygTC_ln4wpSWGRWTgWlfGeOUGXxeGrhU4oJms,4716
|
|
22
|
+
ml4gw/nn/streaming/snapshotter.py,sha256=1vWDpebRQBZIUVeksbXoqngqMnlSzQFkcsgYNrHB9tc,4473
|
|
23
|
+
ml4gw/spectral.py,sha256=lLpnho02i-0zPSi96b0xOPEIgQMnBrmO8JiV1KvPGEw,19811
|
|
24
|
+
ml4gw/transforms/__init__.py,sha256=OaTQJD4GFkDkcxt0DIwt2AzeEcv9t21ciKXxQnqDiuI,447
|
|
25
|
+
ml4gw/transforms/iirfilter.py,sha256=RwgC3DWgYmBnHe7bYjvr9njM1WrRZ9EjBJsZNmtOY8s,3186
|
|
26
|
+
ml4gw/transforms/pearson.py,sha256=CM9FTRxI4384-36FIaJFOcMZwsA7BkgberToJkMU1PA,3227
|
|
27
|
+
ml4gw/transforms/qtransform.py,sha256=5S9y3PxkOmqMAarQmme0Tiy58vRvberpqhg6IeyDJLI,20675
|
|
28
|
+
ml4gw/transforms/scaler.py,sha256=K5mp4w2zGZbpH1AcBUfpQS4n3aVSNzkaGWXedwk2LXs,2508
|
|
29
|
+
ml4gw/transforms/snr_rescaler.py,sha256=XHKTeJXM3F_VOmjWOZetQuVZ6PMum8pEBPaOVbS16-w,2327
|
|
30
|
+
ml4gw/transforms/spectral.py,sha256=4uCLNEcDff4kLheUA5v64L0y_MSOvUTJ92IH4TVcEys,4385
|
|
31
|
+
ml4gw/transforms/spectrogram.py,sha256=8HDStoup7vlwpw9qTKshAuEpa85-lw5_SwYxjxxu1sQ,6158
|
|
32
|
+
ml4gw/transforms/spline_interpolation.py,sha256=oXih-gLMbIbI36DPKLTk39WcjiNUJN_rcQia_k3OrMY,13527
|
|
33
|
+
ml4gw/transforms/transform.py,sha256=lu5ukcOCOYYZDZCM_0amS9AY2bJgkbLpXmZ9DpnSK9I,2504
|
|
34
|
+
ml4gw/transforms/waveforms.py,sha256=koWOuHuUpQWmTT1yawSWa_MOuLfDBuugy91KIyuklOo,3189
|
|
35
|
+
ml4gw/transforms/whitening.py,sha256=8ADmM52lrHt_2yvjX51x0bFxAloKbS7s2owJgrVD5uc,10294
|
|
36
|
+
ml4gw/types.py,sha256=CcctqDcNajR7khGT6BD-WYsfRKpiP0udoSAB0k1qcFw,863
|
|
37
|
+
ml4gw/utils/interferometer.py,sha256=lRS0N3SwUTknhYXX57VACJ99jK1P9M19oUWN_i_nQN0,1814
|
|
38
|
+
ml4gw/utils/slicing.py,sha256=V9tbEzHnukg16-e8jdIFsZIZ1oICF9zBE2sjUsBXW-s,13538
|
|
39
|
+
ml4gw/waveforms/__init__.py,sha256=QVUzBx_y8A9_AsRuTJruPvL9mqGnBt11Iw1MOYjXyE4,40
|
|
40
|
+
ml4gw/waveforms/adhoc/__init__.py,sha256=XVwP4t8TMUj87WY3yMGRTkXsv7_lVr1w8p8iKBW8iKE,71
|
|
41
|
+
ml4gw/waveforms/adhoc/ringdown.py,sha256=m8IBQTxKBBGFqBtWGEO4KG3DEYR8TTnNyGVdVLaMKa8,3316
|
|
42
|
+
ml4gw/waveforms/adhoc/sine_gaussian.py,sha256=-MtrI7ydwBTk4K0O4tdkC8-w5OifQszdnWN9__I4XzY,3569
|
|
43
|
+
ml4gw/waveforms/cbc/__init__.py,sha256=hGbPsFNAIveYJnff8qKY8RWeBPFtZoYcnGHxraPWtWI,99
|
|
44
|
+
ml4gw/waveforms/cbc/coefficients.py,sha256=PMr0IBALEQ38eAvZqYg-w_FE_sS1mH2FWr9soQ5MRfU,1106
|
|
45
|
+
ml4gw/waveforms/cbc/phenom_d.py,sha256=b586PbpBGAA1DO55X0D35_dAJXIGVwUBrNhmPgCBbwU,48983
|
|
46
|
+
ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9n25he9U,53447
|
|
47
|
+
ml4gw/waveforms/cbc/phenom_p.py,sha256=LBtGVUjBjROcYBPLldFnF6T1jZV6ZyuZEnkn9-oTKpQ,27620
|
|
48
|
+
ml4gw/waveforms/cbc/taylorf2.py,sha256=2ga_lG_xkYOsF-BdxgjbU0pgLDjeAO0p5IWuCPvibvQ,10504
|
|
49
|
+
ml4gw/waveforms/cbc/utils.py,sha256=CvZ79PQygb-zwulMV-wRuBcGEsHbVOtJz60UnOJFKoM,3051
|
|
50
|
+
ml4gw/waveforms/conversion.py,sha256=MyADWEZVoEkRkKaHk1ZuQDsGfPYx5xUTtyApj5P3ueQ,6895
|
|
51
|
+
ml4gw/waveforms/generator.py,sha256=i2lgaJzH5eA6gzc-bLQZYYEgEQ8OBLJgE9yNXU3FsKM,12005
|
|
52
|
+
ml4gw-0.7.1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
53
|
+
ml4gw-0.7.1.dist-info/METADATA,sha256=xEVSE7PX32I8b4YIneUVVvTAHLS4WemuQ8bpCKskIXE,3904
|
|
54
|
+
ml4gw-0.7.1.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
|
|
55
|
+
ml4gw-0.7.1.dist-info/RECORD,,
|
ml4gw-0.6.3.dist-info/METADATA
DELETED
|
@@ -1,154 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.3
|
|
2
|
-
Name: ml4gw
|
|
3
|
-
Version: 0.6.3
|
|
4
|
-
Summary: Tools for training torch models on gravitational wave data
|
|
5
|
-
Author: Alec Gunny
|
|
6
|
-
Author-email: alec.gunny@ligo.org
|
|
7
|
-
Requires-Python: >=3.9,<3.13
|
|
8
|
-
Classifier: Programming Language :: Python :: 3
|
|
9
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
10
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
-
Requires-Dist: jaxtyping (>=0.2,<0.3)
|
|
14
|
-
Requires-Dist: numpy (<2.0.0)
|
|
15
|
-
Requires-Dist: torch (>=2.0,<3.0)
|
|
16
|
-
Requires-Dist: torchaudio (>=2.0,<3.0)
|
|
17
|
-
Description-Content-Type: text/markdown
|
|
18
|
-
|
|
19
|
-
# ML4GW
|
|
20
|
-
|
|
21
|
-
Torch utilities for training neural networks in gravitational wave physics applications.
|
|
22
|
-
|
|
23
|
-
## Installation
|
|
24
|
-
### Pip installation
|
|
25
|
-
You can install `ml4gw` with pip:
|
|
26
|
-
|
|
27
|
-
```console
|
|
28
|
-
pip install ml4gw
|
|
29
|
-
```
|
|
30
|
-
|
|
31
|
-
To build with a specific version of PyTorch/CUDA, please see the PyTorch installation instructions [here](https://pytorch.org/get-started/previous-versions/) to see how to specify the desired torch version and `--extra-index-url` flag. For example, to install with torch 1.12 and CUDA 11.6 support, you would run
|
|
32
|
-
|
|
33
|
-
```console
|
|
34
|
-
pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/whl/cu116
|
|
35
|
-
```
|
|
36
|
-
|
|
37
|
-
### Poetry installation
|
|
38
|
-
`ml4gw` is also fully compatible with use in Poetry, with your `pyproject.toml` set up like
|
|
39
|
-
|
|
40
|
-
```toml
|
|
41
|
-
[tool.poetry.dependencies]
|
|
42
|
-
python = "^3.8" # python versions 3.8-3.11 are supported
|
|
43
|
-
ml4gw = "^0.3.0"
|
|
44
|
-
```
|
|
45
|
-
|
|
46
|
-
To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
|
|
47
|
-
|
|
48
|
-
```toml
|
|
49
|
-
[tool.poetry.dependencies]
|
|
50
|
-
python = "^3.8"
|
|
51
|
-
ml4gw = "^0.3.0"
|
|
52
|
-
torch = {version = "^1.12", source = "torch"}
|
|
53
|
-
|
|
54
|
-
[[tool.poetry.source]]
|
|
55
|
-
name = "torch"
|
|
56
|
-
url = "https://download.pytorch.org/whl/cu116"
|
|
57
|
-
secondary = true
|
|
58
|
-
default = false
|
|
59
|
-
```
|
|
60
|
-
|
|
61
|
-
Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
|
|
62
|
-
|
|
63
|
-
## Use cases
|
|
64
|
-
This library provides utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower level functions and utilities are defined at the top level of the library and in the `utils` library.
|
|
65
|
-
|
|
66
|
-
For example, to train a simple autoencoder using a cost function in frequency space, you might do something like:
|
|
67
|
-
|
|
68
|
-
```python
|
|
69
|
-
import numpy as np
|
|
70
|
-
import torch
|
|
71
|
-
from ml4gw.dataloading import InMemoryDataset
|
|
72
|
-
from ml4gw.transforms import SpectralDensity
|
|
73
|
-
|
|
74
|
-
SAMPLE_RATE = 2048
|
|
75
|
-
NUM_IFOS = 2
|
|
76
|
-
DATA_LENGTH = 128
|
|
77
|
-
KERNEL_LENGTH = 4
|
|
78
|
-
DEVICE = "cuda" # or "cpu", wherever you want to run
|
|
79
|
-
|
|
80
|
-
BATCH_SIZE = 32
|
|
81
|
-
LEARNING_RATE = 1e-3
|
|
82
|
-
NUM_EPOCHS = 10
|
|
83
|
-
|
|
84
|
-
dummy_data = np.random.randn(NUM_IFOS, DATA_LENGTH * SAMPLE_RATE)
|
|
85
|
-
|
|
86
|
-
# this will create a dataloader that iterates through your
|
|
87
|
-
# timeseries data sampling 4s long windows of data randomly
|
|
88
|
-
# and non-coincidentally: i.e. the background from each IFO
|
|
89
|
-
# will be sampled independently
|
|
90
|
-
dataset = InMemoryDataset(
|
|
91
|
-
dummy_data,
|
|
92
|
-
kernel_size=KERNEL_LENGTH * SAMPLE_RATE,
|
|
93
|
-
batch_size=BATCH_SIZE,
|
|
94
|
-
batches_per_epoch=50,
|
|
95
|
-
coincident=False,
|
|
96
|
-
shuffle=True,
|
|
97
|
-
device=DEVICE # this will move your dataset to GPU up-front if "cuda"
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
nn = torch.nn.Sequential(
|
|
102
|
-
torch.nn.Conv1d(
|
|
103
|
-
in_channels=2,
|
|
104
|
-
out_channels=8,
|
|
105
|
-
kernel_size=7
|
|
106
|
-
),
|
|
107
|
-
torch.nn.ConvTranspose1d(
|
|
108
|
-
in_channels=8,
|
|
109
|
-
out_channels=2,
|
|
110
|
-
kernel_size=7
|
|
111
|
-
)
|
|
112
|
-
).to(DEVICE)
|
|
113
|
-
|
|
114
|
-
optimizer = torch.optim.Adam(nn.parameters(), lr=LEARNING_RATE)
|
|
115
|
-
|
|
116
|
-
spectral_density = SpectralDensity(SAMPLE_RATE, fftlength=2).to(DEVICE)
|
|
117
|
-
|
|
118
|
-
def loss_function(X, y):
|
|
119
|
-
"""
|
|
120
|
-
MSE in frequency domain. Obviously this doesn't
|
|
121
|
-
give you much on its own, but you can imagine doing
|
|
122
|
-
something like masking to just the bins you care about.
|
|
123
|
-
"""
|
|
124
|
-
X = spectral_density(X)
|
|
125
|
-
y = spectral_density(y)
|
|
126
|
-
return ((X - y)**2).mean()
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
for i in range(NUM_EPOCHS):
|
|
130
|
-
epoch_loss = 0
|
|
131
|
-
for X in dataset:
|
|
132
|
-
optimizer.zero_grad(set_to_none=True)
|
|
133
|
-
assert X.shape == (32, NUM_IFOS, KERNEL_LENGTH * SAMPLE_RATE)
|
|
134
|
-
y = nn(X)
|
|
135
|
-
|
|
136
|
-
loss = loss_function(X, y)
|
|
137
|
-
loss.backward()
|
|
138
|
-
optimizer.step()
|
|
139
|
-
|
|
140
|
-
epoch_loss += loss.item()
|
|
141
|
-
epoch_loss /= len(dataset)
|
|
142
|
-
print(f"Epoch {i + 1}/{NUM_EPOCHS} Loss: {epoch_loss:0.3e}")
|
|
143
|
-
```
|
|
144
|
-
|
|
145
|
-
## Development
|
|
146
|
-
As this library is still very much a work in progress, we anticipate that novel use cases will encounter errors stemming from a lack of robustness.
|
|
147
|
-
We encourage users who encounter these difficulties to file issues on GitHub, and we'll be happy to offer support to extend our coverage to new or improved functionality.
|
|
148
|
-
We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
|
|
149
|
-
For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu).
|
|
150
|
-
By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
|
|
151
|
-
|
|
152
|
-
## Funding
|
|
153
|
-
We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
|
|
154
|
-
|
ml4gw-0.6.3.dist-info/RECORD
DELETED
|
@@ -1,51 +0,0 @@
|
|
|
1
|
-
ml4gw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
ml4gw/augmentations.py,sha256=pZH9tjEpXV0AIqvHHDkpUE-BorG02beOz2pmSipw2EY,1232
|
|
3
|
-
ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
|
|
4
|
-
ml4gw/dataloading/__init__.py,sha256=EHBBqU7y2-Np5iQ_xyufxamUEM1pPEquqFo7oaJnaJE,149
|
|
5
|
-
ml4gw/dataloading/chunked_dataset.py,sha256=FpDc4gFxt-PMyXs5qSWLuTGXMTuS1B-hH8gUOCOGxZk,5260
|
|
6
|
-
ml4gw/dataloading/hdf5_dataset.py,sha256=UB1Eog8l7m4M78Owst7oYQZICb0DRJer9WVLVn4hl_I,6645
|
|
7
|
-
ml4gw/dataloading/in_memory_dataset.py,sha256=kleMA9ABUKA6J0tCdz78tbX9lM6uxVSLhqgHbSa1iWY,9550
|
|
8
|
-
ml4gw/distributions.py,sha256=tUuaOiX5enjKLYWD7uiN8rdRVQcrIKps64xBkTl8fMs,4991
|
|
9
|
-
ml4gw/gw.py,sha256=To_hQz9tUp02ADllGLxFCPsNcfbb-kbvfgGpooxcOII,17693
|
|
10
|
-
ml4gw/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
ml4gw/nn/autoencoder/__init__.py,sha256=ZaT1XhJTHpMuPQqu5E__Jezeh9uwtjcXlT7IZ18byq4,161
|
|
12
|
-
ml4gw/nn/autoencoder/base.py,sha256=4d5Ej30IUzZh3XbldzWlCpp3p0_91YUvKeRID8ZEZGA,3225
|
|
13
|
-
ml4gw/nn/autoencoder/convolutional.py,sha256=2BXDuPWYC-151RO_FL0ogdrqSVTfo4YNrY80lPwrmFA,5419
|
|
14
|
-
ml4gw/nn/autoencoder/skip_connection.py,sha256=fpXxxIIl0CXY4mAfUZQuvI542pEBSwpg90TNG2rbZY8,1411
|
|
15
|
-
ml4gw/nn/autoencoder/utils.py,sha256=m_ivYGNwdrhA7cFxJVD4gqM8AHiWIGmlQI3pFNRklXQ,355
|
|
16
|
-
ml4gw/nn/norm.py,sha256=JIOMXQbUtoWlrhncGsqW6f1-DiGDx9zQH2O3CvQml3U,3594
|
|
17
|
-
ml4gw/nn/resnet/__init__.py,sha256=vBI0IftVP_EYAeDlqomtkGqUYE-RE_S4WNioUhniw9s,64
|
|
18
|
-
ml4gw/nn/resnet/resnet_1d.py,sha256=IQ-EIIzAXd-NWuLwt7JTXLWg5bO3FGJpuFAZwZ78jaI,13218
|
|
19
|
-
ml4gw/nn/resnet/resnet_2d.py,sha256=aK4I0FOZk62JxnYFz0t1O0s5s7J7yRNYSM1flRypvVc,13301
|
|
20
|
-
ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
|
|
21
|
-
ml4gw/nn/streaming/online_average.py,sha256=aI8hkT7I3thXkda9tsXxYrzump9swelSXPdSTwPlJWY,4719
|
|
22
|
-
ml4gw/nn/streaming/snapshotter.py,sha256=B9qtbHxnPszAHQ5WQppWJLRuMnnYIxGk7MRUlgja7Is,4476
|
|
23
|
-
ml4gw/spectral.py,sha256=0UPgbqGay-xP-3uJ7orZCb9fSO4eVbu6JTjzZJOFqj4,19160
|
|
24
|
-
ml4gw/transforms/__init__.py,sha256=-DLdjD4usIi0ttSw61ZV7HieCTgHz1vTwfAlRgzbuDw,414
|
|
25
|
-
ml4gw/transforms/pearson.py,sha256=Ep3mMsY15AF55taRaWNjpHRTvtr1StShUDfqk0dN-qo,3235
|
|
26
|
-
ml4gw/transforms/qtransform.py,sha256=TWQsBeKhRoqJdkc4cPt58pKozgb_6-jZivn8u0AzQyQ,20695
|
|
27
|
-
ml4gw/transforms/scaler.py,sha256=souOt-hOO4M6dqPNXOspfmeU2V9622yGoIMNvju5JZI,2524
|
|
28
|
-
ml4gw/transforms/snr_rescaler.py,sha256=3XXCTaXc2dzzpXRZx7iqRwImvYtRSJLM5fHdBGfpoUs,2351
|
|
29
|
-
ml4gw/transforms/spectral.py,sha256=gTHUeC0gGYbzgBZHb_FxC_4zdhl5H-XCiLg1hrvKB70,4393
|
|
30
|
-
ml4gw/transforms/spectrogram.py,sha256=HS3Rf5iB7JjhlSESRDdFGUwCtIBdvUaJUDulkB4Lmos,6162
|
|
31
|
-
ml4gw/transforms/spline_interpolation.py,sha256=oXih-gLMbIbI36DPKLTk39WcjiNUJN_rcQia_k3OrMY,13527
|
|
32
|
-
ml4gw/transforms/transform.py,sha256=BuzTbPFxp18OEGP9Tu9jBGtvqy3len1cqvqg5X37DiY,2512
|
|
33
|
-
ml4gw/transforms/waveforms.py,sha256=LkYCvxPqYhHa2yYZTvPE6j0E4HFy16b5ndCRQb7WfcA,3196
|
|
34
|
-
ml4gw/transforms/whitening.py,sha256=Aw_ogq93CYCATiHWBqSZ-qsUtaHAMA3k009ZRtQTtHA,9596
|
|
35
|
-
ml4gw/types.py,sha256=CcctqDcNajR7khGT6BD-WYsfRKpiP0udoSAB0k1qcFw,863
|
|
36
|
-
ml4gw/utils/interferometer.py,sha256=lRS0N3SwUTknhYXX57VACJ99jK1P9M19oUWN_i_nQN0,1814
|
|
37
|
-
ml4gw/utils/slicing.py,sha256=ilRz_5sJzwmd5VyBlrj81tvyC3uCnXYjd0TO2fzFMr8,13563
|
|
38
|
-
ml4gw/waveforms/__init__.py,sha256=QVUzBx_y8A9_AsRuTJruPvL9mqGnBt11Iw1MOYjXyE4,40
|
|
39
|
-
ml4gw/waveforms/adhoc/__init__.py,sha256=XVwP4t8TMUj87WY3yMGRTkXsv7_lVr1w8p8iKBW8iKE,71
|
|
40
|
-
ml4gw/waveforms/adhoc/ringdown.py,sha256=m8IBQTxKBBGFqBtWGEO4KG3DEYR8TTnNyGVdVLaMKa8,3316
|
|
41
|
-
ml4gw/waveforms/adhoc/sine_gaussian.py,sha256=-MtrI7ydwBTk4K0O4tdkC8-w5OifQszdnWN9__I4XzY,3569
|
|
42
|
-
ml4gw/waveforms/cbc/__init__.py,sha256=hGbPsFNAIveYJnff8qKY8RWeBPFtZoYcnGHxraPWtWI,99
|
|
43
|
-
ml4gw/waveforms/cbc/phenom_d.py,sha256=0pcVAt7b1cjTbphdClPCjenv2sC8bp6oXGGlEUyW-mY,48973
|
|
44
|
-
ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9n25he9U,53447
|
|
45
|
-
ml4gw/waveforms/cbc/phenom_p.py,sha256=bXh1ohqzVQw-UUqc02uNoIMb9oCaT8-WlEWIrnuab-0,27602
|
|
46
|
-
ml4gw/waveforms/cbc/taylorf2.py,sha256=_s-faE8yWMULMxGd4VvzPI54R3G-O2TF2G4-T2m2rDM,10510
|
|
47
|
-
ml4gw/waveforms/conversion.py,sha256=zPkaGkMVqsdrF0fS3ZscyP-2jX8YK40d4smUoJb4gj4,6903
|
|
48
|
-
ml4gw/waveforms/generator.py,sha256=dO6RQ96EC87p2q0tEkxA62XkkJc1xARFO1SKcGvyDhM,1272
|
|
49
|
-
ml4gw-0.6.3.dist-info/METADATA,sha256=6Kpi5UqMguD4hdv_5FUhx36qc_8hoDkoT4IBo6ydwcg,5735
|
|
50
|
-
ml4gw-0.6.3.dist-info/WHEEL,sha256=RaoafKOydTQ7I_I3JTrPCg6kUmTgtm4BornzOqyEfJ8,88
|
|
51
|
-
ml4gw-0.6.3.dist-info/RECORD,,
|