ml4gw 0.2.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

Files changed (46) hide show
  1. ml4gw-0.2.0/README.md → ml4gw-0.4.0/PKG-INFO +28 -3
  2. ml4gw-0.2.0/PKG-INFO → ml4gw-0.4.0/README.md +7 -19
  3. ml4gw-0.4.0/ml4gw/augmentations.py +43 -0
  4. ml4gw-0.4.0/ml4gw/dataloading/__init__.py +3 -0
  5. ml4gw-0.4.0/ml4gw/dataloading/chunked_dataset.py +134 -0
  6. ml4gw-0.4.0/ml4gw/dataloading/hdf5_dataset.py +176 -0
  7. ml4gw-0.4.0/ml4gw/nn/__init__.py +0 -0
  8. ml4gw-0.4.0/ml4gw/nn/autoencoder/__init__.py +3 -0
  9. ml4gw-0.4.0/ml4gw/nn/autoencoder/base.py +89 -0
  10. ml4gw-0.4.0/ml4gw/nn/autoencoder/convolutional.py +156 -0
  11. ml4gw-0.4.0/ml4gw/nn/autoencoder/skip_connection.py +46 -0
  12. ml4gw-0.4.0/ml4gw/nn/autoencoder/utils.py +14 -0
  13. ml4gw-0.4.0/ml4gw/nn/norm.py +97 -0
  14. ml4gw-0.4.0/ml4gw/nn/resnet/__init__.py +2 -0
  15. ml4gw-0.4.0/ml4gw/nn/resnet/resnet_1d.py +413 -0
  16. ml4gw-0.4.0/ml4gw/nn/resnet/resnet_2d.py +413 -0
  17. ml4gw-0.4.0/ml4gw/nn/streaming/__init__.py +2 -0
  18. ml4gw-0.4.0/ml4gw/nn/streaming/online_average.py +121 -0
  19. ml4gw-0.4.0/ml4gw/nn/streaming/snapshotter.py +121 -0
  20. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/__init__.py +2 -0
  21. ml4gw-0.4.0/ml4gw/transforms/pearson.py +87 -0
  22. ml4gw-0.4.0/ml4gw/transforms/spectrogram.py +162 -0
  23. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/whitening.py +1 -1
  24. ml4gw-0.4.0/ml4gw/waveforms/__init__.py +3 -0
  25. ml4gw-0.4.0/ml4gw/waveforms/phenom_d.py +1359 -0
  26. ml4gw-0.4.0/ml4gw/waveforms/phenom_d_data.py +3026 -0
  27. ml4gw-0.4.0/ml4gw/waveforms/taylorf2.py +306 -0
  28. {ml4gw-0.2.0 → ml4gw-0.4.0}/pyproject.toml +15 -4
  29. ml4gw-0.2.0/ml4gw/dataloading/__init__.py +0 -2
  30. ml4gw-0.2.0/ml4gw/dataloading/chunked_dataset.py +0 -280
  31. ml4gw-0.2.0/ml4gw/waveforms/__init__.py +0 -1
  32. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/__init__.py +0 -0
  33. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/dataloading/in_memory_dataset.py +0 -0
  34. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/distributions.py +0 -0
  35. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/gw.py +0 -0
  36. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/spectral.py +0 -0
  37. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/scaler.py +0 -0
  38. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/snr_rescaler.py +0 -0
  39. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/spectral.py +0 -0
  40. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/transform.py +0 -0
  41. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/transforms/waveforms.py +0 -0
  42. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/types.py +0 -0
  43. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/utils/interferometer.py +0 -0
  44. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/utils/slicing.py +0 -0
  45. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/waveforms/generator.py +0 -0
  46. {ml4gw-0.2.0 → ml4gw-0.4.0}/ml4gw/waveforms/sine_gaussian.py +0 -0
