braindecode 1.3.0.dev173909672__tar.gz → 1.3.0.dev177628147__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. {braindecode-1.3.0.dev173909672/braindecode.egg-info → braindecode-1.3.0.dev177628147}/PKG-INFO +13 -3
  2. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/base.py +1 -1
  3. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/functional.py +255 -54
  4. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/transforms.py +76 -2
  5. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/__init__.py +12 -4
  6. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/base.py +132 -153
  7. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bcicomp.py +4 -4
  8. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bids.py +3 -3
  9. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/experimental.py +2 -2
  10. braindecode-1.3.0.dev177628147/braindecode/datasets/hub.py +962 -0
  11. braindecode-1.3.0.dev177628147/braindecode/datasets/hub_validation.py +113 -0
  12. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/mne.py +3 -5
  13. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/moabb.py +17 -7
  14. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/nmt.py +2 -2
  15. braindecode-1.3.0.dev177628147/braindecode/datasets/registry.py +120 -0
  16. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/sleep_physio_challe_18.py +2 -2
  17. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/sleep_physionet.py +2 -2
  18. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/tuh.py +2 -2
  19. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/xy.py +2 -2
  20. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datautil/__init__.py +11 -1
  21. braindecode-1.3.0.dev177628147/braindecode/datautil/channel_utils.py +114 -0
  22. braindecode-1.3.0.dev177628147/braindecode/datautil/hub_formats.py +180 -0
  23. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datautil/serialization.py +7 -8
  24. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/eegneuralnet.py +2 -0
  25. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/functional/functions.py +6 -2
  26. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/functional/initialization.py +2 -3
  27. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/__init__.py +18 -1
  28. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/atcnet.py +27 -28
  29. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/attentionbasenet.py +39 -32
  30. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/attn_sleep.py +2 -0
  31. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/base.py +280 -2
  32. braindecode-1.3.0.dev177628147/braindecode/models/bendr.py +469 -0
  33. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/biot.py +2 -0
  34. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/contrawr.py +2 -0
  35. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/ctnet.py +8 -3
  36. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/deepsleepnet.py +28 -19
  37. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eegconformer.py +2 -2
  38. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eeginception_erp.py +31 -25
  39. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eegitnet.py +2 -0
  40. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eegminer.py +2 -0
  41. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eegnet.py +1 -1
  42. braindecode-1.3.0.dev177628147/braindecode/models/eegsym.py +917 -0
  43. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eegtcnet.py +2 -0
  44. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/fbcnet.py +5 -1
  45. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/fblightconvnet.py +2 -0
  46. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/fbmsnet.py +20 -6
  47. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/ifnet.py +2 -0
  48. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/labram.py +193 -87
  49. braindecode-1.3.0.dev177628147/braindecode/models/luna.py +836 -0
  50. braindecode-1.3.0.dev177628147/braindecode/models/medformer.py +758 -0
  51. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/msvtnet.py +2 -0
  52. braindecode-1.3.0.dev177628147/braindecode/models/patchedtransformer.py +640 -0
  53. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/signal_jepa.py +111 -27
  54. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/sinc_shallow.py +12 -9
  55. braindecode-1.3.0.dev177628147/braindecode/models/sstdpn.py +869 -0
  56. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/summary.csv +6 -0
  57. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/syncnet.py +2 -0
  58. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/tcn.py +2 -0
  59. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/usleep.py +26 -21
  60. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/util.py +89 -0
  61. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/attention.py +10 -10
  62. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/blocks.py +3 -3
  63. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/filter.py +2 -9
  64. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/layers.py +18 -17
  65. braindecode-1.3.0.dev177628147/braindecode/preprocessing/__init__.py +271 -0
  66. braindecode-1.3.0.dev177628147/braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  67. braindecode-1.3.0.dev177628147/braindecode/preprocessing/mne_preprocess.py +240 -0
  68. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/preprocess.py +146 -51
  69. braindecode-1.3.0.dev177628147/braindecode/preprocessing/util.py +177 -0
  70. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/windowers.py +26 -20
  71. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/samplers/base.py +8 -8
  72. braindecode-1.3.0.dev177628147/braindecode/version.py +1 -0
  73. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147/braindecode.egg-info}/PKG-INFO +13 -3
  74. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/SOURCES.txt +13 -0
  75. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/requires.txt +13 -1
  76. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/api.rst +136 -13
  77. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/conf.py +53 -14
  78. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/index.rst +5 -5
  79. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/install/install_pip.rst +7 -1
  80. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/install/install_source.rst +1 -1
  81. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/attention.rst +2 -2
  82. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/channel.rst +2 -2
  83. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/convolution.rst +2 -2
  84. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/filterbank.rst +3 -3
  85. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/gnn.rst +3 -6
  86. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/interpretable.rst +3 -3
  87. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/lbm.rst +2 -2
  88. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/recurrent.rst +3 -3
  89. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/categorization/spd.rst +3 -3
  90. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/whats_new.rst +33 -1
  91. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/pyproject.toml +16 -2
  92. braindecode-1.3.0.dev173909672/braindecode/preprocessing/__init__.py +0 -37
  93. braindecode-1.3.0.dev173909672/braindecode/preprocessing/mne_preprocess.py +0 -77
  94. braindecode-1.3.0.dev173909672/braindecode/version.py +0 -1
  95. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/LICENSE.txt +0 -0
  96. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/MANIFEST.in +0 -0
  97. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/NOTICE.txt +0 -0
  98. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/README.rst +0 -0
  99. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/__init__.py +0 -0
  100. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/__init__.py +0 -0
  101. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/classifier.py +0 -0
  102. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bbci.py +0 -0
  103. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/datautil/util.py +0 -0
  104. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/functional/__init__.py +0 -0
  105. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/deep4.py +0 -0
  106. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eeginception_mi.py +0 -0
  107. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eegnex.py +0 -0
  108. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/eegsimpleconv.py +0 -0
  109. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/hybrid.py +0 -0
  110. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/sccnet.py +0 -0
  111. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/shallow_fbcsp.py +0 -0
  112. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/sleep_stager_blanco_2020.py +0 -0
  113. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/sleep_stager_chambon_2018.py +0 -0
  114. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/sparcnet.py +0 -0
  115. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/tidnet.py +0 -0
  116. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/models/tsinception.py +0 -0
  117. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/__init__.py +0 -0
  118. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/activation.py +0 -0
  119. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/convolution.py +0 -0
  120. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/linear.py +0 -0
  121. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/parametrization.py +0 -0
  122. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/stats.py +0 -0
  123. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/util.py +0 -0
  124. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/modules/wrapper.py +0 -0
  125. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/regressor.py +0 -0
  126. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/samplers/__init__.py +0 -0
  127. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/samplers/ssl.py +0 -0
  128. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/training/__init__.py +0 -0
  129. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/training/callbacks.py +0 -0
  130. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/training/losses.py +0 -0
  131. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/training/scoring.py +0 -0
  132. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/util.py +0 -0
  133. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/visualization/__init__.py +0 -0
  134. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/visualization/confusion_matrices.py +0 -0
  135. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode/visualization/gradients.py +0 -0
  136. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/dependency_links.txt +0 -0
  137. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/top_level.txt +0 -0
  138. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/Makefile +0 -0
  139. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/_templates/autosummary/class.rst +0 -0
  140. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/_templates/autosummary/function.rst +0 -0
  141. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/cite.rst +0 -0
  142. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/help.rst +0 -0
  143. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/install/install.rst +0 -0
  144. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/models.rst +0 -0
  145. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/models_categorization.rst +0 -0
  146. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/models_table.rst +0 -0
  147. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/models/models_visualization.rst +0 -0
  148. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/docs/sg_execution_times.rst +0 -0
  149. {braindecode-1.3.0.dev173909672 → braindecode-1.3.0.dev177628147}/setup.cfg +0 -0
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.3.0.dev173909672
3
+ Version: 1.3.0.dev177628147
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
- Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
5
+ Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
7
7
  License: BSD-3-Clause
