braindecode 1.3.2.dev168310904__tar.gz → 1.3.2.dev168517820__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. {braindecode-1.3.2.dev168310904/braindecode.egg-info → braindecode-1.3.2.dev168517820}/PKG-INFO +1 -1
  2. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/transforms.py +6 -4
  3. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/base.py +153 -20
  4. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/bendr.py +17 -6
  5. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/biot.py +23 -5
  6. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/cbramod.py +19 -11
  7. braindecode-1.3.2.dev168517820/braindecode/models/deepsleepnet.py +477 -0
  8. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eeginception_mi.py +2 -1
  9. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegitnet.py +4 -3
  10. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegpt.py +15 -2
  11. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/labram.py +15 -1
  12. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/luna.py +5 -16
  13. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/reve.py +36 -20
  14. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/signal_jepa.py +63 -13
  15. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/util.py +1 -1
  16. braindecode-1.3.2.dev168517820/braindecode/version.py +1 -0
  17. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820/braindecode.egg-info}/PKG-INFO +1 -1
  18. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/whats_new.rst +44 -0
  19. braindecode-1.3.2.dev168310904/braindecode/models/deepsleepnet.py +0 -417
  20. braindecode-1.3.2.dev168310904/braindecode/version.py +0 -1
  21. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/LICENSE.txt +0 -0
  22. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/MANIFEST.in +0 -0
  23. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/NOTICE.txt +0 -0
  24. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/README.rst +0 -0
  25. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/__init__.py +0 -0
  26. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/__init__.py +0 -0
  27. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/base.py +0 -0
  28. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/functional.py +0 -0
  29. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/classifier.py +0 -0
  30. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/__init__.py +0 -0
  31. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/base.py +0 -0
  32. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bbci.py +0 -0
  33. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bcicomp.py +0 -0
  34. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/__init__.py +0 -0
  35. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/datasets.py +0 -0
  36. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/format.py +0 -0
  37. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/hub.py +0 -0
  38. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/hub_format.py +0 -0
  39. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/hub_io.py +0 -0
  40. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/hub_validation.py +0 -0
  41. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/iterable.py +0 -0
  42. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/chb_mit.py +0 -0
  43. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/mne.py +0 -0
  44. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/moabb.py +0 -0
  45. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/nmt.py +0 -0
  46. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/registry.py +0 -0
  47. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/siena.py +0 -0
  48. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
  49. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/sleep_physionet.py +0 -0
  50. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/tuh.py +0 -0
  51. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/utils.py +0 -0
  52. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/xy.py +0 -0
  53. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/__init__.py +0 -0
  54. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/channel_utils.py +0 -0
  55. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/hub_formats.py +0 -0
  56. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/serialization.py +0 -0
  57. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/util.py +0 -0
  58. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/eegneuralnet.py +0 -0
  59. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/functional/__init__.py +0 -0
  60. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/functional/functions.py +0 -0
  61. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/functional/initialization.py +0 -0
  62. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/__init__.py +0 -0
  63. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/atcnet.py +0 -0
  64. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/attentionbasenet.py +0 -0
  65. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/attn_sleep.py +0 -0
  66. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/brainmodule.py +0 -0
  67. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/config.py +0 -0
  68. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/contrawr.py +0 -0
  69. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/ctnet.py +0 -0
  70. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/deep4.py +0 -0
  71. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/dgcnn.py +0 -0
  72. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegconformer.py +0 -0
  73. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eeginception_erp.py +0 -0
  74. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegminer.py +0 -0
  75. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegnet.py +0 -0
  76. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegnex.py +0 -0
  77. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegsimpleconv.py +0 -0
  78. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegsym.py +0 -0
  79. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegtcnet.py +0 -0
  80. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/fbcnet.py +0 -0
  81. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/fblightconvnet.py +0 -0
  82. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/fbmsnet.py +0 -0
  83. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/hybrid.py +0 -0
  84. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/ifnet.py +0 -0
  85. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/medformer.py +0 -0
  86. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/msvtnet.py +0 -0
  87. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/patchedtransformer.py +0 -0
  88. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sccnet.py +0 -0
  89. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/shallow_fbcsp.py +0 -0
  90. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sinc_shallow.py +0 -0
  91. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sleep_stager_blanco_2020.py +0 -0
  92. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sleep_stager_chambon_2018.py +0 -0
  93. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sparcnet.py +0 -0
  94. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sstdpn.py +0 -0
  95. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/summary.csv +0 -0
  96. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/syncnet.py +0 -0
  97. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/tcn.py +0 -0
  98. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/tidnet.py +0 -0
  99. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/tsinception.py +0 -0
  100. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/usleep.py +0 -0
  101. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/__init__.py +0 -0
  102. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/activation.py +0 -0
  103. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/attention.py +0 -0
  104. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/blocks.py +0 -0
  105. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/convolution.py +0 -0
  106. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/filter.py +0 -0
  107. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/layers.py +0 -0
  108. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/linear.py +0 -0
  109. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/parametrization.py +0 -0
  110. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/stats.py +0 -0
  111. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/util.py +0 -0
  112. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/wrapper.py +0 -0
  113. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/__init__.py +0 -0
  114. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/eegprep_preprocess.py +0 -0
  115. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/mne_preprocess.py +0 -0
  116. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/preprocess.py +0 -0
  117. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/util.py +0 -0
  118. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/windowers.py +0 -0
  119. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/regressor.py +0 -0
  120. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/samplers/__init__.py +0 -0
  121. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/samplers/base.py +0 -0
  122. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/samplers/ssl.py +0 -0
  123. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/training/__init__.py +0 -0
  124. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/training/callbacks.py +0 -0
  125. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/training/losses.py +0 -0
  126. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/training/scoring.py +0 -0
  127. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/util.py +0 -0
  128. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/visualization/__init__.py +0 -0
  129. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/visualization/confusion_matrices.py +0 -0
  130. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/visualization/gradients.py +0 -0
  131. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode.egg-info/SOURCES.txt +0 -0
  132. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode.egg-info/dependency_links.txt +0 -0
  133. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode.egg-info/requires.txt +0 -0
  134. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode.egg-info/top_level.txt +0 -0
  135. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/Makefile +0 -0
  136. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/_templates/autosummary/class.rst +0 -0
  137. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/_templates/autosummary/class_in_subdir.rst +0 -0
  138. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/_templates/autosummary/function.rst +0 -0
  139. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/_templates/autosummary/function_in_subdir.rst +0 -0
  140. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/api.rst +0 -0
  141. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/cite.rst +0 -0
  142. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/conf.py +0 -0
  143. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/help.rst +0 -0
  144. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/index.rst +0 -0
  145. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/install/install.rst +0 -0
  146. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/install/install_pip.rst +0 -0
  147. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/install/install_source.rst +0 -0
  148. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/attention.rst +0 -0
  149. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/channel.rst +0 -0
  150. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/convolution.rst +0 -0
  151. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/filterbank.rst +0 -0
  152. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/gnn.rst +0 -0
  153. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/interpretable.rst +0 -0
  154. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/lbm.rst +0 -0
  155. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/recurrent.rst +0 -0
  156. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/spd.rst +0 -0
  157. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/models.rst +0 -0
  158. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/models_categorization.rst +0 -0
  159. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/models_table.rst +0 -0
  160. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/models_visualization.rst +0 -0
  161. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/sg_execution_times.rst +0 -0
  162. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/pyproject.toml +0 -0
  163. {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.3.2.dev168310904
3
+ Version: 1.3.2.dev168517820
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -2,6 +2,7 @@
2
2
  # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
3
  # Gustavo Rodrigues <gustavenrique01@gmail.com>
4
4
  # Bruna Lopes <brunajaflopes@gmail.com>
5
+ # Sarthak Tayal <sarthaktayal2@gmail.com>
5
6
  #
6
7
  # License: BSD (3-clause)
7
8
 
@@ -557,8 +558,9 @@ class BandstopFilter(Transform):
557
558
  f" Nyquist frequency ({nyq} Hz)."
558
559
  f" Falling back to max_freq = {nyq}."
559
560
  )
