braindecode 1.3.0.dev175435903__tar.gz → 1.3.0.dev177628147__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {braindecode-1.3.0.dev175435903/braindecode.egg-info → braindecode-1.3.0.dev177628147}/PKG-INFO +1 -5
  2. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/serialization.py +1 -1
  3. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/atcnet.py +11 -11
  4. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/attentionbasenet.py +4 -4
  5. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/attn_sleep.py +7 -9
  6. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/bendr.py +1 -1
  7. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/biot.py +23 -25
  8. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/contrawr.py +2 -2
  9. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/ctnet.py +33 -33
  10. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/deep4.py +4 -4
  11. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/deepsleepnet.py +4 -4
  12. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegconformer.py +22 -27
  13. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eeginception_erp.py +1 -1
  14. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eeginception_mi.py +3 -3
  15. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegitnet.py +3 -3
  16. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegnet.py +1 -1
  17. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegnex.py +1 -1
  18. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegsimpleconv.py +1 -1
  19. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegtcnet.py +3 -3
  20. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/fbcnet.py +1 -1
  21. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/fblightconvnet.py +1 -1
  22. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/fbmsnet.py +1 -1
  23. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/hybrid.py +1 -1
  24. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/ifnet.py +2 -2
  25. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/labram.py +39 -39
  26. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/medformer.py +14 -14
  27. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/msvtnet.py +9 -9
  28. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/patchedtransformer.py +46 -46
  29. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sccnet.py +1 -1
  30. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/shallow_fbcsp.py +2 -4
  31. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sinc_shallow.py +2 -1
  32. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sleep_stager_blanco_2020.py +1 -1
  33. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sleep_stager_chambon_2018.py +1 -1
  34. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sparcnet.py +4 -4
  35. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sstdpn.py +1 -1
  36. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/summary.csv +1 -1
  37. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/syncnet.py +1 -1
  38. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/tcn.py +3 -3
  39. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/tidnet.py +1 -1
  40. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/tsinception.py +1 -1
  41. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/usleep.py +3 -3
  42. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/util.py +14 -154
  43. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/blocks.py +1 -3
  44. braindecode-1.3.0.dev177628147/braindecode/version.py +1 -0
  45. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147/braindecode.egg-info}/PKG-INFO +1 -5
  46. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/SOURCES.txt +0 -1
  47. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/requires.txt +0 -5
  48. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/api.rst +0 -4
  49. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/whats_new.rst +0 -5
  50. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/pyproject.toml +0 -5
  51. braindecode-1.3.0.dev175435903/braindecode/models/config.py +0 -230
  52. braindecode-1.3.0.dev175435903/braindecode/version.py +0 -1
  53. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/LICENSE.txt +0 -0
  54. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/MANIFEST.in +0 -0
  55. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/NOTICE.txt +0 -0
  56. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/README.rst +0 -0
  57. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/__init__.py +0 -0
  58. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/__init__.py +0 -0
  59. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/base.py +0 -0
  60. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/functional.py +0 -0
  61. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/transforms.py +0 -0
  62. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/classifier.py +0 -0
  63. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/__init__.py +0 -0
  64. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/base.py +0 -0
  65. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bbci.py +0 -0
  66. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bcicomp.py +0 -0
  67. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bids.py +0 -0
  68. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/experimental.py +0 -0
  69. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/hub.py +0 -0
  70. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/hub_validation.py +0 -0
  71. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/mne.py +0 -0
  72. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/moabb.py +0 -0
  73. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/nmt.py +0 -0
  74. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/registry.py +0 -0
  75. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
  76. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/sleep_physionet.py +0 -0
  77. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/tuh.py +0 -0
  78. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/xy.py +0 -0
  79. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/__init__.py +0 -0
  80. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/channel_utils.py +0 -0
  81. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/hub_formats.py +0 -0
  82. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/util.py +0 -0
  83. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/eegneuralnet.py +0 -0
  84. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/functional/__init__.py +0 -0
  85. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/functional/functions.py +0 -0
  86. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/functional/initialization.py +0 -0
  87. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/__init__.py +0 -0
  88. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/base.py +0 -0
  89. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegminer.py +0 -0
  90. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegsym.py +0 -0
  91. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/luna.py +0 -0
  92. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/signal_jepa.py +0 -0
  93. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/__init__.py +0 -0
  94. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/activation.py +0 -0
  95. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/attention.py +0 -0
  96. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/convolution.py +0 -0
  97. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/filter.py +0 -0
  98. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/layers.py +0 -0
  99. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/linear.py +0 -0
  100. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/parametrization.py +0 -0
  101. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/stats.py +0 -0
  102. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/util.py +0 -0
  103. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/wrapper.py +0 -0
  104. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/__init__.py +0 -0
  105. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/eegprep_preprocess.py +0 -0
  106. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/mne_preprocess.py +0 -0
  107. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/preprocess.py +0 -0
  108. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/util.py +0 -0
  109. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/windowers.py +0 -0
  110. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/regressor.py +0 -0
  111. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/samplers/__init__.py +0 -0
  112. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/samplers/base.py +0 -0
  113. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/samplers/ssl.py +0 -0
  114. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/training/__init__.py +0 -0
  115. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/training/callbacks.py +0 -0
  116. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/training/losses.py +0 -0
  117. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/training/scoring.py +0 -0
  118. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/util.py +0 -0
  119. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/visualization/__init__.py +0 -0
  120. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/visualization/confusion_matrices.py +0 -0
  121. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/visualization/gradients.py +0 -0
  122. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/dependency_links.txt +0 -0
  123. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/top_level.txt +0 -0
  124. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/Makefile +0 -0
  125. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/_templates/autosummary/class.rst +0 -0
  126. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/_templates/autosummary/function.rst +0 -0
  127. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/cite.rst +0 -0
  128. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/conf.py +0 -0
  129. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/help.rst +0 -0
  130. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/index.rst +0 -0
  131. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/install/install.rst +0 -0
  132. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/install/install_pip.rst +0 -0
  133. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/install/install_source.rst +0 -0
  134. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/attention.rst +0 -0
  135. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/channel.rst +0 -0
  136. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/convolution.rst +0 -0
  137. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/filterbank.rst +0 -0
  138. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/gnn.rst +0 -0
  139. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/interpretable.rst +0 -0
  140. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/lbm.rst +0 -0
  141. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/recurrent.rst +0 -0
  142. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/spd.rst +0 -0
  143. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/models.rst +0 -0
  144. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/models_categorization.rst +0 -0
  145. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/models_table.rst +0 -0
  146. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/models_visualization.rst +0 -0
  147. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/sg_execution_times.rst +0 -0
  148. {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.3.0.dev175435903
3
+ Version: 1.3.0.dev177628147
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -52,9 +52,6 @@ Requires-Dist: pytest-cov; extra == "tests"
52
52
  Requires-Dist: codecov; extra == "tests"
53
53
  Requires-Dist: pytest_cases; extra == "tests"
54
54
  Requires-Dist: mypy; extra == "tests"
55
- Provides-Extra: typing
56
- Requires-Dist: pydantic<3.0,>=2.0; extra == "typing"
57
- Requires-Dist: numpydantic>=1.7; extra == "typing"
58
55
  Provides-Extra: docs
59
56
  Requires-Dist: sphinx_gallery; extra == "docs"
60
57
  Requires-Dist: sphinx_rtd_theme; extra == "docs"
@@ -79,7 +76,6 @@ Requires-Dist: braindecode[hub]; extra == "all"
79
76
  Requires-Dist: braindecode[tests]; extra == "all"
80
77
  Requires-Dist: braindecode[docs]; extra == "all"
81
78
  Requires-Dist: braindecode[eegprep]; extra == "all"
82
- Requires-Dist: braindecode[typing]; extra == "all"
83
79
  Dynamic: license-file
84
80
 
85
81
  .. image:: https://badges.gitter.im/braindecodechat/community.svg
@@ -138,7 +138,7 @@ def _load_signals(fif_file, preload, is_raw):
138
138
  with open(pkl_file, "rb") as f:
139
139
  signals = pickle.load(f)
140
140
 
141
- if all(Path(f).exists() for f in signals.filenames):
141
+ if all(f.exists() for f in signals.filenames):
142
142
  if preload:
143
143
  signals.load_data()
144
144
  return signals
@@ -141,7 +141,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
141
141
  - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
142
142
  ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
143
143
  - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
144
- - ``num_heads``, ``head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
144
+ - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
145
145
  - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
146
146
  longer inputs (see minimum length above). The implementation warns and *rescales*
147
147
  kernels/pools/windows if inputs are too short.
@@ -194,10 +194,10 @@ class ATCNet(EEGModuleMixin, nn.Module):
194
194
  table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
195
195
  n_windows : int
196
196
  Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
197
- head_dim : int
197
+ att_head_dim : int
198
198
  Embedding dimension used in each self-attention head, denoted dh in
199
199
  table 1 of the paper [1]_. Defaults to 8 as in [1]_.
200
- num_heads : int
200
+ att_num_heads : int
201
201
  Number of attention heads, denoted H in table 1 of the paper [1]_.
202
202
  Defaults to 2 as in [1]_.
203
203
  att_dropout : float
@@ -248,13 +248,13 @@ class ATCNet(EEGModuleMixin, nn.Module):
248
248
  conv_block_depth_mult=2,
249
249
  conv_block_dropout=0.3,
250
250
  n_windows=5,
251
- head_dim=8,
252
- num_heads=2,
251
+ att_head_dim=8,
252
+ att_num_heads=2,
253
253
  att_drop_prob=0.5,
254
254
  tcn_depth=2,
255
255
  tcn_kernel_size=4,
256
256
  tcn_drop_prob=0.3,
257
- tcn_activation: type[nn.Module] = nn.ELU,
257
+ tcn_activation: nn.Module = nn.ELU,
258
258
  concat=False,
259
259
  max_norm_const=0.25,
260
260
  chs_info=None,
@@ -316,8 +316,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
316
316
  self.conv_block_depth_mult = conv_block_depth_mult
317
317
  self.conv_block_dropout = conv_block_dropout
318
318
  self.n_windows = n_windows
319
- self.head_dim = head_dim
320
- self.num_heads = num_heads
319
+ self.att_head_dim = att_head_dim
320
+ self.att_num_heads = att_num_heads
321
321
  self.att_dropout = att_drop_prob
322
322
  self.tcn_depth = tcn_depth
323
323
  self.tcn_kernel_size = tcn_kernel_size
@@ -356,8 +356,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
356
356
  [
357
357
  _AttentionBlock(
358
358
  in_shape=self.F2,
359
- head_dim=self.head_dim,
360
- num_heads=num_heads,
359
+ head_dim=self.att_head_dim,
360
+ num_heads=att_num_heads,
361
361
  dropout=att_drop_prob,
362
362
  )
363
363
  for _ in range(self.n_windows)
@@ -656,7 +656,7 @@ class _TCNResidualBlock(nn.Module):
656
656
  kernel_size=4,
657
657
  n_filters=32,
658
658
  dropout=0.3,
659
- activation: type[nn.Module] = nn.ELU,
659
+ activation: nn.Module = nn.ELU,
660
660
  dilation=1,
661
661
  ):
662
662
  super().__init__()
@@ -235,7 +235,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
235
235
  kernel_size : int, default=9
236
236
  The kernel size used in certain types of attention mechanisms for convolution
237
237
  operations.
238
- activation: type[nn.Module] = nn.ELU,
238
+ activation: nn.Module, default=nn.ELU
239
239
  Activation function class to apply. Should be a PyTorch activation
240
240
  module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
241
241
  extra_params : bool, default=False
@@ -277,7 +277,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
277
277
  freq_idx: int = 0,
278
278
  n_codewords: int = 4,
279
279
  kernel_size: int = 9,
280
- activation: type[nn.Module] = nn.ELU,
280
+ activation: nn.Module = nn.ELU,
281
281
  extra_params: bool = False,
282
282
  ):
