braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,372 @@
1
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
2
+ # Pierre Guetschel <pierre.guetschel@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+
6
+
7
+ import abc
8
+ import inspect
9
+ import logging
10
+ from typing import Literal
11
+
12
+ import mne
13
+ import numpy as np
14
+ import torch
15
+ from sklearn.metrics import get_scorer
16
+ from skorch import NeuralNet
17
+ from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
18
+ from skorch.utils import noop, to_numpy, train_loss_score, valid_loss_score
19
+
20
+ from braindecode.datautil import infer_signal_properties
21
+
22
+ from .models.util import models_dict
23
+ from .training.scoring import (
24
+ CroppedTimeSeriesEpochScoring,
25
+ CroppedTrialEpochScoring,
26
+ PostEpochTrainScoring,
27
+ )
28
+
29
+ log = logging.getLogger(__name__)
30
+
31
+
32
+ def _get_model(model: str):
33
+ """Returns the corresponding class in case the model passed is a string."""
34
+ if isinstance(model, str):
35
+ if model in models_dict:
36
+ model = models_dict[model]
37
+ else:
38
+ raise ValueError(f"Unknown model name {model!r}.")
39
+ return model
40
+
41
+
42
class _EEGNeuralNet(NeuralNet, abc.ABC):
    """Abstract base class extending skorch's ``NeuralNet``.

    On top of the parent class it:

    * accepts a model name (resolved through ``_get_model``) as ``module``;
    * expands callbacks given as scorer names into train/valid scoring
      callbacks;
    * infers signal-related module parameters from the data the first time
      ``fit`` or ``partial_fit`` is called.

    Subclasses must implement the :attr:`mode` property.
    """

    # Class-level default; ``fit``/``partial_fit`` set it to True on the
    # instance so that signal parameters are only inferred once.
    signal_args_set_ = False

    @property
    def log(self):
        """Child logger of the module logger, named after the concrete class."""
        return log.getChild(self.__class__.__name__)

    def initialize_module(self):
        """Initializes the module.

        A Braindecode model name can also be passed as module argument; it
        is resolved to the corresponding class before instantiation.

        Instantiation is delegated to ``self.initialized_instance`` with the
        ``module__*`` parameters, and the result is stored in ``module_``.

        Returns
        -------
        self
        """
        kwargs = self.get_params_for("module")
        module = _get_model(self.module)
        module = self.initialized_instance(module, kwargs)
        # pylint: disable=attribute-defined-outside-init
        self.module_ = module
        return self

    def _yield_callbacks(self):
        """Yield ``(name, callback, named_by_user)`` triples.

        Entries coming from the parent generator under the name ``"str"``
        (callbacks supplied as strings, e.g. ``'accuracy'``) are replaced by
        a train scoring callback and, when ``train_split`` is not None, a
        validation scoring callback. All other entries are passed through.
        """
        for name, cb, named_by_user in super()._yield_callbacks():
            if name != "str":
                yield name, cb, named_by_user
                continue
            train_cb, valid_cb = self._parse_str_callback(cb)
            yield train_cb
            # Without a validation split there is nothing to score on.
            if self.train_split is not None:
                yield valid_cb

    def _parse_str_callback(self, cb_supplied_name):
        """Build train/valid scoring callbacks from a scorer name.

        Parameters
        ----------
        cb_supplied_name : str
            Name understood by ``sklearn.metrics.get_scorer``.

        Returns
        -------
        list of tuple
            ``[(train_name, train_cb, True), (valid_name, valid_cb, True)]``.
        """
        scoring = get_scorer(cb_supplied_name)
        scoring_name = scoring._score_func.__name__
        assert scoring_name.endswith(("_score", "_error", "_deviance", "_loss"))
        # "*_score" functions and "neg_*" scorers are maximised; everything
        # else (errors, losses, deviances) is minimised.
        higher_is_better = scoring_name.endswith(
            "_score"
        ) or cb_supplied_name.startswith("neg_")
        lower_is_better = not higher_is_better
        train_name = f"train_{cb_supplied_name}"
        valid_name = f"valid_{cb_supplied_name}"
        # NOTE(review): ``self.cropped`` is not defined in this class; it is
        # presumably provided by the concrete subclasses.
        if self.cropped:
            train_scoring = CroppedTrialEpochScoring(
                cb_supplied_name, lower_is_better, on_train=True, name=train_name
            )
            valid_scoring = CroppedTrialEpochScoring(
                cb_supplied_name, lower_is_better, on_train=False, name=valid_name
            )
        else:
            train_scoring = PostEpochTrainScoring(
                cb_supplied_name, lower_is_better, name=train_name
            )
            valid_scoring = EpochScoring(
                cb_supplied_name, lower_is_better, on_train=False, name=valid_name
            )
        named_by_user = True
        return [
            (train_name, train_scoring, named_by_user),
            (valid_name, valid_scoring, named_by_user),
        ]

    def on_batch_end(self, net, *batch, training=False, **kwargs):
        """Hand the last window indices to the validation cropped scorers.

        Only acts on non-training batches. Every cropped scoring callback
        that has a ``window_inds_`` list and is not a train scorer receives
        ``self._last_window_inds_``, which is then reset to None.
        """
        # If training is false, assume that our loader has indices for this
        # batch
        if training:
            return
        cropped_scorers = (CroppedTrialEpochScoring, CroppedTimeSeriesEpochScoring)
        epoch_cbs = [
            cb
            for _, cb in self.callbacks_
            if isinstance(cb, cropped_scorers)
            and hasattr(cb, "window_inds_")
            and not cb.on_train
        ]
        # for trialwise decoding stuffs it might also be we don't have
        # cropped loader, so no indices there
        if not epoch_cbs:
            return
        # NOTE(review): ``_last_window_inds_`` is not assigned in this class;
        # presumably the subclass iterator/step code sets it per batch.
        assert self._last_window_inds_ is not None
        for cb in epoch_cbs:
            cb.window_inds_.append(self._last_window_inds_)
        self._last_window_inds_ = None

    def predict_with_window_inds_and_ys(self, dataset):
        """Predict every window of ``dataset`` and collect its indices/targets.

        Parameters
        ----------
        dataset : Dataset
            Iterated through ``self.get_iterator(dataset, drop_index=False)``,
            which must yield ``(X, y, i)`` batches where ``i[0]`` and ``i[2]``
            are the window-in-trial indices and the window stop indices.

        Returns
        -------
        dict
            Concatenated numpy arrays under the keys ``preds``,
            ``i_window_in_trials``, ``i_window_stops`` and ``window_ys``.
        """
        # Use the initialised module: ``self.module`` is the raw constructor
        # argument and may be a class or a model name, on which ``eval()`` and
        # ``forward()`` cannot be called. Fall back to ``self.module`` so that
        # an uninitialised net holding a module instance keeps working.
        module = getattr(self, "module_", self.module)
        module.eval()
        preds = []
        i_window_in_trials = []
        i_window_stops = []
        window_ys = []
        for X, y, i in self.get_iterator(dataset, drop_index=False):
            i_window_in_trials.append(i[0].cpu().numpy())
            i_window_stops.append(i[2].cpu().numpy())
            with torch.no_grad():
                preds.append(to_numpy(module.forward(X.to(self.device))))
            window_ys.append(y.cpu().numpy())
        return dict(
            preds=np.concatenate(preds),
            i_window_in_trials=np.concatenate(i_window_in_trials),
            i_window_stops=np.concatenate(i_window_stops),
            window_ys=np.concatenate(window_ys),
        )

    # Changes the default target extractor to noop
    @property
    def _default_callbacks(self):
        """Default callbacks: epoch timer, train/valid loss, print log."""
        train_loss = BatchScoring(
            train_loss_score,
            name="train_loss",
            on_train=True,
            target_extractor=noop,
        )
        valid_loss = BatchScoring(
            valid_loss_score,
            name="valid_loss",
            target_extractor=noop,
        )
        return [
            ("epoch_timer", EpochTimer()),
            ("train_loss", train_loss),
            ("valid_loss", valid_loss),
            ("print_log", PrintLog()),
        ]

    @property
    @abc.abstractmethod
    def mode(self) -> Literal["classification", "regression"]:
        """Task type, forwarded to ``infer_signal_properties``."""
        pass

    def _set_signal_args(self, X, y, classes):
        """Infer signal-related module parameters from the data and set them.

        Nothing is inferred when ``self.module`` is already an instantiated
        ``torch.nn.Module`` or when ``infer_signal_properties`` returns
        nothing. Inferred values that are None are dropped, inferred values
        the module's ``__init__`` does not accept are dropped with a warning,
        and user-specified ``module__*`` parameters take precedence over
        inferred ones.

        Parameters
        ----------
        X, y
            Data passed on to ``infer_signal_properties``.
        classes : array-like or None
            If None, ``self.classes`` is used when that attribute exists.
        """
        if isinstance(self.module, torch.nn.Module):
            self.log.info(
                "The module passed is already initialized which is not recommended. "
                "Instead, you can pass the module class and its parameters separately.\n"
                "For more details, see "
                "https://skorch.readthedocs.io/en/stable/user/neuralnet.html#module \n"
                "Skipping setting signal-related parameters from data."
            )
            return
        if classes is None:
            classes = getattr(self, "classes", None)
        signal_kwargs = infer_signal_properties(X, y, mode=self.mode, classes=classes)
        if not signal_kwargs:
            return

        # kick out missing kwargs:
        module_kwargs = dict()
        module = _get_model(self.module)
        all_module_kwargs = inspect.signature(module.__init__).parameters.keys()
        for k, v in signal_kwargs.items():
            if v is None:
                continue
            if k in all_module_kwargs:
                module_kwargs[k] = v
            else:
                self.log.warning(f"Module {self.module!r} is missing parameter {k!r}.")

        # kick out inferred signal kwargs if user specifies kwargs:
        user_kwargs = self.get_params_for("module")
        user_specified_kwargs = user_kwargs.items()
        if len(user_specified_kwargs) > 0:
            self.log.info(
                f"Overriding inferred parameters with user "
                f"specified parameters{user_specified_kwargs!r}."
            )
        for k, v in user_specified_kwargs:
            # Drop any inferred value first so the user's value wins.
            module_kwargs.pop(k, None)
            module_kwargs[k] = v

        # save kwargs to self:
        self.log.info(
            f"Passing additional parameters {module_kwargs!r} "
            f"to module {self.module!r}."
        )
        module_kwargs = {f"module__{k}": v for k, v in module_kwargs.items()}
        self.set_params(**module_kwargs)

    def get_dataset(self, X, y=None):
        """Get a dataset that contains the input data and is passed to the iterator.

        Override this if you want to initialize your dataset
        differently.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * mne.Epochs
            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
            The same data types as for ``X`` are supported. If your X is
            a Dataset that contains the target, ``y`` may be set to
            None.

        Returns
        -------
        dataset
            The initialized dataset.
        """
        # MNE epochs are converted to an array (requested in microvolts)
        # before being handed to the parent implementation.
        if isinstance(X, mne.BaseEpochs):
            X = X.get_data(units="uV")
        return super().get_dataset(X, y)

    def partial_fit(self, X, y=None, classes=None, **fit_params):
        """Fit the module.

        If the module is initialized, it is not re-initialized, which
        means that this method should be used if you want to continue
        training a model (warm start).
        If possible, signal-related parameters are inferred from the
        data and passed to the module at initialisation.
        Depending on the type of input passed, the following parameters
        are inferred:

        * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
          ``sfreq``, ``input_window_seconds``
        * array-like: ``n_times``, ``n_chans``, ``n_outputs``
        * WindowsDataset with ``targets_from='metadata'``
          (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``,
          ``n_outputs``
        * other Dataset: ``n_times``, ``n_chans``
        * other types: no parameters are inferred.

        The inference only happens on the first call to ``fit`` or
        ``partial_fit`` on this instance.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * mne.Epochs
            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
            The same data types as for ``X`` are supported. If your X is
            a Dataset that contains the target, ``y`` may be set to
            None.

        classes : array, shape (n_classes,)
            Passed to the signal-parameter inference and forwarded to the
            parent ``partial_fit`` (sklearn compatibility).

        **fit_params : dict
            Additional parameters passed to the ``forward`` method of
            the module and to the ``self.train_split`` call.
        """
        # this needs to be executed before the net is initialized:
        if not self.signal_args_set_:
            self._set_signal_args(X, y, classes)
            self.signal_args_set_ = True
        return super().partial_fit(X=X, y=y, classes=classes, **fit_params)

    def fit(self, X, y=None, **fit_params):
        """Initialize and fit the module.

        If the module was already initialized, by calling fit, the
        module will be re-initialized (unless ``warm_start`` is True).
        If possible, signal-related parameters are inferred from the
        data and passed to the module at initialisation.
        Depending on the type of input passed, the following parameters
        are inferred:

        * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
          ``sfreq``, ``input_window_seconds``
        * array-like: ``n_times``, ``n_chans``, ``n_outputs``
        * WindowsDataset with ``targets_from='metadata'``
          (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``,
          ``n_outputs``
        * other Dataset: ``n_times``, ``n_chans``
        * other types: no parameters are inferred.

        The inference only happens on the first call to ``fit`` or
        ``partial_fit`` on this instance.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * mne.Epochs
            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
            The same data types as for ``X`` are supported. If your X is
            a Dataset that contains the target, ``y`` may be set to
            None.

        **fit_params : dict
            Additional parameters passed to the ``forward`` method of
            the module and to the ``self.train_split`` call.
        """
        # this needs to be executed before the net is initialized:
        if not self.signal_args_set_:
            self._set_signal_args(X, y, classes=None)
            self.signal_args_set_ = True
        return super().fit(X=X, y=y, **fit_params)