8
8
  Project-URL: homepage, https://braindecode.org
@@ -38,8 +38,14 @@ Requires-Dist: wfdb
38
38
  Requires-Dist: h5py
39
39
  Requires-Dist: linear_attention_transformer
40
40
  Requires-Dist: docstring_inheritance
41
+ Requires-Dist: rotary_embedding_torch
41
42
  Provides-Extra: moabb
42
43
  Requires-Dist: moabb>=1.2.0; extra == "moabb"
44
+ Provides-Extra: eegprep
45
+ Requires-Dist: eegprep[eeglabio]>=0.1.1; extra == "eegprep"
46
+ Provides-Extra: hub
47
+ Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hub"
48
+ Requires-Dist: zarr<3.0,>=2.18; extra == "hub"
43
49
  Provides-Extra: tests
44
50
  Requires-Dist: pytest; extra == "tests"
45
51
  Requires-Dist: pytest-cov; extra == "tests"
@@ -65,7 +71,11 @@ Requires-Dist: pre-commit; extra == "docs"
65
71
  Requires-Dist: openneuro-py; extra == "docs"
66
72
  Requires-Dist: plotly; extra == "docs"
67
73
  Provides-Extra: all
68
- Requires-Dist: braindecode[docs,moabb,tests]; extra == "all"
74
+ Requires-Dist: braindecode[moabb]; extra == "all"
75
+ Requires-Dist: braindecode[hub]; extra == "all"
76
+ Requires-Dist: braindecode[tests]; extra == "all"
77
+ Requires-Dist: braindecode[docs]; extra == "all"
78
+ Requires-Dist: braindecode[eegprep]; extra == "all"
69
79
  Dynamic: license-file