283
283
  super(AttentionBaseNet, self).__init__()
@@ -453,7 +453,7 @@ class _FeatureExtractor(nn.Module):
453
453
  pool_length: int = 75,
454
454
  pool_stride: int = 15,
455
455
  drop_prob: float = 0.5,
456
- activation: type[nn.Module] = nn.ELU,
456
+ activation: nn.Module = nn.ELU,
457
457
  ):
458
458
  super().__init__()
459
459
 
@@ -592,7 +592,7 @@ class _ChannelAttentionBlock(nn.Module):
592
592
  n_codewords: int = 4,
593
593
  kernel_size: int = 9,
594
594
  extra_params: bool = False,
595
- activation: type[nn.Module] = nn.ELU,
595
+ activation: nn.Module = nn.ELU,
596
596
  ):
597
597
  super().__init__()
598
598
  self.conv = nn.Sequential(
@@ -90,8 +90,8 @@ class AttnSleep(EEGModuleMixin, nn.Module):
90
90
  d_ff=120,
91
91
  n_attn_heads=5,
92
92
  drop_prob=0.1,
93
- activation_mrcnn: type[nn.Module] = nn.GELU,
94
- activation: type[nn.Module] = nn.ReLU,
93
+ activation_mrcnn: nn.Module = nn.GELU,
94
+ activation: nn.Module = nn.ReLU,
95
95
  input_window_seconds=None,
96
96
  n_outputs=None,
97
97
  after_reduced_cnn_size=30,
@@ -230,7 +230,7 @@ class _SEBasicBlock(nn.Module):
230
230
  planes,
231
231
  stride=1,
232
232
  downsample=None,
233
- activation: type[nn.Module] = nn.ReLU,
233
+ activation: nn.Module = nn.ReLU,
234
234
  *,
235
235
  reduction=16,
236
236
  ):
