braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,216 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ import torch
7
+ from einops.layers.torch import Rearrange
8
+ from torch import nn
9
+
10
+ from braindecode.functional import drop_path
11
+
12
+
13
class Ensure4d(nn.Module):
    """Append trailing singleton dimensions until the input is 4-D.

    Inputs with fewer than four dimensions get size-1 axes appended at the
    end, e.g. ``(batch, channels, time)`` becomes
    ``(batch, channels, time, 1)``. Inputs that already have four or more
    dimensions are returned as-is.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import Ensure4d
    >>> module = Ensure4d()
    >>> outputs = module(torch.randn(2, 3, 10))
    >>> outputs.shape
    torch.Size([2, 3, 10, 1])
    """

    def forward(self, x):
        # A negative count yields an empty range, leaving >=4-D inputs alone.
        n_missing = 4 - len(x.shape)
        for _ in range(n_missing):
            x = x.unsqueeze(-1)
        return x
33
+
34
+
35
class Chomp1d(nn.Module):
    """Remove samples from the end of a sequence.

    Drops the last ``chomp_size`` entries along dimension 2 (the time axis of
    a ``(batch, channels, time)`` input) and returns a contiguous tensor.

    Parameters
    ----------
    chomp_size : int
        Number of trailing samples to remove. ``0`` leaves the input
        unchanged.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import Chomp1d
    >>> module = Chomp1d(chomp_size=5)
    >>> inputs = torch.randn(2, 3, 20)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3, 15])
    >>> Chomp1d(chomp_size=0)(inputs).shape
    torch.Size([2, 3, 20])
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def extra_repr(self):
        return "chomp_size={}".format(self.chomp_size)

    def forward(self, x):
        if self.chomp_size == 0:
            # ``x[:, :, :-0]`` is ``x[:, :, :0]``, i.e. an EMPTY tensor, so the
            # no-op case must be handled explicitly.
            return x.contiguous()
        return x[:, :, : -self.chomp_size].contiguous()
58
+
59
+
60
class TimeDistributed(nn.Module):
    """Apply module on multiple windows.

    Apply the provided module on a sequence of windows and return their
    concatenation: the module's output for each window is flattened, so the
    result has shape ``(batch_size, seq_len, output_size)``.
    Useful with sequence-to-prediction models (e.g. sleep stager which must map
    a sequence of consecutive windows to the label of the middle window in the
    sequence).

    Parameters
    ----------
    module : nn.Module
        Module to be applied to the input windows. Must accept an input of
        shape (batch_size, n_channels, n_times).

    Examples
    --------
    >>> import torch
    >>> from torch import nn
    >>> from braindecode.modules import TimeDistributed
    >>> module = TimeDistributed(nn.Conv1d(3, 4, kernel_size=3, padding=1))
    >>> inputs = torch.randn(2, 5, 3, 20)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 5, 80])
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Sequence of windows, of shape (batch_size, seq_len, n_channels,
            n_times).

        Returns
        -------
        torch.Tensor
            Shape (batch_size, seq_len, output_size), where ``output_size`` is
            the number of elements of the module's output for one window.
        """
        b, s, c, t = x.shape
        # ``reshape`` instead of ``view``: merging the batch and sequence axes
        # must also work for non-contiguous inputs (e.g. after a transpose).
        out = self.module(x.reshape(b * s, c, t))
        return out.reshape(b, s, -1)
107
+
108
+
109
class DropPath(nn.Module):
    """Drop paths, also known as Stochastic Depth, per sample.

    Meant to be applied on the main path of residual blocks. The forward pass
    delegates to :func:`braindecode.functional.drop_path`, passing
    ``drop_prob`` and the module's ``training`` flag.

    Parameters
    ----------
    drop_prob: float (default=None)
        Drop path probability (should be in range 0-1).

    Notes
    -----
    Code copied and modified from VISSL facebookresearch:
    https://github.com/facebookresearch/vissl/blob/0b5d6a94437bc00baed112ca90c9d78c6ccfbafb/vissl/models/model_helpers.py#L676

    All rights reserved.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import DropPath
    >>> module = DropPath(drop_prob=0.2)
    >>> _ = module.train()  # train() returns the module; discard its repr
    >>> inputs = torch.randn(2, 3, 10)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3, 10])
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Shown inside the module repr, e.g. ``DropPath(p=0.2)``.
        return f"p={self.drop_prob}"