560
- assert bandwidth < max_freq, (
561
- f"`bandwidth` needs to be smaller than max_freq={max_freq}"
561
+ assert bandwidth < max_freq - 2, (
562
+ f"`bandwidth` needs to be smaller than max_freq - 2={max_freq - 2} "
563
+ f"to allow valid notch frequency sampling with 1 Hz transition bands."
562
564
  )
563
565
 
564
566
  # override bandwidth value when a magnitude is passed
@@ -600,8 +602,8 @@ class BandstopFilter(Transform):
600
602
 
601
603
  # Prevents transitions from going below 0 and above max_freq
602
604
  notched_freqs = self.rng.uniform(
603
- low=1 + 2 * self.bandwidth,
604
- high=self.max_freq - 1 - 2 * self.bandwidth,
605
+ low=1 + self.bandwidth / 2,
606
+ high=self.max_freq - 1 - self.bandwidth / 2,
605
607
  size=X.shape[0],
606
608
  )
607
609
  return {
@@ -59,7 +59,25 @@ def deprecated_args(obj, *old_new_args):
59
59
  return out_args
60
60
 
61
61
 
62
- class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta):
62
+ class _BraindecodeDocstringMeta(NumpyDocstringInheritanceInitMeta):
63
+ """Defer ``__init__`` wrapping until after docstring inheritance.
64
+
65
+ ``NumpyDocstringInheritanceInitMeta`` uses ``inspect.unwrap()``
66
+ internally, which bypasses ``@wraps`` wrappers. By wrapping
67
+ ``__init__`` *after* docstring processing, the metaclass sees the
68
+ unwrapped function and correctly inherits ``cls.__doc__``.
69
+ """
70
+
71
+ def __init__(cls, class_name, class_bases, class_dict):
72
+ super().__init__(class_name, class_bases, class_dict)
73
+ # Only wrap subclass __init__s, not EEGModuleMixin itself.
74
+ # Wrapping the mixin would cause super().__init__() calls to
75
+ # overwrite _braindecode_init_kwargs captured by the subclass.
76
+ if any(isinstance(b, _BraindecodeDocstringMeta) for b in class_bases):
77
+ track_model_init_kwargs(cls)
78
+
79
+
80
+ class EEGModuleMixin(_BaseHubMixin, metaclass=_BraindecodeDocstringMeta):
63
81
  """
64
82
  Mixin class for all EEG models in braindecode.
65
83
 
@@ -132,20 +150,46 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
132
150
  # Load pretrained model
133
151
  model = {name}.from_pretrained("username/my-{name_lower}-model")
134
152
 
135
- The integration automatically handles EEG-specific parameters (n_chans,
136
- n_times, sfreq, chs_info, etc.) by saving them in a config file alongside
137
- the model weights. This ensures that loaded models are correctly configured
138
- for their original data specifications.
153
+ # Load with a different number of outputs (head is rebuilt automatically)
154
+ model = {name}.from_pretrained("username/my-{name_lower}-model", n_outputs=4)
155
+
156
+ **Extracting features and replacing the head:**
157
+
158
+ .. code-block::
159
+
160
+ import torch
161
+
162
+ x = torch.randn(1, model.n_chans, model.n_times)
163
+ # Extract encoder features (consistent dict across all models)
164
+ out = model(x, return_features=True)
165
+ features = out["features"]
166
+
167
+ # Replace the classification head
168
+ model.reset_head(n_outputs=10)
169
+
170
+ **Saving and restoring full configuration:**
171
+
172
+ .. code-block::
173
+
174
+ import json
175
+
176
+ config = model.get_config() # all __init__ params
177
+ with open("config.json", "w") as f:
178
+ json.dump(config, f)
179
+
180
+ model2 = {name}.from_config(config) # reconstruct (no weights)
139
181
 
140
182
  All model parameters (both EEG-specific and model-specific such as
141
183
  dropout rates, activation functions, number of filters) are automatically
142
184
  saved to the Hub and restored when loading.
185
+
186
+ See :ref:`load-pretrained-models` for a complete tutorial.
143
187
  """
144
188
 
145
189
  def __init_subclass__(cls, **kwargs):
146
190
  # Append model-specific Hub integration notes to the docstring.
147
- # This runs after the metaclass, so we concatenate rather than
148
- # override any existing Notes section in the subclass.
191
+ # This runs before the metaclass __init__, so the Hub notes will
192
+ # be included in the docstring that the metaclass processes.
149
193
  if cls.__doc__ is not None:
150
194
  hub_notes = cls._HUB_NOTES_TEMPLATE.format(
151
195
  name=cls.__name__,
@@ -155,7 +199,6 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
155
199
 
156
200
  if not HAS_HF_HUB:
157
201
  super().__init_subclass__(**kwargs)
158
- track_model_init_kwargs(cls)
159
202
  return
160
203
 
161
204
  base_tags = ["braindecode", cls.__name__]
@@ -195,7 +238,6 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
195
238
  coders=coders,
196
239
  **kwargs,
197
240
  )
198
- track_model_init_kwargs(cls)
199
241
 
200
242
  def __init__(
201
243
  self,
@@ -426,6 +468,34 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
426
468
  resolve_type_kwargs(cls, config)
427
469
  return cls(**config)
428
470
 
471
+ def reset_head(self, n_outputs):
472
+ """Replace the classification head for a new number of outputs.
473
+
474
+ This is called automatically by :meth:`from_pretrained` when the
475
+ user passes an ``n_outputs`` that differs from the saved config.
476
+ Override in subclasses that need a model-specific head structure.
477
+
478
+ Parameters
479
+ ----------
480
+ n_outputs : int
481
+ New number of output classes.
482
+
483
+ Examples
484
+ --------
485
+ >>> from braindecode.models import BENDR
486
+ >>> model = BENDR(n_chans=22, n_times=1000, n_outputs=4)
487
+ >>> model.reset_head(10)
488
+ >>> model.n_outputs
489
+ 10
490
+
491
+ .. versionadded:: 1.4
492
+ """
493
+ raise NotImplementedError(
494
+ f"{type(self).__name__} does not implement reset_head(). "
495
+ "Override this method to support changing n_outputs after "
496
+ "loading pretrained weights."
497
+ )
498
+
429
499
  mapping: Optional[Dict[str, str]] = None
430
500
 
431
501
  def load_state_dict(self, state_dict, *args, **kwargs):
@@ -625,15 +695,78 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
625
695
  **model_kwargs,
626
696
  ):
627
697
  model_kwargs.pop("braindecode_version", None)
698
+ filename = model_kwargs.pop("filename", None)
628
699
  resolve_type_kwargs(cls, model_kwargs)
629
- return super()._from_pretrained( # type: ignore
630
- model_id=model_id,
631
- revision=revision,
632
- cache_dir=cache_dir,
633
- force_download=force_download,
634
- local_files_only=local_files_only,
635
- token=token,
636
- map_location=map_location,
637
- strict=strict,
638
- **model_kwargs,
639
- )
700
+
701
+ # Read saved n_outputs from config.json to detect when the
702
+ # user wants a different number of outputs. Works for both
703
+ # local directories and Hub repo IDs.
704
+ saved_n_outputs = None
705
+ try:
706
+ if Path(model_id).is_dir():
707
+ config_file = Path(model_id) / "config.json"
708
+ else:
709
+ config_file = huggingface_hub.hf_hub_download(
710
+ repo_id=model_id,
711
+ filename="config.json",
712
+ revision=revision,
713
+ cache_dir=cache_dir,
714
+ force_download=force_download,
715
+ token=token,
716
+ local_files_only=local_files_only,
717
+ )
718
+ with open(config_file, "r") as f:
719
+ saved_n_outputs = json.load(f).get("n_outputs")
720
+ except (OSError, json.JSONDecodeError, KeyError):
721
+ pass # config unavailable; skip reset_head logic
722
+
723
+ requested_n_outputs = model_kwargs.get("n_outputs")
724
+
725
+ # If the user requests different n_outputs, load with the
726
+ # saved value first (so weights match), then swap the head.
727
+ if (
728
+ saved_n_outputs is not None
729
+ and requested_n_outputs is not None
730
+ and requested_n_outputs != saved_n_outputs
731
+ ):
732
+ model_kwargs["n_outputs"] = saved_n_outputs
733
+
734
+ # If a custom filename is provided, temporarily override the
735
+ # HuggingFace constant so the parent class downloads the
736
+ # correct file (e.g. "LUNA_base.safetensors" instead of
737
+ # "model.safetensors").
738
+ hf_constants = huggingface_hub.constants
739
+ _orig_safetensors = hf_constants.SAFETENSORS_SINGLE_FILE
740
+ if filename is not None:
741
+ hf_constants.SAFETENSORS_SINGLE_FILE = filename
742
+ try:
743
+ model = super()._from_pretrained( # type: ignore
744
+ model_id=model_id,
745
+ revision=revision,
746
+ cache_dir=cache_dir,
747
+ force_download=force_download,
748
+ local_files_only=local_files_only,
749
+ token=token,
750
+ map_location=map_location,
751
+ strict=strict,
752
+ **model_kwargs,
753
+ )
754
+ finally:
755
+ hf_constants.SAFETENSORS_SINGLE_FILE = _orig_safetensors
756
+
757
+ if (
758
+ saved_n_outputs is not None
759
+ and requested_n_outputs is not None
760
+ and requested_n_outputs != saved_n_outputs
761
+ ):
762
+ try:
763
+ model.reset_head(requested_n_outputs)
764
+ except NotImplementedError:
765
+ raise ValueError(
766
+ f"{type(model).__name__} does not support changing "
767
+ f"n_outputs after loading. Saved model has "
768
+ f"n_outputs={saved_n_outputs}, but "
769
+ f"n_outputs={requested_n_outputs} was requested."
770
+ ) from None
771
+
772
+ return model
@@ -291,15 +291,23 @@ class BENDR(EEGModuleMixin, nn.Module):
291
291
  else:
292
292
  in_features = encoder_h
293
293
 
294
+ self._head_in_features = in_features
294
295
  self.final_layer = None
295
296
  if self.include_final_layer:
296
- # in_features: encoder_h*4 (encoder_only) or encoder_h (full model)
297
- linear = nn.Linear(in_features=in_features, out_features=self.n_outputs)
298
- self.final_layer = nn.utils.parametrizations.weight_norm(
299
- linear, name="weight", dim=1
300
- )
297
+ self._build_head(self.n_outputs)
301
298
 
302
- def forward(self, x):
299
+ def _build_head(self, n_outputs):
300
+ linear = nn.Linear(in_features=self._head_in_features, out_features=n_outputs)
301
+ self.final_layer = nn.utils.parametrizations.weight_norm(
302
+ linear, name="weight", dim=1
303
+ )
304
+
305
+ def reset_head(self, n_outputs):
306
+ self._n_outputs = n_outputs
307
+ self.include_final_layer = True
308
+ self._build_head(n_outputs)
309
+
310
+ def forward(self, x, return_features=False):
303
311
  if self.channel_projection is not None:
304
312
  x = self.channel_projection(x)
305
313
  encoded = self.encoder(x)
@@ -328,6 +336,9 @@ class BENDR(EEGModuleMixin, nn.Module):
328
336
  feature = context[:, :, 0]
329
337
  # feature: [batch_size, encoder_h]
330
338
 
339
+ if return_features:
340
+ return {"features": feature, "cls_token": None}
341
+
331
342
  if self.final_layer is not None:
332
343
  feature = self.final_layer(feature)
333
344
  # feature: [batch_size, n_outputs]
@@ -183,13 +183,22 @@ class BIOT(EEGModuleMixin, nn.Module):
183
183
  attn_layer_dropout=att_layer_drop_prob,
184
184
  )
