ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/augmentations.py +43 -0
- ml4gw/dataloading/__init__.py +2 -1
- ml4gw/dataloading/chunked_dataset.py +66 -212
- ml4gw/dataloading/hdf5_dataset.py +176 -0
- ml4gw/nn/__init__.py +0 -0
- ml4gw/nn/autoencoder/__init__.py +3 -0
- ml4gw/nn/autoencoder/base.py +89 -0
- ml4gw/nn/autoencoder/convolutional.py +156 -0
- ml4gw/nn/autoencoder/skip_connection.py +46 -0
- ml4gw/nn/autoencoder/utils.py +14 -0
- ml4gw/nn/norm.py +97 -0
- ml4gw/nn/resnet/__init__.py +2 -0
- ml4gw/nn/resnet/resnet_1d.py +413 -0
- ml4gw/nn/resnet/resnet_2d.py +413 -0
- ml4gw/nn/streaming/__init__.py +2 -0
- ml4gw/nn/streaming/online_average.py +121 -0
- ml4gw/nn/streaming/snapshotter.py +121 -0
- ml4gw/transforms/__init__.py +2 -0
- ml4gw/transforms/pearson.py +87 -0
- ml4gw/transforms/spectrogram.py +162 -0
- ml4gw/transforms/whitening.py +1 -1
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/phenom_d.py +1359 -0
- ml4gw/waveforms/phenom_d_data.py +3026 -0
- ml4gw/waveforms/taylorf2.py +306 -0
- {ml4gw-0.2.0.dist-info → ml4gw-0.4.0.dist-info}/METADATA +14 -6
- ml4gw-0.4.0.dist-info/RECORD +43 -0
- {ml4gw-0.2.0.dist-info → ml4gw-0.4.0.dist-info}/WHEEL +1 -1
- ml4gw-0.2.0.dist-info/RECORD +0 -23
ml4gw/augmentations.py
ADDED
@@ -0,0 +1,43 @@
+import torch
+
+
+class SignalInverter(torch.nn.Module):
+    """
+    Takes a tensor of timeseries of arbitrary dimension
+    and randomly inverts (i.e. h(t) -> -h(t))
+    each timeseries with probability `prob`.
+
+    Args:
+        prob:
+            Probability that a timeseries is inverted
+    """
+
+    def __init__(self, prob: float = 0.5):
+        super().__init__()
+        self.prob = prob
+
+    def forward(self, X):
+        mask = torch.rand(size=X.shape[:-1]) < self.prob
+        X[mask] *= -1
+        return X
+
+
+class SignalReverser(torch.nn.Module):
+    """
+    Takes a tensor of timeseries of arbitrary dimension
+    and randomly reverses (i.e. h(t) -> h(-t))
+    each timeseries with probability `prob`.
+
+    Args:
+        prob:
+            Probability that a kernel is reversed
+    """
+
+    def __init__(self, prob: float = 0.5):
+        super().__init__()
+        self.prob = prob
+
+    def forward(self, X):
+        mask = torch.rand(size=X.shape[:-1]) < self.prob
+        X[mask] = X[mask].flip(-1)
+        return X
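For orientation, a minimal usage sketch of the two new augmentation modules; the batch shape, probability values, and the `torch.nn.Sequential` composition are illustrative assumptions rather than anything prescribed by the diff:

import torch

from ml4gw.augmentations import SignalInverter, SignalReverser  # assumed import path

# hypothetical batch: 32 samples, 2 interferometer channels, 2048 timesteps
X = torch.randn(32, 2, 2048)

# each (sample, channel) timeseries is independently sign-flipped and/or
# time-reversed with probability 0.5; note the modules modify X in place
augment = torch.nn.Sequential(SignalInverter(prob=0.5), SignalReverser(prob=0.5))
X = augment(X)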
ml4gw/dataloading/chunked_dataset.py
CHANGED
@@ -1,262 +1,113 @@
-from
+from collections.abc import Iterable
 
-import h5py
-import numpy as np
 import torch
 
 
-class
-    def __init__(
-        self,
-        fnames: List[str],
-        channels: List[str],
-        chunk_size: int,
-        reads_per_chunk: int,
-        chunks_per_epoch: int,
-        coincident: bool = True,
-    ) -> None:
-        self.fnames = fnames
-        self.channels = channels
-        self.chunk_size = chunk_size
-        self.reads_per_chunk = reads_per_chunk
-        self.chunks_per_epoch = chunks_per_epoch
-        self.coincident = coincident
-
-        sizes = []
-        for f in self.fnames:
-            with h5py.File(f, "r") as f:
-                size = len(f[self.channels[0]])
-                sizes.append(size)
-        total = sum(sizes)
-        self.probs = np.array([i / total for i in sizes])
-
-    def sample_fnames(self):
-        return np.random.choice(
-            self.fnames,
-            p=self.probs,
-            size=(self.reads_per_chunk,),
-            replace=True,
-        )
-
-    def load_coincident(self):
-        fnames = self.sample_fnames()
-        chunks = []
-        for fname in fnames:
-            with h5py.File(fname, "r") as f:
-                chunk, idx = [], None
-                for channel in self.channels:
-                    if idx is None:
-                        end = len(f[channel]) - self.chunk_size
-                        idx = np.random.randint(0, end)
-                    x = f[channel][idx : idx + self.chunk_size]
-                    chunk.append(x)
-                chunks.append(np.stack(chunk))
-        return np.stack(chunks)
-
-    def load_noncoincident(self):
-        chunks = []
-        for channel in self.channels:
-            fnames = self.sample_fnames()
-            chunk = []
-            for fname in fnames:
-                with h5py.File(fname, "r") as f:
-                    end = len(f[channel]) - self.chunk_size
-                    idx = np.random.randint(0, end)
-                    x = f[channel][idx : idx + self.chunk_size]
-                    chunk.append(x)
-            chunks.append(np.stack(chunk))
-        return np.stack(chunks, axis=1)
-
-    def iter_epoch(self):
-        for _ in range(self.chunks_per_epoch):
-            if self.coincident:
-                yield torch.Tensor(self.load_coincident())
-            else:
-                yield torch.Tensor(self.load_noncoincident())
-
-    def collate(self, xs):
-        return torch.cat(xs, axis=0)
-
-    def __iter__(self):
-        return self.iter_epoch()
-
-
-class ChunkedDataset(torch.utils.data.IterableDataset):
+class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
     """
-    background from `reads_per_chunk` randomly sampled
-    files up front, then samples `batches_per_chunk`
-    batches of kernels from this chunk before loading
-    in the next one. Terminates after `chunks_per_epoch`
-    chunks have been exhausted, which amounts to
-    `chunks_per_epoch * batches_per_chunk` batches.
-
-    Note that filenames are not sampled uniformly
-    at chunk-loading time, but are weighted according
-    to the amount of data each file contains. This ensures
-    a uniform sampling over time across the whole dataset.
-
-    To load chunks asynchronously in the background,
-    specify `num_workers > 0`. Note that if the
-    number of workers is not an even multiple of
-    `chunks_per_epoch`, the last chunks of an epoch
-    will be composed of fewer than `reads_per_chunk`
-    individual segments.
+    Wrapper dataset that will loop through chunks of timeseries
+    data produced by another iterable and sample windows from
+    these chunks.
 
     Args:
-            has been sampled.
+        chunk_it:
+            Iterator which will produce chunks of timeseries
+            data to sample windows from. Should have shape
+            `(N, C, T)`, where `N` is the number of chunks
+            to sample from, `C` is the number of channels,
+            and `T` is the number of samples along the
+            time dimension for each chunk.
+        kernel_size:
+            Size of windows to be sampled from each chunk.
+            Should be less than the size of each chunk
+            along the time dimension.
         batch_size:
-            Number of
-        reads_per_chunk:
-            Number of file reads to perform when generating
-            each chunk
-        chunk_length:
-            Amount of data to read for each segment loaded
-            into each chunk, in seconds
+            Number of windows to sample at each iteration
         batches_per_chunk:
-            Number of batches to sample from
-            before
+            Number of batches of windows to sample from
+            each chunk before moving on to the next one.
+            Sampling fewer batches from each chunk means
+            a lower likelihood of sampling duplicate windows,
+            but an increase in chunk-loading overhead.
         coincident:
-            asynchronously. If left as 0, chunk loading will
-            be performed in serial with batch sampling.
+            Whether the windows sampled from individual
+            channels in each batch element should be
+            sampled coincidentally, i.e. consisting of
+            the same timesteps, or whether each window
+            should be sample independently from the others.
         device:
+            Which device chunks should be moved to upon loading.
     """
 
     def __init__(
         self,
-        kernel_length: float,
-        sample_rate: float,
+        chunk_it: Iterable,
+        kernel_size: float,
         batch_size: int,
-        reads_per_chunk: int,
-        chunk_length: float,
         batches_per_chunk: int,
-        chunks_per_epoch: int,
         coincident: bool = True,
-        num_workers: int = 0,
         device: str = "cpu",
-        pin_memory: bool = False,
     ) -> None:
-        elif reads_per_chunk < num_workers:
-            raise ValueError(
-                "Too many workers {} for number of reads_per_chunk {}".format(
-                    num_workers, reads_per_chunk
-                )
-            )
-        else:
-            reads_per_worker = int(reads_per_chunk // num_workers)
-
-        if kernel_length > chunk_length:
-            raise ValueError(
-                "Kernel length {} must be shorter than "
-                "chunk length {}".format(kernel_length, chunk_length)
-            )
-        self.kernel_size = int(kernel_length * sample_rate)
-        self.chunk_size = int(chunk_length * sample_rate)
-
-        chunk_loader = ChunkLoader(
-            fnames,
-            channels,
-            self.chunk_size,
-            reads_per_worker,
-            chunks_per_epoch,
-            coincident=coincident,
-        )
-
-        if not num_workers:
-            self.chunk_loader = chunk_loader
-        else:
-            self.chunk_loader = torch.utils.data.DataLoader(
-                chunk_loader,
-                batch_size=num_workers,
-                num_workers=num_workers,
-                pin_memory=pin_memory,
-                collate_fn=chunk_loader.collate,
-            )
-
-        self.device = device
-        self.num_channels = len(channels)
-        self.coincident = coincident
-
+        self.chunk_it = chunk_it
+        self.kernel_size = kernel_size
         self.batch_size = batch_size
         self.batches_per_chunk = batches_per_chunk
-        self.
-        self.
+        self.coincident = coincident
+        self.device = device
 
     def __len__(self):
-        return self.chunks_per_epoch * self.batches_per_chunk
+        return len(self.chunk_it) * self.batches_per_chunk
 
+    def __iter__(self):
+        it = iter(self.chunk_it)
+        chunk = next(it)
+        num_chunks, num_channels, chunk_size = chunk.shape
+
+        # if we're sampling coincidentally, we only need
+        # to sample indices on a per-batch-element basis.
+        # Otherwise, we'll need indices for both each
+        # batch sample _and_ each channel with each sample
+        if self.coincident:
+            sample_size = (self.batch_size,)
+        else:
+            sample_size = (self.batch_size, num_channels)
 
-    def iter_epoch(self):
         # slice kernels out a flattened chunk tensor
         # index-for-index. We'll account for batch/
         # channel indices by introducing offsets later on
         idx = torch.arange(self.kernel_size, device=self.device)
         idx = idx.view(1, 1, -1)
-        idx = idx.repeat(self.batch_size,
+        idx = idx.repeat(self.batch_size, num_channels, 1)
 
         # this will just be a set of aranged channel indices
         # repeated to offset the kernel indices in the
         # flattened chunk tensor
-        channel_idx = torch.arange(
+        channel_idx = torch.arange(num_channels, device=self.device)
         channel_idx = channel_idx.view(1, -1, 1)
         channel_idx = channel_idx.repeat(self.batch_size, 1, self.kernel_size)
-        idx += channel_idx *
+        idx += channel_idx * chunk_size
 
+        while True:
             # record the number of rows in the chunk, then
             # flatten it to make it easier to slice
+            if chunk_size < self.kernel_size:
+                raise ValueError(
+                    "Can't sample kernels of size {} from chunk "
+                    "with size {}".format(self.kernel_size, chunk_size)
+                )
+            chunk = chunk.reshape(-1)
 
             # generate batches from the current chunk
             for _ in range(self.batches_per_chunk):
-                # if we're sampling coincidentally, we only need
-                # to sample indices on a per-batch-element basis.
-                # Otherwise, we'll need indices for both each
-                # batch sample _and_ each channel with each sample
-                if self.coincident:
-                    size = (self.batch_size,)
-                else:
-                    size = (self.batch_size, self.num_channels)
-
                 # first sample the indices of which chunk elements
                 # we're going to read batch elements from
                 chunk_idx = torch.randint(
-                    0, num_chunks, size=
+                    0, num_chunks, size=sample_size, device=self.device
                 )
 
                 # account for the offset this batch element
                 # introduces in the flattened array
-                chunk_idx *=
+                chunk_idx *= num_channels * chunk_size
                 chunk_idx = chunk_idx.view(self.batch_size, -1, 1)
                 chunk_idx = chunk_idx + idx
 
@@ -265,7 +116,7 @@ class ChunkedDataset(torch.utils.data.IterableDataset):
                 time_idx = torch.randint(
                     0,
                     chunk_size - self.kernel_size,
-                    size=
+                    size=sample_size,
                     device=self.device,
                 )
                 time_idx = time_idx.view(self.batch_size, -1, 1)
 
@@ -276,5 +127,8 @@ class ChunkedDataset(torch.utils.data.IterableDataset):
                 # now slice this 3D tensor from our flattened chunk
                 yield chunk[chunk_idx]
 
+            try:
+                chunk = next(it)
+            except StopIteration:
+                break
+            num_chunks, num_channels, chunk_size = chunk.shape
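A minimal sketch of how the renamed class might be driven with an arbitrary chunk-producing iterable; the toy `chunks` list, the shapes, and the hyperparameter values are illustrative assumptions, and the import path assumes the class is re-exported from `ml4gw.dataloading`:

import torch

from ml4gw.dataloading import ChunkedTimeSeriesDataset  # assumed import path

# toy stand-in for a chunk iterator: 3 chunks, each with shape
# (num_chunks=8, num_channels=2, chunk_size=16384)
chunks = [torch.randn(8, 2, 16384) for _ in range(3)]

dataset = ChunkedTimeSeriesDataset(
    chunk_it=chunks,
    kernel_size=2048,
    batch_size=32,
    batches_per_chunk=4,
    coincident=True,
    device="cpu",
)

# yields 3 chunks * 4 batches_per_chunk = 12 batches of windows
for X in dataset:
    assert X.shape == (32, 2, 2048)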
ml4gw/dataloading/hdf5_dataset.py
ADDED
@@ -0,0 +1,176 @@
+import warnings
+from typing import Sequence, Union
+
+import h5py
+import numpy as np
+import torch
+
+from ml4gw.types import WaveformTensor
+
+
+class ContiguousHdf5Warning(Warning):
+    pass
+
+
+class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
+    """
+    Iterable dataset that samples and loads windows of
+    timeseries data uniformly from a set of HDF5 files.
+    It is _strongly_ recommended that these files have been
+    written using [chunked storage]
+    (https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage).
+    This has shown to produce increases in read-time speeds
+    of over an order of magnitude.
+
+    Args:
+        fnames:
+            Paths to HDF5 files from which to sample data.
+        channels:
+            Datasets to read from the indicated files, which
+            will be stacked along dim 1 of the generated batches
+            during iteration.
+        kernel_size:
+            Size of the windows to read, in number of samples.
+            This will be the size of the last dimension of the
+            generated batches.
+        batch_size:
+            Number of windows to sample at each iteration.
+        batches_per_epoch:
+            Number of batches to generate during each call
+            to `__iter__`.
+        coincident:
+            Whether windows for each channel in a given batch
+            element should be sampled coincidentally, i.e.
+            corresponding to the same time indices from the
+            same files, or should be sampled independently.
+            For the latter case, users can either specify
+            `False`, which will sample filenames independently
+            for each channel, or `"files"`, which will sample
+            windows independently within a given file for each
+            channel. The latter setting limits the amount of
+            entropy in the effective dataset, but can provide
+            over 2x improvement in total throughput.
+    """
+
+    def __init__(
+        self,
+        fnames: Sequence[str],
+        channels: Sequence[str],
+        kernel_size: int,
+        batch_size: int,
+        batches_per_epoch: int,
+        coincident: Union[bool, str],
+    ) -> None:
+        if not isinstance(coincident, bool) and coincident != "files":
+            raise ValueError(
+                "coincident must be either a boolean or 'files', "
+                "got unrecognized value {}".format(coincident)
+            )
+
+        self.fnames = fnames
+        self.channels = channels
+        self.num_channels = len(channels)
+        self.kernel_size = kernel_size
+        self.batch_size = batch_size
+        self.batches_per_epoch = batches_per_epoch
+        self.coincident = coincident
+
+        self.sizes = {}
+        for fname in self.fnames:
+            with h5py.File(fname, "r") as f:
+                dset = f[channels[0]]
+                if dset.chunks is None:
+                    warnings.warn(
+                        "File {} contains datasets that were generated "
+                        "without using chunked storage. This can have "
+                        "severe performance impacts at data loading time. "
+                        "If you need faster loading, try re-generating "
+                        "your datset with chunked storage turned on.".format(
+                            fname
+                        ),
+                        category=ContiguousHdf5Warning,
+                    )
+
+                self.sizes[fname] = len(dset)
+        total = sum(self.sizes.values())
+        self.probs = np.array([i / total for i in self.sizes.values()])
+
+    def __len__(self) -> int:
+        return self.batches_per_epoch
+
+    def sample_fnames(self, size) -> np.ndarray:
+        return np.random.choice(
+            self.fnames,
+            p=self.probs,
+            size=size,
+            replace=True,
+        )
+
+    def sample_batch(self) -> WaveformTensor:
+        """
+        Sample a single batch of multichannel timeseries
+        """
+
+        # allocate memory up front
+        x = np.zeros((self.batch_size, len(self.channels), self.kernel_size))
+
+        # sample filenames, but only loop through each unique
+        # filename once to avoid unnecessary I/O overhead
+        if self.coincident is not False:
+            size = (self.batch_size,)
+        else:
+            size = (self.batch_size, self.num_channels)
+        fnames = self.sample_fnames(size)
+
+        unique_fnames, inv, counts = np.unique(
+            fnames, return_inverse=True, return_counts=True
+        )
+        for i, (fname, count) in enumerate(zip(unique_fnames, counts)):
+            size = self.sizes[fname]
+            max_idx = size - self.kernel_size
+
+            # figure out which batch indices should be
+            # sampled from the current filename
+            indices = np.where(inv == i)[0]
+
+            # when sampling coincidentally either fully
+            # or at the file level, all channels will
+            # correspond to the same file
+            if self.coincident is not False:
+                batch_indices = np.repeat(indices, self.num_channels)
+                channel_indices = np.arange(self.num_channels)
+                channel_indices = np.concatenate([channel_indices] * count)
+            else:
+                batch_indices = indices // self.num_channels
+                channel_indices = indices % self.num_channels
+
+            # if we're sampling fully coincidentally, each
+            # channel will be the same in each file
+            if self.coincident is True:
+                idx = np.random.randint(max_idx, size=count)
+                idx = np.repeat(idx, self.num_channels)
+            else:
+                # otherwise, every channel will be different
+                # for the given file
+                idx = np.random.randint(max_idx, size=len(batch_indices))
+
+            # open the file and sample a different set of
+            # kernels for each batch element it occupies
+            with h5py.File(fname, "r") as f:
+                for b, c, i in zip(batch_indices, channel_indices, idx):
+                    x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
+        return torch.Tensor(x)
+
+    def __iter__(self) -> torch.Tensor:
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:
+            num_batches = self.batches_per_epoch
+        else:
+            num_batches, remainder = divmod(
+                self.batches_per_epoch, worker_info.num_workers
+            )
+            if worker_info.id < remainder:
+                num_batches += 1
+
+        for _ in range(num_batches):
+            yield self.sample_batch()
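To make the chunked-storage warning and the worker-splitting logic in `__iter__` concrete, here is a sketch that writes toy chunked HDF5 files and reads from them; the filenames, channel names, sample counts, and the `DataLoader` wrapping are illustrative assumptions:

import h5py
import numpy as np
import torch

from ml4gw.dataloading import Hdf5TimeSeriesDataset  # assumed import path

# write toy files with chunked storage, which the class strongly recommends
fnames = ["background-0.h5", "background-1.h5"]
for fname in fnames:
    with h5py.File(fname, "w") as f:
        for channel in ["H1", "L1"]:
            f.create_dataset(channel, data=np.random.randn(65536), chunks=True)

dataset = Hdf5TimeSeriesDataset(
    fnames=fnames,
    channels=["H1", "L1"],
    kernel_size=2048,
    batch_size=32,
    batches_per_epoch=10,
    coincident=True,
)

# batches_per_epoch is divided across workers inside __iter__, so a
# DataLoader with batch_size=None only parallelizes the file I/O
loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)
for X in loader:
    assert X.shape == (32, 2, 2048)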
ml4gw/nn/__init__.py
ADDED
|
File without changes
ml4gw/nn/autoencoder/base.py
ADDED
@@ -0,0 +1,89 @@
+from collections.abc import Sequence
+from typing import Optional
+
+import torch
+
+from ml4gw.nn.autoencoder.skip_connection import SkipConnection
+
+
+class Autoencoder(torch.nn.Module):
+    """
+    Base autoencoder class that defines some of the
+    basic methods and functionality. Autoencoders are
+    defined here as a set of sequential blocks that
+    have an `encode` method, which acts on the input
+    data to the autoencoder, and a `decode` method, which
+    acts on the encoded vector generated by the `encode`
+    method. `forward` just runs these steps one after the
+    other. Although it isn't explicitly enforced, a good
+    rule of thumb is that the ouput of a block's `decode`
+    method should have the same shape as the _input_ of its
+    `encode` method.
+
+    Accepts a `skip_connection` argument that defines how to
+    combine information from the input of one block's `encode`
+    layer with the output to its `decode` layer. See `skip_connections.py`
+    for more info about what these classes are expected to contain
+    and how they operate.
+    """
+
+    def __init__(self, skip_connection: Optional[SkipConnection] = None):
+        super().__init__()
+        self.skip_connection = skip_connection
+        self.blocks = torch.nn.ModuleList()
+
+    def encode(self, *X: torch.Tensor, return_states: bool = False):
+        states = []
+        for block in self.blocks:
+            if isinstance(X, tuple):
+                X = block.encode(*X)
+            else:
+                X = block.encode(X)
+            states.append(X)
+
+        # don't need to return the last
+        # state, since that's just equal
+        # to the output of this layer
+        if return_states:
+            return X, states[:-1]
+        return X
+
+    def decode(self, *X, states: Optional[Sequence[torch.Tensor]] = None):
+        if self.skip_connection is not None and states is None:
+            raise ValueError(
+                "Must pass intermediate states when autoencoder "
+                "has a skip connection function specified"
+            )
+        elif states is not None:
+            if len(states) != len(self.blocks) - 1:
+                raise ValueError(
+                    "Passed {} intermediate states, expected {}".format(
+                        len(states), len(self.blocks) - 1
+                    )
+                )
+
+            # Don't skip connect the output layer
+            states = states[::-1] + [None]
+
+        for i, block in enumerate(self.blocks[::-1]):
+            if isinstance(X, tuple):
+                X = block.decode(*X)
+            else:
+                X = block.decode(X)
+
+            state = states[-i - 1]
+            if state is not None:
+                X = self.skip_connection(X, state)
+        return X
+
+    def forward(self, *X):
+        return_states = self.skip_connection is not None
+        X = self.encode(*X, return_states=return_states)
+        if return_states:
+            *X, states = X
+        else:
+            states = None
+
+        if isinstance(X, torch.Tensor):
+            X = (X,)
+        return self.decode(*X, states=states)
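To illustrate the encode/decode block contract the docstring describes, here is a toy subclass; the block definition, channel sizes, and registering blocks via `self.blocks.append` are assumptions for illustration (only the encode path is exercised here, since no skip connection is configured):

import torch

from ml4gw.nn.autoencoder.base import Autoencoder  # assumed import path


class ToyBlock(torch.nn.Module):
    """Toy block whose decode output matches its encode input shape."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.down = torch.nn.Conv1d(in_channels, out_channels, 4, stride=2, padding=1)
        self.up = torch.nn.ConvTranspose1d(out_channels, in_channels, 4, stride=2, padding=1)

    def encode(self, X):
        return torch.relu(self.down(X))

    def decode(self, X):
        return self.up(X)


class ToyAutoencoder(Autoencoder):
    def __init__(self):
        super().__init__(skip_connection=None)
        self.blocks.append(ToyBlock(2, 8))
        self.blocks.append(ToyBlock(8, 16))


model = ToyAutoencoder()
X = torch.randn(32, 2, 2048)

# encode runs each block's `encode` in turn; return_states=True also
# returns the intermediate outputs that a skip connection would consume
encoded, states = model.encode(X, return_states=True)
assert encoded.shape == (32, 16, 512)
assert [s.shape for s in states] == [torch.Size([32, 8, 1024])]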