@@ -278,8 +278,8 @@ class _MRCNN(nn.Module):
278
278
  self,
279
279
  after_reduced_cnn_size,
280
280
  kernel_size=7,
281
- activation: type[nn.Module] = nn.GELU,
282
- activation_se: type[nn.Module] = nn.ReLU,
281
+ activation: nn.Module = nn.GELU,
282
+ activation_se: nn.Module = nn.ReLU,
283
283
  ):
284
284
  super(_MRCNN, self).__init__()
285
285
  drate = 0.5
@@ -325,7 +325,7 @@ class _MRCNN(nn.Module):
325
325
  )
326
326
 
327
327
  def _make_layer(
328
- self, block, planes, blocks, stride=1, activate: type[nn.Module] = nn.ReLU
328
+ self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
329
329
  ): # makes residual SE block
330
330
  downsample = None
331
331
  if stride != 1 or self.inplanes != planes * block.expansion:
@@ -526,9 +526,7 @@ class _EncoderLayer(nn.Module):
526
526
  class _PositionwiseFeedForward(nn.Module):
527
527
  """Positionwise feed-forward network."""
528
528
 
529
- def __init__(
530
- self, d_model, d_ff, dropout=0.1, activation: type[nn.Module] = nn.ReLU
531
- ):
529
+ def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
532
530
  super().__init__()
