braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/modules/attention.py

@@ -35,16 +35,6 @@ class SqueezeAndExcitation(nn.Module):
     bias: bool, default=False
         if True, adds a learnable bias will be used in the convolution.
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import SqueezeAndExcitation
-    >>> module = SqueezeAndExcitation(in_channels=16, reduction_rate=4)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018.
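The doctests deleted from the channel-attention blocks in this file all made the same check: a (batch, channels, 1, time) input comes back with its shape unchanged. A minimal sketch that keeps that check alive outside the docstrings, assuming the blocks remain importable from braindecode.modules as in the removed examples:

    import torch
    from braindecode.modules import SqueezeAndExcitation

    # Shapes taken from the removed doctest: (batch, channels, 1, time)
    module = SqueezeAndExcitation(in_channels=16, reduction_rate=4)
    outputs = module(torch.randn(2, 16, 1, 64))
    assert outputs.shape == torch.Size([2, 16, 1, 64])  # attention blocks preserve shape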
@@ -100,16 +90,6 @@ class GSoP(nn.Module):
     bias: bool, default=False
         if True, adds a learnable bias will be used in the convolution.
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import GSoP
-    >>> module = GSoP(in_channels=16, reduction_rate=4)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Gao2018] Gao, Z., Jiangtao, X., Wang, Q., Li, P., 2018.
@@ -166,16 +146,6 @@ class FCA(nn.Module):
     reduction_rate : int, default=4
         Reduction ratio of the fully-connected layers.
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import FCA
-    >>> module = FCA(in_channels=16, seq_len=64, reduction_rate=4, freq_idx=0)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021.
@@ -260,16 +230,6 @@ class EncNet(nn.Module):
     n_codewords : int
         number of codewords
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import EncNet
-    >>> module = EncNet(in_channels=16, n_codewords=8)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Zhang2018] Zhang, H. et al. 2018.
@@ -327,16 +287,6 @@ class ECA(nn.Module):
         kernel size of convolutional layer, determines degree of channel
         interaction, must be odd.
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import ECA
-    >>> module = ECA(in_channels=16, kernel_size=3)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention
@@ -388,16 +338,6 @@ class GatherExcite(nn.Module):
     reduction_rate : int, default=4
         reduction ratio of the excite block (if used)
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import GatherExcite
-    >>> module = GatherExcite(in_channels=16, seq_len=64, extra_params=False, use_mlp=True)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018.
@@ -467,16 +407,6 @@ class GCT(nn.Module):
     in_channels : int
         number of input feature channels
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import GCT
-    >>> module = GCT(in_channels=16)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Yang2020] Yang, Z. Linchao, Z., Wu, Y., Yang, Y., 2020.
@@ -522,16 +452,6 @@ class SRM(nn.Module):
     reduction_rate : int, default=4
         reduction ratio of the fully-connected layers (if used),
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import SRM
-    >>> module = SRM(in_channels=16, use_mlp=False)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based
@@ -597,16 +517,6 @@ class CBAM(nn.Module):
     kernel_size : int
         kernel size of the convolutional layer
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import CBAM
-    >>> module = CBAM(in_channels=16, reduction_rate=4, kernel_size=3)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018.
@@ -665,16 +575,6 @@ class CAT(nn.Module):
     bias : bool, default=False
         if True, adds a learnable bias will be used in the convolution,
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import CAT
-    >>> module = CAT(in_channels=16, reduction_rate=4, kernel_size=3)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Wu2023] Wu, Z. et al., 2023
@@ -771,16 +671,6 @@ class CATLite(nn.Module):
     bias : bool, default=True
         if True, adds a learnable bias will be used in the convolution,
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import CATLite
-    >>> module = CATLite(in_channels=16, reduction_rate=4)
-    >>> inputs = torch.randn(2, 16, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 16, 1, 64])
-
     References
     ----------
     .. [Wu2023] Wu, Z. et al., 2023 CAT: Learning to Collaborate Channel and
@@ -834,19 +724,6 @@ class CATLite(nn.Module):
 
 
 class MultiHeadAttention(nn.Module):
-    """Multi-head self-attention block.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import MultiHeadAttention
-    >>> module = MultiHeadAttention(emb_size=32, num_heads=4, dropout=0.1)
-    >>> inputs = torch.randn(2, 10, 32)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 10, 32])
-    """
-
     def __init__(self, emb_size, num_heads, dropout):
         super().__init__()
         self.emb_size = emb_size
braindecode/modules/blocks.py

@@ -14,22 +14,6 @@ class InceptionBlock(nn.Module):
     ----------
     branches : list of nn.Module
         List of convolutional branches to apply to the input.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from torch import nn
-    >>> from braindecode.modules import InceptionBlock
-    >>> block = InceptionBlock(
-    ...     [
-    ...         nn.Conv1d(3, 4, kernel_size=1),
-    ...         nn.Conv1d(3, 4, kernel_size=3, padding=1),
-    ...     ]
-    ... )
-    >>> inputs = torch.randn(2, 3, 100)
-    >>> outputs = block(inputs)
-    >>> outputs.shape
-    torch.Size([2, 8, 100])
     """
 
     def __init__(self, branches):
@@ -71,16 +55,6 @@ class MLP(nn.Sequential):
         Dropout rate.
     normalize: bool (default=False)
         Whether to apply layer normalization.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import MLP
-    >>> module = MLP(in_features=32, hidden_features=(64,), out_features=16)
-    >>> inputs = torch.randn(2, 10, 32)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 10, 16])
     """
 
     def __init__(
@@ -125,33 +99,7 @@ class MLP(nn.Sequential):
 
 
 class FeedForwardBlock(nn.Sequential):
-    """Feedforward network block.
-
-    Parameters
-    ----------
-    emb_size : int
-        Embedding dimension.
-    expansion : int
-        Expansion factor for the hidden layer size.
-    drop_p : float
-        Dropout probability.
-    activation : type[nn.Module], default=nn.GELU
-        Activation function constructor.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import FeedForwardBlock
-    >>> module = FeedForwardBlock(emb_size=32, expansion=2, drop_p=0.1)
-    >>> inputs = torch.randn(2, 10, 32)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 10, 32])
-    """
-
-    def __init__(
-        self, emb_size, expansion, drop_p, activation: type[nn.Module] = nn.GELU
-    ):
+    def __init__(self, emb_size, expansion, drop_p, activation: nn.Module = nn.GELU):
         super().__init__(
             nn.Linear(emb_size, expansion * emb_size),
             activation(),
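The rewritten __init__ still instantiates its argument inside the block (activation()), so callers continue to pass an activation class rather than an instance; the looser nn.Module annotation does not change the calling convention. A minimal sketch of the assumed usage after this change, reusing the shapes from the deleted doctest and assuming FeedForwardBlock stays importable from braindecode.modules:

    import torch
    from braindecode.modules import FeedForwardBlock

    # activation is passed as a class; the block calls activation() internally
    module = FeedForwardBlock(emb_size=32, expansion=2, drop_p=0.1, activation=torch.nn.GELU)
    outputs = module(torch.randn(2, 10, 32))
    assert outputs.shape == torch.Size([2, 10, 32])  # the block preserves shape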
braindecode/modules/convolution.py

@@ -25,16 +25,6 @@ class AvgPool2dWithConv(nn.Module):
         Dilation applied to the pooling filter.
     padding: int or (int,int)
         Padding applied before the pooling operation.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import AvgPool2dWithConv
-    >>> module = AvgPool2dWithConv(kernel_size=(1, 4), stride=(1, 4))
-    >>> inputs = torch.randn(2, 4, 1, 16)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 4, 1, 4])
     """
 
     def __init__(self, kernel_size, stride, dilation=1, padding=0):
@@ -83,19 +73,6 @@ class AvgPool2dWithConv(nn.Module):
 
 
 class Conv2dWithConstraint(nn.Conv2d):
-    """2D convolution with max-norm constraint on the weights.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import Conv2dWithConstraint
-    >>> module = Conv2dWithConstraint(4, 8, kernel_size=(1, 3), padding=(0, 1), bias=False)
-    >>> inputs = torch.randn(2, 4, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 8, 1, 64])
-    """
-
     def __init__(self, *args, max_norm=1, **kwargs):
         super().__init__(*args, **kwargs)
         self.max_norm = max_norm
@@ -124,16 +101,6 @@ class CombinedConv(nn.Module):
     bias_spat: bool
         Whether to use bias in the spatial conv
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import CombinedConv
-    >>> module = CombinedConv(in_chans=8, n_filters_time=4, n_filters_spat=4, filter_time_length=5)
-    >>> inputs = torch.randn(2, 1, 100, 8)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 4, 96, 1])
-
     """
 
     def __init__(
@@ -215,16 +182,6 @@ class CausalConv1d(nn.Conv1d):
     ----------
     .. [1] https://discuss.pytorch.org/t/causal-convolution/3456/4
     .. [2] https://gist.github.com/paultsw/7a9d6e3ce7b70e9e2c61bc9287addefc
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import CausalConv1d
-    >>> module = CausalConv1d(in_channels=4, out_channels=8, kernel_size=5, dilation=2)
-    >>> inputs = torch.randn(2, 4, 128)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 8, 128])
     """
 
     def __init__(
@@ -293,16 +250,6 @@ class DepthwiseConv2d(torch.nn.Conv2d):
         Padding mode to use. Options are 'zeros', 'reflect', 'replicate', or
         'circular'.
         Default is 'zeros'.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import DepthwiseConv2d
-    >>> module = DepthwiseConv2d(in_channels=4, depth_multiplier=2, kernel_size=3, padding=1)
-    >>> inputs = torch.randn(2, 4, 1, 64)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 8, 1, 64])
     """
 
    def __init__(
braindecode/modules/filter.py

@@ -113,22 +113,6 @@ class FilterBankLayer(nn.Module):
         Control verbosity of the logging output. If ``None``, use the default
         verbosity level. See the func:`mne.verbose` for details.
         Should only be passed as a keyword argument.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import FilterBankLayer
-    >>> module = FilterBankLayer(
-    ...     n_chans=2,
-    ...     sfreq=128,
-    ...     band_filters=[(4.0, 8.0), (8.0, 12.0)],
-    ...     method="fir",
-    ...     verbose=False,
-    ... )
-    >>> inputs = torch.randn(2, 2, 256)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 2, 2, 256])
     """
 
     def __init__(
@@ -417,21 +401,6 @@ class GeneralizedGaussianFilter(nn.Module):
         Minimum and maximum allowable values for the center frequency `f_mean` in Hz.
         Specified as (min_f_mean, max_f_mean). Default is (1.0, 45.0).
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import GeneralizedGaussianFilter
-    >>> module = GeneralizedGaussianFilter(
-    ...     in_channels=2,
-    ...     out_channels=2,
-    ...     sequence_length=256,
-    ...     sample_rate=128,
-    ...     inverse_fourier=True,
-    ... )
-    >>> inputs = torch.randn(3, 2, 256)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([3, 2, 256])
 
     Notes
     -----
braindecode/modules/layers.py

@@ -11,21 +11,6 @@ from braindecode.functional import drop_path
 
 
 class Ensure4d(nn.Module):
-    """Ensure the input tensor has 4 dimensions.
-
-    This is a small utility layer that repeatedly adds a singleton dimension at
-    the end until the input has shape ``(batch, channels, time, 1)``.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import Ensure4d
-    >>> module = Ensure4d()
-    >>> outputs = module(torch.randn(2, 3, 10))
-    >>> outputs.shape
-    torch.Size([2, 3, 10, 1])
-    """
-
     def forward(self, x):
         while len(x.shape) < 4:
             x = x.unsqueeze(-1)
@@ -33,19 +18,6 @@ class Ensure4d(nn.Module):
 
 
 class Chomp1d(nn.Module):
-    """Remove samples from the end of a sequence.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import Chomp1d
-    >>> module = Chomp1d(chomp_size=5)
-    >>> inputs = torch.randn(2, 3, 20)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 3, 15])
-    """
-
     def __init__(self, chomp_size):
         super().__init__()
         self.chomp_size = chomp_size
@@ -71,17 +43,6 @@ class TimeDistributed(nn.Module):
     module : nn.Module
         Module to be applied to the input windows. Must accept an input of
         shape (batch_size, n_channels, n_times).
-
-    Examples
-    --------
-    >>> import torch
-    >>> from torch import nn
-    >>> from braindecode.modules import TimeDistributed
-    >>> module = TimeDistributed(nn.Conv1d(3, 4, kernel_size=3, padding=1))
-    >>> inputs = torch.randn(2, 5, 3, 20)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 5, 4])
     """
 
     def __init__(self, module):
@@ -130,17 +91,6 @@ class DropPath(nn.Module):
     LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     SOFTWARE.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import DropPath
-    >>> module = DropPath(drop_prob=0.2)
-    >>> module.train()
-    >>> inputs = torch.randn(2, 3, 10)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 3, 10])
     """
 
     def __init__(self, drop_prob=None):
@@ -180,37 +130,3 @@ class SqueezeFinalOutput(nn.Module):
         if x.shape[-1] == 1:
             x = x.squeeze(-1)
         return x
-
-
-class SubjectLayers(nn.Module):
-    """Per-subject linear transformation layer.
-
-    Applies subject-specific linear transformations to the input. Each subject
-    owns an independent weight matrix, enabling personalized feature
-    processing.
-    """
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        n_subjects: int,
-        init_id: bool = False,
-    ):
-        super().__init__()
-        self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels))
-        if init_id:
-            if in_channels != out_channels:
-                raise AssertionError("init_id requires in_channels == out_channels")
-            self.weights.data[:] = torch.eye(in_channels)[None]
-        self.weights.data *= 1 / (in_channels**0.5)
-
-    def forward(self, x: torch.Tensor, subjects: torch.Tensor) -> torch.Tensor:
-        """Apply the subject-specific linear transforms."""
-        _, C, D = self.weights.shape
-        weights = self.weights.gather(0, subjects.view(-1, 1, 1).expand(-1, C, D))
-        return torch.einsum("bct,bcd->bdt", x, weights)
-
-    def __repr__(self) -> str:
-        S, C, D = self.weights.shape
-        return f"SubjectLayers({C}, {D}, {S})"
braindecode/modules/linear.py

@@ -20,16 +20,6 @@ class MaxNormLinear(nn.Linear):
         If set to ``False``, the layer will not learn an additive bias.
         Default: ``True``.
 
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import MaxNormLinear
-    >>> module = MaxNormLinear(10, 5, max_norm_val=2)
-    >>> inputs = torch.randn(2, 10)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 5])
-
     References
     ----------
     .. [1] https://keras.io/api/layers/core_layers/dense/#dense-class
@@ -51,18 +41,7 @@ class MaxNormLinear(nn.Linear):
 
 
 class LinearWithConstraint(nn.Linear):
-    """Linear layer with max-norm constraint on the weights.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import LinearWithConstraint
-    >>> module = LinearWithConstraint(10, 5, max_norm=1.0)
-    >>> inputs = torch.randn(2, 10)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 5])
-    """
+    """Linear layer with max-norm constraint on the weights."""
 
     def __init__(self, *args, max_norm=1.0, **kwargs):
         super(LinearWithConstraint, self).__init__(*args, **kwargs)
braindecode/modules/stats.py

@@ -22,16 +22,6 @@ class StatLayer(nn.Module):
         Used only for functions requiring clamping (e.g., log variance).
     apply_log : bool, default=False
         Whether to apply log after computation (used for LogVarLayer).
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import StatLayer
-    >>> module = StatLayer(stat_fn=torch.mean, dim=-1, keepdim=True)
-    >>> inputs = torch.randn(2, 3, 10)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 3, 1])
     """
 
     def __init__(
braindecode/modules/util.py

@@ -71,15 +71,6 @@ def aggregate_probas(logits, n_windows_stride=1):
        De Vos, M. (2018). Joint classification and prediction CNN framework
        for automatic sleep stage classification. IEEE Transactions on
        Biomedical Engineering, 66(5), 1285-1296.
-
-    Examples
-    --------
-    >>> import numpy as np
-    >>> from braindecode.modules import aggregate_probas
-    >>> logits = np.random.randn(3, 4, 5)  # (n_sequences, n_classes, n_windows)
-    >>> probas = aggregate_probas(logits, n_windows_stride=1)
-    >>> probas.shape
-    (7, 4)
     """
     log_probas = log_softmax(logits, axis=1)
     return _pad_shift_array(log_probas, stride=n_windows_stride).sum(axis=0).T
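The deleted doctest's expected shape follows from the aggregation arithmetic: 3 sequences of 5 windows taken at stride 1 span 3 + (5 - 1) = 7 distinct windows, and the class axis (4) ends up last after the transpose, hence (7, 4). A sketch of the same check, assuming aggregate_probas remains exported from braindecode.modules:

    import numpy as np
    from braindecode.modules import aggregate_probas

    logits = np.random.randn(3, 4, 5)  # (n_sequences, n_classes, n_windows)
    # 3 sequences, 5 windows each, stride 1 -> 3 + (5 - 1) = 7 windows overall
    assert aggregate_probas(logits, n_windows_stride=1).shape == (7, 4)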
braindecode/modules/wrapper.py

@@ -10,16 +10,6 @@ class Expression(nn.Module):
     expression_fn : callable
         Should accept variable number of objects of type
         `torch.autograd.Variable` to compute its output.
-
-    Examples
-    --------
-    >>> import torch
-    >>> from braindecode.modules import Expression
-    >>> module = Expression(lambda x: x**2)
-    >>> inputs = torch.randn(2, 3)
-    >>> outputs = module(inputs)
-    >>> outputs.shape
-    torch.Size([2, 3])
     """
 
     def __init__(self, expression_fn):
@@ -59,13 +49,6 @@ class IntermediateOutputWrapper(nn.Module):
     >>> model = Deep4Net()
     >>> select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
     >>> model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
-
-    >>> import torch
-    >>> base = torch.nn.Sequential(torch.nn.Linear(10, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
-    >>> wrapped = IntermediateOutputWrapper(to_select=["0", "2"], model=base)
-    >>> outputs = wrapped(torch.randn(4, 10))
-    >>> len(outputs)
-    2
     """
 
     def __init__(self, to_select, model):
braindecode/preprocessing/preprocess.py

@@ -324,9 +324,6 @@ def _replace_inplace(concat_ds, new_concat_ds):
             concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
         )
 
-    # Recompute cumulative_sizes after replacing datasets
-    concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)
-
 
 def _preprocess(
     ds: RecordDataset,
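The two deleted lines kept cumulative_sizes in sync after _replace_inplace swapped out concat_ds.datasets. torch.utils.data.ConcatDataset builds that index once in __init__, so mutating .datasets leaves len() and indexing pointing at the old layout. A minimal illustration of the invariant the removed code maintained, using plain PyTorch datasets rather than braindecode's:

    import torch
    from torch.utils.data import ConcatDataset, TensorDataset

    ds = ConcatDataset([TensorDataset(torch.zeros(3)), TensorDataset(torch.zeros(5))])
    assert ds.cumulative_sizes == [3, 8]

    # Replace the underlying datasets in place; the cached index goes stale.
    ds.datasets = [TensorDataset(torch.zeros(2)), TensorDataset(torch.zeros(2))]
    assert len(ds) == 8  # still reports the old total

    # The removed line restored consistency like this:
    ds.cumulative_sizes = ds.cumsum(ds.datasets)
    assert len(ds) == 4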