185
185
 
186
+ self._head_activation = activation
186
187
  self.final_layer = _ClassificationHead(
187
188
  emb_size=self.embed_dim,
188
189
  n_outputs=self.n_outputs,
189
190
  activation=activation,
190
191
  )
191
192
 
192
- def forward(self, x):
193
+ def reset_head(self, n_outputs):
194
+ self._n_outputs = n_outputs
195
+ self.final_layer = _ClassificationHead(
196
+ emb_size=self.embed_dim,
197
+ n_outputs=n_outputs,
198
+ activation=self._head_activation,
199
+ )
200
+
201
+ def forward(self, x, return_features=False):
193
202
  """
194
203
  Pass the input through the BIOT encoder, and then through the
195
204
  classification head.
@@ -198,15 +207,24 @@ class BIOT(EEGModuleMixin, nn.Module):
198
207
  ----------
199
208
  x: Tensor
200
209
  (batch_size, n_channels, n_times)
210
+ return_features : bool
211
+ If True, return a dict with ``"features"`` and ``"cls_token"``
212
+ instead of the classification output.
201
213
 
202
214
  Returns
203
215
  -------
204
- out: Tensor
205
- (batch_size, n_outputs)
206
- (out, emb): tuple Tensor
207
- (batch_size, n_outputs), (batch_size, emb_size)
216
+ torch.Tensor or tuple or dict
217
+ Default: ``torch.Tensor`` of shape ``(batch_size, n_outputs)``.
218
+ If ``return_features=True``: ``dict`` with ``"features"``
219
+ ``(batch_size, emb_size)`` and ``"cls_token"`` (``None``).
220
+ If legacy ``return_feature=True`` (init param):
221
+ ``(out, emb)`` tuple (ignored when ``return_features=True``).
208
222
  """
