braindecode 1.0.0__py3-none-any.whl → 1.1.1.dev174934380__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (42)
  1. braindecode/augmentation/transforms.py +0 -1
  2. braindecode/datautil/__init__.py +3 -0
  3. braindecode/datautil/serialization.py +13 -2
  4. braindecode/functional/__init__.py +12 -0
  5. braindecode/functional/functions.py +0 -1
  6. braindecode/models/__init__.py +48 -0
  7. braindecode/models/atcnet.py +46 -11
  8. braindecode/models/attentionbasenet.py +49 -0
  9. braindecode/models/biot.py +29 -8
  10. braindecode/models/contrawr.py +29 -8
  11. braindecode/models/ctnet.py +99 -13
  12. braindecode/models/deep4.py +52 -2
  13. braindecode/models/eegconformer.py +2 -3
  14. braindecode/models/eeginception_mi.py +9 -3
  15. braindecode/models/eegitnet.py +0 -1
  16. braindecode/models/eegminer.py +0 -1
  17. braindecode/models/eegnet.py +0 -1
  18. braindecode/models/fbcnet.py +1 -1
  19. braindecode/models/fbmsnet.py +0 -1
  20. braindecode/models/labram.py +23 -3
  21. braindecode/models/msvtnet.py +1 -1
  22. braindecode/models/sccnet.py +29 -4
  23. braindecode/models/signal_jepa.py +0 -1
  24. braindecode/models/sleep_stager_eldele_2021.py +0 -1
  25. braindecode/models/sparcnet.py +62 -16
  26. braindecode/models/tcn.py +1 -1
  27. braindecode/models/tsinception.py +38 -13
  28. braindecode/models/util.py +2 -6
  29. braindecode/modules/__init__.py +46 -0
  30. braindecode/modules/filter.py +0 -4
  31. braindecode/modules/layers.py +3 -5
  32. braindecode/modules/linear.py +1 -2
  33. braindecode/modules/util.py +0 -1
  34. braindecode/modules/wrapper.py +0 -2
  35. braindecode/samplers/base.py +0 -2
  36. braindecode/version.py +1 -1
  37. {braindecode-1.0.0.dist-info → braindecode-1.1.1.dev174934380.dist-info}/METADATA +8 -7
  38. {braindecode-1.0.0.dist-info → braindecode-1.1.1.dev174934380.dist-info}/RECORD +42 -42
  39. {braindecode-1.0.0.dist-info → braindecode-1.1.1.dev174934380.dist-info}/WHEEL +0 -0
  40. {braindecode-1.0.0.dist-info → braindecode-1.1.1.dev174934380.dist-info}/licenses/LICENSE.txt +0 -0
  41. {braindecode-1.0.0.dist-info → braindecode-1.1.1.dev174934380.dist-info}/licenses/NOTICE.txt +0 -0
  42. {braindecode-1.0.0.dist-info → braindecode-1.1.1.dev174934380.dist-info}/top_level.txt +0 -0
braindecode/models/deep4.py CHANGED
@@ -3,6 +3,7 @@
  # License: BSD (3-clause)

  from einops.layers.torch import Rearrange
+ from mne.utils import warn
  from torch import nn
  from torch.nn import init

@@ -115,6 +116,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
  batch_norm=True,
  batch_norm_alpha=0.1,
  stride_before_pool=False,
+ # Braindecode EEGModuleMixin parameters
  chs_info=None,
  input_window_seconds=None,
  sfreq=None,
@@ -155,6 +157,27 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
  self.batch_norm_alpha = batch_norm_alpha
  self.stride_before_pool = stride_before_pool

+ min_n_times = self._get_min_n_times()
+ if self.n_times < min_n_times:
+ scaling_factor = self.n_times / min_n_times
+ warn(
+ f"n_times ({self.n_times}) is smaller than the minimum required "
+ f"({min_n_times}) for the current model parameters configuration. "
+ "Adjusting parameters to ensure compatibility."
+ "Reducing the kernel, pooling, and stride sizes accordingly."
+ "Scaling factor: {:.2f}".format(scaling_factor),
+ UserWarning,
+ )
+ # Calculate a scaling factor to adjust temporal parameters
+ # Apply the scaling factor to all temporal kernel and pooling sizes
+ self.filter_time_length = max(
+ 1, int(self.filter_time_length * scaling_factor)
+ )
+ self.pool_time_length = max(1, int(self.pool_time_length * scaling_factor))
+ self.pool_time_stride = max(1, int(self.pool_time_stride * scaling_factor))
+ self.filter_length_2 = max(1, int(self.filter_length_2 * scaling_factor))
+ self.filter_length_3 = max(1, int(self.filter_length_3 * scaling_factor))
+ self.filter_length_4 = max(1, int(self.filter_length_4 * scaling_factor))
  # For the load_state_dict
  # When padronize all layers,
  # add the old's parameters here
@@ -268,7 +291,6 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
  self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
  )

- # self.add_module('drop_classifier', nn.Dropout(p=self.drop_prob))
  self.eval()
  if self.final_conv_length == "auto":
  self.final_conv_length = self.get_output_shape()[2]
@@ -299,7 +321,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
  if self.split_first_layer:
  init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
  if not self.batch_norm:
- init.constant_(self.conv_spat.bias, 0)
+ init.constant_(self.conv_time_spat.conv_spat.bias, 0)
  if self.batch_norm:
  init.constant_(self.bnorm.weight, 1)
  init.constant_(self.bnorm.bias, 0)
@@ -320,3 +342,31 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
  init.constant_(self.final_layer.conv_classifier.bias, 0)

  self.train()
+
+ def _get_min_n_times(self) -> int:
+ """
+ Calculate the minimum number of time samples required for the model
+ to work with the given temporal parameters.
+ """
+ # Start with the minimum valid output length of the network (1)
+ min_len = 1
+
+ # List of conv kernel sizes and pool parameters for the 4 blocks, in reverse order
+ # Each tuple: (filter_length, pool_length, pool_stride)
+ block_params = [
+ (self.filter_length_4, self.pool_time_length, self.pool_time_stride),
+ (self.filter_length_3, self.pool_time_length, self.pool_time_stride),
+ (self.filter_length_2, self.pool_time_length, self.pool_time_stride),
+ (self.filter_time_length, self.pool_time_length, self.pool_time_stride),
+ ]
+
+ # Work backward from the last layer to the input
+ for filter_len, pool_len, pool_stride in block_params:
+ # Reverse the pooling operation
+ # L_in = stride * (L_out - 1) + kernel_size
+ min_len = pool_stride * (min_len - 1) + pool_len
+ # Reverse the convolution operation (assuming stride=1)
+ # L_in = L_out + kernel_size - 1
+ min_len = min_len + filter_len - 1
+
+ return min_len
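The new `_get_min_n_times` helper walks the four blocks in reverse, inverting the pooling relation L_in = stride * (L_out - 1) + kernel_size and the valid-convolution relation L_in = L_out + kernel_size - 1. The standalone sketch below (not part of the package) reproduces the same arithmetic with assumed default values, a temporal filter length of 10 and pooling length/stride of 3, rather than anything read from this diff:

# Standalone sketch of the minimum-input-length computation above.
# All parameter values here are assumptions for illustration only.
def min_n_times(filter_lengths=(10, 10, 10, 10), pool_length=3, pool_stride=3):
    min_len = 1  # smallest valid output length after the last block
    for filter_len in reversed(filter_lengths):
        # undo the pooling: L_in = stride * (L_out - 1) + kernel_size
        min_len = pool_stride * (min_len - 1) + pool_length
        # undo the convolution (stride 1): L_in = L_out + kernel_size - 1
        min_len = min_len + filter_len - 1
    return min_len

print(min_n_times())  # 441 with these example values

With those example values the model needs at least 441 input samples; shorter inputs trigger the scaling branch added earlier in `__init__`.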
braindecode/models/eegconformer.py CHANGED
@@ -2,11 +2,8 @@
  #
  # License: BSD (3-clause)
  import warnings
- from typing import Optional

  import torch
- import torch.nn.functional as F
- from einops import rearrange
  from einops.layers.torch import Rearrange
  from torch import Tensor, nn

@@ -150,6 +147,8 @@ class EEGConformer(EEGModuleMixin, nn.Module):
  if final_fc_length == "auto":
  assert self.n_times is not None
  self.final_fc_length = self.get_fc_size()
+ else:
+ self.final_fc_length = final_fc_length

  self.transformer = _TransformerEncoder(
  att_depth=att_depth,
braindecode/models/eeginception_mi.py CHANGED
@@ -66,9 +66,9 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
  n_outputs=None,
  input_window_seconds=None,
  sfreq=250,
- n_convs=5,
- n_filters=48,
- kernel_unit_s=0.1,
+ n_convs: int = 5,
+ n_filters: int = 48,
+ kernel_unit_s: float = 0.1,
  activation: nn.Module = nn.ReLU,
  chs_info=None,
  n_times=None,
@@ -307,6 +307,12 @@ class _InceptionModuleMI(nn.Module):

  X2 = self.pooling(X)
  X2 = self.pooling_conv(X2)
+ # Get the target length from one of the conv branches
+ target_len = X1[0].shape[-1]
+
+ # Crop the pooling output if its length does not match
+ if X2.shape[-1] != target_len:
+ X2 = X2[..., :target_len]

  out = torch.cat(X1 + [X2], 1)

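The added crop protects `torch.cat` from a small length mismatch between the pooled branch and the convolutional branches. A self-contained sketch of the same guard, using made-up tensor shapes:

import torch

# Illustration only: two branch outputs whose temporal lengths differ by one
# sample, as can happen when pooling rounds differently from the conv branches.
x1 = torch.randn(2, 12, 4, 250)   # conv branch output (batch, filters, chans, time)
x2 = torch.randn(2, 12, 4, 251)   # pooled branch, one sample longer

target_len = x1.shape[-1]
if x2.shape[-1] != target_len:
    x2 = x2[..., :target_len]     # crop trailing samples so the lengths match

out = torch.cat([x1, x2], dim=1)  # concatenation along the filter axis now works
print(out.shape)                  # torch.Size([2, 24, 4, 250])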
@@ -1,7 +1,6 @@
  # Authors: Ghaith Bouallegue <ghaithbouallegue@gmail.com>
  #
  # License: BSD-3
- import torch
  from einops.layers.torch import Rearrange
  from torch import nn

@@ -10,7 +10,6 @@ from functools import partial
  import torch
  from einops.layers.torch import Rearrange
  from torch import nn
- from torch.fft import fftfreq

  import braindecode.functional as F
  from braindecode.models.base import EEGModuleMixin
@@ -14,7 +14,6 @@ from braindecode.models.base import EEGModuleMixin
  from braindecode.modules import (
  Conv2dWithConstraint,
  Ensure4d,
- Expression,
  LinearWithConstraint,
  SqueezeFinalOutput,
  )
@@ -5,7 +5,7 @@ from typing import Any
  import torch
  from einops.layers.torch import Rearrange
  from mne.utils import warn
- from torch import Tensor, nn
+ from torch import nn

  from braindecode.models.base import EEGModuleMixin
  from braindecode.modules import (
@@ -1,6 +1,5 @@
  from __future__ import annotations

- from functools import partial
  from typing import Optional, Sequence

  import torch
braindecode/models/labram.py CHANGED
@@ -166,11 +166,23 @@ class Labram(EEGModuleMixin, nn.Module):
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

  self.patch_size = patch_size
- self.n_path = self.n_times // patch_size
  self.num_features = self.emb_size = emb_size
  self.neural_tokenizer = neural_tokenizer
  self.init_scale = init_scale

+ if patch_size > self.n_times:
+ warn(
+ f"patch_size ({patch_size}) > n_times ({self.n_times}); "
+ f"setting patch_size = {self.n_times}.",
+ UserWarning,
+ )
+ self.patch_size = self.n_times
+ self.num_features = None
+ self.emb_size = None
+ else:
+ self.patch_size = patch_size
+ self.n_path = self.n_times // self.patch_size
+
  if neural_tokenizer and in_channels != 1:
  warn(
  "The model is in Neural Tokenizer mode, but the variable "
@@ -220,8 +232,17 @@ class Labram(EEGModuleMixin, nn.Module):
  emb_dim=self.emb_size,
  ),
  )
+
+ with torch.no_grad():
+ dummy = torch.zeros(1, self.n_chans, self.n_times)
+ out = self.patch_embed(dummy)
+ # out.shape for tokenizer: (1, n_chans, emb_dim)
+ # for decoder: (1, n_patch, patch_size, emb_dim), but we want last dim
+ self.emb_size = out.shape[-1]
+ self.num_features = self.emb_size
+
  # Defining the parameters
- # Creating a parameter list with cls token
+ # Creating a parameter list with cls token]
  self.cls_token = nn.Parameter(torch.zeros(1, 1, self.emb_size))
  # Positional embedding and time embedding are complementary
  # one is for the spatial information and the other is for the temporal
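The `torch.no_grad()` block infers the embedding size by pushing a dummy input through `patch_embed` and reading the last dimension of the result, instead of deriving it analytically. A generic sketch of that pattern; the stand-in module and sizes below are invented and are not Labram's actual tokenizer:

import torch
from torch import nn

# Made-up stand-in for a patch embedding module.
patch_embed = nn.Conv1d(64, 200, kernel_size=40, stride=40)

with torch.no_grad():
    dummy = torch.zeros(1, 64, 1600)  # (batch, n_chans, n_times)
    out = patch_embed(dummy)

emb_size = out.shape[-1]              # read the size off the real output
print(out.shape, emb_size)            # torch.Size([1, 200, 40]) 40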
@@ -366,7 +387,6 @@ class Labram(EEGModuleMixin, nn.Module):
  x : torch.Tensor
  The output of the model.
  """
-
  if self.neural_tokenizer:
  batch_size, nch, n_patch, temporal = self.patch_embed.segment_patch(x).shape
  else:
@@ -1,7 +1,7 @@
  # Authors: Tao Yang <sheeptao@outlook.com>
  # Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
  #
- from typing import Dict, Optional, Type, Union
+ from typing import Type, Union

  import torch
  import torch.nn as nn
braindecode/models/sccnet.py CHANGED
@@ -4,6 +4,7 @@
  # License: BSD (3-clause)

  import math
+ from warnings import warn

  import torch
  from einops.layers.torch import Rearrange
@@ -98,9 +99,33 @@ class SCCNet(EEGModuleMixin, nn.Module):
  self.n_spatial_filters_smooth = n_spatial_filters_smooth
  self.drop_prob = drop_prob

- self.samples_100ms = int(math.floor(self.sfreq * 0.1))
- self.kernel_size_pool = int(self.sfreq * 0.5)
- # Equivalent to 0.5 seconds
+ # Original logical for SCCNet
+ conv_kernel_time = 0.1 # 100ms
+ pool_kernel_time = 0.5 # 500ms
+
+ # Calculate sample-based sizes from time durations
+ conv_kernel_samples = int(math.floor(self.sfreq * conv_kernel_time))
+ pool_kernel_samples = int(math.floor(self.sfreq * pool_kernel_time))
+
+ # If the input window is too short for the default kernel sizes,
+ # scale them down proportionally.
+ total_kernel_samples = conv_kernel_samples + pool_kernel_samples
+
+ if self.n_times < total_kernel_samples:
+ warning_msg = (
+ f"Input window seconds ({self.input_window_seconds:.2f}s) is smaller than the "
+ f"model's combined kernel sizes ({(total_kernel_samples / self.sfreq):.2f}s). "
+ "Scaling temporal parameters down proportionally."
+ )
+ warn(warning_msg, UserWarning, stacklevel=2)
+
+ scaling_factor = self.n_times / total_kernel_samples
+ conv_kernel_samples = int(math.floor(conv_kernel_samples * scaling_factor))
+ pool_kernel_samples = int(math.floor(pool_kernel_samples * scaling_factor))
+
+ # Ensure kernels are at least 1 sample wide
+ self.samples_100ms = max(1, conv_kernel_samples)
+ self.kernel_size_pool = max(1, pool_kernel_samples)

  num_features = self._calc_num_features()

@@ -135,7 +160,7 @@ class SCCNet(EEGModuleMixin, nn.Module):

  self.dropout = nn.Dropout(self.drop_prob)
  self.temporal_smoothing = nn.AvgPool2d(
- kernel_size=(1, int(self.sfreq / 2)),
+ kernel_size=(1, self.kernel_size_pool),
  stride=(1, self.samples_100ms),
  )

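The proportional scaling above is easiest to follow with concrete numbers; the sampling rate and window length in this sketch are assumed values, not taken from SCCNet:

import math

# Assumed values: 250 Hz sampling rate but only 100 samples (0.4 s) of signal.
sfreq, n_times = 250.0, 100

conv_kernel_samples = int(math.floor(sfreq * 0.1))  # 25 samples (100 ms)
pool_kernel_samples = int(math.floor(sfreq * 0.5))  # 125 samples (500 ms)
total = conv_kernel_samples + pool_kernel_samples   # 150 > n_times, so scale down

if n_times < total:
    factor = n_times / total                                             # 0.666...
    conv_kernel_samples = int(math.floor(conv_kernel_samples * factor))  # 16
    pool_kernel_samples = int(math.floor(pool_kernel_samples * factor))  # 83

print(max(1, conv_kernel_samples), max(1, pool_kernel_samples))  # 16 83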
@@ -8,7 +8,6 @@ from copy import deepcopy
  from typing import Any, Sequence

  import torch
- from einops import parse_shape, rearrange, repeat
  from einops.layers.torch import Rearrange
  from torch import nn

@@ -5,7 +5,6 @@
  import math
  import warnings
  from copy import deepcopy
- from typing import Callable, Optional

  import torch
  import torch.nn.functional as F
braindecode/models/sparcnet.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations

  from collections import OrderedDict
  from math import floor, log2
- from typing import Any

  import torch
  import torch.nn as nn
@@ -69,6 +68,19 @@ class SPARCNet(EEGModuleMixin, nn.Module):
  conv_bias: bool = True,
  batch_norm: bool = True,
  activation: nn.Module = nn.ELU,
+ kernel_size_conv0: int = 7,
+ kernel_size_conv1: int = 1,
+ kernel_size_conv2: int = 3,
+ kernel_size_pool: int = 3,
+ stride_pool: int = 2,
+ stride_conv0: int = 2,
+ stride_conv1: int = 1,
+ stride_conv2: int = 1,
+ padding_pool: int = 1,
+ padding_conv0: int = 3,
+ padding_conv2: int = 1,
+ kernel_size_trans: int = 2,
+ stride_trans: int = 2,
  # EEGModuleMixin parameters
  # (another way to present the same parameters)
  chs_info=None,
@@ -96,9 +108,9 @@ class SPARCNet(EEGModuleMixin, nn.Module):
  nn.Conv1d(
  in_channels=self.n_chans,
  out_channels=out_channels,
- kernel_size=7,
- stride=2,
- padding=3,
+ kernel_size=kernel_size_conv0,
+ stride=stride_conv0,
+ padding=padding_conv0,
  bias=conv_bias,
  ),
  )
@@ -106,7 +118,11 @@ class SPARCNet(EEGModuleMixin, nn.Module):
  )
  first_conv["norm0"] = nn.BatchNorm1d(out_channels)
  first_conv["act_layer"] = activation()
- first_conv["pool0"] = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+ first_conv["pool0"] = nn.MaxPool1d(
+ kernel_size=kernel_size_pool,
+ stride=stride_pool,
+ padding=padding_pool,
+ )

  self.encoder = nn.Sequential(first_conv)

@@ -123,6 +139,11 @@ class SPARCNet(EEGModuleMixin, nn.Module):
  conv_bias=conv_bias,
  batch_norm=batch_norm,
  activation=activation,
+ kernel_size_conv1=kernel_size_conv1,
+ kernel_size_conv2=kernel_size_conv2,
+ stride_conv1=stride_conv1,
+ stride_conv2=stride_conv2,
+ padding_conv2=padding_conv2,
  )
  self.encoder.add_module("denseblock%d" % (n_layer + 1), block)
  # update the number of channels after each dense block
@@ -134,16 +155,19 @@ class SPARCNet(EEGModuleMixin, nn.Module):
  conv_bias=conv_bias,
  batch_norm=batch_norm,
  activation=activation,
+ kernel_size_trans=kernel_size_trans,
+ stride_trans=stride_trans,
  )
  self.encoder.add_module("transition%d" % (n_layer + 1), trans)
  # update the number of channels after each transition layer
  n_channels = n_channels // 2

+ self.adaptative_pool = nn.AdaptiveAvgPool1d(1)
+ self.activation_layer = activation()
+ self.flatten_layer = nn.Flatten()
+
  # add final convolutional layer
- self.final_layer = nn.Sequential(
- activation(),
- nn.Linear(n_channels, self.n_outputs),
- )
+ self.final_layer = nn.Linear(n_channels, self.n_outputs)

  self._init_weights()

@@ -178,7 +202,10 @@ class SPARCNet(EEGModuleMixin, nn.Module):
  torch.Tensor
  The output tensor of the model with shape (batch_size, n_outputs)
  """
- emb = self.encoder(X).squeeze(-1)
+ emb = self.encoder(X)
+ emb = self.adaptative_pool(emb)
+ emb = self.activation_layer(emb)
+ emb = self.flatten_layer(emb)
  out = self.final_layer(emb)
  return out

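The reworked head pools the encoder output down to a single time step before the linear layer, so it no longer relies on the encoder already producing length-1 features. A shape-level sketch of that path; the channel count, sequence length, and class count are assumptions:

import torch
from torch import nn

emb = torch.randn(8, 128, 7)     # assumed encoder output: (batch, channels, time > 1)

head = nn.Sequential(
    nn.AdaptiveAvgPool1d(1),     # -> (8, 128, 1), regardless of the input length
    nn.ELU(),
    nn.Flatten(),                # -> (8, 128)
    nn.Linear(128, 5),           # -> (8, 5)
)
print(head(emb).shape)           # torch.Size([8, 5])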
@@ -224,6 +251,11 @@ class _DenseLayer(nn.Sequential):
  conv_bias: bool = True,
  batch_norm: bool = True,
  activation: nn.Module = nn.ELU,
+ kernel_size_conv1: int = 1,
+ kernel_size_conv2: int = 3,
+ stride_conv1: int = 1,
+ stride_conv2: int = 1,
+ padding_conv2: int = 1,
  ):
  super().__init__()
  if batch_norm:
@@ -235,8 +267,8 @@ class _DenseLayer(nn.Sequential):
  nn.Conv1d(
  in_channels=in_channels,
  out_channels=bottleneck_size * growth_rate,
- kernel_size=1,
- stride=1,
+ kernel_size=kernel_size_conv1,
+ stride=stride_conv1,
  bias=conv_bias,
  ),
  )
@@ -248,9 +280,9 @@ class _DenseLayer(nn.Sequential):
  nn.Conv1d(
  in_channels=bottleneck_size * growth_rate,
  out_channels=growth_rate,
- kernel_size=3,
- stride=1,
- padding=1,
+ kernel_size=kernel_size_conv2,
+ stride=stride_conv2,
+ padding=padding_conv2,
  bias=conv_bias,
  ),
  )
@@ -311,6 +343,11 @@ class _DenseBlock(nn.Sequential):
  conv_bias=True,
  batch_norm=True,
  activation: nn.Module = nn.ELU,
+ kernel_size_conv1: int = 1,
+ kernel_size_conv2: int = 3,
+ stride_conv1: int = 1,
+ stride_conv2: int = 1,
+ padding_conv2: int = 1,
  ):
  super(_DenseBlock, self).__init__()
  for idx_layer in range(num_layers):
@@ -322,6 +359,11 @@ class _DenseBlock(nn.Sequential):
  conv_bias=conv_bias,
  batch_norm=batch_norm,
  activation=activation,
+ kernel_size_conv1=kernel_size_conv1,
+ kernel_size_conv2=kernel_size_conv2,
+ stride_conv1=stride_conv1,
+ stride_conv2=stride_conv2,
+ padding_conv2=padding_conv2,
  )
  self.add_module(f"denselayer{idx_layer + 1}", layer)

@@ -360,6 +402,8 @@ class _TransitionLayer(nn.Sequential):
  conv_bias=True,
  batch_norm=True,
  activation: nn.Module = nn.ELU,
+ kernel_size_trans: int = 2,
+ stride_trans: int = 2,
  ):
  super(_TransitionLayer, self).__init__()
  if batch_norm:
@@ -375,4 +419,6 @@ class _TransitionLayer(nn.Sequential):
  bias=conv_bias,
  ),
  )
- self.add_module("pool", nn.AvgPool1d(kernel_size=2, stride=2))
+ self.add_module(
+ "pool", nn.AvgPool1d(kernel_size=kernel_size_trans, stride=stride_trans)
+ )
braindecode/models/tcn.py CHANGED
@@ -8,7 +8,7 @@ from torch.nn import init
  from torch.nn.utils.parametrizations import weight_norm

  from braindecode.models.base import EEGModuleMixin
- from braindecode.modules import Chomp1d, Ensure4d, Expression, SqueezeFinalOutput
+ from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput


  class BDTCN(EEGModuleMixin, nn.Module):
braindecode/models/tsinception.py CHANGED
@@ -7,6 +7,7 @@ from __future__ import annotations
  import torch
  import torch.nn as nn
  from einops.layers.torch import Rearrange
+ from mne.utils import warn

  from braindecode.models.base import EEGModuleMixin

@@ -98,20 +99,44 @@ class TSceptionV1(EEGModuleMixin, nn.Module):

  ### Layers
  self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
+ if self.input_window_seconds < max(self.inception_windows):
+ inception_windows = (
+ self.input_window_seconds,
+ self.input_window_seconds / 2,
+ self.input_window_seconds / 4,
+ )
+ warning_msg = (
+ "Input window size is smaller than the maximum inception window size. "
+ "We are adjusting the input window size to match the maximum inception window size.\n"
+ f"Original input window size: {self.inception_windows}, \n"
+ f"Adjusted inception windows: {inception_windows}"
+ )
+ warn(warning_msg, UserWarning)
+ self.inception_windows = inception_windows
  # Define temporal convolutional layers (Tception)
- self.temporal_blocks = nn.ModuleList(
- [
- self._conv_block(
- in_channels=1,
- out_channels=number_filter_temp,
- kernel_size=(1, int(window * self.sfreq)),
- stride=1,
- pool_size=self.pool_size,
- activation=self.activation,
- )
- for window in self.inception_windows
- ]
- )
+ self.temporal_blocks = nn.ModuleList()
+ for window in self.inception_windows:
+ # 1. Calculate the temporal kernel size for this block
+ kernel_size_t = int(window * self.sfreq)
+
+ # 2. Calculate the output length of the convolution
+ conv_out_len = self.n_times - kernel_size_t + 1
+
+ # 3. Ensure the pooling size is not larger than the conv output
+ # and is at least 1.
+ dynamic_pool_size = max(1, min(self.pool_size, conv_out_len))
+
+ # 4. Create the block with the dynamic pooling size
+ block = self._conv_block(
+ in_channels=1,
+ out_channels=self.number_filter_temp,
+ kernel_size=(1, kernel_size_t),
+ stride=1,
+ pool_size=dynamic_pool_size, # Use the dynamic size
+ activation=self.activation,
+ )
+ self.temporal_blocks.append(block)
+
  self.batch_temporal_lay = nn.BatchNorm2d(self.number_filter_temp)

  # Define spatial convolutional layers (Sception)
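The dynamic pooling size keeps every Tception branch valid when the input window is short, because the pool can never exceed the length left after the convolution. A numeric sketch of the rule; the sampling rate, window length, pool size, and inception windows are assumed values:

# Assumed values for illustration only.
sfreq, n_times, pool_size = 128.0, 64, 8
inception_windows = (0.5, 0.25, 0.125)            # seconds

for window in inception_windows:
    kernel_size_t = int(window * sfreq)           # temporal kernel in samples
    conv_out_len = n_times - kernel_size_t + 1    # valid-convolution output length
    dynamic_pool_size = max(1, min(pool_size, conv_out_len))
    print(kernel_size_t, conv_out_len, dynamic_pool_size)
# 64 1 1
# 32 33 8
# 16 49 8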
braindecode/models/util.py CHANGED
@@ -5,11 +5,7 @@
  import inspect
  from pathlib import Path

- import numpy as np
  import pandas as pd
- import torch
- from scipy.special import log_softmax
- from sklearn.utils import deprecated

  import braindecode.models as models

@@ -76,12 +72,12 @@ models_mandatory_parameters = [
  ), # 1 channel
  ("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
  ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
- ("BIOT", ["n_chans", "n_outputs", "sfreq"], None),
+ ("BIOT", ["n_chans", "n_outputs", "sfreq", "n_times"], None),
  ("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
  ("Labram", ["n_chans", "n_outputs", "n_times"], None),
  ("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
  ("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
- ("ContraWR", ["n_chans", "n_outputs", "sfreq"], dict(sfreq=200.0)),
+ ("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"], dict(sfreq=200.0)),
  ("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
  ("TSceptionV1", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
  ("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
braindecode/modules/__init__.py CHANGED
@@ -36,3 +36,49 @@ from .stats import (
  )
  from .util import aggregate_probas
  from .wrapper import Expression, IntermediateOutputWrapper
+
+ __all__ = [
+ "LogActivation",
+ "SafeLog",
+ "CAT",
+ "CBAM",
+ "ECA",
+ "FCA",
+ "GCT",
+ "SRM",
+ "CATLite",
+ "EncNet",
+ "GatherExcite",
+ "GSoP",
+ "MultiHeadAttention",
+ "SqueezeAndExcitation",
+ "MLP",
+ "FeedForwardBlock",
+ "InceptionBlock",
+ "AvgPool2dWithConv",
+ "CausalConv1d",
+ "CombinedConv",
+ "Conv2dWithConstraint",
+ "DepthwiseConv2d",
+ "FilterBankLayer",
+ "GeneralizedGaussianFilter",
+ "Chomp1d",
+ "DropPath",
+ "Ensure4d",
+ "SqueezeFinalOutput",
+ "TimeDistributed",
+ "LinearWithConstraint",
+ "MaxNormLinear",
+ "MaxNorm",
+ "MaxNormParametrize",
+ "LogPowerLayer",
+ "LogVarLayer",
+ "MaxLayer",
+ "MeanLayer",
+ "StatLayer",
+ "StdLayer",
+ "VarLayer",
+ "aggregate_probas",
+ "Expression",
+ "IntermediateOutputWrapper",
+ ]
braindecode/modules/filter.py CHANGED
@@ -1,18 +1,14 @@
  from __future__ import annotations

- from functools import partial
  from typing import Optional

  import torch
- from einops.layers.torch import Rearrange
  from mne.filter import _check_coefficients, create_filter
  from mne.utils import warn
  from torch import Tensor, from_numpy, nn
  from torch.fft import fftfreq
  from torchaudio.functional import fftconvolve, filtfilt

- import braindecode.functional as F
-

  class FilterBankLayer(nn.Module):
  """Apply multiple band-pass filters to generate multiview signal representation.