braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/deepsleepnet.py (new file, +417 lines)
# Authors: Théo Gnassounou <theo.gnassounou@inria.fr>
#
# License: BSD (3-clause)
import torch
import torch.nn as nn

from braindecode.models.base import EEGModuleMixin


class DeepSleepNet(EEGModuleMixin, nn.Module):
    r"""DeepSleepNet from Supratak et al. (2017) [Supratak2017]_.

    :bdg-success:`Convolution` :bdg-secondary:`Recurrent`

    .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/master/img/deepsleepnet.png
        :align: center
        :alt: DeepSleepNet Architecture
        :width: 700px

    .. rubric:: Architectural Overview

    DeepSleepNet couples **dual-path convolutional representation learning** with
    **sequence residual learning** via bidirectional LSTMs.

    The network:

    - (i) learns complementary time-frequency features from each
      30-s epoch using **two parallel CNNs** (small vs. large first-layer filters), then
    - (ii) models **temporal dependencies across epochs** using **two-layer BiLSTMs**
      with a **residual shortcut** from the CNN features, and finally
    - (iii) outputs per-epoch sleep stages. This design encodes both
      epoch-local patterns and longer-range transition rules used by human scorers.

    In terms of implementation:

    - (i) :class:`_SmallCNN` and :class:`_LargeCNN`: two CNNs extract epoch-wise features
      (small-filter path for temporal precision; large-filter path for frequency precision);
    - (ii) :class:`_BiLSTM`: a stacked bidirectional LSTM plus a residual shortcut
      injects temporal context while preserving CNN evidence;
    - (iii) a :class:`~torch.nn.Linear` readout (softmax) for the five sleep stages.

    .. rubric:: Macro Components

    - **Representation learning** (dual-path CNN → epoch feature; implemented by
      :class:`_SmallCNN` and :class:`_LargeCNN`)

      - *Operations.*

        - **Small-filter CNN**, four convolutional blocks of:

          - :class:`~torch.nn.Conv1d`
          - :class:`~torch.nn.BatchNorm1d`
          - :class:`~torch.nn.ReLU`
          - :class:`~torch.nn.MaxPool1d` after the first and last blocks.

        - The first conv uses **filter length ≈ Fs/2** and **stride ≈ Fs/16** to emphasize the
          *timing* of graphoelements.

        - **Large-filter CNN**:

          - the same stack, but the first conv uses **filter length ≈ 4·Fs**
          - and **stride ≈ Fs/2** to emphasize *frequency* content.

        - Outputs from both paths are **concatenated** into the epoch embedding ``a_t``.

      - *Rationale.*
        Two first-layer scales provide a **learned, dual-scale filter bank** that trades
        temporal vs. frequency precision without hand-crafted features.

    - **Sequence residual learning** (BiLSTM context + residual fusion; implemented by
      :class:`_BiLSTM`)

      - *Operations.*

        - A **two-layer BiLSTM** (with **peephole connections** in the original paper) processes
          the sequence of epoch embeddings ``{a_t}`` forward and backward; hidden states from
          both directions are **concatenated**.
        - A **shortcut MLP** (fully connected + :class:`~torch.nn.BatchNorm1d`; the original
          paper also applies a :class:`~torch.nn.ReLU`) projects ``a_t`` to the BiLSTM output
          dimension and is **added** (residual) to the BiLSTM output at each time step.

      - *Role.* Encodes **stage-transition rules** and smooths predictions over time while
        preserving salient CNN features via the residual path.

    - **Classifier** (epoch-wise prediction)

      - *Operations.*

        - :class:`~torch.nn.Linear` to produce per-epoch class scores.

      Original training uses a two-step optimization: CNN pretraining on class-balanced data,
      then end-to-end fine-tuning with sequential batches.
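
    A minimal sketch of that fine-tuning step, assuming ``model`` is a constructed
    :class:`DeepSleepNet`; the parameter-group split and the learning-rate values are
    illustrative, not part of this implementation:

    .. code-block:: python

        cnn_params = list(model.cnn1.parameters()) + list(model.cnn2.parameters())
        other_params = [
            p for n, p in model.named_parameters() if not n.startswith(("cnn1", "cnn2"))
        ]
        optimizer = torch.optim.Adam(
            [
                {"params": cnn_params, "lr": 1e-6},  # pretrained CNN paths: small LR
                {"params": other_params, "lr": 1e-4},  # sequence encoder + head: larger LR
            ]
        )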

    .. rubric:: Convolutional Details

    - **Temporal (where time-domain patterns are learned).**

      Both CNN paths use **1-D temporal convolutions**. The *small-filter* path (first kernel ≈ Fs/2,
      stride ≈ Fs/16) captures *when* characteristic transients occur; the *large-filter* path
      (first kernel ≈ 4·Fs, stride ≈ Fs/2) captures *which* frequency components dominate over the
      epoch. Deeper layers use **small kernels** to refine features with fewer parameters, interleaved
      with **max pooling** for downsampling.

    - **Spatial (how channels are processed).**

      The original model operates on **single-channel** raw EEG; convolutions therefore mix only
      along time (no spatial convolution across electrodes).

    - **Spectral (how frequency information emerges).**

      No explicit Fourier/wavelet transform is used. The **large-filter path** serves as a
      *frequency-sensitive* analyzer, while the **small-filter path** remains *time-sensitive*,
      together functioning as a **two-band learned filter bank** at the first layer.
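
    As a concrete check, the first-layer sizes below reproduce the values hard-coded in this
    implementation, which assumes 100 Hz single-channel input (see :class:`_SmallCNN` and
    :class:`_LargeCNN`):

    .. code-block:: python

        fs = 100                 # sampling rate assumed by this implementation
        small_kernel = fs // 2   # 50 samples, as in _SmallCNN.conv1
        small_stride = fs // 16  # 6 samples (floor of 6.25)
        large_kernel = 4 * fs    # 400 samples, as in _LargeCNN.conv1
        large_stride = fs // 2   # 50 samples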

    .. rubric:: Attention / Sequential Modules

    - **Type.** A **bidirectional LSTM** (two layers, with **peephole connections** in the
      original paper); forward and backward streams are independent and their outputs are
      concatenated.
    - **Shapes.** For a sequence of ``N`` epochs, the CNN produces ``{a_t} ∈ R^D``;
      the BiLSTM outputs ``h_t ∈ R^{2H}``; the shortcut MLP maps ``a_t → R^{2H}`` to enable
      **element-wise residual addition**.
    - **Role.** Models **long-range temporal dependencies** (e.g., persisting N2 without visible
      K-complexes/spindles), stabilizing per-epoch predictions.
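
    In this implementation, ``D = 3072`` (concatenated CNN features) and ``H = 512``, so both
    the BiLSTM output and the shortcut projection are 1024-dimensional. A minimal sketch of the
    residual fusion, assuming ``model`` is a constructed :class:`DeepSleepNet` (shapes only,
    simplified from the actual ``forward``):

    .. code-block:: python

        a = torch.randn(8, 3072)                # a batch of 8 epoch embeddings
        shortcut = model.fc(a)                  # (8, 1024): projected CNN features
        context = model.bilstm(a.unsqueeze(1))  # (8, 1, 1024): length-1 sequence
        fused = context.squeeze(1) + shortcut   # element-wise residual addition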

    .. rubric:: Additional Mechanisms

    - **Residual shortcut over the sequence encoder.** Adds projected CNN features to the BiLSTM
      outputs, improving gradient flow and retaining discriminative content from representation
      learning.
    - **Two-step training.**

      - (i) **Pretrain** the CNN paths with class-balanced sampling;
      - (ii) **fine-tune** the full network with sequential batches, using a **lower LR** for the
        CNNs and a **higher LR** for the sequence encoder.

    - **State handling.** BiLSTM states are **reinitialized per subject** so that temporal context
      does not leak across recordings.

    .. rubric:: Usage and Configuration

    - **Epoch pipeline.** Use **two parallel CNNs** with the first conv sized to **Fs/2** (small path)
      and **4·Fs** (large path), with strides of **Fs/16** and **Fs/2**, respectively; stack three more
      conv blocks with small kernels, plus **max pooling**, in each path. Concatenate the path outputs
      to form epoch embeddings.
    - **Sequence encoder.** Apply a **two-layer BiLSTM** over the sequence of embeddings;
      add a **projection MLP** on the CNN features and **sum** it with the BiLSTM outputs (residual).
      Finish with a :class:`~torch.nn.Linear` layer per epoch.
    - **Reference implementation.** See the official repository for a faithful implementation and
      training scripts.
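
    A minimal usage sketch; the 100 Hz sampling rate and 30-s windows are assumptions that
    match the layer sizes hard-coded in this implementation:

    .. code-block:: python

        import torch

        from braindecode.models import DeepSleepNet

        model = DeepSleepNet(n_outputs=5, n_chans=1, sfreq=100, input_window_seconds=30)
        x = torch.randn(4, 1, 3000)  # (batch, n_channels, 30 s at 100 Hz)
        y = model(x)                 # (4, 5) per-window class scores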

    Parameters
    ----------
    activation_large : nn.Module, default=nn.ELU
        Activation function class to apply in the large-filter CNN path. Should be
        a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
    activation_small : nn.Module, default=nn.ReLU
        Activation function class to apply in the small-filter CNN path. Should be
        a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
    return_feats : bool
        If True, return the features, i.e. the output of the feature extractor
        (before the final linear layer). If False, pass the features through
        the final linear layer.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.

    References
    ----------
    .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
        DeepSleepNet: A model for automatic sleep stage scoring based
        on raw single-channel EEG. IEEE Transactions on Neural Systems
        and Rehabilitation Engineering, 25(11), 1998-2008.
    """

    def __init__(
        self,
        n_outputs=5,
        return_feats=False,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        activation_large: type[nn.Module] = nn.ELU,
        activation_small: type[nn.Module] = nn.ReLU,
        drop_prob: float = 0.5,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.cnn1 = _SmallCNN(activation=activation_small, drop_prob=drop_prob)
        self.cnn2 = _LargeCNN(activation=activation_large, drop_prob=drop_prob)
        self.dropout = nn.Dropout(p=drop_prob)
        self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
        # Shortcut MLP: projects the 3072-dim concatenated CNN features to the
        # 1024-dim BiLSTM output space for the residual addition.
        self.fc = nn.Sequential(
            nn.Linear(3072, 1024, bias=False), nn.BatchNorm1d(num_features=1024)
        )

        self.features_extractor = nn.Identity()
        self.len_last_layer = 1024
        self.return_feats = return_feats

        # TODO: Add new way to handle return_features == True
        if not return_feats:
            self.final_layer = nn.Linear(1024, self.n_outputs)
        else:
            self.final_layer = nn.Identity()

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).
        """
        # Add a singleton feature dimension: (batch, 1, n_channels, n_times).
        if x.ndim == 3:
            x = x.unsqueeze(1)

        # Dual-path representation learning on the same window.
        x1 = self.cnn1(x)
        x1 = x1.flatten(start_dim=1)

        x2 = self.cnn2(x)
        x2 = x2.flatten(start_dim=1)

        # Concatenate both paths into the 3072-dim epoch embedding.
        x = torch.cat((x1, x2), dim=1)
        x = self.dropout(x)
        # Residual shortcut: project the CNN features to the BiLSTM output size.
        temp = self.fc(x)
        # The BiLSTM processes each window as a length-1 sequence.
        x = self.bilstm(x.unsqueeze(1))
        x = x.squeeze(1)
        x = torch.add(x, temp)
        x = self.dropout(x)

        feats = self.features_extractor(x)

        if self.return_feats:
            return feats
        else:
            return self.final_layer(feats)


class _SmallCNN(nn.Module):
    r"""Small-filter CNN path: smaller kernels to learn temporal information.

    Parameters
    ----------
    activation : nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
    """

    def __init__(self, activation: type[nn.Module] = nn.ReLU, drop_prob: float = 0.5):
        super().__init__()
        # First conv: kernel 50 ≈ Fs/2 and stride 6 ≈ Fs/16 at Fs = 100 Hz.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=64,
                kernel_size=(1, 50),
                stride=(1, 6),
                padding=(0, 22),
                bias=False,
            ),
            nn.BatchNorm2d(num_features=64),
            activation(),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 2))
        self.dropout = nn.Dropout(p=drop_prob)
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=(1, 8),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(1, 8),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(1, 8),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = self.dropout(self.pool1(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool2(x)
        return x


class _LargeCNN(nn.Module):
    r"""Large-filter CNN path: larger kernels to learn frequency information.

    Parameters
    ----------
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
    """

    def __init__(self, activation: type[nn.Module] = nn.ELU, drop_prob: float = 0.5):
        super().__init__()
        # First conv: kernel 400 ≈ 4·Fs and stride 50 ≈ Fs/2 at Fs = 100 Hz.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=64,
                kernel_size=(1, 400),
                stride=(1, 50),
                padding=(0, 175),
                bias=False,
            ),
            nn.BatchNorm2d(num_features=64),
            activation(),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout = nn.Dropout(p=drop_prob)
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=(1, 6),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(1, 6),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(1, 6),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
            activation(),
        )
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = self.dropout(self.pool1(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool2(x)
        return x


class _BiLSTM(nn.Module):
    """Two-layer bidirectional LSTM used as the sequence encoder."""

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=0.5,
            bidirectional=True,
        )

    def forward(self, x):
        # Explicit zero initial hidden and cell states, one pair per layer and
        # direction; zeros are also the nn.LSTM default.
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)

        # Forward propagate the LSTM and return the outputs of all time steps.
        out, _ = self.lstm(x, (h0, c0))
        return out