braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/tidnet.py
@@ -0,0 +1,397 @@
+ from math import ceil
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import nn
+ from torch.nn import init
+ from torch.nn.utils.parametrizations import weight_norm
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import Ensure4d
+
+
+ class TIDNet(EEGModuleMixin, nn.Module):
+     r"""Thinker Invariance DenseNet model from Kostas & Rudzicz (2020) [TIDNet]_.
+
+     :bdg-success:`Convolution`
+
+     .. figure:: https://content.cld.iop.org/journals/1741-2552/17/5/056008/revision3/jneabb7a7f1_hr.jpg
+        :align: center
+        :alt: TIDNet Architecture
+
+     See [TIDNet]_ for details.
+
+     Parameters
+     ----------
+     s_growth : int
+         DenseNet-style growth factor (number of filters added per DenseFilter).
+     t_filters : int
+         Number of temporal filters.
+     drop_prob : float
+         Dropout probability.
+     pooling : int
+         Width and stride of the max temporal pooling.
+     temp_layers : int
+         Number of temporal layers.
+     spat_layers : int
+         Number of DenseFilters.
+     temp_span : float
+         Fraction of ``n_times`` that defines the temporal filter length:
+         ``temp_len = ceil(temp_span * n_times)``.
+         E.g., a ``temp_span`` of 0.05 with 1500 ``n_times`` yields a temporal
+         filter of length 75.
+     bottleneck : int
+         Bottleneck factor within each DenseFilter.
+     summary : int
+         Output size of the ``AdaptiveAvgPool1d`` layer. If set to -1, the value
+         is computed automatically as ``n_times // pooling``.
+     in_chans : int
+         Alias for ``n_chans``.
+     n_classes : int
+         Alias for ``n_outputs``.
+     input_window_samples : int
+         Alias for ``n_times``.
+     activation : nn.Module, default=nn.LeakyReLU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.LeakyReLU``.
+
+     Notes
+     -----
+     Code adapted from https://github.com/SPOClab-ca/ThinkerInvariance/
+
+     References
+     ----------
+     .. [TIDNet] Kostas, D. & Rudzicz, F.
+        Thinker invariance: enabling deep neural networks for BCI across more
+        people.
+        J. Neural Eng. 17, 056008 (2020).
+        doi: 10.1088/1741-2552/abb7a7.
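+
+     Examples
+     --------
+     A minimal usage sketch (the shapes below are illustrative, not from the
+     paper):
+
+     >>> import torch
+     >>> from braindecode.models import TIDNet
+     >>> model = TIDNet(n_chans=22, n_outputs=4, n_times=1500)
+     >>> x = torch.randn(16, 22, 1500)  # (batch, n_chans, n_times)
+     >>> model(x).shape
+     torch.Size([16, 4])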
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ n_chans=None,
74
+ n_outputs=None,
75
+ n_times=None,
76
+ input_window_seconds=None,
77
+ sfreq=None,
78
+ chs_info=None,
79
+ s_growth: int = 24,
80
+ t_filters: int = 32,
81
+ drop_prob: float = 0.4,
82
+ pooling: int = 15,
83
+ temp_layers: int = 2,
84
+ spat_layers: int = 2,
85
+ temp_span: float = 0.05,
86
+ bottleneck: int = 3,
87
+ summary: int = -1,
88
+ activation: type[nn.Module] = nn.LeakyReLU,
89
+ ):
90
+ super().__init__(
91
+ n_outputs=n_outputs,
92
+ n_chans=n_chans,
93
+ n_times=n_times,
94
+ input_window_seconds=input_window_seconds,
95
+ sfreq=sfreq,
96
+ chs_info=chs_info,
97
+ )
98
+ del n_outputs, n_chans, n_times, input_window_seconds, sfreq, chs_info
99
+
100
+ self.temp_len = ceil(temp_span * self.n_times)
101
+
102
+ self.dscnn = _TIDNetFeatures(
103
+ s_growth=s_growth,
104
+ t_filters=t_filters,
105
+ n_chans=self.n_chans,
106
+ n_times=self.n_times,
107
+ drop_prob=drop_prob,
108
+ pooling=pooling,
109
+ temp_layers=temp_layers,
110
+ spat_layers=spat_layers,
111
+ temp_span=temp_span,
112
+ bottleneck=bottleneck,
113
+ summary=summary,
114
+ activation=activation,
115
+ )
116
+
117
+ self._num_features = self.dscnn.num_features
118
+
119
+ self.flatten = nn.Flatten(start_dim=1)
120
+
121
+ self.final_layer = self._create_classifier(self.num_features, self.n_outputs)
122
+
123
+ def _create_classifier(self, incoming: int, n_outputs: int):
124
+ classifier = nn.Linear(incoming, n_outputs)
125
+ init.xavier_normal_(classifier.weight)
126
+ classifier.bias.data.zero_()
127
+ seq_clf = nn.Sequential(classifier, nn.Identity())
128
+
129
+ return seq_clf
130
+
131
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
132
+ """Forward pass.
133
+
134
+ Parameters
135
+ ----------
136
+ x: torch.Tensor
137
+ Batch of EEG windows of shape (batch_size, n_channels, n_times).
138
+ """
139
+
140
+ x = self.dscnn(x)
141
+ x = self.flatten(x)
142
+ return self.final_layer(x)
143
+
144
+ @property
145
+ def num_features(self):
146
+ return self._num_features
147
+
148
+
149
+ class _BatchNormZG(nn.BatchNorm2d):
150
+ def reset_parameters(self):
151
+ if self.track_running_stats:
152
+ self.running_mean.zero_()
153
+ self.running_var.fill_(1)
154
+ if self.affine:
155
+ self.weight.data.zero_()
156
+ self.bias.data.zero_()
157
+
158
+
159
+ class _ConvBlock2D(nn.Module):
160
+ r"""Implements Convolution block with order:
161
+ Convolution, dropout, activation, batch-norm
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ in_filters: int,
167
+ out_filters: int,
168
+ kernel: tuple[int, int],
169
+ stride: tuple[int, int] = (1, 1),
170
+ padding: int = 0,
171
+ dilation: int = 1,
172
+ groups: int = 1,
173
+ drop_prob: float = 0.5,
174
+ batch_norm: bool = True,
175
+ activation: type[nn.Module] = nn.LeakyReLU,
176
+ residual: bool = False,
177
+ ):
178
+ super().__init__()
179
+ self.kernel = kernel
180
+ self.activation = activation()
181
+ self.residual = residual
182
+
183
+ self.conv = nn.Conv2d(
184
+ in_filters,
185
+ out_filters,
186
+ kernel,
187
+ stride=stride,
188
+ padding=padding,
189
+ dilation=dilation,
190
+ groups=groups,
191
+ bias=not batch_norm,
192
+ )
193
+ self.dropout = nn.Dropout2d(p=float(drop_prob))
194
+ self.batch_norm = (
195
+ _BatchNormZG(out_filters)
196
+ if residual
197
+ else nn.BatchNorm2d(out_filters)
198
+ if batch_norm
199
+ else nn.Identity()
200
+ )
201
+
202
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
203
+ res = input
204
+ input = self.conv(
205
+ input,
206
+ )
207
+ input = self.dropout(input)
208
+ input = self.activation(input)
209
+ input = self.batch_norm(input)
210
+ return input + res if self.residual else input
211
+
212
+
213
+ class _DenseFilter(nn.Module):
214
+ def __init__(
215
+ self,
216
+ in_features: int,
217
+ growth_rate: int,
218
+ filter_len: int = 5,
219
+ drop_prob: float = 0.5,
220
+ bottleneck: int = 2,
221
+ activation: type[nn.Module] = nn.LeakyReLU,
222
+ dim: int = -2,
223
+ ):
224
+ super().__init__()
225
+ dim = dim if dim > 0 else dim + 4
226
+ if dim < 2 or dim > 3:
227
+ raise ValueError("Only last two dimensions supported")
228
+ kernel = (filter_len, 1) if dim == 2 else (1, filter_len)
229
+
230
+ self.net = nn.Sequential(
231
+ nn.BatchNorm2d(in_features),
232
+ activation(),
233
+ nn.Conv2d(in_features, bottleneck * growth_rate, 1),
234
+ nn.BatchNorm2d(bottleneck * growth_rate),
235
+ activation(),
236
+ nn.Conv2d(
237
+ bottleneck * growth_rate,
238
+ growth_rate,
239
+ kernel,
240
+ padding=tuple((k // 2 for k in kernel)),
241
+ ),
242
+ nn.Dropout2d(p=float(drop_prob)),
243
+ )
244
+
245
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
246
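+         # Illustrative shape note: an input of shape (B, F, H, W) comes back
+         # as (B, F + growth_rate, H, W); the block's growth_rate new feature
+         # maps are concatenated onto its input, DenseNet-style.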
+         return torch.cat((x, self.net(x)), dim=1)
+
+
+ class _DenseSpatialFilter(nn.Module):
+     def __init__(
+         self,
+         n_chans: int,
+         growth: int,
+         depth: int,
+         in_ch: int = 1,
+         bottleneck: int = 4,
+         drop_prob: float = 0.0,
+         activation: type[nn.Module] = nn.LeakyReLU,
+         collapse: bool = True,
+     ):
+         super().__init__()
+         self.net = nn.Sequential(
+             *[
+                 _DenseFilter(
+                     in_ch + growth * d,
+                     growth,
+                     bottleneck=bottleneck,
+                     drop_prob=drop_prob,
+                     activation=activation,
+                 )
+                 for d in range(depth)
+             ]
+         )
+         n_filters = in_ch + growth * depth
+         self.collapse = collapse
+         if collapse:
+             self.channel_collapse = _ConvBlock2D(
+                 n_filters, n_filters, (n_chans, 1), drop_prob=0, activation=activation
+             )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if len(x.shape) < 4:
+             x = x.unsqueeze(1).permute([0, 1, 3, 2])
+         x = self.net(x)
+         if self.collapse:
+             return self.channel_collapse(x).squeeze(-2)
+         return x
+
+
+ class _TemporalFilter(nn.Module):
+     def __init__(
+         self,
+         n_chans: int,
+         filters: int,
+         depth: int,
+         temp_len: int,
+         drop_prob: float = 0.0,
+         activation: type[nn.Module] = nn.LeakyReLU,
+         residual: str = "netwise",
+     ):
+         super().__init__()
+         # Force an odd temporal kernel length so "same" padding is exact.
+         temp_len = temp_len + 1 - temp_len % 2
+         self.residual_style = str(residual)
+         net = list()
+
+         for i in range(depth):
+             dil = depth - i
+             conv = weight_norm(
+                 nn.Conv2d(
+                     n_chans if i == 0 else filters,
+                     filters,
+                     kernel_size=(1, temp_len),
+                     dilation=dil,
+                     padding=(0, dil * (temp_len - 1) // 2),
+                 )
+             )
+             net.append(
+                 nn.Sequential(conv, activation(), nn.Dropout2d(p=float(drop_prob)))
+             )
+         if self.residual_style.lower() == "netwise":
+             self.net = nn.Sequential(*net)
+             self.residual = nn.Conv2d(n_chans, filters, (1, 1))
+         elif residual.lower() == "dense":
+             # Use a ModuleList so the layers are registered as submodules.
+             self.net = nn.ModuleList(net)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         style = self.residual_style.lower()
+         if style == "netwise":
+             return self.net(x) + self.residual(x)
+         elif style == "dense":
+             for layer in self.net:
+                 x = torch.cat((x, layer(x)), dim=1)
+             return x
+         # TorchScript now knows this path always returns or errors
+         else:
+             # Use an assertion so TorchScript can compile it
+             assert False, f"Unsupported residual style: {self.residual_style}"
+
+
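+ # Illustrative sanity check for _TemporalFilter (a sketch, assuming the
+ # default "netwise" style): each dilated convolution pads the temporal axis by
+ # dil * (temp_len - 1) // 2, so the temporal length is preserved.
+ #
+ # >>> tf = _TemporalFilter(n_chans=1, filters=32, depth=2, temp_len=75)
+ # >>> tf(torch.zeros(2, 1, 22, 1500)).shape
+ # torch.Size([2, 32, 22, 1500])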
+ class _TIDNetFeatures(nn.Module):
+     def __init__(
+         self,
+         s_growth: int,
+         t_filters: int,
+         n_chans: int,
+         n_times: int,
+         drop_prob: float,
+         pooling: int,
+         temp_layers: int,
+         spat_layers: int,
+         temp_span: float,
+         bottleneck: int,
+         summary: int,
+         activation: type[nn.Module] = nn.LeakyReLU,
+     ):
+         super().__init__()
+         self.n_chans = n_chans
+         self.temp_len = ceil(temp_span * n_times)
+
+         self.temporal = nn.Sequential(
+             Ensure4d(),
+             Rearrange("batch C T 1 -> batch 1 C T"),
+             _TemporalFilter(
+                 1,
+                 t_filters,
+                 depth=temp_layers,
+                 temp_len=self.temp_len,
+                 activation=activation,
+             ),
+             nn.MaxPool2d((1, pooling)),
+             nn.Dropout2d(p=float(drop_prob)),
+         )
+         summary = n_times // pooling if summary == -1 else summary
+
+         self.spatial = _DenseSpatialFilter(
+             n_chans=n_chans,
+             growth=s_growth,
+             depth=spat_layers,
+             in_ch=t_filters,
+             drop_prob=drop_prob,
+             bottleneck=bottleneck,
+             activation=activation,
+         )
+         self.extract_features = nn.Sequential(
+             nn.AdaptiveAvgPool1d(int(summary)), nn.Flatten(start_dim=1)
+         )
+
+         self._num_features = (t_filters + s_growth * spat_layers) * summary
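+         # Worked example with the defaults (illustrative): t_filters=32,
+         # s_growth=24, spat_layers=2, pooling=15 and n_times=1500 give
+         # summary = 1500 // 15 = 100, so
+         # num_features = (32 + 24 * 2) * 100 = 8000.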
+
+     @property
+     def num_features(self):
+         return self._num_features
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.temporal(x)
+         x = self.spatial(x)
+         return self.extract_features(x)
braindecode/models/tsinception.py
@@ -0,0 +1,295 @@
+ # Authors: Bruno Aristimunha <b.aristimunha>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ from einops.layers.torch import Rearrange
+ from mne.utils import deprecated, warn
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class TSception(EEGModuleMixin, nn.Module):
+     r"""TSception model from Ding et al. (2020) [ding2020]_.
+
+     :bdg-success:`Convolution`
+
+     TSception: A deep learning framework for emotion detection using EEG.
+
+     .. figure:: https://user-images.githubusercontent.com/58539144/74716976-80415e00-526a-11ea-9433-02ab2b753f6b.PNG
+        :align: center
+        :alt: TSception Architecture
+
+     The model consists of temporal and spatial convolutional layers
+     (Tception and Sception) designed to learn temporal and spatial features
+     from EEG data.
+
+     Parameters
+     ----------
+     number_filter_temp : int
+         Number of temporal convolutional filters.
+     number_filter_spat : int
+         Number of spatial convolutional filters.
+     hidden_size : int
+         Number of units in the hidden fully connected layer.
+     drop_prob : float
+         Dropout rate applied after the hidden layer.
+     activation : nn.Module, optional
+         Activation function class to apply. Should be a PyTorch activation
+         module like ``nn.ReLU`` or ``nn.LeakyReLU``. Default is ``nn.LeakyReLU``.
+     pool_size : int, optional
+         Pooling size for the average pooling layers. Default is 8.
+     inception_windows : tuple[float, float, float], optional
+         Window sizes (in seconds) for the inception modules.
+         Default is (0.5, 0.25, 0.125).
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; the modifications are minimal and the
+     model is expected to work as intended. Adapted from the original code
+     [code2020]_.
+
+     References
+     ----------
+     .. [ding2020] Ding, Y., Robinson, N., Zeng, Q., Chen, D., Wai, A. A. P.,
+        Lee, T. S., & Guan, C. (2020, July). TSception: a deep learning framework
+        for emotion detection using EEG. In 2020 International Joint Conference
+        on Neural Networks (IJCNN) (pp. 1-7). IEEE.
+     .. [code2020] Ding, Y., Robinson, N., Zeng, Q., Chen, D., Wai, A. A. P.,
+        Lee, T. S., & Guan, C. (2020, July). TSception: a deep learning framework
+        for emotion detection using EEG.
+        https://github.com/deepBrains/TSception/blob/master/Models.py
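+
+     Examples
+     --------
+     A minimal usage sketch (the shapes and sampling rate are illustrative,
+     not from the paper):
+
+     >>> import torch
+     >>> from braindecode.models import TSception
+     >>> model = TSception(n_chans=28, n_outputs=2, n_times=512, sfreq=128)
+     >>> x = torch.randn(8, 28, 512)  # (batch, n_chans, n_times)
+     >>> model(x).shape
+     torch.Size([8, 2])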
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ # Braindecode parameters
70
+ n_chans=None,
71
+ n_outputs=None,
72
+ input_window_seconds=None,
73
+ chs_info=None,
74
+ n_times=None,
75
+ sfreq=None,
76
+ # Model parameters
77
+ number_filter_temp: int = 9,
78
+ number_filter_spat: int = 6,
79
+ hidden_size: int = 128,
80
+ drop_prob: float = 0.5,
81
+ activation: type[nn.Module] = nn.LeakyReLU,
82
+ pool_size: int = 8,
83
+ inception_windows: tuple[float, float, float] = (0.5, 0.25, 0.125),
84
+ ):
85
+ super().__init__(
86
+ n_outputs=n_outputs,
87
+ n_chans=n_chans,
88
+ chs_info=chs_info,
89
+ n_times=n_times,
90
+ input_window_seconds=input_window_seconds,
91
+ sfreq=sfreq,
92
+ )
93
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
94
+
95
+ self.activation = activation
96
+ self.pool_size = pool_size
97
+ self.inception_windows = inception_windows
98
+ self.number_filter_spat = number_filter_spat
99
+ self.number_filter_temp = number_filter_temp
100
+ self.drop_prob = drop_prob
101
+
102
+ ### Layers
103
+ self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
104
+ if self.input_window_seconds < max(self.inception_windows):
105
+ inception_windows = (
106
+ self.input_window_seconds,
107
+ self.input_window_seconds / 2,
108
+ self.input_window_seconds / 4,
109
+ )
110
+ warning_msg = (
111
+ "Input window size is smaller than the maximum inception window size. "
112
+ "We are adjusting the input window size to match the maximum inception window size.\n"
113
+ f"Original input window size: {self.inception_windows}, \n"
114
+ f"Adjusted inception windows: {inception_windows}"
115
+ )
116
+ warn(warning_msg, UserWarning)
117
+ self.inception_windows = inception_windows
118
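+         # For example (illustrative): a 0.4 s input window shrinks the default
+         # inception windows (0.5, 0.25, 0.125) s to (0.4, 0.2, 0.1) s.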
+         # Define temporal convolutional layers (Tception)
+         self.temporal_blocks = nn.ModuleList()
+         for window in self.inception_windows:
+             # 1. Calculate the temporal kernel size for this block
+             kernel_size_t = int(window * self.sfreq)
+
+             # 2. Calculate the output length of the convolution
+             conv_out_len = self.n_times - kernel_size_t + 1
+
+             # 3. Ensure the pooling size is not larger than the conv output
+             # and is at least 1.
+             dynamic_pool_size = max(1, min(self.pool_size, conv_out_len))
+
+             # 4. Create the block with the dynamic pooling size
+             block = self._conv_block(
+                 in_channels=1,
+                 out_channels=self.number_filter_temp,
+                 kernel_size=(1, kernel_size_t),
+                 stride=1,
+                 pool_size=dynamic_pool_size,  # Use the dynamic size
+                 activation=self.activation,
+             )
+             self.temporal_blocks.append(block)
+
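+         # Illustrative arithmetic (a sketch, not from the paper): with
+         # sfreq=128 and the default windows (0.5, 0.25, 0.125) s, the temporal
+         # kernels span int(0.5 * 128) = 64, 32 and 16 samples, so for
+         # n_times=512 the conv outputs are 449, 481 and 497 samples long.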
+         self.batch_temporal_lay = nn.BatchNorm2d(self.number_filter_temp)
+
+         # Define spatial convolutional layers (Sception)
+         pool_size_spat = self.pool_size // 4
+
+         self.spatial_block_1 = self._conv_block(
+             in_channels=self.number_filter_temp,
+             out_channels=self.number_filter_spat,
+             kernel_size=(self.n_chans, 1),
+             stride=1,
+             pool_size=pool_size_spat,
+             activation=self.activation,
+         )
+
+         kernel_size_spat_2 = (max(1, self.n_chans // 2), 1)
+
+         self.spatial_block_2 = self._conv_block(
+             in_channels=self.number_filter_temp,
+             out_channels=self.number_filter_spat,
+             kernel_size=kernel_size_spat_2,
+             stride=kernel_size_spat_2,
+             pool_size=pool_size_spat,
+             activation=self.activation,
+         )
+         self.batch_spatial_lay = nn.BatchNorm2d(self.number_filter_spat)
+
+         # Calculate the size of the features after convolution and pooling layers
+         self.feature_size = self._calculate_feature_size()
+
+         # Define the final classification layers
+         self.dense_layer = nn.Sequential(
+             nn.Linear(self.feature_size, hidden_size),
+             self.activation(),
+             nn.Dropout(self.drop_prob),
+         )
+
+         self.final_layer = nn.Linear(hidden_size, self.n_outputs)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the TSception model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_channels, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor of shape (batch_size, n_outputs).
+         """
+         # Temporal convolution
+         # shape: (batch_size, n_channels, n_times)
+         x = self.ensuredim(x)
+         # shape: (batch_size, 1, n_channels, n_times)
+
+         t_features = [layer(x) for layer in self.temporal_blocks]
+         # each: (batch_size, number_filter_temp, n_channels, pooled_time)
+         t_out = torch.cat(t_features, dim=-1)
+
+         t_out = self.batch_temporal_lay(t_out)
+
+         # Spatial convolution
+         s_out1 = self.spatial_block_1(t_out)
+         s_out2 = self.spatial_block_2(t_out)
+         s_out = torch.cat((s_out1, s_out2), dim=2)
+         s_out = self.batch_spatial_lay(s_out)
+
+         # Flatten and apply final layers
+         s_out = s_out.view(s_out.size(0), -1)
+         output = self.dense_layer(s_out)
+         output = self.final_layer(output)
+         return output
+
+     def _calculate_feature_size(self) -> int:
+         """
+         Calculate the flattened feature size after the convolution and pooling
+         layers by passing a dummy input through them.
+
+         Returns
+         -------
+         int
+             Flattened size of the features after convolution and pooling layers.
+         """
+         with torch.no_grad():
+             dummy_input = torch.ones(1, 1, self.n_chans, self.n_times)
+             t_features = [layer(dummy_input) for layer in self.temporal_blocks]
+             t_out = torch.cat(t_features, dim=-1)
+             t_out = self.batch_temporal_lay(t_out)
+
+             s_out1 = self.spatial_block_1(t_out)
+             s_out2 = self.spatial_block_2(t_out)
+             s_out = torch.cat((s_out1, s_out2), dim=2)
+             s_out = self.batch_spatial_lay(s_out)
+
+             feature_size = s_out.view(1, -1).size(1)
+         return feature_size
+
+     @staticmethod
+     def _conv_block(
+         in_channels: int,
+         out_channels: int,
+         kernel_size: tuple,
+         stride: tuple[int, int] | int,
+         pool_size: int,
+         activation: type[nn.Module],
+     ) -> nn.Sequential:
+         """
+         Create a convolutional block with Conv2d, activation, and AvgPool2d layers.
+
+         Parameters
+         ----------
+         in_channels : int
+             Number of input channels.
+         out_channels : int
+             Number of output channels.
+         kernel_size : tuple
+             Size of the convolutional kernel.
+         stride : tuple[int, int] | int
+             Stride of the convolution.
+         pool_size : int
+             Size of the pooling kernel.
+         activation : type[nn.Module]
+             Activation function class.
+
+         Returns
+         -------
+         nn.Sequential
+             A sequential container of the convolutional block.
+         """
+         return nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=0,
+             ),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_size)),
+         )
+
+
+ @deprecated(
+     "`TSceptionV1` was renamed to `TSception` in v1.12; "
+     "this alias will be removed in v1.14."
+ )
+ class TSceptionV1(TSception):
+     r"""Deprecated alias for TSception."""
+
+     pass
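+
+ # Usage note (a sketch): instantiating the deprecated alias still builds a
+ # regular TSception, but mne.utils.deprecated is expected to emit a
+ # DeprecationWarning at call time, e.g.:
+ #
+ # >>> import warnings
+ # >>> with warnings.catch_warnings(record=True) as caught:
+ # ...     warnings.simplefilter("always")
+ # ...     _ = TSceptionV1(n_chans=28, n_outputs=2, n_times=512, sfreq=128)
+ # >>> any(issubclass(w.category, DeprecationWarning) for w in caught)
+ # True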