@@ -0,0 +1,22 @@
1
+ from .functions import (
2
+ _get_gaussian_kernel1d,
3
+ drop_path,
4
+ hilbert_freq,
5
+ identity,
6
+ plv_time,
7
+ safe_log,
8
+ square,
9
+ )
10
+ from .initialization import glorot_weight_zero_bias, rescale_parameter
11
+
12
+ __all__ = [
13
+ "_get_gaussian_kernel1d",
14
+ "drop_path",
15
+ "hilbert_freq",
16
+ "identity",
17
+ "plv_time",
18
+ "safe_log",
19
+ "square",
20
+ "glorot_weight_zero_bias",
21
+ "rescale_parameter",
22
+ ]
@@ -0,0 +1,251 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
def square(x):
    """Return ``x`` multiplied by itself (element-wise for tensors)."""
    squared = x * x
    return squared
11
+
12
+
13
def safe_log(x, eps: float = 1e-6) -> torch.Tensor:
    """Prevents :math:`log(0)` by using :math:`log(max(x, eps))`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    eps : float
        Lower bound applied to ``x`` before taking the logarithm.

    Returns
    -------
    torch.Tensor
        Natural logarithm of ``x`` clamped from below at ``eps``.
    """
    floored = torch.clamp(x, min=eps)
    return torch.log(floored)
