braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,273 @@
1
+ # Authors: Patryk Chrabaszcz
2
+ # Lukas Gemein <l.gemein@gmail.com>
3
+ #
4
+ # License: BSD-3
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import init
8
+ from torch.nn.utils.parametrizations import weight_norm
9
+
10
+ from braindecode.models.base import EEGModuleMixin
11
+ from braindecode.modules import Chomp1d, Ensure4d, Expression, SqueezeFinalOutput
12
+
13
+
14
class BDTCN(EEGModuleMixin, nn.Module):
    """Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.

    .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S1053811920305073-gr3_lrg.jpg
       :align: center
       :alt: Braindecode TCN Architecture

    See [gemein2020]_ for details.

    The model is a :class:`TCN` followed by global average pooling over the
    remaining time steps, yielding one prediction vector per input window.

    Parameters
    ----------
    n_filters: int
        number of output filters of each convolution
    n_blocks: int
        number of temporal blocks in the network
    kernel_size: int
        kernel size of the convolutions
    drop_prob: float
        dropout probability
    activation: nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.

    References
    ----------
    .. [gemein2020] Gemein, L. A., Schirrmeister, R. T., Chrabąszcz, P., Wilson, D.,
       Boedecker, J., Schulze-Bonhage, A., ... & Ball, T. (2020). Machine-learning-based
       diagnostics of EEG pathology. NeuroImage, 220, 117021.
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        chs_info=None,
        n_times=None,
        sfreq=None,
        input_window_seconds=None,
        # Model's parameters
        n_blocks=3,
        n_filters=30,
        kernel_size=5,
        drop_prob=0.5,
        activation: nn.Module = nn.ReLU,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        self.base_tcn = TCN(
            n_chans=self.n_chans,
            n_outputs=self.n_outputs,
            n_blocks=n_blocks,
            n_filters=n_filters,
            kernel_size=kernel_size,
            drop_prob=drop_prob,
            activation=activation,
        )

        # Average over the TCN's output time steps, then drop that axis.
        self.final_layer = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten()
        )

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Predictions of shape (batch_size, n_outputs).
        """
        x = self.base_tcn(x)
        # The TCN's final layer applies ``SqueezeFinalOutput``, which
        # (assumed from its name/usage -- confirm) drops a time axis of
        # length 1, i.e. when n_times == min_len. AdaptiveAvgPool1d would
        # read such a 2-D (batch, n_outputs) tensor as an unbatched
        # (channels, length) input and return (batch, 1), so restore the
        # singleton time axis first.
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        x = self.final_layer(x)
        return x
88
+
89
+
90
class TCN(nn.Module):
    """Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.

    See [Bai2018]_ for details.

    Code adapted from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py

    Parameters
    ----------
    n_chans: int
        number of input EEG channels (input channels of the first block)
    n_outputs: int
        number of outputs of the final linear layer
    n_filters: int
        number of output filters of each convolution
    n_blocks: int
        number of temporal blocks in the network
    kernel_size: int
        kernel size of the convolutions
    drop_prob: float
        dropout probability
    activation: nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.

    Notes
    -----
    Inputs must have at least ``min_len`` time samples, where
    ``min_len = 1 + sum_i 2 * (kernel_size - 1) * 2**i`` over the blocks.

    References
    ----------
    .. [Bai2018] Bai, S., Kolter, J. Z., & Koltun, V. (2018).
       An empirical evaluation of generic convolutional and recurrent networks
       for sequence modeling.
       arXiv preprint arXiv:1803.01271.
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_blocks=3,
        n_filters=30,
        kernel_size=5,
        drop_prob=0.5,
        activation: nn.Module = nn.ReLU,
    ):
        super().__init__()
        # State-dict key renames from a flat ``fc`` layout to the current
        # ``final_layer.fc`` layout (presumably consumed by a checkpoint
        # loader elsewhere -- confirm).
        self.mapping = {
            "fc.weight": "final_layer.fc.weight",
            "fc.bias": "final_layer.fc.bias",
        }
        self.ensuredims = Ensure4d()
        t_blocks = nn.Sequential()
        for i in range(n_blocks):
            # The first block maps the input channels to n_filters; the
            # following blocks keep n_filters channels.
            n_inputs = n_chans if i == 0 else n_filters
            # Dilation doubles from one block to the next.
            dilation_size = 2**i
            t_blocks.add_module(
                "temporal_block_{:d}".format(i),
                _TemporalBlock(
                    n_inputs=n_inputs,
                    n_outputs=n_filters,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=dilation_size,
                    padding=(kernel_size - 1) * dilation_size,
                    drop_prob=drop_prob,
                    activation=activation,
                ),
            )
        self.temporal_blocks = t_blocks

        self.final_layer = _FinalLayer(
            in_features=n_filters,
            out_features=n_outputs,
        )
        # Receptive field of the whole stack: every block contains two
        # dilated convolutions, each widening it by (kernel_size - 1) * dilation.
        self.min_len = 1
        for i in range(n_blocks):
            dilation = 2**i
            self.min_len += 2 * (kernel_size - 1) * dilation

        # Start in train mode (already the nn.Module default; kept explicit).
        self.train()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Output of ``final_layer``: (batch_size, n_outputs, n_out_times)
            with ``n_out_times = 1 + n_times - min_len``; the time axis is
            presumably squeezed away by ``SqueezeFinalOutput`` when it has
            length 1 -- confirm.
        """
        x = self.ensuredims(x)
        # x is in format: B x C x T x 1
        (batch_size, _, time_size, _) = x.size()
        assert time_size >= self.min_len, (
            f"Input has {time_size} time samples but this TCN needs at least "
            f"{self.min_len} (its receptive field)."
        )
        # remove empty trailing dimension
        x = x.squeeze(3)
        x = self.temporal_blocks(x)
        # Convert to: B x T x C
        x = x.transpose(1, 2).contiguous()

        out = self.final_layer(x, batch_size, time_size, self.min_len)

        return out