70
80
 
71
81
  .. image:: https://badges.gitter.im/braindecodechat/community.svg
@@ -189,7 +189,7 @@ class AugmentedDataLoader(DataLoader):
189
189
 
190
190
  Parameters
191
191
  ----------
192
- dataset : BaseDataset
192
+ dataset : RecordDataset
193
193
  The dataset containing the signals.
194
194
  transforms : list | Transform, optional
195
195
  Transform or sequence of Transform to be applied to each batch.
@@ -1,12 +1,17 @@
1
1
  # Authors: Cédric Rommel <cedric.rommel@inria.fr>
2
2
  # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
3
  # Gustavo Rodrigues <gustavenrique01@gmail.com>
4
+ # Bruna Lopes <brunajaflopes@gmail.com>
4
5
  #
5
6
  # License: BSD (3-clause)
6
7
 
8
+ from __future__ import annotations
9
+
7
10
  from numbers import Real
11
+ from typing import Literal
8
12
 
9
13
  import numpy as np
14
+ import numpy.typing as npt
10
15
  import torch
11
16
  from mne.filter import notch_filter
12
17
  from scipy.interpolate import Rbf
@@ -15,7 +20,7 @@ from torch.fft import fft, ifft
15
20
  from torch.nn.functional import one_hot, pad
16
21
 
17
22
 