@@ -1,3 +1,22 @@
1
+ Metadata-Version: 2.1
2
+ Name: ml4gw
3
+ Version: 0.4.0
4
+ Summary: Tools for training torch models on gravitational wave data
5
+ Author: Alec Gunny
6
+ Author-email: alec.gunny@ligo.org
7
+ Requires-Python: >=3.8,<3.12
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.8
10
+ Classifier: Programming Language :: Python :: 3.9
11
+ Classifier: Programming Language :: Python :: 3.10
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Requires-Dist: torch (>=1.10,<2.0) ; python_version >= "3.8" and python_version < "3.11"
14
+ Requires-Dist: torch (>=2.0,<3.0) ; python_version >= "3.11"
15
+ Requires-Dist: torchaudio (>=0.13,<0.14) ; python_version >= "3.8" and python_version < "3.11"
16
+ Requires-Dist: torchaudio (>=2.0,<3.0) ; python_version >= "3.11"
17
+ Requires-Dist: torchtyping (>=0.1,<0.2)
18
+ Description-Content-Type: text/markdown
19
+
1
20
  # ML4GW
2
21
 
3
22
  Torch utilities for training neural networks in gravitational wave physics applications.
@@ -21,8 +40,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
21
40
 
22
41
  ```toml
23
42
  [tool.poetry.dependencies]
24
- python = "^3.8" # python versions 3.8-3.10 are supported
25
- ml4gw = "^0.1.0"
43
+ python = "^3.8" # python versions 3.8-3.11 are supported
44
+ ml4gw = "^0.3.0"
26
45
  ```
27
46
 
28
47
  To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -30,7 +49,7 @@ To build against a specific PyTorch/CUDA combination, consult the PyTorch instal
30
49
  ```toml
31
50
  [tool.poetry.dependencies]
32
51
  python = "^3.8"
33
- ml4gw = "^0.1.0"
52
+ ml4gw = "^0.3.0"
34
53
  torch = {version = "^1.12", source = "torch"}
35
54
 
36
55
  [[tool.poetry.source]]
@@ -40,6 +59,8 @@ secondary = true
40
59
  default = false
41
60
  ```
42
61
 
62
+ Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
63
+
43
64
  ## Use cases
44
65
  This library provides utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower level functions and utilities are defined at the top level of the library and in the `utils` library.
45
66
 
@@ -128,3 +149,7 @@ We encourage users who encounter these difficulties to file issues on GitHub, an
128
149
  We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
129
150
  For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
130
151
  By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
152
+
153
+ ## Funding
154
+ We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
155
+
@@ -1,19 +1,3 @@
1
- Metadata-Version: 2.1
2
- Name: ml4gw
3
- Version: 0.2.0
4
- Summary: Tools for training torch models on gravitational wave data
5
- Author: Alec Gunny
6
- Author-email: alec.gunny@ligo.org
7
- Requires-Python: >=3.8,<4.0
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.8
10
- Classifier: Programming Language :: Python :: 3.9
11
- Classifier: Programming Language :: Python :: 3.10
12
- Classifier: Programming Language :: Python :: 3.11
13
- Requires-Dist: torch (>=1.10,<2.0)
14
- Requires-Dist: torchtyping (>=0.1,<0.2)
15
- Description-Content-Type: text/markdown
16
-
17
1
  # ML4GW
18
2
 
19
3
  Torch utilities for training neural networks in gravitational wave physics applications.
@@ -37,8 +21,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
37
21
 
38
22
  ```toml
39
23
  [tool.poetry.dependencies]
40
- python = "^3.8" # python versions 3.8-3.10 are supported
41
- ml4gw = "^0.1.0"
24
+ python = "^3.8" # python versions 3.8-3.11 are supported
25
+ ml4gw = "^0.3.0"
42
26
  ```
43
27
 
44
28
  To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -46,7 +30,7 @@ To build against a specific PyTorch/CUDA combination, consult the PyTorch instal
46
30
  ```toml
47
31
  [tool.poetry.dependencies]
48
32
  python = "^3.8"
49
- ml4gw = "^0.1.0"
33
+ ml4gw = "^0.3.0"
50
34
  torch = {version = "^1.12", source = "torch"}
51
35
 
52
36
  [[tool.poetry.source]]
@@ -56,6 +40,8 @@ secondary = true
56
40
  default = false
57
41
  ```
58
42
 
43
+ Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
44
+
59
45
  ## Use cases
60
46
  This library provides utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower level functions and utilities are defined at the top level of the library and in the `utils` library.
61
47
 
@@ -145,3 +131,5 @@ We also strongly encourage ML users in the GW physics space to try their hand at
145
131
  For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
146
132
  By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
147
133
 