16
+
17
+
18
def identity(x):
    """Return the argument unchanged."""
    return x
20
+
21
+
22
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample.

    One Bernoulli draw is made per entry of the first dimension of ``x`` and
    broadcast over all remaining dimensions, so a sample is either kept whole
    or zeroed whole.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; the first dimension indexes the samples.
    drop_prob : float, optional
        Probability of zeroing a sample (each sample is kept with
        probability ``1 - drop_prob``), by default 0.0.
    training : bool, optional
        Whether the model is in training mode, by default False. When False
        the input is returned unchanged.
    scale_by_keep : bool, optional
        Whether to divide the kept samples by the keep probability, by
        default True. No scaling is applied when the keep probability is 0.

    Returns
    -------
    torch.Tensor
        ``x`` itself when ``drop_prob == 0.0`` or ``training`` is False,
        otherwise ``x`` multiplied by the per-sample mask.

    Notes
    -----
    This implementation is taken from the timm library. All credit goes to
    Ross Wightman.

    Notes from Ross Wightman:
    (when applied in main path of residual blocks)
    This is the same as the DropConnect impl I created for EfficientNet,
    etc. networks, however,
    the original name is misleading as 'Drop Connect' is a different form
    of dropout in a separate paper...
    See discussion : https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    ... I've opted for changing the layer and argument names to 'drop path'
    rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask entry per sample, broadcastable over every other dimension
    # (works with tensors of any rank, not just 2D ConvNets).
    trailing_ones = (1,) * (x.ndim - 1)
    mask_shape = (x.shape[0],) + trailing_ones
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
68
+
69
+
70
+ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
71
+ """
72
+ Generates a 1-dimensional Gaussian kernel based on the specified kernel.
73
+
74
+ size and standard deviation (sigma).
75
+ This kernel is useful for Gaussian smoothing or filtering operations in
76
+ image processing. The function calculates a range limit to ensure the kernel
77
+ effectively covers the Gaussian distribution. It generates a tensor of
78
+ specified size and type, filled with values distributed according to a
79
+ Gaussian curve, normalized using a softmax function
80
+ to ensure all weights sum to 1.
81
+
82
+ Parameters
83
+ ----------
84
+ kernel_size : int
85
+ sigma : float
86
+
87
+ Returns
88
+ -------
89
+ kernel1d : torch.Tensor
90
+
91
+ Notes
92
+ -----
93
+ Code copied and modified from TorchVision:
94
+ https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py#L725-L732
95
+ All rights reserved.
96
+
97
+ LICENSE in https://github.com/pytorch/vision/blob/main/LICENSE
98
+ """
99
+ ksize_half = (kernel_size - 1) * 0.5
100
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
101
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
102
+ kernel1d = pdf / pdf.sum()
103
+ return kernel1d
104
+
105
+
106
def hilbert_freq(x, forward_fourier=True):
    r"""
    Compute the analytic signal (Hilbert transform) using PyTorch, returning
    separate real and imaginary parts.

    The analytic signal :math:`x_a(t)` of a real-valued signal :math:`x(t)`
    is defined as:

    .. math::

        x_a(t) = x(t) + i y(t) = \mathcal{F}^{-1} \{ U(f) \mathcal{F}\{x(t)\} \}

    where:
    - :math:`\mathcal{F}` is the Fourier transform,
    - :math:`U(f)` is the unit step function,
    - :math:`y(t)` is the Hilbert transform of :math:`x(t)`.


    Parameters
    ----------
    x : torch.Tensor
        Input tensor. The expected shape depends on the `forward_fourier` parameter:

        - If `forward_fourier` is True:
            (..., seq_len)
        - If `forward_fourier` is False:
            (..., seq_len / 2 + 1, 2)

    forward_fourier : bool, optional
        Whether the forward (real) Fourier transform still has to be applied.
        - If True, the input is a time-domain signal and ``rfft`` is applied
          along the last dimension.
        - If False, the input already holds the one-sided spectrum with
          separate real and imaginary parts in the last dimension.
        Default is True.

    Returns
    -------
    torch.Tensor
        Output tensor with shape (..., seq_len, 2), where the last dimension
        represents the real and imaginary parts of the analytic signal. When
        `forward_fourier` is False the original length cannot be recovered
        from the spectrum and an even ``seq_len = 2 * (n_freqs - 1)`` is
        assumed.

    Examples
    --------
    >>> import torch
    >>> input = torch.randn(10, 100)  # Example input tensor
    >>> output = hilbert_freq(input)
    >>> print(output.shape)
    torch.Size([10, 100, 2])

    Notes
    -----
    The implementation follows the approach of the scipy implementation, but
    using torch.
    https://github.com/scipy/scipy/blob/v1.14.1/scipy/signal/_signaltools.py#L2287-L2394

    Every bin except the DC term is doubled; for even ``seq_len`` this
    includes the Nyquist bin, which scipy leaves unscaled.
    """
    if forward_fourier:
        # Remember the true length: the one-sided spectrum alone cannot tell
        # an odd-length signal from the next shorter even-length one.
        seq_len = x.shape[-1]
        x = torch.fft.rfft(x, norm=None, dim=-1)
        x = torch.view_as_real(x)
    else:
        seq_len = 2 * (x.shape[-2] - 1)
    n_freqs = x.shape[-2]
    x = x * 2.0
    x[..., 0, :] = x[..., 0, :] / 2.0  # Don't multiply the DC-term by 2
    # Zero the negative-frequency half so the spectrum has seq_len bins again.
    x = F.pad(x, [0, 0, 0, seq_len - n_freqs])
    x = torch.view_as_complex(x)
    x = torch.fft.ifft(x, norm=None, dim=-1)  # returns complex signal
    x = torch.view_as_real(x)

    return x