18
- def identity(X, y):
23
+ def identity(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
19
24
  """Identity operation.
20
25
 
21
26
  Parameters
@@ -35,7 +40,7 @@ def identity(X, y):
35
40
  return X, y
36
41
 
37
42
 
38
- def time_reverse(X, y):
43
+ def time_reverse(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
39
44
  """Flip the time axis of each input.
40
45
 
41
46
  Parameters
@@ -55,7 +60,7 @@ def time_reverse(X, y):
55
60
  return torch.flip(X, [-1]), y
56
61
 
57
62
 
58
- def sign_flip(X, y):
63
+ def sign_flip(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
59
64
  """Flip the sign axis of each input.
60
65
 
61
66
  Parameters
@@ -75,7 +80,13 @@ def sign_flip(X, y):
75
80
  return -X, y
76
81
 
77
82
 
78
- def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
83
+ def _new_random_fft_phase_odd(
84
+ batch_size: int,
85
+ c: int,
86
+ n: int,
87
+ device: torch.device,
88
+ random_state: int | np.random.RandomState | None,
89
+ ) -> torch.Tensor:
79
90
  rng = check_random_state(random_state)
80
91
  random_phase = torch.from_numpy(
81
92
  2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
@@ -90,7 +101,13 @@ def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
90
101
  )
91
102
 
92
103
 
93
- def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
104
+ def _new_random_fft_phase_even(
105
+ batch_size: int,
106
+ c: int,
107
+ n: int,
108
+ device: torch.device,
109
+ random_state: int | np.random.RandomState | None,
110
+ ) -> torch.Tensor:
94
111
  rng = check_random_state(random_state)
95
112
  random_phase = torch.from_numpy(
96
113
  2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
@@ -109,7 +126,13 @@ def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
109
126
  _new_random_fft_phase = {0: _new_random_fft_phase_even, 1: _new_random_fft_phase_odd}
110
127
 
111
128
 
112
- def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
129
+ def ft_surrogate(
130
+ X: torch.Tensor,
131
+ y: torch.Tensor,
132
+ phase_noise_magnitude: float,
133
+ channel_indep: bool,
134
+ random_state: int | np.random.RandomState | None = None,
135
+ ) -> tuple[torch.Tensor, torch.Tensor]:
113
136
  """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.
114
137
 
115
138
  Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
@@ -175,7 +198,9 @@ def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
175
198
  return transformed_X, y
176
199
 
177
200
 
178
- def _pick_channels_randomly(X, p_pick, random_state):
201
+ def _pick_channels_randomly(
202
+ X: torch.Tensor, p_pick: float, random_state: int | np.random.RandomState | None
203
+ ) -> torch.Tensor:
179
204
  rng = check_random_state(random_state)
180
205
  batch_size, n_channels, _ = X.shape
181
206
  # allows to use the same RNG
@@ -188,7 +213,12 @@ def _pick_channels_randomly(X, p_pick, random_state):
188
213
  return torch.sigmoid(1000 * (unif_samples - p_pick))
189
214
 
190
215
 
191
- def channels_dropout(X, y, p_drop, random_state=None):
216
+ def channels_dropout(
217
+ X: torch.Tensor,
218
+ y: torch.Tensor,
219
+ p_drop: float,
220
+ random_state: int | np.random.RandomState | None = None,
221
+ ) -> tuple[torch.Tensor, torch.Tensor]:
192
222
  """Randomly set channels to flat signal.
193
223
 
194
224
  Part of the CMSAugment policy proposed in [1]_
@@ -222,7 +252,9 @@ def channels_dropout(X, y, p_drop, random_state=None):
222
252
  return X * mask.unsqueeze(-1), y
223
253
 
224
254
 
225
- def _make_permutation_matrix(X, mask, random_state):
255
+ def _make_permutation_matrix(
256
+ X: torch.Tensor, mask: torch.Tensor, random_state: int | np.random.Generator | None
257
+ ) -> torch.Tensor:
226
258
  rng = check_random_state(random_state)
227
259
  batch_size, n_channels, _ = X.shape
228
260
  hard_mask = mask.round()
@@ -241,7 +273,12 @@ def _make_permutation_matrix(X, mask, random_state):
241
273
  return batch_permutations
242
274
 
243
275
 
244
- def channels_shuffle(X, y, p_shuffle, random_state=None):
276
+ def channels_shuffle(
277
+ X: torch.Tensor,
278
+ y: torch.Tensor,
279
+ p_shuffle: float,
280
+ random_state: int | np.random.RandomState | None = None,
281
+ ) -> tuple[torch.Tensor, torch.Tensor]:
245
282
  """Randomly shuffle channels in EEG data matrix.
246
283
 
247
284
  Part of the CMSAugment policy proposed in [1]_
@@ -280,7 +317,12 @@ def channels_shuffle(X, y, p_shuffle, random_state=None):
280
317
  return torch.matmul(batch_permutations, X), y
281
318
 
282
319
 
283
- def gaussian_noise(X, y, std, random_state=None):
320
+ def gaussian_noise(
321
+ X: torch.Tensor,
322
+ y: torch.Tensor,
323
+ std: float,
324
+ random_state: int | np.random.RandomState | None = None,
325
+ ) -> tuple[torch.Tensor, torch.Tensor]:
284
326
  """Randomly add white Gaussian noise to all channels.
285
327
 
286
328
  Suggested e.g. in [1]_, [2]_ and [3]_
@@ -332,7 +374,9 @@ def gaussian_noise(X, y, std, random_state=None):
332
374
  return transformed_X, y
333
375
 
334
376
 
335
- def channels_permute(X, y, permutation):
377
+ def channels_permute(
378
+ X: torch.Tensor, y: torch.Tensor, permutation: list[int]
379
+ ) -> tuple[torch.Tensor, torch.Tensor]:
336
380
  """Permute EEG channels according to fixed permutation matrix.
337
381
 
338
382
  Suggested e.g. in [1]_
@@ -362,7 +406,12 @@ def channels_permute(X, y, permutation):
362
406
  return X[..., permutation, :], y
363
407
 
364
408
 
365
- def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
409
+ def smooth_time_mask(
410
+ X: torch.Tensor,
411
+ y: torch.Tensor,
412
+ mask_start_per_sample: torch.Tensor,
413
+ mask_len_samples: int,
414
+ ) -> tuple[torch.Tensor, torch.Tensor]:
366
415
  """Smoothly replace a contiguous part of all channels by zeros.
367
416
 
368
417
  Originally proposed in [1]_ and [2]_
@@ -412,7 +461,13 @@ def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
412
461
  return X * mask, y
413
462
 
414
463
 
415
- def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
464
+ def bandstop_filter(
465
+ X: torch.Tensor,
466
+ y: torch.Tensor,
467
+ sfreq: float,
468
+ bandwidth: float,
469
+ freqs_to_notch: npt.ArrayLike | None,
470
+ ) -> tuple[torch.Tensor, torch.Tensor]:
416
471
  """Apply a band-stop filter with desired bandwidth at the desired frequency
417
472
  position.
418
473
 
@@ -451,7 +506,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
451
506
  Representation Learning for Electroencephalogram Classification. In
452
507
  Machine Learning for Health (pp. 238-253). PMLR.
453
508
  """
454
- if bandwidth == 0:
509
+ if bandwidth == 0 or freqs_to_notch is None:
455
510
  return X, y
456
511
  transformed_X = X.clone()
457
512
  for c, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
@@ -469,7 +524,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
469
524
  return transformed_X, y
470
525
 
471
526
 
472
- def _analytic_transform(x):
527
+ def _analytic_transform(x: torch.Tensor) -> torch.Tensor:
473
528
  if torch.is_complex(x):
474
529
  raise ValueError("x must be real.")
475
530
 
@@ -486,12 +541,12 @@ def _analytic_transform(x):
486
541
  return ifft(f * h, dim=-1)
487
542
 
488
543
 
489
- def _nextpow2(n):
544
+ def _nextpow2(n: int) -> int:
490
545
  """Return the first integer N such that 2**N >= abs(n)."""
491
546
  return int(np.ceil(np.log2(np.abs(n))))
492
547
 
493
548
 
494
- def _frequency_shift(X, fs, f_shift):
549
+ def _frequency_shift(X: torch.Tensor, fs: float, f_shift: float) -> torch.Tensor:
495
550
  """
496
551
  Shift the specified signal by the specified frequency.
497
552
 
@@ -504,9 +559,13 @@ def _frequency_shift(X, fs, f_shift):
504
559
  t = torch.arange(N_padded, device=X.device) / fs
505
560
  padded = pad(X, (0, N_padded - N_orig))
506
561
  analytical = _analytic_transform(padded)
507
- if isinstance(f_shift, (float, int, np.ndarray, list)):
508
- f_shift = torch.as_tensor(f_shift).float()
509
- f_shift_stack = f_shift.repeat(N_padded, n_channels, 1)
562
+ if isinstance(f_shift, torch.Tensor):
563
+ _f_shift = f_shift
564
+ elif isinstance(f_shift, (float, int, np.ndarray, list)):
565
+ _f_shift = torch.as_tensor(f_shift).float()
566
+ else:
567
+ raise ValueError(f"Invalid f_shift type: {type(f_shift)}")
568
+ f_shift_stack = _f_shift.repeat(N_padded, n_channels, 1)
510
569
  reshaped_f_shift = f_shift_stack.permute(
511
570
  *torch.arange(f_shift_stack.ndim - 1, -1, -1)
512
571
  )
@@ -514,7 +573,9 @@ def _frequency_shift(X, fs, f_shift):
514
573
  return shifted[..., :N_orig].real.float()
515
574
 
516
575
 
517
- def frequency_shift(X, y, delta_freq, sfreq):
576
+ def frequency_shift(
577
+ X: torch.Tensor, y: torch.Tensor, delta_freq: float, sfreq: float
578
+ ) -> tuple[torch.Tensor, torch.Tensor]:
518
579
  """Adds a shift in the frequency domain to all channels.
519
580
 
520
581
  Note that here, the shift is the same for all channels of a single example.
@@ -545,7 +606,7 @@ def frequency_shift(X, y, delta_freq, sfreq):
545
606
  return transformed_X, y
546
607
 
547
608
 
548
- def _torch_normalize_vectors(rr):
609
+ def _torch_normalize_vectors(rr: torch.Tensor) -> torch.Tensor:
549
610
  """Normalize surface vertices."""
550
611
  norm = torch.linalg.norm(rr, axis=1, keepdim=True)
551
612
  mask = norm > 0
@@ -554,7 +615,9 @@ def _torch_normalize_vectors(rr):
554
615
  return new_rr
555
616
 
556
617
 
557
- def _torch_legval(x, c, tensor=True):
618
+ def _torch_legval(
619
+ x: torch.Tensor, c: torch.Tensor, tensor: bool = True
620
+ ) -> torch.Tensor:
558
621
  """
559
622
  Evaluate a Legendre series at points x.
560
623
  If `c` is of length `n + 1`, this function returns the value:
@@ -662,7 +725,9 @@ def _torch_legval(x, c, tensor=True):
662
725
  return c0 + c1 * x
663
726
 
664
727
 
665
- def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
728
+ def _torch_calc_g(
729
+ cosang: torch.Tensor, stiffness: float = 4, n_legendre_terms: int = 50
730
+ ) -> torch.Tensor:
666
731
  """Calculate spherical spline g function between points on a sphere.
667
732
 
668
733
  Parameters
@@ -718,23 +783,25 @@ def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
718
783
  return _torch_legval(cosang, [0] + factors)
719
784
 
720
785
 
721
- def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
786
+ def _torch_make_interpolation_matrix(
787
+ pos_from: torch.Tensor, pos_to: torch.Tensor, alpha: float = 1e-5
788
+ ) -> torch.Tensor:
722
789
  """Compute interpolation matrix based on spherical splines.
723
790
 
724
791
  Implementation based on [1]_
725
792
 
726
793
  Parameters
727
794
  ----------
728
- pos_from : np.ndarray of float, shape(n_good_sensors, 3)
795
+ pos_from : torch.Tensor of float, shape(n_good_sensors, 3)
729
796
  The positions to interpolate from.
730
- pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
797
+ pos_to : torch.Tensor of float, shape(n_bad_sensors, 3)
731
798
  The positions to interpolate.
732
799
  alpha : float
733
800
  Regularization parameter. Defaults to 1e-5.
734
801
 
735
802
  Returns
736
803
  -------
737
- interpolation : np.ndarray of float, shape(len(pos_from), len(pos_to))
804
+ interpolation : torch.Tensor of float, shape(len(pos_from), len(pos_to))
738
805
  The interpolation matrix that maps good signals to the location
739
806
  of bad signals.
740
807
 
@@ -822,7 +889,12 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
822
889
  return interpolation
823
890
 
824
891
 
825
- def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
892
+ def _rotate_signals(
893
+ X: torch.Tensor,
894
+ rotations: list[torch.Tensor],
895
+ sensors_positions_matrix: torch.Tensor,
896
+ spherical: bool = True,
897
+ ) -> torch.Tensor:
826
898
  sensors_positions_matrix = sensors_positions_matrix.to(X.device)
827
899
  rot_sensors_matrices = [
828
900
  rotation.matmul(sensors_positions_matrix) for rotation in rotations
@@ -853,22 +925,29 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
853
925
  return transformed_X
854
926
 
855
927
 
856
- def _make_rotation_matrix(axis, angle, degrees=True):
928
+ def _make_rotation_matrix(
929
+ axis: Literal["x", "y", "z"],
930
+ angle: float | int | np.ndarray | list | torch.Tensor,
931
+ degrees: bool = True,
932
+ ) -> torch.Tensor:
857
933
  assert axis in ["x", "y", "z"], "axis should be either x, y or z."
858
-
859
- if isinstance(angle, (float, int, np.ndarray, list)):
860
- angle = torch.as_tensor(angle)
934
+ if isinstance(angle, torch.Tensor):
935
+ _angle = angle
936
+ elif isinstance(angle, (float, int, np.ndarray, list)):
937
+ _angle = torch.as_tensor(angle)
938
+ else:
939
+ raise ValueError(f"Invalid angle type: {type(angle)}")
861
940
 
862
941
  if degrees:
863
- angle = angle * np.pi / 180
942
+ _angle = _angle * np.pi / 180
864
943
 
865
- device = angle.device
944
+ device = _angle.device
866
945
  zero = torch.zeros(1, device=device)
867
946
  rot = torch.stack(
868
947
  [
869
948
  torch.as_tensor([1, 0, 0], device=device),
870
- torch.hstack([zero, torch.cos(angle), -torch.sin(angle)]),
871
- torch.hstack([zero, torch.sin(angle), torch.cos(angle)]),
949
+ torch.hstack([zero, torch.cos(_angle), -torch.sin(_angle)]),
950
+ torch.hstack([zero, torch.sin(_angle), torch.cos(_angle)]),
872
951
  ]
873
952
  )
874
953
  if axis == "x":
@@ -881,7 +960,14 @@ def _make_rotation_matrix(axis, angle, degrees=True):
881
960
  return rot[:, [1, 2, 0]]
882
961
 
883
962
 
884
- def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_splines):
963
+ def sensors_rotation(
964
+ X: torch.Tensor,
965
+ y: torch.Tensor,
966
+ sensors_positions_matrix: torch.Tensor,
967
+ axis: Literal["x", "y", "z"],
968
+ angles: npt.ArrayLike,
969
+ spherical_splines: bool,
970
+ ) -> tuple[torch.Tensor, torch.Tensor]:
885
971
  """Interpolates EEG signals over sensors rotated around the desired axis
886
972
  with the desired angle.
887
973
 
@@ -893,7 +979,7 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
893
979
  EEG input example or batch.
894
980
  y : torch.Tensor
895
981
  EEG labels for the example or batch.
896
- sensors_positions_matrix : numpy.ndarray
982
+ sensors_positions_matrix : torch.Tensor
897
983
  Matrix giving the positions of each sensor in a 3D cartesian coordinate
898
984
  system. Should have shape (3, n_channels), where n_channels is the
899
985
  number of channels. Standard 10-20 positions can be obtained from
@@ -924,7 +1010,9 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
924
1010
  return rotated_X, y
925
1011
 
926
1012
 
927
- def mixup(X, y, lam, idx_perm):
1013
+ def mixup(
1014
+ X: torch.Tensor, y: torch.Tensor, lam: torch.Tensor, idx_perm: torch.Tensor
1015
+ ) -> tuple[torch.Tensor, torch.Tensor]:
928
1016
  """Mixes two channels of EEG data.
929
1017
 
930
1018
  See [1]_ for details.
@@ -973,8 +1061,13 @@ def mixup(X, y, lam, idx_perm):
973
1061
 
974
1062
 
975
1063
  def segmentation_reconstruction(
976
- X, y, n_segments, data_classes, rand_indices, idx_shuffle
977
- ):
1064
+ X: torch.Tensor,
1065
+ y: torch.Tensor,
1066
+ n_segments: int,
1067
+ data_classes: list[tuple[int, torch.Tensor]],
1068
+ rand_indices: npt.ArrayLike,
1069
+ idx_shuffle: npt.ArrayLike,
1070
+ ) -> tuple[torch.Tensor, torch.Tensor]:
978
1071
  """Segment and reconstruct EEG data from [1]_.
979
1072
 
980
1073
  See [1]_ for details.
@@ -987,6 +1080,8 @@ def segmentation_reconstruction(
987
1080
  EEG labels for the example or batch.
988
1081
  n_segments : int
989
1082
  Number of segments to use in the batch.
1083
+ data_classes: list[tuple[int, torch.Tensor]]
1084
+ List of tuples. Each tuple contains the class index and the corresponding EEG data.
990
1085
  rand_indices: array-like
991
1086
  Array of indices that indicates which trial to use in each segment.
992
1087
  idx_shuffle: array-like
@@ -1005,8 +1100,8 @@ def segmentation_reconstruction(
1005
1100
  """
1006
1101
 
1007
1102
  # Initialize lists to store augmented data and corresponding labels
1008
- aug_data = []
1009
- aug_label = []
1103
+ aug_data: list[torch.Tensor] = []
1104
+ aug_label: list[torch.Tensor] = []
1010
1105
 
1011
1106
  # Iterate through each class to separate and augment data
1012
1107
  for class_index, X_class in data_classes:
@@ -1030,20 +1125,26 @@ def segmentation_reconstruction(
1030
1125
  aug_data.append(X_aug)
1031
1126
  aug_label.append(torch.full((n_trials,), class_index))
1032
1127
  # Concatenate the augmented data and labels
1033
- aug_data = torch.cat(aug_data, dim=0)
1034
- aug_data = aug_data.to(dtype=X.dtype, device=X.device)
1035
- aug_data = aug_data[idx_shuffle]
1128
+ concat_aug_data = torch.cat(aug_data, dim=0)
1129
+ concat_aug_data = concat_aug_data.to(dtype=X.dtype, device=X.device)
1130
+ concat_aug_data = concat_aug_data[idx_shuffle]
1036
1131
 
1037
1132
  if y is not None:
1038
- aug_label = torch.cat(aug_label, dim=0)
1039
- aug_label = aug_label.to(dtype=y.dtype, device=y.device)
1040
- aug_label = aug_label[idx_shuffle]
1041
- return aug_data, aug_label
1133
+ concat_label = torch.cat(aug_label, dim=0)
1134
+ concat_label = concat_label.to(dtype=y.dtype, device=y.device)
1135
+ concat_label = concat_label[idx_shuffle]
1136
+ return concat_aug_data, concat_label
1042
1137
 
1043
- return aug_data, y
1138
+ return concat_aug_data, None
1044
1139
 
1045
1140
 
1046
- def mask_encoding(X, y, time_start, segment_length, n_segments):
1141
+ def mask_encoding(
1142
+ X: torch.Tensor,
1143
+ y: torch.Tensor,
1144
+ time_start: torch.Tensor,
1145
+ segment_length: int,
1146
+ n_segments: int,
1147
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1047
1148
  """Mark encoding from Ding et al. (2024) from [ding2024]_.
1048
1149
 
1049
1150
  Replaces a contiguous part (or parts) of all channels by zeros
@@ -1094,3 +1195,103 @@ def mask_encoding(X, y, time_start, segment_length, n_segments):
1094
1195
  X[mask] = 0
1095
1196
 
1096
1197
  return X, y # Return the masked tensor and labels
1198
+
1199
+
1200
def channels_rereference(
    X: torch.Tensor,
    y: torch.Tensor,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly re-reference channels in EEG data matrix.

    For every trial of the batch, one channel is drawn uniformly at random and
    used as the new reference: its signal is subtracted from every channel of
    that trial, and the selected channel itself is replaced by its negated
    signal.

    Part of the augmentations proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG batch of shape (batch_size, n_channels, n_times).
    y : torch.Tensor
        EEG labels for the batch. Returned unchanged.
    random_state: int | numpy.random.RandomState | None, optional
        Seed or random state, passed to ``check_random_state`` to obtain the
        random number generator used to pick the reference channels.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs (a new tensor; ``X`` is not modified in place).
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. Proceedings
       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
       Learning Research 136:238-253

    """
    rng = check_random_state(random_state)
    batch_size, n_channels, _ = X.shape

    # One new reference channel per trial. Indices are built on X's device so
    # the advanced indexing below does not need a host-to-device transfer.
    ch = torch.as_tensor(rng.randint(0, n_channels, size=batch_size), device=X.device)
    trials = torch.arange(batch_size, device=X.device)

    # Advanced indexing returns a copy of shape (batch_size, n_times), so the
    # subtraction below does not alias the reference signal.
    X_ch = X[trials, ch, :]
    X = X - X_ch.unsqueeze(1)
    # The chosen channel would otherwise become all zeros; store its negated
    # signal instead.
    X[trials, ch, :] = -X_ch

    return X, y
1245
+
1246
+
1247
+ def amplitude_scale(
1248
+ X: torch.Tensor,
1249
+ y: torch.Tensor,
1250
+ scale: tuple,
1251
+ random_state: int | np.random.RandomState | None = None,
1252
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1253
+ """Rescale amplitude of each channel based on a random sampled scaling value.
1254
+
1255
+ Part of the augmentations proposed in [1]_
1256
+
1257
+ Parameters
1258
+ ----------
1259
+ X : torch.Tensor
1260
+ EEG input example or batch.
1261
+ y : torch.Tensor
1262
+ EEG labels for the example or batch.
1263
+ scale : tuple of floats
1264
+ Interval from which ypu sample the scaling value
1265
+ random_state: int | numpy.random.Generator, optional
1266
+ Seed to be used to instantiate numpy random number generator instance.
1267
+ Defaults to None.
1268
+
1269
+ Returns
1270
+ -------
1271
+ torch.Tensor
1272
+ Transformed inputs.
1273
+ torch.Tensor
1274
+ Transformed labels.
1275
+
1276
+ References
1277
+ ----------
1278
+ .. [1] Mohsenvand, M.N., Izadi, M.R. &amp; Maes, P.. (2020). Contrastive
1279
+ Representation Learning for Electroencephalogram Classification. Proceedings
1280
+ of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
1281
+ Learning Research 136:238-253
1282
+
1283
+ """
1284
+
1285
+ rng = torch.Generator()
1286
+ rng.manual_seed(random_state)
1287
+ batch_size, n_channels, _ = X.shape
1288
+
1289
+ # Parameter for scaling amplitude / channel / trial
1290
+ l, h = scale
1291
+ s = l + (h - l) * torch.rand(
1292
+ batch_size, n_channels, 1, generator=rng, device=X.device, dtype=X.dtype
1293
+ )
1294
+
1295
+ X = s * X
1296
+
1297
+ return X, y