134
+ ## Funding
135
+ We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
@@ -0,0 +1,43 @@
1
+ import torch
2
+
3
+
4
class SignalInverter(torch.nn.Module):
    """
    Takes a tensor of timeseries of arbitrary dimension
    and randomly inverts (i.e. h(t) -> -h(t))
    each timeseries with probability `prob`.

    The last dimension of the input is treated as the time
    dimension; every slice along it is inverted (or not)
    independently of the others.

    Args:
        prob:
            Probability that a timeseries is inverted
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, X):
        """
        Invert a random subset of the timeseries in `X`.

        `X` is modified in-place and the same tensor object
        is returned.
        """
        # draw the mask on the same device as X so that the
        # boolean index never has to be copied across devices
        mask = torch.rand(size=X.shape[:-1], device=X.device) < self.prob
        X[mask] *= -1
        return X
23
+
24
+
25
class SignalReverser(torch.nn.Module):
    """
    Takes a tensor of timeseries of arbitrary dimension
    and randomly reverses (i.e. h(t) -> h(-t))
    each timeseries with probability `prob`.

    The last dimension of the input is treated as the time
    dimension; every slice along it is reversed (or not)
    independently of the others.

    Args:
        prob:
            Probability that a timeseries is reversed
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, X):
        """
        Reverse a random subset of the timeseries in `X`.

        `X` is modified in-place and the same tensor object
        is returned.
        """
        # draw the mask on the same device as X so that the
        # boolean index never has to be copied across devices
        mask = torch.rand(size=X.shape[:-1], device=X.device) < self.prob
        X[mask] = X[mask].flip(-1)
        return X
@@ -0,0 +1,3 @@
1
+ from .chunked_dataset import ChunkedTimeSeriesDataset
2
+ from .hdf5_dataset import Hdf5TimeSeriesDataset
3
+ from .in_memory_dataset import InMemoryDataset
@@ -0,0 +1,134 @@
1
+ from collections.abc import Iterable
2
+
3
+ import torch
4
+
5
+
6
class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
    """
    Wrapper dataset that will loop through chunks of timeseries
    data produced by another iterable and sample windows from
    these chunks.

    Args:
        chunk_it:
            Iterable which will produce chunks of timeseries
            data to sample windows from. Each chunk should have
            shape `(N, C, T)`, where `N` is the number of chunk
            elements to sample from, `C` is the number of channels,
            and `T` is the number of samples along the
            time dimension for each chunk. Chunks are allowed
            to differ in shape from one another.
        kernel_size:
            Size of windows to be sampled from each chunk,
            in number of samples. Should be no greater than
            the size of each chunk along the time dimension.
        batch_size:
            Number of windows to sample at each iteration
        batches_per_chunk:
            Number of batches of windows to sample from
            each chunk before moving on to the next one.
            Sampling fewer batches from each chunk means
            a lower likelihood of sampling duplicate windows,
            but an increase in chunk-loading overhead.
        coincident:
            Whether the windows sampled from individual
            channels in each batch element should be
            sampled coincidentally, i.e. consisting of
            the same timesteps, or whether each window
            should be sampled independently from the others.
        device:
            Device on which the sampling indices are created.
            Chunks are not moved by this class, so `chunk_it`
            is expected to produce tensors that already live
            on this device.
    """

    def __init__(
        self,
        chunk_it: Iterable,
        kernel_size: int,
        batch_size: int,
        batches_per_chunk: int,
        coincident: bool = True,
        device: str = "cpu",
    ) -> None:
        super().__init__()
        self.chunk_it = chunk_it
        self.kernel_size = kernel_size
        self.batch_size = batch_size
        self.batches_per_chunk = batches_per_chunk
        self.coincident = coincident
        self.device = device

    def __len__(self):
        return len(self.chunk_it) * self.batches_per_chunk

    def __iter__(self):
        """
        Yield `batches_per_chunk` batches of shape
        `(batch_size, C, kernel_size)` from every chunk.

        Raises:
            ValueError:
                If a chunk is shorter than `kernel_size`
                along its time dimension.
        """
        for chunk in self.chunk_it:
            num_chunks, num_channels, chunk_size = chunk.shape
            if chunk_size < self.kernel_size:
                raise ValueError(
                    "Can't sample kernels of size {} from chunk "
                    "with size {}".format(self.kernel_size, chunk_size)
                )

            # if we're sampling coincidentally, we only need
            # to sample indices on a per-batch-element basis.
            # Otherwise, we'll need indices for both each
            # batch sample _and_ each channel with each sample
            if self.coincident:
                sample_size = (self.batch_size,)
            else:
                sample_size = (self.batch_size, num_channels)

            # kernels are sliced out of a flattened chunk tensor
            # index-for-index. These offsets depend on the shape
            # of the chunk, so they're rebuilt for every chunk
            # rather than being cached from the first one
            kernel_idx = torch.arange(self.kernel_size, device=self.device)
            channel_idx = torch.arange(num_channels, device=self.device)
            # shape (1, C, kernel_size): position of every kernel
            # sample relative to the start of its chunk element
            base_idx = kernel_idx.view(1, 1, -1) + (
                channel_idx.view(1, -1, 1) * chunk_size
            )

            flat = chunk.reshape(-1)
            for _ in range(self.batches_per_chunk):
                # first sample the indices of which chunk elements
                # we're going to read batch elements from, and
                # account for the offset each introduces in the
                # flattened array
                chunk_idx = torch.randint(
                    0, num_chunks, size=sample_size, device=self.device
                )
                chunk_idx = chunk_idx.view(self.batch_size, -1, 1)
                chunk_idx = chunk_idx * (num_channels * chunk_size)

                # now sample the start index within each chunk
                # element we're going to grab our time windows from.
                # The upper bound of randint is exclusive, so add 1
                # to make the final valid start index reachable
                time_idx = torch.randint(
                    0,
                    chunk_size - self.kernel_size + 1,
                    size=sample_size,
                    device=self.device,
                )
                time_idx = time_idx.view(self.batch_size, -1, 1)

                # broadcasting produces a (batch_size, C, kernel_size)
                # index which slices a 3D batch from the flat chunk
                yield flat[chunk_idx + time_idx + base_idx]