209
223
  emb = self.encoder(x)
224
+
225
+ if return_features:
226
+ return {"features": emb, "cls_token": None}
227
+
210
228
  x = self.final_layer(emb)
211
229
 
212
230
  if self.return_feature:
@@ -205,20 +205,15 @@ class CBraMod(EEGModuleMixin, nn.Module):
205
205
  self.encoder = _TransformerEncoder(encoder_layer, num_layers=n_layer)
206
206
  self.proj_out = nn.Sequential(nn.Linear(d_model, emb_dim))
207
207
 
208
+ self._emb_dim = emb_dim
209
+ self._patch_size = patch_size
208
210
  self._weights_init()
209
211
 
210
- try:
211
- n_times = self.n_times
212
- n_chans = self.n_chans
213
- except ValueError:
214
- n_times = None
215
- n_chans = None
216
-
217
212
  if return_encoder_output:
218
213
  self.final_layer = nn.Identity()
219
- elif n_times is not None and n_chans is not None:
220
- n_patch = n_times // patch_size
221
- flat_dim = n_chans * n_patch * emb_dim
214
+ elif self._n_times is not None and self._n_chans is not None:
215
+ n_patch = self._n_times // patch_size
216
+ flat_dim = self._n_chans * n_patch * emb_dim
222
217
  self.final_layer = nn.Sequential(
223
218
  nn.Flatten(), nn.Linear(flat_dim, self.n_outputs)
224
219
  )