533
531
  self.w_1 = nn.Linear(d_model, d_ff)
534
532
  self.w_2 = nn.Linear(d_ff, d_model)
@@ -176,7 +176,7 @@ class BENDR(EEGModuleMixin, nn.Module):
176
176
  projection_head=False, # Whether encoder should project back to input feature size (unused in original fine-tuning)
177
177
  drop_prob=0.1, # General dropout probability (paper: 0.15 for pretraining, 0.0 for fine-tuning)
178
178
  layer_drop=0.0, # Probability of dropping transformer layers during training (paper: 0.01 for pretraining)
179
- activation: type[nn.Module] = nn.GELU, # Activation function
179
+ activation=nn.GELU, # Activation function
180
180
  # Transformer specific parameters
181
181
  transformer_layers=8,
182
182
  transformer_heads=8,
@@ -45,11 +45,11 @@ class BIOT(EEGModuleMixin, nn.Module):
45
45
 
46
46
  Parameters
47
47
  ----------
48
- embed_dim : int, optional
48
+ emb_size : int, optional
49
49
  The size of the embedding layer, by default 256
50
- num_heads : int, optional
50
+ att_num_heads : int, optional
51
51
  The number of attention heads, by default 8
52
- num_layers : int, optional
52
+ n_layers : int, optional
53
53
  The number of transformer layers, by default 4
