ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

ml4gw/augmentations.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+
3
+
4
class SignalInverter(torch.nn.Module):
    """
    Takes a tensor of timeseries of arbitrary dimension
    and randomly inverts (i.e. h(t) -> -h(t))
    each timeseries with probability `prob`.

    The last dimension of the input is treated as the time
    dimension: one random draw is made for every index of the
    leading dimensions, and the whole timeseries at that index
    is inverted if the draw falls below `prob`. The input
    tensor is modified in place, and the same tensor is returned.

    Args:
        prob:
            Probability that a timeseries is inverted
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Draw the mask on X's own device so that indexing does not
        # require a host-to-device copy of the mask on every call.
        mask = torch.rand(size=X.shape[:-1], device=X.device) < self.prob
        X[mask] *= -1
        return X
23
+
24
+
25
class SignalReverser(torch.nn.Module):
    """
    Takes a tensor of timeseries of arbitrary dimension
    and randomly reverses (i.e. h(t) -> h(-t))
    each timeseries with probability `prob`.

    The last dimension of the input is treated as the time
    dimension: one random draw is made for every index of the
    leading dimensions, and the timeseries at that index is
    flipped along its last dimension if the draw falls below
    `prob`. The input tensor is modified in place, and the same
    tensor is returned.

    Args:
        prob:
            Probability that a timeseries is reversed
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Draw the mask on X's own device so that indexing does not
        # require a host-to-device copy of the mask on every call.
        mask = torch.rand(size=X.shape[:-1], device=X.device) < self.prob
        X[mask] = X[mask].flip(-1)
        return X
@@ -1,2 +1,3 @@
1
- from .chunked_dataset import ChunkedDataset
1
+ from .chunked_dataset import ChunkedTimeSeriesDataset
2
+ from .hdf5_dataset import Hdf5TimeSeriesDataset
2
3
  from .in_memory_dataset import InMemoryDataset
@@ -1,262 +1,113 @@
1
- from typing import List
1
+ from collections.abc import Iterable
2
2
 
3
- import h5py
4
- import numpy as np
5
3
  import torch
6
4
 
7
5
 
