ml4gw 0.2.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw-0.2.0/README.md → ml4gw-0.4.0/PKG-INFO +28 -3
- ml4gw-0.2.0/PKG-INFO → ml4gw-0.4.0/README.md +7 -19
- ml4gw-0.4.0/ml4gw/augmentations.py +43 -0
- ml4gw-0.4.0/ml4gw/dataloading/__init__.py +3 -0
- ml4gw-0.4.0/ml4gw/dataloading/chunked_dataset.py +134 -0
- ml4gw-0.4.0/ml4gw/dataloading/hdf5_dataset.py +176 -0
- ml4gw-0.4.0/ml4gw/nn/__init__.py +0 -0
- ml4gw-0.4.0/ml4gw/nn/autoencoder/__init__.py +3 -0
- ml4gw-0.4.0/ml4gw/nn/autoencoder/base.py +89 -0
- ml4gw-0.4.0/ml4gw/nn/autoencoder/convolutional.py +156 -0
- ml4gw-0.4.0/ml4gw/nn/autoencoder/skip_connection.py +46 -0
- ml4gw-0.4.0/ml4gw/nn/autoencoder/utils.py +14 -0
- ml4gw-0.4.0/ml4gw/nn/norm.py +97 -0
- ml4gw-0.4.0/ml4gw/nn/resnet/__init__.py +2 -0
- ml4gw-0.4.0/ml4gw/nn/resnet/resnet_1d.py +413 -0
- ml4gw-0.4.0/ml4gw/nn/resnet/resnet_2d.py +413 -0
- ml4gw-0.4.0/ml4gw/nn/streaming/__init__.py +2 -0
- ml4gw-0.4.0/ml4gw/nn/streaming/online_average.py +121 -0
- ml4gw-0.4.0/ml4gw/nn/streaming/snapshotter.py +121 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/__init__.py +2 -0
- ml4gw-0.4.0/ml4gw/transforms/pearson.py +87 -0
- ml4gw-0.4.0/ml4gw/transforms/spectrogram.py +162 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/whitening.py +1 -1
- ml4gw-0.4.0/ml4gw/waveforms/__init__.py +3 -0
- ml4gw-0.4.0/ml4gw/waveforms/phenom_d.py +1359 -0
- ml4gw-0.4.0/ml4gw/waveforms/phenom_d_data.py +3026 -0
- ml4gw-0.4.0/ml4gw/waveforms/taylorf2.py +306 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/pyproject.toml +15 -4
- ml4gw-0.2.0/ml4gw/dataloading/__init__.py +0 -2
- ml4gw-0.2.0/ml4gw/dataloading/chunked_dataset.py +0 -280
- ml4gw-0.2.0/ml4gw/waveforms/__init__.py +0 -1
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/__init__.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/dataloading/in_memory_dataset.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/distributions.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/gw.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/spectral.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/scaler.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/snr_rescaler.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/spectral.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/transform.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/waveforms.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/types.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/utils/interferometer.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/utils/slicing.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/waveforms/generator.py +0 -0
- {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/waveforms/sine_gaussian.py +0 -0
````diff
--- ml4gw-0.2.0/README.md
+++ ml4gw-0.4.0/PKG-INFO
@@ -1,3 +1,22 @@
+Metadata-Version: 2.1
+Name: ml4gw
+Version: 0.4.0
+Summary: Tools for training torch models on gravitational wave data
+Author: Alec Gunny
+Author-email: alec.gunny@ligo.org
+Requires-Python: >=3.8,<3.12
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Requires-Dist: torch (>=1.10,<2.0) ; python_version >= "3.8" and python_version < "3.11"
+Requires-Dist: torch (>=2.0,<3.0) ; python_version >= "3.11"
+Requires-Dist: torchaudio (>=0.13,<0.14) ; python_version >= "3.8" and python_version < "3.11"
+Requires-Dist: torchaudio (>=2.0,<3.0) ; python_version >= "3.11"
+Requires-Dist: torchtyping (>=0.1,<0.2)
+Description-Content-Type: text/markdown
+
 # ML4GW
 
 Torch utilities for training neural networks in gravitational wave physics applications.
@@ -21,8 +40,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
 
 ```toml
 [tool.poetry.dependencies]
-python = "^3.8" # python versions 3.8-3.
-ml4gw = "^0.
+python = "^3.8" # python versions 3.8-3.11 are supported
+ml4gw = "^0.3.0"
 ```
 
 To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -30,7 +49,7 @@ To build against a specific PyTorch/CUDA instal
 ```toml
 [tool.poetry.dependencies]
 python = "^3.8"
-ml4gw = "^0.
+ml4gw = "^0.3.0"
 torch = {version = "^1.12", source = "torch"}
 
 [[tool.poetry.source]]
@@ -40,6 +59,8 @@ secondary = true
 default = false
 ```
 
+Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
+
 ## Use cases
 This library provided utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower level functions and utilies are defined at the top level of the library and in the `utils` library.
 
@@ -128,3 +149,7 @@ We encourage users who encounter these difficulties to file issues on GitHub, an
 We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
 For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
 By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
+
+## Funding
+We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
+
````
````diff
--- ml4gw-0.2.0/PKG-INFO
+++ ml4gw-0.4.0/README.md
@@ -1,19 +1,3 @@
-Metadata-Version: 2.1
-Name: ml4gw
-Version: 0.2.0
-Summary: Tools for training torch models on gravitational wave data
-Author: Alec Gunny
-Author-email: alec.gunny@ligo.org
-Requires-Python: >=3.8,<4.0
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Programming Language :: Python :: 3.11
-Requires-Dist: torch (>=1.10,<2.0)
-Requires-Dist: torchtyping (>=0.1,<0.2)
-Description-Content-Type: text/markdown
-
 # ML4GW
 
 Torch utilities for training neural networks in gravitational wave physics applications.
@@ -37,8 +21,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
 
 ```toml
 [tool.poetry.dependencies]
-python = "^3.8" # python versions 3.8-3.
-ml4gw = "^0.
+python = "^3.8" # python versions 3.8-3.11 are supported
+ml4gw = "^0.3.0"
 ```
 
 To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -46,7 +30,7 @@ To build against a specific PyTorch/CUDA instal
 ```toml
 [tool.poetry.dependencies]
 python = "^3.8"
-ml4gw = "^0.
+ml4gw = "^0.3.0"
 torch = {version = "^1.12", source = "torch"}
 
 [[tool.poetry.source]]
@@ -56,6 +40,8 @@ secondary = true
 default = false
 ```
 
+Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
+
 ## Use cases
 This library provided utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower level functions and utilies are defined at the top level of the library and in the `utils` library.
 
@@ -145,3 +131,5 @@ We also strongly encourage ML users in the GW physics space to try their hand at
 For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
 By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
 
+## Funding
+We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
````
```diff
--- /dev/null
+++ ml4gw-0.4.0/ml4gw/augmentations.py
@@ -0,0 +1,43 @@
+import torch
+
+
+class SignalInverter(torch.nn.Module):
+    """
+    Takes a tensor of timeseries of arbitrary dimension
+    and randomly inverts (i.e. h(t) -> -h(t))
+    each timeseries with probability `prob`.
+
+    Args:
+        prob:
+            Probability that a timeseries is inverted
+    """
+
+    def __init__(self, prob: float = 0.5):
+        super().__init__()
+        self.prob = prob
+
+    def forward(self, X):
+        mask = torch.rand(size=X.shape[:-1]) < self.prob
+        X[mask] *= -1
+        return X
+
+
+class SignalReverser(torch.nn.Module):
+    """
+    Takes a tensor of timeseries of arbitrary dimension
+    and randomly reverses (i.e. h(t) -> h(-t))
+    each timeseries with probability `prob`.
+
+    Args:
+        prob:
+            Probability that a kernel is reversed
+    """
+
+    def __init__(self, prob: float = 0.5):
+        super().__init__()
+        self.prob = prob
+
+    def forward(self, X):
+        mask = torch.rand(size=X.shape[:-1]) < self.prob
+        X[mask] = X[mask].flip(-1)
+        return X
```
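Both augmentations act in place along the trailing time dimension and are indifferent to leading batch dimensions, so they can be chained ahead of a model during training. A minimal usage sketch (the batch shape here is illustrative, not prescribed by the package):

```python
import torch

from ml4gw.augmentations import SignalInverter, SignalReverser

# hypothetical batch: 32 elements, 2 interferometer channels,
# 2048 time samples each
X = torch.randn(32, 2, 2048)

# each timeseries is independently sign-flipped and/or
# time-reversed with probability 0.5
augment = torch.nn.Sequential(
    SignalInverter(prob=0.5),
    SignalReverser(prob=0.5),
)
X = augment(X)
assert X.shape == (32, 2, 2048)
```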
```diff
--- /dev/null
+++ ml4gw-0.4.0/ml4gw/dataloading/chunked_dataset.py
@@ -0,0 +1,134 @@
+from collections.abc import Iterable
+
+import torch
+
+
+class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
+    """
+    Wrapper dataset that will loop through chunks of timeseries
+    data produced by another iterable and sample windows from
+    these chunks.
+
+    Args:
+        chunk_it:
+            Iterator which will produce chunks of timeseries
+            data to sample windows from. Should have shape
+            `(N, C, T)`, where `N` is the number of chunks
+            to sample from, `C` is the number of channels,
+            and `T` is the number of samples along the
+            time dimension for each chunk.
+        kernel_size:
+            Size of windows to be sampled from each chunk.
+            Should be less than the size of each chunk
+            along the time dimension.
+        batch_size:
+            Number of windows to sample at each iteration
+        batches_per_chunk:
+            Number of batches of windows to sample from
+            each chunk before moving on to the next one.
+            Sampling fewer batches from each chunk means
+            a lower likelihood of sampling duplicate windows,
+            but an increase in chunk-loading overhead.
+        coincident:
+            Whether the windows sampled from individual
+            channels in each batch element should be
+            sampled coincidentally, i.e. consisting of
+            the same timesteps, or whether each window
+            should be sample independently from the others.
+        device:
+            Which device chunks should be moved to upon loading.
+    """
+
+    def __init__(
+        self,
+        chunk_it: Iterable,
+        kernel_size: float,
+        batch_size: int,
+        batches_per_chunk: int,
+        coincident: bool = True,
+        device: str = "cpu",
+    ) -> None:
+        self.chunk_it = chunk_it
+        self.kernel_size = kernel_size
+        self.batch_size = batch_size
+        self.batches_per_chunk = batches_per_chunk
+        self.coincident = coincident
+        self.device = device
+
+    def __len__(self):
+        return len(self.chunk_it) * self.batches_per_chunk
+
+    def __iter__(self):
+        it = iter(self.chunk_it)
+        chunk = next(it)
+        num_chunks, num_channels, chunk_size = chunk.shape
+
+        # if we're sampling coincidentally, we only need
+        # to sample indices on a per-batch-element basis.
+        # Otherwise, we'll need indices for both each
+        # batch sample _and_ each channel with each sample
+        if self.coincident:
+            sample_size = (self.batch_size,)
+        else:
+            sample_size = (self.batch_size, num_channels)
+
+        # slice kernels out a flattened chunk tensor
+        # index-for-index. We'll account for batch/
+        # channel indices by introducing offsets later on
+        idx = torch.arange(self.kernel_size, device=self.device)
+        idx = idx.view(1, 1, -1)
+        idx = idx.repeat(self.batch_size, num_channels, 1)
+
+        # this will just be a set of aranged channel indices
+        # repeated to offset the kernel indices in the
+        # flattened chunk tensor
+        channel_idx = torch.arange(num_channels, device=self.device)
+        channel_idx = channel_idx.view(1, -1, 1)
+        channel_idx = channel_idx.repeat(self.batch_size, 1, self.kernel_size)
+        idx += channel_idx * chunk_size
+
+        while True:
+            # record the number of rows in the chunk, then
+            # flatten it to make it easier to slice
+            if chunk_size < self.kernel_size:
+                raise ValueError(
+                    "Can't sample kernels of size {} from chunk "
+                    "with size {}".format(self.kernel_size, chunk_size)
+                )
+            chunk = chunk.reshape(-1)
+
+            # generate batches from the current chunk
+            for _ in range(self.batches_per_chunk):
+                # first sample the indices of which chunk elements
+                # we're going to read batch elements from
+                chunk_idx = torch.randint(
+                    0, num_chunks, size=sample_size, device=self.device
+                )
+
+                # account for the offset this batch element
+                # introduces in the flattened array
+                chunk_idx *= num_channels * chunk_size
+                chunk_idx = chunk_idx.view(self.batch_size, -1, 1)
+                chunk_idx = chunk_idx + idx
+
+                # now sample the start index within each chunk
+                # element we're going to grab our time windows from
+                time_idx = torch.randint(
+                    0,
+                    chunk_size - self.kernel_size,
+                    size=sample_size,
+                    device=self.device,
+                )
+                time_idx = time_idx.view(self.batch_size, -1, 1)
+
+                # there's no additional offset factor to account for here
+                chunk_idx += time_idx
+
+                # now slice this 3D tensor from our flattened chunk
+                yield chunk[chunk_idx]
+
+            try:
+                chunk = next(it)
+            except StopIteration:
+                break
+            num_chunks, num_channels, chunk_size = chunk.shape
```
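The subtle part above is the index arithmetic: each `(N, C, T)` chunk is flattened to 1D, and the precomputed `idx` tensor carries per-channel offsets of `chunk_size`, so a single fancy-indexing call slices out an entire `(batch, channels, kernel)` tensor at once. A sketch of driving the dataset with a toy chunk producer (the loader class and all shapes below are invented for illustration):

```python
import torch

from ml4gw.dataloading.chunked_dataset import ChunkedTimeSeriesDataset


class DummyChunkLoader:
    """Stand-in for an iterable that loads chunks from disk."""

    def __init__(self, n_loads: int = 5):
        self.n_loads = n_loads

    def __len__(self):
        return self.n_loads

    def __iter__(self):
        # each load yields 8 chunks of 2-channel data,
        # 4096 samples long
        for _ in range(self.n_loads):
            yield torch.randn(8, 2, 4096)


dataset = ChunkedTimeSeriesDataset(
    DummyChunkLoader(),
    kernel_size=1024,
    batch_size=16,
    batches_per_chunk=4,
    coincident=True,
)
for X in dataset:
    # windows are sampled at the same time indices across
    # both channels since coincident=True
    assert X.shape == (16, 2, 1024)
```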
```diff
--- /dev/null
+++ ml4gw-0.4.0/ml4gw/dataloading/hdf5_dataset.py
@@ -0,0 +1,176 @@
+import warnings
+from typing import Sequence, Union
+
+import h5py
+import numpy as np
+import torch
+
+from ml4gw.types import WaveformTensor
+
+
+class ContiguousHdf5Warning(Warning):
+    pass
+
+
+class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
+    """
+    Iterable dataset that samples and loads windows of
+    timeseries data uniformly from a set of HDF5 files.
+    It is _strongly_ recommended that these files have been
+    written using [chunked storage]
+    (https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage).
+    This has shown to produce increases in read-time speeds
+    of over an order of magnitude.
+
+    Args:
+        fnames:
+            Paths to HDF5 files from which to sample data.
+        channels:
+            Datasets to read from the indicated files, which
+            will be stacked along dim 1 of the generated batches
+            during iteration.
+        kernel_size:
+            Size of the windows to read, in number of samples.
+            This will be the size of the last dimension of the
+            generated batches.
+        batch_size:
+            Number of windows to sample at each iteration.
+        batches_per_epoch:
+            Number of batches to generate during each call
+            to `__iter__`.
+        coincident:
+            Whether windows for each channel in a given batch
+            element should be sampled coincidentally, i.e.
+            corresponding to the same time indices from the
+            same files, or should be sampled independently.
+            For the latter case, users can either specify
+            `False`, which will sample filenames independently
+            for each channel, or `"files"`, which will sample
+            windows independently within a given file for each
+            channel. The latter setting limits the amount of
+            entropy in the effective dataset, but can provide
+            over 2x improvement in total throughput.
+    """
+
+    def __init__(
+        self,
+        fnames: Sequence[str],
+        channels: Sequence[str],
+        kernel_size: int,
+        batch_size: int,
+        batches_per_epoch: int,
+        coincident: Union[bool, str],
+    ) -> None:
+        if not isinstance(coincident, bool) and coincident != "files":
+            raise ValueError(
+                "coincident must be either a boolean or 'files', "
+                "got unrecognized value {}".format(coincident)
+            )
+
+        self.fnames = fnames
+        self.channels = channels
+        self.num_channels = len(channels)
+        self.kernel_size = kernel_size
+        self.batch_size = batch_size
+        self.batches_per_epoch = batches_per_epoch
+        self.coincident = coincident
+
+        self.sizes = {}
+        for fname in self.fnames:
+            with h5py.File(fname, "r") as f:
+                dset = f[channels[0]]
+                if dset.chunks is None:
+                    warnings.warn(
+                        "File {} contains datasets that were generated "
+                        "without using chunked storage. This can have "
+                        "severe performance impacts at data loading time. "
+                        "If you need faster loading, try re-generating "
+                        "your datset with chunked storage turned on.".format(
+                            fname
+                        ),
+                        category=ContiguousHdf5Warning,
+                    )
+
+                self.sizes[fname] = len(dset)
+        total = sum(self.sizes.values())
+        self.probs = np.array([i / total for i in self.sizes.values()])
+
+    def __len__(self) -> int:
+        return self.batches_per_epoch
+
+    def sample_fnames(self, size) -> np.ndarray:
+        return np.random.choice(
+            self.fnames,
+            p=self.probs,
+            size=size,
+            replace=True,
+        )
+
+    def sample_batch(self) -> WaveformTensor:
+        """
+        Sample a single batch of multichannel timeseries
+        """
+
+        # allocate memory up front
+        x = np.zeros((self.batch_size, len(self.channels), self.kernel_size))
+
+        # sample filenames, but only loop through each unique
+        # filename once to avoid unnecessary I/O overhead
+        if self.coincident is not False:
+            size = (self.batch_size,)
+        else:
+            size = (self.batch_size, self.num_channels)
+        fnames = self.sample_fnames(size)
+
+        unique_fnames, inv, counts = np.unique(
+            fnames, return_inverse=True, return_counts=True
+        )
+        for i, (fname, count) in enumerate(zip(unique_fnames, counts)):
+            size = self.sizes[fname]
+            max_idx = size - self.kernel_size
+
+            # figure out which batch indices should be
+            # sampled from the current filename
+            indices = np.where(inv == i)[0]
+
+            # when sampling coincidentally either fully
+            # or at the file level, all channels will
+            # correspond to the same file
+            if self.coincident is not False:
+                batch_indices = np.repeat(indices, self.num_channels)
+                channel_indices = np.arange(self.num_channels)
+                channel_indices = np.concatenate([channel_indices] * count)
+            else:
+                batch_indices = indices // self.num_channels
+                channel_indices = indices % self.num_channels
+
+            # if we're sampling fully coincidentally, each
+            # channel will be the same in each file
+            if self.coincident is True:
+                idx = np.random.randint(max_idx, size=count)
+                idx = np.repeat(idx, self.num_channels)
+            else:
+                # otherwise, every channel will be different
+                # for the given file
+                idx = np.random.randint(max_idx, size=len(batch_indices))
+
+            # open the file and sample a different set of
+            # kernels for each batch element it occupies
+            with h5py.File(fname, "r") as f:
+                for b, c, i in zip(batch_indices, channel_indices, idx):
+                    x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
+        return torch.Tensor(x)
+
+    def __iter__(self) -> torch.Tensor:
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:
+            num_batches = self.batches_per_epoch
+        else:
+            num_batches, remainder = divmod(
+                self.batches_per_epoch, worker_info.num_workers
+            )
+            if worker_info.id < remainder:
+                num_batches += 1
+
+        for _ in range(num_batches):
+            yield self.sample_batch()
```
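Because `__iter__` consults `torch.utils.data.get_worker_info()`, the `batches_per_epoch` budget is split across dataloader workers automatically, with the remainder going to the lowest worker ids. A sketch of typical use (the filenames and channel names are invented; note that batching is done by the dataset itself, hence `batch_size=None` in the `DataLoader`):

```python
import h5py
import numpy as np
import torch

from ml4gw.dataloading.hdf5_dataset import Hdf5TimeSeriesDataset

# write a couple of toy files, with chunked storage turned
# on as the docstring recommends
channels = ["H1", "L1"]
fnames = ["background-1.h5", "background-2.h5"]
for fname in fnames:
    with h5py.File(fname, "w") as f:
        for channel in channels:
            f.create_dataset(
                channel, data=np.random.randn(16384), chunks=True
            )

dataset = Hdf5TimeSeriesDataset(
    fnames,
    channels=channels,
    kernel_size=2048,
    batch_size=32,
    batches_per_epoch=10,
    coincident=True,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)
for X in loader:
    # each of the 2 workers contributes 5 of the 10 batches
    assert X.shape == (32, 2, 2048)
```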
ml4gw-0.4.0/ml4gw/nn/__init__.py
File without changes
```diff
--- /dev/null
+++ ml4gw-0.4.0/ml4gw/nn/autoencoder/base.py
@@ -0,0 +1,89 @@
+from collections.abc import Sequence
+from typing import Optional
+
+import torch
+
+from ml4gw.nn.autoencoder.skip_connection import SkipConnection
+
+
+class Autoencoder(torch.nn.Module):
+    """
+    Base autoencoder class that defines some of the
+    basic methods and functionality. Autoencoders are
+    defined here as a set of sequential blocks that
+    have an `encode` method, which acts on the input
+    data to the autoencoder, and a `decode` method, which
+    acts on the encoded vector generated by the `encode`
+    method. `forward` just runs these steps one after the
+    other. Although it isn't explicitly enforced, a good
+    rule of thumb is that the ouput of a block's `decode`
+    method should have the same shape as the _input_ of its
+    `encode` method.
+
+    Accepts a `skip_connection` argument that defines how to
+    combine information from the input of one block's `encode`
+    layer with the output to its `decode`layer. See `skip_connections.py`
+    for more info about what these classes are expected to contain
+    and how they operate.
+    """
+
+    def __init__(self, skip_connection: Optional[SkipConnection] = None):
+        super().__init__()
+        self.skip_connection = skip_connection
+        self.blocks = torch.nn.ModuleList()
+
+    def encode(self, *X: torch.Tensor, return_states: bool = False):
+        states = []
+        for block in self.blocks:
+            if isinstance(X, tuple):
+                X = block.encode(*X)
+            else:
+                X = block.encode(X)
+            states.append(X)
+
+        # don't need to return the last
+        # state, since that's just equal
+        # to the output of this layer
+        if return_states:
+            return X, states[:-1]
+        return X
+
+    def decode(self, *X, states: Optional[Sequence[torch.Tensor]] = None):
+        if self.skip_connection is not None and states is None:
+            raise ValueError(
+                "Must pass intermediate states when autoencoder "
+                "has a skip connection function specified"
+            )
+        elif states is not None:
+            if len(states) != len(self.blocks) - 1:
+                raise ValueError(
+                    "Passed {} intermediate states, expected {}".format(
+                        len(states), len(self.blocks) - 1
+                    )
+                )
+
+        # Don't skip connect the output layer
+        states = states[::-1] + [None]
+
+        for i, block in enumerate(self.blocks[::-1]):
+            if isinstance(X, tuple):
+                X = block.decode(*X)
+            else:
+                X = block.decode(X)
+
+            state = states[-i - 1]
+            if state is not None:
+                X = self.skip_connection(X, state)
+        return X
+
+    def forward(self, *X):
+        return_states = self.skip_connection is not None
+        X = self.encode(*X, return_states=return_states)
+        if return_states:
+            *X, states = X
+        else:
+            states = None
+
+        if isinstance(X, torch.Tensor):
+            X = (X,)
+        return self.decode(*X, states=states)
```
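Concrete subclasses populate `self.blocks` (the convolutional implementation elsewhere in this release does exactly that). The docstring's rule of thumb, that a block's `decode` should restore the shape its `encode` consumed, can be shown with a toy block; everything below is invented for illustration and is not the library's actual block API:

```python
import torch


class ToyBlock(torch.nn.Module):
    """Toy encode/decode pair: decode undoes encode's shape change."""

    def __init__(self, in_features: int, hidden: int):
        super().__init__()
        self.encoder = torch.nn.Linear(in_features, hidden)
        self.decoder = torch.nn.Linear(hidden, in_features)

    def encode(self, X: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.encoder(X))

    def decode(self, X: torch.Tensor) -> torch.Tensor:
        return self.decoder(X)


block = ToyBlock(in_features=128, hidden=32)
X = torch.randn(8, 128)
# the output of decode matches the input of encode
assert block.decode(block.encode(X)).shape == X.shape
```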