@@ -0,0 +1,176 @@
1
+ import warnings
2
+ from typing import Sequence, Union
3
+
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+
8
+ from ml4gw.types import WaveformTensor
9
+
10
+
11
class ContiguousHdf5Warning(Warning):
    """
    Warning category used to flag HDF5 datasets that
    were written without chunked storage.
    """
13
+
14
+
15
class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
    """
    Iterable dataset that samples and loads windows of
    timeseries data uniformly from a set of HDF5 files.
    It is _strongly_ recommended that these files have been
    written using [chunked storage]
    (https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage).
    This has been shown to produce increases in read-time speeds
    of over an order of magnitude.

    Args:
        fnames:
            Paths to HDF5 files from which to sample data.
        channels:
            Datasets to read from the indicated files, which
            will be stacked along dim 1 of the generated batches
            during iteration.
        kernel_size:
            Size of the windows to read, in number of samples.
            This will be the size of the last dimension of the
            generated batches.
        batch_size:
            Number of windows to sample at each iteration.
        batches_per_epoch:
            Number of batches to generate during each call
            to `__iter__`.
        coincident:
            Whether windows for each channel in a given batch
            element should be sampled coincidentally, i.e.
            corresponding to the same time indices from the
            same files, or should be sampled independently.
            For the latter case, users can either specify
            `False`, which will sample filenames independently
            for each channel, or `"files"`, which will sample
            windows independently within a given file for each
            channel. The latter setting limits the amount of
            entropy in the effective dataset, but can provide
            over 2x improvement in total throughput.

    Raises:
        ValueError:
            If `coincident` is neither a boolean nor `"files"`,
            or if any file holds fewer than `kernel_size` samples.
    """

    def __init__(
        self,
        fnames: Sequence[str],
        channels: Sequence[str],
        kernel_size: int,
        batch_size: int,
        batches_per_epoch: int,
        coincident: Union[bool, str],
    ) -> None:
        super().__init__()
        if not isinstance(coincident, bool) and coincident != "files":
            raise ValueError(
                "coincident must be either a boolean or 'files', "
                "got unrecognized value {}".format(coincident)
            )

        self.fnames = fnames
        self.channels = channels
        self.num_channels = len(channels)
        self.kernel_size = kernel_size
        self.batch_size = batch_size
        self.batches_per_epoch = batches_per_epoch
        self.coincident = coincident

        # record the length of each file, using its first channel
        # as a proxy. NOTE(review): this assumes every channel in
        # a file has the same length -- confirm against the writer
        self.sizes = {}
        for fname in self.fnames:
            with h5py.File(fname, "r") as f:
                dset = f[channels[0]]
                if dset.chunks is None:
                    warnings.warn(
                        "File {} contains datasets that were generated "
                        "without using chunked storage. This can have "
                        "severe performance impacts at data loading time. "
                        "If you need faster loading, try re-generating "
                        "your datset with chunked storage turned on.".format(
                            fname
                        ),
                        category=ContiguousHdf5Warning,
                    )
                size = len(dset)

            # fail fast rather than with an obscure numpy error
            # the first time this file happens to get sampled
            if size < kernel_size:
                raise ValueError(
                    "Can't sample kernels of size {} from file {} "
                    "with size {}".format(kernel_size, fname, size)
                )
            self.sizes[fname] = size

        # weight each file by its length. Build the weights from
        # `fnames` rather than `sizes` so that the two stay aligned
        # even if the same file is listed more than once
        sizes = np.array([self.sizes[fname] for fname in self.fnames])
        self.probs = sizes / sizes.sum()

    def __len__(self) -> int:
        return self.batches_per_epoch

    def sample_fnames(self, size) -> np.ndarray:
        """
        Sample an array of filenames of shape `size`, with
        replacement, weighted by the length of each file.
        """
        return np.random.choice(
            self.fnames,
            p=self.probs,
            size=size,
            replace=True,
        )

    def sample_batch(self) -> WaveformTensor:
        """
        Sample a single batch of multichannel timeseries
        of shape `(batch_size, num_channels, kernel_size)`.
        """

        # allocate memory up front
        x = np.zeros((self.batch_size, self.num_channels, self.kernel_size))

        # sample filenames, but only loop through each unique
        # filename once to avoid unnecessary I/O overhead
        if self.coincident is not False:
            size = (self.batch_size,)
        else:
            size = (self.batch_size, self.num_channels)

        # flatten explicitly (row-major) so that the inverse
        # indices returned by np.unique are always 1D, whatever
        # the numpy version does for multi-dimensional inputs
        fnames = self.sample_fnames(size).reshape(-1)
        unique_fnames, inv, counts = np.unique(
            fnames, return_inverse=True, return_counts=True
        )
        inv = inv.reshape(-1)

        for file_idx, (fname, count) in enumerate(zip(unique_fnames, counts)):
            # number of valid window start positions in this file.
            # np.random.randint excludes its upper bound, so this
            # makes the last full window reachable
            num_starts = self.sizes[fname] - self.kernel_size + 1

            # figure out which (flattened) sample indices
            # should be read from the current filename
            indices = np.where(inv == file_idx)[0]

            # when sampling coincidentally either fully
            # or at the file level, all channels will
            # correspond to the same file
            if self.coincident is not False:
                batch_indices = np.repeat(indices, self.num_channels)
                channel_indices = np.arange(self.num_channels)
                channel_indices = np.concatenate([channel_indices] * count)
            else:
                batch_indices = indices // self.num_channels
                channel_indices = indices % self.num_channels

            # if we're sampling fully coincidentally, each
            # channel will use the same window in each file
            if self.coincident is True:
                starts = np.random.randint(num_starts, size=count)
                starts = np.repeat(starts, self.num_channels)
            else:
                # otherwise, every channel will be different
                # for the given file
                starts = np.random.randint(
                    num_starts, size=len(batch_indices)
                )

            # open the file and sample a different set of
            # kernels for each batch element it occupies
            with h5py.File(fname, "r") as f:
                for b, c, start in zip(batch_indices, channel_indices, starts):
                    dset = f[self.channels[c]]
                    x[b, c] = dset[start : start + self.kernel_size]
        return torch.Tensor(x)

    def __iter__(self) -> torch.Tensor:
        """
        Yield `batches_per_epoch` batches in total, split as
        evenly as possible between dataloader workers if this
        dataset is being iterated in worker processes.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            num_batches = self.batches_per_epoch
        else:
            num_batches, remainder = divmod(
                self.batches_per_epoch, worker_info.num_workers
            )
            # the first `remainder` workers pick up one extra batch
            if worker_info.id < remainder:
                num_batches += 1

        for _ in range(num_batches):
            yield self.sample_batch()
File without changes
@@ -0,0 +1,3 @@
1
+ from .base import Autoencoder
2
+ from .convolutional import ConvolutionalAutoencoder
3
+ from .skip_connection import AddSkipConnect, ConcatSkipConnect, SkipConnection
@@ -0,0 +1,89 @@
1
+ from collections.abc import Sequence
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ from ml4gw.nn.autoencoder.skip_connection import SkipConnection
7
+
8
+
9
class Autoencoder(torch.nn.Module):
    """
    Base autoencoder class that defines some of the
    basic methods and functionality. Autoencoders are
    defined here as a set of sequential blocks that
    have an `encode` method, which acts on the input
    data to the autoencoder, and a `decode` method, which
    acts on the encoded vector generated by the `encode`
    method. `forward` just runs these steps one after the
    other. Although it isn't explicitly enforced, a good
    rule of thumb is that the output of a block's `decode`
    method should have the same shape as the _input_ of its
    `encode` method.

    Accepts a `skip_connection` argument that defines how to
    combine information from the input of one block's `encode`
    layer with the output of its `decode` layer. See
    `skip_connection.py` for more info about what these classes
    are expected to contain and how they operate.
    """

    def __init__(self, skip_connection: Optional["SkipConnection"] = None):
        super().__init__()
        self.skip_connection = skip_connection
        self.blocks = torch.nn.ModuleList()

    def encode(self, *X: torch.Tensor, return_states: bool = False):
        """
        Run the inputs through the `encode` method of each block.

        Args:
            *X:
                Inputs to the first block's `encode` method.
            return_states:
                If `True`, also return the list of outputs of
                every block but the last.

        Returns:
            The output of the final block, or a tuple of that
            output and the intermediate states if `return_states`
            is `True`.
        """
        states = []
        for block in self.blocks:
            # blocks may produce either one tensor or a tuple of them
            if isinstance(X, tuple):
                X = block.encode(*X)
            else:
                X = block.encode(X)
            states.append(X)

        # don't need to return the last
        # state, since that's just equal
        # to the output of this layer
        if return_states:
            return X, states[:-1]
        return X

    def decode(self, *X, states: Optional[Sequence[torch.Tensor]] = None):
        """
        Run the inputs through the `decode` method of each
        block, starting from the last block and working backwards.

        Args:
            *X:
                Inputs to the last block's `decode` method.
            states:
                Intermediate states as returned by `encode` with
                `return_states=True`, ordered from the first block
                to the second-to-last. Required if a skip connection
                was specified, and ignored if one was not.

        Raises:
            ValueError:
                If a skip connection was specified but no states
                were passed, or if the number of states is not
                one fewer than the number of blocks.
        """
        if self.skip_connection is not None and states is None:
            raise ValueError(
                "Must pass intermediate states when autoencoder "
                "has a skip connection function specified"
            )

        num_blocks = len(self.blocks)
        if states is None:
            skip_states = [None] * num_blocks
        else:
            if len(states) != num_blocks - 1:
                raise ValueError(
                    "Passed {} intermediate states, expected {}".format(
                        len(states), num_blocks - 1
                    )
                )

            # the k-th state is the input to the (k+1)-th block's
            # `encode`, so walking the blocks backwards pairs each
            # block with the states in reverse order. Don't skip
            # connect the output layer, which is decoded last
            skip_states = list(states)[::-1] + [None]

        for block, state in zip(list(self.blocks)[::-1], skip_states):
            if isinstance(X, tuple):
                X = block.decode(*X)
            else:
                X = block.decode(X)

            if state is not None and self.skip_connection is not None:
                X = self.skip_connection(X, state)
        return X

    def forward(self, *X):
        """
        Encode the inputs and then decode the result, passing
        intermediate states along if a skip connection is in use.
        """
        return_states = self.skip_connection is not None
        X = self.encode(*X, return_states=return_states)
        if return_states:
            *X, states = X
        else:
            states = None

        if isinstance(X, torch.Tensor):
            X = (X,)
        return self.decode(*X, states=states)