8
- class ChunkLoader(torch.utils.data.IterableDataset):
9
- def __init__(
10
- self,
11
- fnames: List[str],
12
- channels: List[str],
13
- chunk_size: int,
14
- reads_per_chunk: int,
15
- chunks_per_epoch: int,
16
- coincident: bool = True,
17
- ) -> None:
18
- self.fnames = fnames
19
- self.channels = channels
20
- self.chunk_size = chunk_size
21
- self.reads_per_chunk = reads_per_chunk
22
- self.chunks_per_epoch = chunks_per_epoch
23
- self.coincident = coincident
24
-
25
- sizes = []
26
- for f in self.fnames:
27
- with h5py.File(f, "r") as f:
28
- size = len(f[self.channels[0]])
29
- sizes.append(size)
30
- total = sum(sizes)
31
- self.probs = np.array([i / total for i in sizes])
32
-
33
- def sample_fnames(self):
34
- return np.random.choice(
35
- self.fnames,
36
- p=self.probs,
37
- size=(self.reads_per_chunk,),
38
- replace=True,
39
- )
40
-
41
- def load_coincident(self):
42
- fnames = self.sample_fnames()
43
- chunks = []
44
- for fname in fnames:
45
- with h5py.File(fname, "r") as f:
46
- chunk, idx = [], None
47
- for channel in self.channels:
48
- if idx is None:
49
- end = len(f[channel]) - self.chunk_size
50
- idx = np.random.randint(0, end)
51
- x = f[channel][idx : idx + self.chunk_size]
52
- chunk.append(x)
53
- chunks.append(np.stack(chunk))
54
- return np.stack(chunks)
55
-
56
- def load_noncoincident(self):
57
- chunks = []
58
- for channel in self.channels:
59
- fnames = self.sample_fnames()
60
- chunk = []
61
- for fname in fnames:
62
- with h5py.File(fname, "r") as f:
63
- end = len(f[channel]) - self.chunk_size
64
- idx = np.random.randint(0, end)
65
- x = f[channel][idx : idx + self.chunk_size]
66
- chunk.append(x)
67
- chunks.append(np.stack(chunk))
68
- return np.stack(chunks, axis=1)
69
-
70
- def iter_epoch(self):
71
- for _ in range(self.chunks_per_epoch):
72
- if self.coincident:
73
- yield torch.Tensor(self.load_coincident())
74
- else:
75
- yield torch.Tensor(self.load_noncoincident())
76
-
77
- def collate(self, xs):
78
- return torch.cat(xs, axis=0)
79
-
80
- def __iter__(self):
81
- return self.iter_epoch()
82
-
83
-
84
- class ChunkedDataset(torch.utils.data.IterableDataset):
6
+ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
85
7
  """
86
- Iterable dataset for generating batches of background data
87
- loaded on-the-fly from multiple HDF5 files. Loads
88
- `chunk_length`-sized randomly-sampled stretches of
89
- background from `reads_per_chunk` randomly sampled
90
- files up front, then samples `batches_per_chunk`
91
- batches of kernels from this chunk before loading
92
- in the next one. Terminates after `chunks_per_epoch`
93
- chunks have been exhausted, which amounts to
94
- `chunks_per_epoch * batches_per_chunk` batches.
95
-
96
- Note that filenames are not sampled uniformly
97
- at chunk-loading time, but are weighted according
98
- to the amount of data each file contains. This ensures
99
- a uniform sampling over time across the whole dataset.
100
-
101
- To load chunks asynchronously in the background,
102
- specify `num_workers > 0`. Note that if the
103
- number of workers is not an even multiple of
104
- `chunks_per_epoch`, the last chunks of an epoch
105
- will be composed of fewer than `reads_per_chunk`
106
- individual segments.
8
+ Wrapper dataset that will loop through chunks of timeseries
9
+ data produced by another iterable and sample windows from
10
+ these chunks.
107
11
 
108
12
  Args:
109
- fnames:
110
- List of HDF5 archives containing data to read.
111
- Each file should have all of the channels specified
112
- in `channels` as top-level datasets.
113
- channels:
114
- Datasets to load from each filename in `fnames`
115
- kernel_length:
116
- Length of the windows returned at iteration time
117
- in seconds
118
- sample_rate:
119
- Rate at which the data in the specified `fnames`
120
- has been sampled.
13
+ chunk_it:
14
+ Iterator which will produce chunks of timeseries
15
+ data to sample windows from. Should have shape
16
+ `(N, C, T)`, where `N` is the number of chunks
17
+ to sample from, `C` is the number of channels,
18
+ and `T` is the number of samples along the
19
+ time dimension for each chunk.
20
+ kernel_size:
21
+ Size of windows to be sampled from each chunk.
22
+ Should be less than the size of each chunk
23
+ along the time dimension.
121
24
  batch_size:
122
- Number of samples to return at iteration time
123
- reads_per_chunk:
124
- Number of file reads to perform when generating
125
- each chunk
126
- chunk_length:
127
- Amount of data to read for each segment loaded
128
- into each chunk, in seconds
25
+ Number of windows to sample at each iteration
129
26
  batches_per_chunk:
130
- Number of batches to sample from each chunk
131
- before loading the next one
132
- chunks_per_epoch:
133
- Number of chunks to generate before iteration
134
- terminates
27
+ Number of batches of windows to sample from
28
+ each chunk before moving on to the next one.
29
+ Sampling fewer batches from each chunk means
30
+ a lower likelihood of sampling duplicate windows,
31
+ but an increase in chunk-loading overhead.
135
32
  coincident:
136
- Flag indicating whether windows returned at iteration
137
- time should come from the same point in time for
138
- each channel in a given batch sample.
139
- num_workers:
140
- Number of workers for performing chunk loading
141
- asynchronously. If left as 0, chunk loading will
142
- be performed in serial with batch sampling.
33
+ Whether the windows sampled from individual
34
+ channels in each batch element should be
35
+ sampled coincidentally, i.e. consisting of
36
+ the same timesteps, or whether each window
37
+ should be sample independently from the others.
143
38
  device:
144
- Device on which to host loaded chunks
39
+ Which device chunks should be moved to upon loading.
145
40
  """
146
41
 
147
42
  def __init__(
148
43
  self,
149
- fnames: List[str],
150
- channels: List[str],
151
- kernel_length: float,
152
- sample_rate: float,
44
+ chunk_it: Iterable,
45
+ kernel_size: float,
153
46
  batch_size: int,
154
- reads_per_chunk: int,
155
- chunk_length: float,
156
47
  batches_per_chunk: int,
157
- chunks_per_epoch: int,
158
48
  coincident: bool = True,
159
- num_workers: int = 0,
160
49
  device: str = "cpu",
161
- pin_memory: bool = False,
162
50
  ) -> None:
163
- if not num_workers:
164
- reads_per_worker = reads_per_chunk
165
- elif reads_per_chunk < num_workers:
166
- raise ValueError(
167
- "Too many workers {} for number of reads_per_chunk {}".format(
168
- num_workers, reads_per_chunk
169
- )
170
- )
171
- else:
172
- reads_per_worker = int(reads_per_chunk // num_workers)
173
-
174
- if kernel_length > chunk_length:
175
- raise ValueError(
176
- "Kernel length {} must be shorter than "
177
- "chunk length {}".format(kernel_length, chunk_length)
178
- )
179
- self.kernel_size = int(kernel_length * sample_rate)
180
- self.chunk_size = int(chunk_length * sample_rate)
181
-
182
- chunk_loader = ChunkLoader(
183
- fnames,
184
- channels,
185
- self.chunk_size,
186
- reads_per_worker,
187
- chunks_per_epoch,
188
- coincident=coincident,
189
- )
190
-
191
- if not num_workers:
192
- self.chunk_loader = chunk_loader
193
- else:
194
- self.chunk_loader = torch.utils.data.DataLoader(
195
- chunk_loader,
196
- batch_size=num_workers,
197
- num_workers=num_workers,
198
- pin_memory=pin_memory,
199
- collate_fn=chunk_loader.collate,
200
- )
201
-
202
- self.device = device
203
- self.num_channels = len(channels)
204
- self.coincident = coincident
205
-
51
+ self.chunk_it = chunk_it
52
+ self.kernel_size = kernel_size
206
53
  self.batch_size = batch_size
207
54
  self.batches_per_chunk = batches_per_chunk
208
- self.chunks_per_epoch = chunks_per_epoch
209
- self.num_workers = num_workers
55
+ self.coincident = coincident
56
+ self.device = device
210
57
 
211
58
  def __len__(self):
212
- if not self.num_workers:
213
- return self.chunks_per_epoch * self.batches_per_chunk
59
+ return len(self.chunk_it) * self.batches_per_chunk
214
60
 
215
- num_chunks = (self.chunks_per_epoch - 1) // self.num_workers + 1
216
- return num_chunks * self.num_workers * self.batches_per_chunk
61
+ def __iter__(self):
62
+ it = iter(self.chunk_it)
63
+ chunk = next(it)
64
+ num_chunks, num_channels, chunk_size = chunk.shape
65
+
66
+ # if we're sampling coincidentally, we only need
67
+ # to sample indices on a per-batch-element basis.
68
+ # Otherwise, we'll need indices for both each
69
+ # batch sample _and_ each channel with each sample
70
+ if self.coincident:
71
+ sample_size = (self.batch_size,)
72
+ else:
73
+ sample_size = (self.batch_size, num_channels)
217
74
 
218
- def iter_epoch(self):
219
75
  # slice kernels out a flattened chunk tensor
220
76
  # index-for-index. We'll account for batch/
221
77
  # channel indices by introducing offsets later on
222
78
  idx = torch.arange(self.kernel_size, device=self.device)
223
79
  idx = idx.view(1, 1, -1)
224
- idx = idx.repeat(self.batch_size, self.num_channels, 1)
80
+ idx = idx.repeat(self.batch_size, num_channels, 1)
225
81
 
226
82
  # this will just be a set of aranged channel indices
227
83
  # repeated to offset the kernel indices in the
228
84
  # flattened chunk tensor
229
- channel_idx = torch.arange(self.num_channels, device=self.device)
85
+ channel_idx = torch.arange(num_channels, device=self.device)
230
86
  channel_idx = channel_idx.view(1, -1, 1)
231
87
  channel_idx = channel_idx.repeat(self.batch_size, 1, self.kernel_size)
232
- idx += channel_idx * self.chunk_size
88
+ idx += channel_idx * chunk_size
233
89
 
234
- for chunk in self.chunk_loader:
90
+ while True:
235
91
  # record the number of rows in the chunk, then
236
92
  # flatten it to make it easier to slice
237
- num_chunks, _, chunk_size = chunk.shape
238
- chunk = chunk.to(self.device).reshape(-1)
93
+ if chunk_size < self.kernel_size:
94
+ raise ValueError(
95
+ "Can't sample kernels of size {} from chunk "
96
+ "with size {}".format(self.kernel_size, chunk_size)
97
+ )
98
+ chunk = chunk.reshape(-1)
239
99
 
240
100
  # generate batches from the current chunk
241
101
  for _ in range(self.batches_per_chunk):
242
- # if we're sampling coincidentally, we only need
243
- # to sample indices on a per-batch-element basis.
244
- # Otherwise, we'll need indices for both each
245
- # batch sample _and_ each channel with each sample
246
- if self.coincident:
247
- size = (self.batch_size,)
248
- else:
249
- size = (self.batch_size, self.num_channels)
250
-
251
102
  # first sample the indices of which chunk elements
252
103
  # we're going to read batch elements from
253
104
  chunk_idx = torch.randint(
254
- 0, num_chunks, size=size, device=self.device
105
+ 0, num_chunks, size=sample_size, device=self.device
255
106
  )
256
107
 
257
108
  # account for the offset this batch element
258
109
  # introduces in the flattened array
259
- chunk_idx *= self.num_channels * self.chunk_size
110
+ chunk_idx *= num_channels * chunk_size
260
111
  chunk_idx = chunk_idx.view(self.batch_size, -1, 1)
261
112
  chunk_idx = chunk_idx + idx
262
113
 
@@ -265,7 +116,7 @@ class ChunkedDataset(torch.utils.data.IterableDataset):
265
116
  time_idx = torch.randint(
266
117
  0,
267
118
  chunk_size - self.kernel_size,
268
- size=size,
119
+ size=sample_size,
269
120
  device=self.device,
270
121
  )
271
122
  time_idx = time_idx.view(self.batch_size, -1, 1)
@@ -276,5 +127,8 @@ class ChunkedDataset(torch.utils.data.IterableDataset):
276
127
  # now slice this 3D tensor from our flattened chunk
277
128
  yield chunk[chunk_idx]
278
129
 
279
- def __iter__(self):
280
- return self.iter_epoch()
130
+ try:
131
+ chunk = next(it)
132
+ except StopIteration:
133
+ break
134
+ num_chunks, num_channels, chunk_size = chunk.shape
@@ -0,0 +1,176 @@
1
+ import warnings
2
+ from typing import Sequence, Union
3
+
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+
8
+ from ml4gw.types import WaveformTensor
9
+
10
+
11
class ContiguousHdf5Warning(Warning):
    """Warning category raised when an HDF5 dataset was written without chunked storage."""
13
+
14
+
15
class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
    """
    Iterable dataset that samples and loads windows of
    timeseries data uniformly from a set of HDF5 files.
    It is _strongly_ recommended that these files have been
    written using [chunked storage]
    (https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage).
    This has been shown to produce increases in read-time speeds
    of over an order of magnitude.

    Files are sampled with probability proportional to their
    length, where each file's length is taken from the dataset
    of the first entry in `channels`. Window start offsets are
    then sampled uniformly over every valid position in the file.

    Args:
        fnames:
            Paths to HDF5 files from which to sample data.
        channels:
            Datasets to read from the indicated files, which
            will be stacked along dim 1 of the generated batches
            during iteration.
        kernel_size:
            Size of the windows to read, in number of samples.
            This will be the size of the last dimension of the
            generated batches.
        batch_size:
            Number of windows to sample at each iteration.
        batches_per_epoch:
            Number of batches to generate during each call
            to `__iter__`.
        coincident:
            Whether windows for each channel in a given batch
            element should be sampled coincidentally, i.e.
            corresponding to the same time indices from the
            same files, or should be sampled independently.
            For the latter case, users can either specify
            `False`, which will sample filenames independently
            for each channel, or `"files"`, which will sample
            windows independently within a given file for each
            channel. The latter setting limits the amount of
            entropy in the effective dataset, but can provide
            over 2x improvement in total throughput.

    Raises:
        ValueError:
            If `coincident` is neither a boolean nor `"files"`, or
            (at sampling time) if a sampled file is shorter than
            `kernel_size`.
    """

    def __init__(
        self,
        fnames: Sequence[str],
        channels: Sequence[str],
        kernel_size: int,
        batch_size: int,
        batches_per_epoch: int,
        coincident: Union[bool, str],
    ) -> None:
        if not isinstance(coincident, bool) and coincident != "files":
            raise ValueError(
                "coincident must be either a boolean or 'files', "
                "got unrecognized value {}".format(coincident)
            )

        self.fnames = fnames
        self.channels = channels
        self.num_channels = len(channels)
        self.kernel_size = kernel_size
        self.batch_size = batch_size
        self.batches_per_epoch = batches_per_epoch
        self.coincident = coincident

        # record the length of each file, measured on the first
        # channel (all channels are presumably the same length
        # -- NOTE(review): not checked here)
        self.sizes = {}
        for fname in self.fnames:
            with h5py.File(fname, "r") as f:
                dset = f[channels[0]]
                if dset.chunks is None:
                    warnings.warn(
                        "File {} contains datasets that were generated "
                        "without using chunked storage. This can have "
                        "severe performance impacts at data loading time. "
                        "If you need faster loading, try re-generating "
                        "your dataset with chunked storage turned on.".format(
                            fname
                        ),
                        category=ContiguousHdf5Warning,
                    )

                self.sizes[fname] = len(dset)

        # Build one probability per entry of `fnames` (not per unique
        # key of `self.sizes`) so that `probs` always lines up with
        # `fnames` in `np.random.choice`, even if a path is repeated.
        lengths = [self.sizes[fname] for fname in self.fnames]
        total = sum(lengths)
        self.probs = np.array([length / total for length in lengths])

    def __len__(self) -> int:
        return self.batches_per_epoch

    def sample_fnames(self, size) -> np.ndarray:
        """Sample filenames with replacement, weighted by file length."""
        return np.random.choice(
            self.fnames,
            p=self.probs,
            size=size,
            replace=True,
        )

    def sample_batch(self) -> WaveformTensor:
        """
        Sample a single batch of multichannel timeseries.

        Returns:
            Tensor of shape `(batch_size, len(channels), kernel_size)`.
        """

        # allocate memory up front
        x = np.zeros((self.batch_size, self.num_channels, self.kernel_size))

        # one filename per batch element when channels share a file,
        # otherwise one per (batch element, channel) pair
        if self.coincident is not False:
            sample_shape = (self.batch_size,)
        else:
            sample_shape = (self.batch_size, self.num_channels)

        # Flatten before calling np.unique: NumPy >= 2 returns an
        # inverse shaped like its input, which would break the flat
        # index arithmetic below for the 2D (coincident=False) case.
        fnames = self.sample_fnames(sample_shape).reshape(-1)

        # only loop through each unique filename once
        # to avoid unnecessary I/O overhead
        unique_fnames, inv, counts = np.unique(
            fnames, return_inverse=True, return_counts=True
        )
        for file_num, (fname, count) in enumerate(zip(unique_fnames, counts)):
            file_size = self.sizes[fname]

            # largest valid window start; windows starting here end
            # exactly at the last sample of the file
            max_idx = file_size - self.kernel_size
            if max_idx < 0:
                raise ValueError(
                    "File {} has {} samples, which is fewer than "
                    "kernel_size {}".format(fname, file_size, self.kernel_size)
                )

            # flat positions in `fnames` that drew the current file
            indices = np.flatnonzero(inv == file_num)

            # when sampling coincidentally either fully
            # or at the file level, all channels will
            # correspond to the same file
            if self.coincident is not False:
                batch_indices = np.repeat(indices, self.num_channels)
                channel_indices = np.tile(np.arange(self.num_channels), count)
            else:
                # flat position = batch * num_channels + channel
                batch_indices = indices // self.num_channels
                channel_indices = indices % self.num_channels

            # if we're sampling fully coincidentally, every channel
            # of a batch element shares the same start offset;
            # otherwise each (batch, channel) pair gets its own
            if self.coincident is True:
                starts = np.random.randint(max_idx + 1, size=count)
                starts = np.repeat(starts, self.num_channels)
            else:
                starts = np.random.randint(
                    max_idx + 1, size=len(batch_indices)
                )

            # open the file and sample a different set of
            # kernels for each batch element it occupies
            with h5py.File(fname, "r") as f:
                for b, c, start in zip(batch_indices, channel_indices, starts):
                    x[b, c] = f[self.channels[c]][
                        start : start + self.kernel_size
                    ]
        return torch.Tensor(x)

    def __iter__(self) -> torch.Tensor:
        # When running inside DataLoader workers, split the epoch's
        # batches across workers; the first `remainder` workers each
        # produce one extra batch so the total equals batches_per_epoch.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            num_batches = self.batches_per_epoch
        else:
            num_batches, remainder = divmod(
                self.batches_per_epoch, worker_info.num_workers
            )
            if worker_info.id < remainder:
                num_batches += 1

        for _ in range(num_batches):
            yield self.sample_batch()
ml4gw/nn/__init__.py ADDED
File without changes
@@ -0,0 +1,3 @@
1
+ from .base import Autoencoder
2
+ from .convolutional import ConvolutionalAutoencoder
3
+ from .skip_connection import AddSkipConnect, ConcatSkipConnect, SkipConnection
@@ -0,0 +1,89 @@
1
+ from collections.abc import Sequence
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ from ml4gw.nn.autoencoder.skip_connection import SkipConnection
7
+
8
+
9
class Autoencoder(torch.nn.Module):
    """
    Base autoencoder class that defines some of the
    basic methods and functionality. Autoencoders are
    defined here as a set of sequential blocks that
    have an `encode` method, which acts on the input
    data to the autoencoder, and a `decode` method, which
    acts on the encoded vector generated by the `encode`
    method. `forward` just runs these steps one after the
    other. Although it isn't explicitly enforced, a good
    rule of thumb is that the output of a block's `decode`
    method should have the same shape as the _input_ of its
    `encode` method.

    Accepts a `skip_connection` argument that defines how to
    combine information from the input of one block's `encode`
    layer with the output of its `decode` layer. See the
    `skip_connection` module for more info about what these
    classes are expected to contain and how they operate.

    Args:
        skip_connection:
            Optional callable, invoked as `skip_connection(X, state)`,
            that merges a block's decoded output `X` with the input
            that block received during encoding. It is applied after
            every block's `decode` except the top-most (first) block,
            whose encoding input is the raw autoencoder input. If
            `None`, no skip connections are applied.
    """

    def __init__(self, skip_connection: Optional[SkipConnection] = None):
        super().__init__()
        self.skip_connection = skip_connection
        self.blocks = torch.nn.ModuleList()

    def encode(self, *X: torch.Tensor, return_states: bool = False):
        """
        Run every block's `encode` in order.

        Tuple outputs of a block are unpacked as positional arguments
        to the next block's `encode`.

        Args:
            X:
                Input(s) to the first block's `encode`.
            return_states:
                If `True`, also return the outputs of every block
                except the last one.

        Returns:
            The final encoding, or `(encoding, states)` when
            `return_states` is `True`, where `states[k]` is the
            output of block `k`'s `encode`.
        """
        states = []
        for block in self.blocks:
            if isinstance(X, tuple):
                X = block.encode(*X)
            else:
                X = block.encode(X)
            states.append(X)

        # don't need to return the last
        # state, since that's just equal
        # to the output of this layer
        if return_states:
            return X, states[:-1]
        return X

    def decode(self, *X, states: Optional[Sequence[torch.Tensor]] = None):
        """
        Run every block's `decode` in reverse order.

        Args:
            X:
                Encoded input(s) to the last block's `decode`.
            states:
                Intermediate encodings as returned by
                `encode(..., return_states=True)`: `states[k]` is the
                output of block `k`'s `encode`, and there must be
                exactly `len(self.blocks) - 1` of them. Required
                when a skip connection is set.

        Raises:
            ValueError:
                If a skip connection is set but `states` is `None`, or
                if `states` has the wrong length.
        """
        if self.skip_connection is not None and states is None:
            raise ValueError(
                "Must pass intermediate states when autoencoder "
                "has a skip connection function specified"
            )
        elif states is not None:
            if len(states) != len(self.blocks) - 1:
                raise ValueError(
                    "Passed {} intermediate states, expected {}".format(
                        len(states), len(self.blocks) - 1
                    )
                )

            # Put states in decoding order: the i-th decoded block
            # (block n-1-i) outputs a tensor shaped like its own
            # encoding input, i.e. block n-2-i's encoding output,
            # which is states[n-2-i]. The final decoded block
            # (block 0) is the output layer and is not skip
            # connected. `list(...)` also allows tuple inputs.
            states = list(states)[::-1] + [None]

        for i, block in enumerate(self.blocks[::-1]):
            if isinstance(X, tuple):
                X = block.decode(*X)
            else:
                X = block.decode(X)

            # nothing to combine when there's no skip
            # connection or no states were provided
            if self.skip_connection is None or states is None:
                continue

            state = states[i]
            if state is not None:
                X = self.skip_connection(X, state)
        return X

    def forward(self, *X):
        return_states = self.skip_connection is not None
        X = self.encode(*X, return_states=return_states)
        if return_states:
            X, states = X
        else:
            states = None

        # single tensors are wrapped so that multi-output
        # encodings are unpacked consistently into `decode`
        # whether or not a skip connection is used
        if isinstance(X, torch.Tensor):
            X = (X,)
        return self.decode(*X, states=states)