156
+
157
+
158
class SqueezeFinalOutput(nn.Module):
    """Strip trailing singleton dimensions from a model's final output.

    The last dimension, which the einops pattern ``"b c t 1 -> b c t"``
    requires to be of size 1, is always removed. If the remaining last (time)
    dimension also has size 1, it is removed too. A plain ``squeeze()`` is
    deliberately avoided so that the leading (batch) dimension is never
    dropped.
    """

    def __init__(self):
        super().__init__()

        self.squeeze = Rearrange("b c t 1 -> b c t")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` without its trailing feature dim (and time dim if 1).

        Returns
        -------
        x: torch.Tensor
            squeezed tensor
        """
        without_feature_dim = self.squeeze(x)
        has_singleton_time = without_feature_dim.shape[-1] == 1
        return (
            without_feature_dim.squeeze(-1)
            if has_singleton_time
            else without_feature_dim
        )
183
+
184
+
185
class SubjectLayers(nn.Module):
    """Per-subject linear transformation layer.

    Holds one ``(in_channels, out_channels)`` matrix per subject. Each sample
    of a batch is mixed across channels with the matrix of its own subject,
    mapping ``(batch, in_channels, time)`` to ``(batch, out_channels, time)``.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    n_subjects : int
        Number of distinct subjects (one weight matrix each).
    init_id : bool, default=False
        If True, every subject's matrix starts as the identity (this requires
        ``in_channels == out_channels``); otherwise it is drawn from a standard
        normal. In both cases the weights are then scaled by
        ``1 / sqrt(in_channels)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        n_subjects: int,
        init_id: bool = False,
    ):
        super().__init__()
        initial = torch.randn(n_subjects, in_channels, out_channels)
        if init_id:
            if in_channels != out_channels:
                raise AssertionError("init_id requires in_channels == out_channels")
            initial[:] = torch.eye(in_channels)[None]
        # Scaling applies to both the random and the identity initialisation.
        initial *= 1 / (in_channels**0.5)
        self.weights = nn.Parameter(initial)

    def forward(self, x: torch.Tensor, subjects: torch.Tensor) -> torch.Tensor:
        """Mix the channels of each sample with its subject's matrix.

        ``subjects`` holds one integer subject index per sample of ``x``.
        """
        _, n_in, n_out = self.weights.shape
        index = subjects.view(-1, 1, 1).expand(-1, n_in, n_out)
        per_sample = torch.gather(self.weights, 0, index)
        return torch.einsum("bct,bcd->bdt", x, per_sample)

    def __repr__(self) -> str:
        n_subjects, n_in, n_out = self.weights.shape
        return f"SubjectLayers({n_in}, {n_out}, {n_subjects})"
@@ -0,0 +1,70 @@
1
+ from torch import nn
2
+ from torch.nn.utils.parametrize import register_parametrization
3
+
4
+ from braindecode.modules.parametrization import MaxNorm, MaxNormParametrize
5
+
6
+
7
class MaxNormLinear(nn.Linear):
    """Linear layer with MaxNorm constraining on weights.

    Equivalent of Keras tf.keras.Dense(..., kernel_constraint=max_norm())
    [1]_ and [2]_. Implemented as advised in [3]_: the ``weight`` is
    reparametrized through a ``MaxNorm(max_norm_val, eps)`` parametrization.

    Parameters
    ----------
    in_features: int
        Size of each input sample.
    out_features: int
        Size of each output sample.
    bias: bool, optional
        If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``.
    max_norm_val: float, optional
        Norm cap passed to the ``MaxNorm`` parametrization. Default: ``2``.
    eps: float, optional
        Numerical-stability term passed to ``MaxNorm``. Default: ``1e-5``.
    **kwargs
        Forwarded unchanged to :class:`torch.nn.Linear`.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import MaxNormLinear
    >>> module = MaxNormLinear(10, 5, max_norm_val=2)
    >>> inputs = torch.randn(2, 10)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 5])

    References
    ----------
    .. [1] https://keras.io/api/layers/core_layers/dense/#dense-class
    .. [2] https://www.tensorflow.org/api_docs/python/tf/keras/constraints/
       MaxNorm
    .. [3] https://discuss.pytorch.org/t/how-to-correctly-implement-in-place-
       max-norm-constraint/96769
    """

    def __init__(
        self, in_features, out_features, bias=True, max_norm_val=2, eps=1e-5, **kwargs
    ):
        super().__init__(
            in_features=in_features, out_features=out_features, bias=bias, **kwargs
        )
        self._max_norm_val = max_norm_val
        self._eps = eps
        constraint = MaxNorm(self._max_norm_val, self._eps)
        register_parametrization(self, "weight", constraint)
51
+
52
+
53
class LinearWithConstraint(nn.Linear):
    """Linear layer with max-norm constraint on the weights.

    The ``weight`` is reparametrized with ``MaxNormParametrize(max_norm)``.
    All other positional and keyword arguments are forwarded to
    :class:`torch.nn.Linear`.

    Parameters
    ----------
    max_norm : float, default=1.0
        Norm cap passed to ``MaxNormParametrize``.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import LinearWithConstraint
    >>> module = LinearWithConstraint(10, 5, max_norm=1.0)
    >>> inputs = torch.randn(2, 10)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 5])
    """

    def __init__(self, *args, max_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm
        constraint = MaxNormParametrize(self.max_norm)
        register_parametrization(self, "weight", constraint)
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class MaxNorm(nn.Module):
    """Parametrization that rescales slices whose L2 norm exceeds a cap.

    Norms are taken along dim 0 (one norm per column for a 2-D tensor). Each
    norm is floored at ``max_norm_val / 2`` and the floored value is capped at
    ``max_norm_val``. The tensor is multiplied by ``capped / (floored + eps)``.
    Slices with norm above ``max_norm_val`` therefore end up with norm
    approximately ``max_norm_val``. All others are left (almost) unchanged.

    Parameters
    ----------
    max_norm_val : float, default=2.0
        Maximum allowed norm.
    eps : float, default=1e-5
        Added to the denominator of the scale factor.
    """

    def __init__(self, max_norm_val=2.0, eps=1e-5):
        super().__init__()
        self.max_norm_val = max_norm_val
        self.eps = eps

    def _scale(self, X: torch.Tensor) -> torch.Tensor:
        """Per-slice multiplicative factor shared by forward/right_inverse."""
        col_norm = X.norm(2, dim=0, keepdim=True)
        # Flooring avoids dividing by tiny norms; below the cap the ratio ~ 1.
        floored = col_norm.clamp(min=self.max_norm_val / 2)
        capped = floored.clamp(max=self.max_norm_val)
        return capped / (floored + self.eps)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return X * self._scale(X)

    def right_inverse(self, X: torch.Tensor) -> torch.Tensor:
        # Divides by the factor computed from X itself; this undoes forward
        # only where that factor is ~1 (norm <= max_norm_val).
        return X / self._scale(X)
25
+
26
+
27
class MaxNormParametrize(nn.Module):
    """Parametrization capping the L2 norm of every slice along dim 0.

    For a 2-D weight this caps each row. Rows whose norm is already within
    ``max_norm`` are left unchanged (``torch.renorm`` semantics).

    Parameters
    ----------
    max_norm : float, default=1.0
        Largest allowed L2 norm of each dim-0 slice.
    """

    def __init__(self, max_norm: float = 1.0):
        super().__init__()
        self.max_norm = max_norm

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return torch.renorm(X, p=2, dim=0, maxnorm=self.max_norm)
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import partial
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
class StatLayer(nn.Module):
    """Apply a reduction such as mean/var/std along one dimension.

    ``stat_fn`` is called as ``stat_fn(x, dim=dim, keepdim=keepdim)``. Its
    result is optionally clamped to ``clamp_range`` and then optionally passed
    through ``torch.log``.

    Parameters
    ----------
    stat_fn : Callable
        A function like torch.mean, torch.std, etc.
    dim : int
        Dimension along which to apply the function.
    keepdim : bool, default=True
        Whether to keep the reduced dimension.
    clamp_range : tuple(float, float), optional
        ``(min, max)`` bounds applied to the statistic before any log, e.g.
        to keep a variance strictly positive.
    apply_log : bool, default=False
        Whether to apply log after computation (used for LogVarLayer).

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import StatLayer
    >>> module = StatLayer(stat_fn=torch.mean, dim=-1, keepdim=True)
    >>> inputs = torch.randn(2, 3, 10)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3, 1])
    """

    def __init__(
        self,
        stat_fn: Callable[..., torch.Tensor],
        dim: int,
        keepdim: bool = True,
        clamp_range: Optional[tuple[float, float]] = None,
        apply_log: bool = False,
    ) -> None:
        super().__init__()
        self.stat_fn = stat_fn
        self.dim, self.keepdim = dim, keepdim
        self.clamp_range = clamp_range
        self.apply_log = apply_log

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.stat_fn(x, dim=self.dim, keepdim=self.keepdim)
        if self.clamp_range is not None:
            result = result.clamp(min=self.clamp_range[0], max=self.clamp_range[1])
        return torch.log(result) if self.apply_log else result
59
+
60
+
61
+ # make things more simple
62
def _max_fn(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
    """Maximum values of ``x`` along ``dim`` (the argmax indices are dropped)."""
    values, _indices = x.max(dim=dim, keepdim=keepdim)
    return values
64
+
65
+
66
def _power_fn(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
    """Mean of the squared values of ``x`` along ``dim``."""
    squared = x.pow(2)
    return squared.mean(dim=dim, keepdim=keepdim)
69
+
70
+
71
# Ready-made StatLayer variants. Each partial fixes ``stat_fn`` (and, for the
# log variants, ``clamp_range`` and ``apply_log``); callers still provide
# ``dim`` and may pass ``keepdim`` or any other StatLayer argument, positionally
# or by keyword, e.g. ``LogVarLayer(dim=-1)``.
# ``Callable[..., StatLayer]`` is used because ``Callable[[int, bool], ...]``
# would reject keyword calls and calls that omit the defaulted ``keepdim``.
MeanLayer: Callable[..., StatLayer] = partial(StatLayer, torch.mean)
MaxLayer: Callable[..., StatLayer] = partial(StatLayer, _max_fn)
VarLayer: Callable[..., StatLayer] = partial(StatLayer, torch.var)
StdLayer: Callable[..., StatLayer] = partial(StatLayer, torch.std)
# Variance clamped into [1e-6, 1e6] before the log so the result stays finite.
LogVarLayer: Callable[..., StatLayer] = partial(
    StatLayer,
    torch.var,
    clamp_range=(1e-6, 1e6),
    apply_log=True,
)

# Mean power clamped into [1e-4, 1e4] before the log.
LogPowerLayer: Callable[..., StatLayer] = partial(
    StatLayer,
    _power_fn,
    clamp_range=(1e-4, 1e4),
    apply_log=True,
)
@@ -0,0 +1,85 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ # Hubert Banville <hubert.jbanville@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+
6
+ import numpy as np
7
+ from scipy.special import log_softmax
8
+
9
+
10
def _pad_shift_array(x, stride=1):
    """Zero-pad and shift rows of a 3D array.

    E.g., used to align predictions of corresponding windows in
    sequence-to-sequence models.

    Parameters
    ----------
    x : np.ndarray
        Array of shape (n_rows, n_classes, n_windows).
    stride : int
        Number of non-overlapping elements between two consecutive sequences.

    Returns
    -------
    np.ndarray :
        Array of shape (n_rows, n_classes, (n_rows - 1) * stride + n_windows).
        Row ``i`` is row ``i`` of ``x`` preceded by ``i * stride`` zeros and
        followed by ``(n_rows - 1 - i) * stride`` zeros in the last dimension.
        The result is a strided view (``as_strided``) on an internal padded
        copy whose rows overlap in memory, so treat it as read-only.

    Raises
    ------
    NotImplementedError
        If ``x`` is not 3-dimensional.
    """
    if x.ndim != 3:
        raise NotImplementedError(
            f"x must be of shape (n_rows, n_classes, n_windows), got {x.shape}"
        )
    # Append (n_rows - 1) * stride zeros to every row. ``np.pad`` keeps a
    # Fortran layout for Fortran-ordered input, but the stride arithmetic below
    # is only valid for a C-contiguous buffer, so force C order.
    x_padded = np.ascontiguousarray(
        np.pad(x, ((0, 0), (0, 0), (0, (x.shape[0] - 1) * stride)))
    )
    orig_strides = x_padded.strides
    # Shrinking the row stride by ``stride`` elements makes row i start
    # i * stride elements earlier in memory, i.e. shifts it right by
    # i * stride. The elements read before a row's own start are the trailing
    # zero padding of the previous (row, class) slice, which is long enough.
    new_strides = (
        orig_strides[0] - stride * orig_strides[2],
        orig_strides[1],
        orig_strides[2],
    )
    return np.lib.stride_tricks.as_strided(x_padded, strides=new_strides)
42
+
43
+
44
def aggregate_probas(logits, n_windows_stride=1):
    """Aggregate predicted probabilities with self-ensembling.

    Aggregate window-wise predicted probabilities obtained on overlapping
    sequences of windows using multiplicative voting as described in
    [Phan2018]_. The product of probabilities is computed in log space. The
    log-softmax of the logits is taken over the class axis. Consecutive
    sequences are offset by ``n_windows_stride`` windows. The
    log-probabilities predicted for the same window by different sequences
    are then summed.

    Parameters
    ----------
    logits : np.ndarray
        Array of shape (n_sequences, n_classes, n_windows) containing the
        logits (i.e. the raw unnormalized scores for each class) for each
        window of each sequence.
    n_windows_stride : int
        Number of windows between two consecutive sequences. Default is 1
        (maximally overlapping sequences).

    Returns
    -------
    np.ndarray :
        Array of shape
        ``((n_sequences - 1) * n_windows_stride + n_windows, n_classes)``.
        For each window covered by the input sequences, it holds the sum of
        the log-probabilities predicted for that window. These are
        unnormalized log-scores, i.e. the log of the product of the
        per-sequence probabilities, not probabilities. Take
        ``argmax(axis=1)`` for the aggregated prediction, or apply a softmax
        over axis 1 to get normalized probabilities. Windows near either end
        are covered by fewer sequences, so their sums have fewer terms.

    References
    ----------
    .. [Phan2018] Phan, H., Andreotti, F., Cooray, N., Chén, O. Y., &
       De Vos, M. (2018). Joint classification and prediction CNN framework
       for automatic sleep stage classification. IEEE Transactions on
       Biomedical Engineering, 66(5), 1285-1296.

    Examples
    --------
    >>> import numpy as np
    >>> from braindecode.modules import aggregate_probas
    >>> logits = np.random.randn(3, 4, 5)  # (n_sequences, n_classes, n_windows)
    >>> probas = aggregate_probas(logits, n_windows_stride=1)
    >>> probas.shape
    (7, 4)
    """
    log_probas = log_softmax(logits, axis=1)
    # Shift each sequence to its absolute window position (zero-padded), then
    # sum over sequences and put windows first: (n_total_windows, n_classes).
    return _pad_shift_array(log_probas, stride=n_windows_stride).sum(axis=0).T
@@ -0,0 +1,90 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class Expression(nn.Module):
6
+ """Compute given expression on forward pass.
7
+
8
+ Parameters
9
+ ----------
10
+ expression_fn : callable
11
+ Should accept variable number of objects of type
12
+ `torch.autograd.Variable` to compute its output.
13
+
14
+ Examples
15
+ --------
16
+ >>> import torch
17
+ >>> from braindecode.modules import Expression
18
+ >>> module = Expression(lambda x: x**2)
19
+ >>> inputs = torch.randn(2, 3)
20
+ >>> outputs = module(inputs)
21
+ >>> outputs.shape
22
+ torch.Size([2, 3])
23
+ """
24
+
25
+ def __init__(self, expression_fn):
26
+ super().__init__()
27
+ self.expression_fn = expression_fn
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ return self.expression_fn(x)
31
+
32
+ def __repr__(self):
33
+ if hasattr(self.expression_fn, "func") and hasattr(
34
+ self.expression_fn, "kwargs"
35
+ ):
36
+ expression_str = "{:s} {:s}".format(
37
+ self.expression_fn.func.__name__, str(self.expression_fn.kwargs)
38
+ )
39
+ elif hasattr(self.expression_fn, "__name__"):
40
+ expression_str = self.expression_fn.__name__
41
+ else:
42
+ expression_str = repr(self.expression_fn)
43
+ return self.__class__.__name__ + "(expression=%s) " % expression_str
44
+
45
+
46
class IntermediateOutputWrapper(nn.Module):
    """Wraps network model such that outputs of intermediate layers can be returned.

    The direct children of ``model`` are registered (shared, not copied) on
    the wrapper in their original order. ``forward()`` calls them one after
    another, feeding each child's output into the next. It returns a list of
    the activations of the children whose names are in ``to_select``, in
    module order. Note that ``model.forward`` itself is never called, so this
    only reproduces models whose forward pass is a plain chain of their
    direct children.

    Parameters
    ----------
    to_select : list
        list of module names for which activation should be returned
    model : model object
        network model

    Raises
    ------
    Exception
        If the number of children of ``model`` differs from the number of
        its named children.

    Examples
    --------
    For instance, wrapping a ``Deep4Net`` instance ``model`` with
    ``IntermediateOutputWrapper(['conv_spat', 'conv_2', 'conv_3', 'conv_4'],
    model)`` returns the outputs of those blocks. A self-contained example:

    >>> import torch
    >>> base = torch.nn.Sequential(torch.nn.Linear(10, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
    >>> wrapped = IntermediateOutputWrapper(to_select=["0", "2"], model=base)
    >>> outputs = wrapped(torch.randn(4, 10))
    >>> len(outputs)
    2
    """

    def __init__(self, to_select, model):
        if not len(list(model.children())) == len(list(model.named_children())):
            raise Exception("All modules in model need to have names!")

        super().__init__()

        for key, module in model.named_children():
            # Registers the very same module object, so parameters are shared
            # with ``model``; no state copy is needed.
            self.add_module(key, module)
        self._to_select = to_select

    def forward(self, x):
        # Call modules individually and append activation to output if module is in to_select
        o = []
        for name, module in self._modules.items():
            x = module(x)
            if name in self._to_select:
                o.append(x)
        return o
+ return o