54
54
  activation: nn.Module, default=nn.ELU
55
55
  Activation function class to apply. Should be a PyTorch activation
@@ -76,9 +76,9 @@ class BIOT(EEGModuleMixin, nn.Module):
76
76
 
77
77
  def __init__(
78
78
  self,
79
- embed_dim=256,
80
- num_heads=8,
81
- num_layers=4,
79
+ emb_size=256,
80
+ att_num_heads=8,
81
+ n_layers=4,
82
82
  sfreq=200,
83
83
  hop_length=100,
84
84
  return_feature=False,
@@ -87,12 +87,12 @@ class BIOT(EEGModuleMixin, nn.Module):
87
87
  chs_info=None,
88
88
  n_times=None,
89
89
  input_window_seconds=None,
90
- activation: type[nn.Module] = nn.ELU,
90
+ activation: nn.Module = nn.ELU,
91
91
  drop_prob: float = 0.5,
92
92
  # Parameters for the encoder
93
93
  max_seq_len: int = 1024,
94
- att_drop_prob=0.2,
95
- att_layer_drop_prob=0.2,
94
+ attn_dropout=0.2,
95
+ attn_layer_dropout=0.2,
96
96
  ):
97
97
  super().__init__(
98
98
  n_outputs=n_outputs,
@@ -103,10 +103,10 @@ class BIOT(EEGModuleMixin, nn.Module):
103
103
  sfreq=sfreq,
104
104
  )
105
105
  del n_outputs, n_chans, chs_info, n_times, sfreq
106
- self.embed_dim = embed_dim
106
+ self.emb_size = emb_size
107
107
  self.hop_length = hop_length
108
- self.num_heads = num_heads
109
- self.num_layers = num_layers
108
+ self.att_num_heads = att_num_heads
109
+ self.n_layers = n_layers
110
110
  self.return_feature = return_feature
111
111
  if (self.sfreq != 200) & (self.sfreq is not None):
112
112
  warn(
@@ -114,7 +114,7 @@ class BIOT(EEGModuleMixin, nn.Module):
114
114
  + "no guarantee to generalize well with the default parameters",
115
115
  UserWarning,
116
116
  )