@@ -227,6 +222,17 @@ class CBraMod(EEGModuleMixin, nn.Module):
227
222
  nn.Flatten(), nn.LazyLinear(self.n_outputs)
228
223
  )
229
224
 
225
+ def reset_head(self, n_outputs):
226
+ self._n_outputs = n_outputs
227
+ if self._n_times is not None and self._n_chans is not None:
228
+ n_patch = self._n_times // self._patch_size
229
+ flat_dim = self._n_chans * n_patch * self._emb_dim
230
+ self.final_layer = nn.Sequential(
231
+ nn.Flatten(), nn.Linear(flat_dim, n_outputs)
232
+ )
233
+ else:
234
+ self.final_layer = nn.Sequential(nn.Flatten(), nn.LazyLinear(n_outputs))
235
+
230
236
  def _weights_init(self):
231
237
  for m in self.modules():
232
238
  if isinstance(m, nn.Linear):
@@ -237,11 +243,13 @@ class CBraMod(EEGModuleMixin, nn.Module):
237
243
  nn.init.constant_(m.weight, 1)
238
244
  nn.init.constant_(m.bias, 0)
239
245
 
240
- def forward(self, x, mask=None):
246
+ def forward(self, x, mask=None, return_features=False):
241
247
  x = self.rearrange(x)
242
248
  patch_emb = self.patch_embedding(x, mask)
243
249
  feats = self.encoder(patch_emb)
244
250
  out = self.proj_out(feats)
251
+ if return_features:
252
+ return {"features": out, "cls_token": None}
245
253
  return self.final_layer(out)
246
254
 
247
255