187
+
188
+
189
class _FinalLayer(nn.Module):
    """Per-time-step linear read-out used as the last stage of ``TCN``.

    ``fc`` is applied independently to every time step. Only the trailing
    ``1 + max(0, time_size - min_len)`` steps are kept, and they are returned
    channel-first through ``SqueezeFinalOutput``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features=in_features, out_features=out_features)
        self.out_fun = nn.Identity()
        self.squeeze = SqueezeFinalOutput()

    def forward(
        self, x: torch.Tensor, batch_size: int, time_size: int, min_len: int
    ) -> torch.Tensor:
        # ``x`` is expected as (batch_size, time_size, in_features), which is
        # what TCN.forward passes after its transpose.
        n_features = x.size(2)
        # Fold time into the batch axis so the linear layer sees 2-D input.
        flat_in = x.view(batch_size * time_size, n_features)
        scores = self.out_fun(self.fc(flat_in))
        scores = scores.view(batch_size, time_size, scores.size(1))

        # Keep one step per sample beyond ``min_len`` (at least one step).
        n_keep = 1 + max(0, time_size - min_len)
        kept = scores[:, -n_keep:, :].transpose(1, 2)
        # Re-add a trailing 4th dimension for compatibility with braindecode.
        return self.squeeze(kept.unsqueeze(-1))
210
+
211
+
212
class _TemporalBlock(nn.Module):
    """Residual block of two weight-normalised, dilated 1-D convolutions.

    Each convolution is padded by ``padding`` on both sides and followed by
    ``Chomp1d(padding)`` (presumably trimming the trailing ``padding``
    samples so the convolution is causal, as in the locuslab TCN this is
    adapted from -- confirm), an activation and ``Dropout2d``. A 1x1
    convolution projects the skip path when ``n_inputs != n_outputs``.
    """

    def __init__(
        self,
        n_inputs,
        n_outputs,
        kernel_size,
        stride,
        dilation,
        padding,
        drop_prob,
        activation: nn.Module = nn.ReLU,
    ):
        super().__init__()

        def dilated_conv(in_channels):
            # Weight-normalised dilated conv producing ``n_outputs`` channels.
            return weight_norm(
                nn.Conv1d(
                    in_channels,
                    n_outputs,
                    kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                )
            )

        # First convolution stage.
        self.conv1 = dilated_conv(n_inputs)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = activation()
        self.dropout1 = nn.Dropout2d(drop_prob)

        # Second convolution stage.
        self.conv2 = dilated_conv(n_outputs)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = activation()
        self.dropout2 = nn.Dropout2d(drop_prob)

        # 1x1 projection on the skip path, only needed when widths differ.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = activation()

        # NOTE(review): conv1/conv2 are wrapped with
        # torch.nn.utils.parametrizations.weight_norm, where ``conv.weight`` is
        # recomputed from the parametrization on access. The in-place init
        # below therefore appears not to persist for them; only the plain
        # ``downsample`` conv is affected. Kept as-is to preserve the existing
        # initialisation and RNG stream -- confirm intent.
        for conv in (self.conv1, self.conv2, self.downsample):
            if conv is not None:
                init.normal_(conv.weight, 0, 0.01)

    def forward(self, x):
        h = self.dropout1(self.relu1(self.chomp1(self.conv1(x))))
        h = self.dropout2(self.relu2(self.chomp2(self.conv2(h))))
        if self.downsample is None:
            skip = x
        else:
            skip = self.downsample(x)
        return self.relu(h + skip)
@@ -0,0 +1,395 @@
1
+ from math import ceil
2
+
3
+ import torch
4
+ from einops.layers.torch import Rearrange
5
+ from torch import nn
6
+ from torch.nn import init
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from braindecode.models.base import EEGModuleMixin
10
+ from braindecode.modules import Ensure4d
11
+
12
+
13
class TIDNet(EEGModuleMixin, nn.Module):
    """Thinker Invariance DenseNet model from Kostas et al. (2020) [TIDNet]_.

    .. figure:: https://content.cld.iop.org/journals/1741-2552/17/5/056008/revision3/jneabb7a7f1_hr.jpg
       :align: center
       :alt: TIDNet Architecture

    See [TIDNet]_ for details.

    Parameters
    ----------
    s_growth : int
        DenseNet-style growth factor (added filters per DenseFilter)
    t_filters : int
        Number of temporal filters.
    drop_prob : float
        Dropout probability
    pooling : int
        Max temporal pooling (width and stride)
    temp_layers : int
        Number of temporal layers
    spat_layers : int
        Number of DenseFilters
    temp_span : float
        Percentage of n_times that defines the temporal filter length:
        temp_len = ceil(temp_span * n_times)
        e.g A value of 0.05 for temp_span with 1500 n_times will yield a temporal
        filter of length 75.
    bottleneck : int
        Bottleneck factor within Densefilter
    summary : int
        Output size of AdaptiveAvgPool1D layer. If set to -1, value will be calculated
        automatically (n_times // pooling).
    activation: nn.Module, default=nn.LeakyReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.LeakyReLU``.

    Notes
    -----
    Code adapted from: https://github.com/SPOClab-ca/ThinkerInvariance/

    References
    ----------
    .. [TIDNet] Kostas, D. & Rudzicz, F.
       Thinker invariance: enabling deep neural networks for BCI across more
       people.
       J. Neural Eng. 17, 056008 (2020).
       doi: 10.1088/1741-2552/abb7a7.
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
        s_growth: int = 24,
        t_filters: int = 32,
        drop_prob: float = 0.4,
        pooling: int = 15,
        temp_layers: int = 2,
        spat_layers: int = 2,
        temp_span: float = 0.05,
        bottleneck: int = 3,
        summary: int = -1,
        activation: nn.Module = nn.LeakyReLU,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            chs_info=chs_info,
        )
        del n_outputs, n_chans, n_times, input_window_seconds, sfreq, chs_info

        # Temporal kernel length; the same value is derived inside the
        # feature extractor, it is kept here as a public attribute.
        self.temp_len = ceil(temp_span * self.n_times)

        self.dscnn = _TIDNetFeatures(
            s_growth=s_growth,
            t_filters=t_filters,
            n_chans=self.n_chans,
            n_times=self.n_times,
            drop_prob=drop_prob,
            pooling=pooling,
            temp_layers=temp_layers,
            spat_layers=spat_layers,
            temp_span=temp_span,
            bottleneck=bottleneck,
            summary=summary,
            activation=activation,
        )

        self._num_features = self.dscnn.num_features

        self.flatten = nn.Flatten(start_dim=1)

        self.final_layer = self._create_classifier(self.num_features, self.n_outputs)

    def _create_classifier(self, incoming: int, n_outputs: int):
        """Build the linear read-out with Xavier-normal weights and zero bias."""
        classifier = nn.Linear(incoming, n_outputs)
        init.xavier_normal_(classifier.weight)
        init.zeros_(classifier.bias)
        return nn.Sequential(classifier, nn.Identity())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Output of the linear read-out, shape (batch_size, n_outputs).
        """
        x = self.dscnn(x)
        x = self.flatten(x)
        return self.final_layer(x)

    @property
    def num_features(self):
        """Number of flattened features fed to the final linear layer."""
        return self._num_features
145
+
146
+
147
+ class _BatchNormZG(nn.BatchNorm2d):
148
+ def reset_parameters(self):
149
+ if self.track_running_stats:
150
+ self.running_mean.zero_()
151
+ self.running_var.fill_(1)
152
+ if self.affine:
153
+ self.weight.data.zero_()
154
+ self.bias.data.zero_()
155
+
156
+
157
+ class _ConvBlock2D(nn.Module):
158
+ """Implements Convolution block with order:
159
+ Convolution, dropout, activation, batch-norm
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ in_filters: int,
165
+ out_filters: int,
166
+ kernel: tuple[int, int],
167
+ stride: tuple[int, int] = (1, 1),
168
+ padding: int = 0,
169
+ dilation: int = 1,
170
+ groups: int = 1,
171
+ drop_prob: float = 0.5,
172
+ batch_norm: bool = True,
173
+ activation: type[nn.Module] = nn.LeakyReLU,
174
+ residual: bool = False,
175
+ ):
176
+ super().__init__()
177
+ self.kernel = kernel
178
+ self.activation = activation()
179
+ self.residual = residual
180
+
181
+ self.conv = nn.Conv2d(
182
+ in_filters,
183
+ out_filters,
184
+ kernel,
185
+ stride=stride,
186
+ padding=padding,
187
+ dilation=dilation,
188
+ groups=groups,
189
+ bias=not batch_norm,
190
+ )
191
+ self.dropout = nn.Dropout2d(p=float(drop_prob))
192
+ self.batch_norm = (
193
+ _BatchNormZG(out_filters)
194
+ if residual
195
+ else nn.BatchNorm2d(out_filters)
196
+ if batch_norm
197
+ else nn.Identity()
198
+ )
199
+
200
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
201
+ res = input
202
+ input = self.conv(
203
+ input,
204
+ )
205
+ input = self.dropout(input)
206
+ input = self.activation(input)
207
+ input = self.batch_norm(input)
208
+ return input + res if self.residual else input
209
+
210
+
211
+ class _DenseFilter(nn.Module):
212
+ def __init__(
213
+ self,
214
+ in_features: int,
215
+ growth_rate: int,
216
+ filter_len: int = 5,
217
+ drop_prob: float = 0.5,
218
+ bottleneck: int = 2,
219
+ activation: type[nn.Module] = nn.LeakyReLU,
220
+ dim: int = -2,
221
+ ):
222
+ super().__init__()
223
+ dim = dim if dim > 0 else dim + 4
224
+ if dim < 2 or dim > 3:
225
+ raise ValueError("Only last two dimensions supported")
226
+ kernel = (filter_len, 1) if dim == 2 else (1, filter_len)
227
+
228
+ self.net = nn.Sequential(
229
+ nn.BatchNorm2d(in_features),
230
+ activation(),
231
+ nn.Conv2d(in_features, bottleneck * growth_rate, 1),
232
+ nn.BatchNorm2d(bottleneck * growth_rate),
233
+ activation(),
234
+ nn.Conv2d(
235
+ bottleneck * growth_rate,
236
+ growth_rate,
237
+ kernel,
238
+ padding=tuple((k // 2 for k in kernel)),
239
+ ),
240
+ nn.Dropout2d(p=float(drop_prob)),
241
+ )
242
+
243
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
244
+ return torch.cat((x, self.net(x)), dim=1)
245
+
246
+
247
+ class _DenseSpatialFilter(nn.Module):
248
+ def __init__(
249
+ self,
250
+ n_chans: int,
251
+ growth: int,
252
+ depth: int,
253
+ in_ch: int = 1,
254
+ bottleneck: int = 4,
255
+ drop_prob: float = 0.0,
256
+ activation: type[nn.Module] = nn.LeakyReLU,
257
+ collapse: bool = True,
258
+ ):
259
+ super().__init__()
260
+ self.net = nn.Sequential(
261
+ *[
262
+ _DenseFilter(
263
+ in_ch + growth * d,
264
+ growth,
265
+ bottleneck=bottleneck,
266
+ drop_prob=drop_prob,
267
+ activation=activation,
268
+ )
269
+ for d in range(depth)
270
+ ]
271
+ )
272
+ n_filters = in_ch + growth * depth
273
+ self.collapse = collapse
274
+ if collapse:
275
+ self.channel_collapse = _ConvBlock2D(
276
+ n_filters, n_filters, (n_chans, 1), drop_prob=0, activation=activation
277
+ )
278
+
279
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
280
+ if len(x.shape) < 4:
281
+ x = x.unsqueeze(1).permute([0, 1, 3, 2])
282
+ x = self.net(x)
283
+ if self.collapse:
284
+ return self.channel_collapse(x).squeeze(-2)
285
+ return x
286
+
287
+
288
+ class _TemporalFilter(nn.Module):
289
+ def __init__(
290
+ self,
291
+ n_chans: int,
292
+ filters: int,
293
+ depth: int,
294
+ temp_len: int,
295
+ drop_prob: float = 0.0,
296
+ activation: type[nn.Module] = nn.LeakyReLU,
297
+ residual: str = "netwise",
298
+ ):
299
+ super().__init__()
300
+ temp_len = temp_len + 1 - temp_len % 2
301
+ self.residual_style = str(residual)
302
+ net = list()
303
+
304
+ for i in range(depth):
305
+ dil = depth - i
306
+ conv = weight_norm(
307
+ nn.Conv2d(
308
+ n_chans if i == 0 else filters,
309
+ filters,
310
+ kernel_size=(1, temp_len),
311
+ dilation=dil,
312
+ padding=(0, dil * (temp_len - 1) // 2),
313
+ )
314
+ )
315
+ net.append(
316
+ nn.Sequential(conv, activation(), nn.Dropout2d(p=float(drop_prob)))
317
+ )
318
+ if self.residual_style.lower() == "netwise":
319
+ self.net = nn.Sequential(*net)
320
+ self.residual = nn.Conv2d(n_chans, filters, (1, 1))
321
+ elif residual.lower() == "dense":
322
+ self.net = net
323
+
324
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
325
+ style = self.residual_style.lower()
326
+ if style == "netwise":
327
+ return self.net(x) + self.residual(x)
328
+ elif style == "dense":
329
+ for layer in self.net:
330
+ x = torch.cat((x, layer(x)), dim=1)
331
+ return x
332
+ # TorchScript now knows this path always returns or errors
333
+ else:
334
+ # Use an assertion so TorchScript can compile it
335
+ assert False, f"Unsupported residual style: {self.residual_style}"
336
+
337
+
338
class _TIDNetFeatures(nn.Module):
    """Feature extractor of TIDNet.

    The stages are:

    1. temporal filtering on a (batch, 1, n_chans, n_times) view,
    2. max pooling over time and dropout,
    3. dense spatial filtering that collapses the channel axis,
    4. adaptive average pooling to ``summary`` time steps, then flattening.

    The output has ``num_features`` features per example.
    """

    def __init__(
        self,
        s_growth: int,
        t_filters: int,
        n_chans: int,
        n_times: int,
        drop_prob: float,
        pooling: int,
        temp_layers: int,
        spat_layers: int,
        temp_span: float,
        bottleneck: int,
        summary: int,
        activation: type[nn.Module] = nn.LeakyReLU,
    ):
        super().__init__()
        self.n_chans = n_chans
        # Temporal kernel length as a fraction of the window length.
        self.temp_len = ceil(temp_span * n_times)

        self.temporal = nn.Sequential(
            Ensure4d(),
            Rearrange("batch C T 1 -> batch 1 C T"),
            _TemporalFilter(
                1,
                t_filters,
                depth=temp_layers,
                temp_len=self.temp_len,
                activation=activation,
            ),
            nn.MaxPool2d((1, pooling)),
            nn.Dropout2d(p=float(drop_prob)),
        )
        # -1 means one summary step per pooling window. Cast once so the pool
        # size and the feature count below agree on the same int.
        summary = int(n_times // pooling if summary == -1 else summary)

        self.spatial = _DenseSpatialFilter(
            n_chans=n_chans,
            growth=s_growth,
            depth=spat_layers,
            in_ch=t_filters,
            drop_prob=drop_prob,
            bottleneck=bottleneck,
            activation=activation,
        )
        self.extract_features = nn.Sequential(
            nn.AdaptiveAvgPool1d(summary), nn.Flatten(start_dim=1)
        )

        # Maps after the dense spatial stack, times the pooled time steps.
        self._num_features = (t_filters + s_growth * spat_layers) * summary

    @property
    def num_features(self):
        """Number of flattened features produced per example."""
        return self._num_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.temporal(x)
        x = self.spatial(x)
        return self.extract_features(x)