117
- if self.n_chans > embed_dim:
117
+ if self.n_chans > emb_size:
118
118
  warn(
119
119
  "The number of channels is larger than the embedding size. "
120
120
  + "This may cause overfitting. Consider using a larger "
@@ -142,20 +142,20 @@ class BIOT(EEGModuleMixin, nn.Module):
142
142
  self.n_fft = int(self.sfreq)
143
143
 
144
144
  self.encoder = _BIOTEncoder(
145
- emb_size=self.embed_dim,
146
- num_heads=self.num_heads,
147
- n_layers=self.num_layers,
145
+ emb_size=emb_size,
146
+ att_num_heads=att_num_heads,
147
+ n_layers=n_layers,
148
148
  n_chans=self.n_chans,
149
149
  n_fft=self.n_fft,
150
150
  hop_length=hop_length,
151
151
  drop_prob=drop_prob,
152
152
  max_seq_len=max_seq_len,
153
- attn_dropout=att_drop_prob,
154
- attn_layer_dropout=att_layer_drop_prob,
153
+ attn_dropout=attn_dropout,
154
+ attn_layer_dropout=attn_layer_dropout,
155
155
  )
156
156
 
157
157
  self.final_layer = _ClassificationHead(
158
- emb_size=self.embed_dim,
158
+ emb_size=emb_size,
159
159
  n_outputs=self.n_outputs,
160
160
  activation=activation,
161
161
  )
@@ -250,9 +250,7 @@ class _ClassificationHead(nn.Sequential):
250
250
  (batch, n_outputs)
251
251
  """
252
252
 
253
- def __init__(
254
- self, emb_size: int, n_outputs: int, activation: type[nn.Module] = nn.ELU
255
- ):
253
+ def __init__(self, emb_size: int, n_outputs: int, activation: nn.Module = nn.ELU):
256
254
  super().__init__()
257
255
  self.activation_layer = activation()
258
256
  self.classification_head = nn.Linear(emb_size, n_outputs)
@@ -347,7 +345,7 @@ class _BIOTEncoder(nn.Module):
347
345
  The number of channels
348
346
  emb_size: int
349
347
  The size of the embedding layer
350
- num_heads: int
348
+ att_num_heads: int
351
349
  The number of attention heads
352
350
  n_layers: int
353
351
  The number of transformer layers
@@ -360,7 +358,7 @@ class _BIOTEncoder(nn.Module):
360
358
  def __init__(
361
359
  self,
362
360
  emb_size=256, # The size of the embedding layer
363
- num_heads=8, # The number of attention heads
361
+ att_num_heads=8, # The number of attention heads
364
362
  n_chans=16, # The number of channels
365
363
  n_layers=4, # The number of transformer layers
366
364
  n_fft=200, # Related with the frequency resolution
@@ -380,7 +378,7 @@ class _BIOTEncoder(nn.Module):
380
378
  )
381
379
  self.transformer = LinearAttentionTransformer(
382
380
  dim=emb_size,
383
- heads=num_heads,
381
+ heads=att_num_heads,
384
382
  depth=n_layers,
385
383
  max_seq_len=max_seq_len,
386
384
  attn_layer_dropout=attn_layer_dropout,
@@ -58,7 +58,7 @@ class ContraWR(EEGModuleMixin, nn.Module):
58
58
  emb_size: int = 256,
59
59
  res_channels: list[int] = [32, 64, 128],
60
60
  steps=20,
61
- activation: type[nn.Module] = nn.ELU,
61
+ activation: nn.Module = nn.ELU,
62
62
  drop_prob: float = 0.5,
63
63
  stride_res: int = 2,
64
64
  kernel_size_res: int = 3,
@@ -195,7 +195,7 @@ class _ResBlock(nn.Module):
195
195
  kernel_size=3,
196
196
  padding=1,
197
197
  drop_prob=0.5,
198
- activation: type[nn.Module] = nn.ReLU,
198
+ activation: nn.Module = nn.ReLU,
199
199
  ):
200
200
  super().__init__()
201
201
  self.conv1 = nn.Conv2d(
@@ -61,11 +61,11 @@ class CTNet(EEGModuleMixin, nn.Module):
61
61
  ----------
62
62
  activation : nn.Module, default=nn.GELU
63
63
  Activation function to use in the network.
64
- num_heads : int, default=4
64
+ heads : int, default=4
65
65
  Number of attention heads in the Transformer encoder.
66
- embed_dim : int or None, default=None
66
+ emb_size : int or None, default=None
67
67
  Embedding size (dimensionality) for the Transformer encoder.
68
- num_layers : int, default=6
68
+ depth : int, default=6
69
69
  Number of encoder layers in the Transformer.
70
70
  n_filters_time : int, default=20
71
71
  Number of temporal filters in the first convolutional layer.
@@ -77,11 +77,11 @@ class CTNet(EEGModuleMixin, nn.Module):
77
77
  Pooling size for the first average pooling layer.
78
78
  pool_size_2 : int, default=8
79
79
  Pooling size for the second average pooling layer.
80
- cnn_drop_prob: float, default=0.3
80
+ drop_prob_cnn : float, default=0.3
81
81
  Dropout probability after convolutional layers.
82
- att_positional_drop_prob : float, default=0.1
82
+ drop_prob_posi : float, default=0.1
83
83
  Dropout probability for the positional encoding in the Transformer.
84
- final_drop_prob : float, default=0.5
84
+ drop_prob_final : float, default=0.5
85
85
  Dropout probability before the final classification layer.
86
86
 
87
87
  Notes
@@ -109,15 +109,15 @@ class CTNet(EEGModuleMixin, nn.Module):
109
109
  n_times=None,
110
110
  input_window_seconds=None,
111
111
  # Model specific arguments
112
- activation_patch: type[nn.Module] = nn.ELU,
113
- activation_transformer: type[nn.Module] = nn.GELU,
114
- cnn_drop_prob: float = 0.3,
115
- att_positional_drop_prob: float = 0.1,
116
- final_drop_prob: float = 0.5,
112
+ activation_patch: nn.Module = nn.ELU,
113
+ activation_transformer: nn.Module = nn.GELU,
114
+ drop_prob_cnn: float = 0.3,
115
+ drop_prob_posi: float = 0.1,
116
+ drop_prob_final: float = 0.5,
117
117
  # other parameters
118
- num_heads: int = 4,
119
- embed_dim: Optional[int] = 40,
120
- num_layers: int = 6,
118
+ heads: int = 4,
119
+ emb_size: Optional[int] = 40,
120
+ depth: int = 6,
121
121
  n_filters_time: Optional[int] = None,
122
122
  kernel_size: int = 64,
123
123
  depth_multiplier: Optional[int] = 2,
@@ -136,14 +136,14 @@ class CTNet(EEGModuleMixin, nn.Module):
136
136
 
137
137
  self.activation_patch = activation_patch
138
138
  self.activation_transformer = activation_transformer
139
- self.cnn_drop_prob = cnn_drop_prob
139
+ self.drop_prob_cnn = drop_prob_cnn
140
140
  self.pool_size_1 = pool_size_1
141
141
  self.pool_size_2 = pool_size_2
142
142
  self.kernel_size = kernel_size
143
- self.att_positional_drop_prob = att_positional_drop_prob
144
- self.final_drop_prob = final_drop_prob
145
- self.num_heads = num_heads
146
- self.num_layers = num_layers
143
+ self.drop_prob_posi = drop_prob_posi
144
+ self.drop_prob_final = drop_prob_final
145
+ self.heads = heads
146
+ self.depth = depth
147
147
  # n_times - pool_size_1 / p
148
148
  self.sequence_length = math.floor(
149
149
  (
@@ -154,8 +154,8 @@ class CTNet(EEGModuleMixin, nn.Module):
154
154
  + 1
155
155
  )
156
156
 
157
- self.depth_multiplier, self.n_filters_time, self.embed_dim = self._resolve_dims(
158
- depth_multiplier, n_filters_time, embed_dim
157
+ self.depth_multiplier, self.n_filters_time, self.emb_size = self._resolve_dims(
158
+ depth_multiplier, n_filters_time, emb_size
159
159
  )
160
160
 
161
161
  # Layers
@@ -168,32 +168,32 @@ class CTNet(EEGModuleMixin, nn.Module):
168
168
  depth_multiplier=self.depth_multiplier,
169
169
  pool_size_1=self.pool_size_1,
170
170
  pool_size_2=self.pool_size_2,
171
- drop_prob=self.cnn_drop_prob,
171
+ drop_prob=self.drop_prob_cnn,
172
172
  n_chans=self.n_chans,
173
173
  activation=self.activation_patch,
174
174
  )
175
175
 
176
176
  self.position = _PositionalEncoding(
177
- emb_size=self.embed_dim,
178
- drop_prob=self.att_positional_drop_prob,
177
+ emb_size=self.emb_size,
178
+ drop_prob=self.drop_prob_posi,
179
179
  n_times=self.n_times,
180
180
  pool_size=self.pool_size_1,
181
181
  )
182
182
 
183
183
  self.trans = _TransformerEncoder(
184
- self.num_heads,
185
- self.num_layers,
186
- self.embed_dim,
184
+ self.heads,
185
+ self.depth,
186
+ self.emb_size,
187
187
  activation=self.activation_transformer,
188
188
  )
189
189
 
190
190
  self.flatten_drop_layer = nn.Sequential(
191
191
  nn.Flatten(),
192
- nn.Dropout(p=self.final_drop_prob),
192
+ nn.Dropout(p=self.drop_prob_final),
193
193
  )
194
194
 
195
195
  self.final_layer = nn.Linear(
196
- in_features=int(self.embed_dim * self.sequence_length),
196
+ in_features=int(self.emb_size * self.sequence_length),
197
197
  out_features=self.n_outputs,
198
198
  )
199
199
 
@@ -213,7 +213,7 @@ class CTNet(EEGModuleMixin, nn.Module):
213
213
  """
214
214
  x = self.ensuredim(x)
215
215
  cnn = self.cnn(x)
216
- cnn = cnn * math.sqrt(self.embed_dim)
216
+ cnn = cnn * math.sqrt(self.emb_size)
217
217
  cnn = self.position(cnn)
218
218
  trans = self.trans(cnn)
219
219
  features = cnn + trans
@@ -312,7 +312,7 @@ class _PatchEmbeddingEEGNet(nn.Module):
312
312
  pool_size_2: int = 8,
313
313
  drop_prob: float = 0.3,
314
314
  n_chans: int = 22,
315
- activation: type[nn.Module] = nn.ELU,
315
+ activation: nn.Module = nn.ELU,
316
316
  ):