173
+
174
+
175
def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
    """Compute the Phase Locking Value (PLV) metric in the time domain.

    The Phase Locking Value (PLV) is a measure of the synchronization between
    different channels by evaluating the consistency of phase differences
    over time. It ranges from 0 (no synchronization) to 1 (perfect
    synchronization) [Lachaux1999]_.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor containing the signal data.

        - If `forward_fourier` is `True`, the shape should be `(..., channels, time)`.
        - If `forward_fourier` is `False`, the shape should be `(..., channels, freqs, 2)`,
          where the last dimension represents the real and imaginary parts.

    forward_fourier : bool, optional
        Specifies the format of the input tensor `x`; it is passed on to
        ``hilbert_freq``.

        - If `True`, `x` is assumed to be in the time domain.
        - If `False`, `x` is assumed to be in the Fourier domain with separate real and
          imaginary components.

        Default is `True`.
    epsilon : float, default 1e-6
        Small value added under the final square root to keep its argument
        strictly positive.

    Returns
    -------
    plv : torch.Tensor
        The Phase Locking Value matrix with shape `(..., channels, channels)`. Each
        element `[i, j]` represents the PLV between channel `i` and channel `j`.

    References
    ----------
    .. [Lachaux1999] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
       Measuring phase synchrony in brain signals. Human brain mapping,
       8(4), 194-208.
    """
    # Analytic signal with real and imaginary parts in the last dimension.
    analytic_signal = hilbert_freq(x, forward_fourier)

    # Magnitude of the analytic signal; the fixed 1e-6 keeps the division
    # below away from zero.
    amplitude = torch.sqrt(
        analytic_signal[..., 0] ** 2 + analytic_signal[..., 1] ** 2 + 1e-6
    )

    # Unit-length phasors, split into their real and imaginary components.
    unit_phasor = analytic_signal / amplitude.unsqueeze(-1)
    phasor_re = unit_phasor[..., 0]
    phasor_im = unit_phasor[..., 1]

    # Channel-by-channel products summed over the time axis.
    real_real = torch.matmul(phasor_re, phasor_re.transpose(-2, -1))
    imag_imag = torch.matmul(phasor_im, phasor_im.transpose(-2, -1))
    real_imag = torch.matmul(phasor_re, phasor_im.transpose(-2, -1))
    imag_real = torch.matmul(phasor_im, phasor_re.transpose(-2, -1))

    # Real and imaginary parts of the complex correlation between channels.
    correlation_real = real_real + imag_imag
    correlation_imag = real_imag - imag_real

    # Number of samples that were summed by the matrix products.
    time = amplitude.shape[-1]

    # PLV: magnitude of the complex correlation averaged over the samples.
    magnitude = torch.sqrt(correlation_real**2 + correlation_imag**2 + epsilon)
    plv_matrix = 1 / time * magnitude

    return plv_matrix