317
317
  super().__init__()
318
318
  n_filters_out = depth_multiplier * n_filters_time
@@ -416,7 +416,7 @@ class _TransformerEncoderBlock(nn.Module):
416
416
  drop_prob: float = 0.5,
417
417
  forward_expansion: int = 4,
418
418
  forward_drop_p: float = 0.5,
419
- activation: type[nn.Module] = nn.GELU,
419
+ activation: nn.Module = nn.GELU,
420
420
  ):
421
421
  super().__init__()
422
422
  self.attention = _ResidualAdd(
@@ -466,7 +466,7 @@ class _TransformerEncoder(nn.Module):
466
466
  nheads: int,
467
467
  depth: int,
468
468
  dim_feedforward: int,
469
- activation: type[nn.Module] = nn.GELU,
469
+ activation: nn.Module = nn.GELU,
470
470
  ):
471
471
  super().__init__()
472
472
  self.layers = nn.Sequential(
@@ -109,12 +109,12 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
109
109
  filter_length_3=10,
110
110
  n_filters_4=200,
111
111
  filter_length_4=10,
112
- activation_first_conv_nonlin: type[nn.Module] = nn.ELU,
112
+ activation_first_conv_nonlin: nn.Module = nn.ELU,
113
113
  first_pool_mode="max",
114
- first_pool_nonlin: type[nn.Module] = nn.Identity,
115
- activation_later_conv_nonlin: type[nn.Module] = nn.ELU,
114
+ first_pool_nonlin: nn.Module = nn.Identity,
115
+ activation_later_conv_nonlin: nn.Module = nn.ELU,
116
116
  later_pool_mode="max",
117
- later_pool_nonlin: type[nn.Module] = nn.Identity,
117
+ later_pool_nonlin: nn.Module = nn.Identity,
118
118
  drop_prob=0.5,
119
119
  split_first_layer=True,
120
120
  batch_norm=True,
@@ -172,8 +172,8 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
172
172
  n_times=None,
173
173
  input_window_seconds=None,
174
174
  sfreq=None,
175
- activation_large: type[nn.Module] = nn.ELU,
176
- activation_small: type[nn.Module] = nn.ReLU,
175
+ activation_large: nn.Module = nn.ELU,
176
+ activation_small: nn.Module = nn.ReLU,
177
177
  drop_prob: float = 0.5,
178
178
  ):
179
179
  super().__init__(
@@ -252,7 +252,7 @@ class _SmallCNN(nn.Module):
252
252
  The dropout rate for regularization. Values should be between 0 and 1.
253
253
  """
254
254
 
255
- def __init__(self, activation: type[nn.Module] = nn.ReLU, drop_prob: float = 0.5):
255
+ def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
256
256
  super().__init__()
257
257
  self.conv1 = nn.Sequential(
258
258
  nn.Conv2d(
@@ -328,7 +328,7 @@ class _LargeCNN(nn.Module):
328
328
 
329
329
  """
330
330
 
331
- def __init__(self, activation: type[nn.Module] = nn.ELU, drop_prob: float = 0.5):
331
+ def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
332
332
  super().__init__()
333
333
 
334
334
  self.conv1 = nn.Sequential(