@@ -0,0 +1,47 @@
1
+ import math
2
+
3
+ from torch import nn
4
+
5
+
6
def glorot_weight_zero_bias(model):
    """Initialize weights with glorot uniform/xavier and biases with zero.

    For every sub-module of ``model`` (including ``model`` itself):

    * weights of batch norm layers (class name containing ``"BatchNorm"``)
      are set to 1;
    * all other weights with at least two dimensions are initialized with
      glorot uniform/xavier initialization (gain 1);
    * biases are set to zero.

    Weights and biases that are None (e.g. non-affine batch norm,
    ``bias=False``) and non-batch-norm weights with fewer than two
    dimensions (for which fan-in/fan-out are undefined) are left untouched.

    Parameters
    ----------
    model : Module
        Module whose parameters are initialized in place.
    """
    for module in model.modules():
        weight = getattr(module, "weight", None)
        if weight is not None:
            if "BatchNorm" in module.__class__.__name__:
                nn.init.constant_(weight, 1)
            elif weight.dim() >= 2:
                # xavier_uniform_ needs fan-in and fan-out, hence >= 2 dims.
                nn.init.xavier_uniform_(weight, gain=1)
        bias = getattr(module, "bias", None)
        if bias is not None:
            nn.init.constant_(bias, 0)
23
+
24
+
25
def rescale_parameter(param, layer_id):
    r"""Rescale the parameter of the l-th transformer layer.

    Divides the parameter tensor in place by
    :math:`\sqrt{2 \cdot \text{layer\_id}}`, i.e. multiplies it by
    :math:`\frac{1}{\sqrt{2 \cdot \text{layer\_id}}}` [Beit2022].

    In the labram, this is used to rescale the output matrices
    (i.e., the last linear projection within each sub-layer) of the
    self-attention module.

    Parameters
    ----------
    param: :class:`torch.Tensor`
        tensor to be rescaled (modified in place)
    layer_id: int
        layer id in the neural network

    References
    ----------
    [Beit2022] Hangbo Bao, Li Dong, Songhao Piao, Furu We (2022). BEIT: BERT
    Pre-Training of Image Transformers.
    """
    divisor = math.sqrt(2.0 * layer_id)
    param.div_(divisor)