braindecode 1.5.0.dev985__tar.gz → 1.5.0.dev172791986__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. {braindecode-1.5.0.dev985/braindecode.egg-info → braindecode-1.5.0.dev172791986}/PKG-INFO +1 -1
  2. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/__init__.py +10 -3
  3. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/bendr.py +92 -28
  4. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/biot.py +61 -1
  5. braindecode-1.5.0.dev172791986/braindecode/models/interpolated.py +182 -0
  6. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/labram.py +208 -266
  7. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/reve.py +1 -1
  8. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/signal_jepa.py +17 -0
  9. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/summary.csv +4 -0
  10. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/util.py +121 -20
  11. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/__init__.py +2 -0
  12. braindecode-1.5.0.dev172791986/braindecode/modules/interpolation.py +201 -0
  13. braindecode-1.5.0.dev172791986/braindecode/version.py +1 -0
  14. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986/braindecode.egg-info}/PKG-INFO +1 -1
  15. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode.egg-info/SOURCES.txt +2 -0
  16. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/api.rst +6 -0
  17. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/whats_new.rst +40 -0
  18. braindecode-1.5.0.dev985/braindecode/version.py +0 -1
  19. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/LICENSE.txt +0 -0
  20. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/MANIFEST.in +0 -0
  21. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/NOTICE.txt +0 -0
  22. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/README.rst +0 -0
  23. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/__init__.py +0 -0
  24. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/augmentation/__init__.py +0 -0
  25. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/augmentation/base.py +0 -0
  26. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/augmentation/functional.py +0 -0
  27. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/augmentation/transforms.py +0 -0
  28. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/classifier.py +0 -0
  29. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/__init__.py +0 -0
  30. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/base.py +0 -0
  31. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bbci.py +0 -0
  32. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bcicomp.py +0 -0
  33. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/__init__.py +0 -0
  34. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/datasets.py +0 -0
  35. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/format.py +0 -0
  36. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/hub.py +0 -0
  37. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/hub_format.py +0 -0
  38. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/hub_io.py +0 -0
  39. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/hub_validation.py +0 -0
  40. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/iterable.py +0 -0
  41. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/chb_mit.py +0 -0
  42. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/mne.py +0 -0
  43. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/moabb.py +0 -0
  44. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/nmt.py +0 -0
  45. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/registry.py +0 -0
  46. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/siena.py +0 -0
  47. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
  48. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/sleep_physionet.py +0 -0
  49. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/tuh.py +0 -0
  50. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/utils.py +0 -0
  51. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/xy.py +0 -0
  52. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/__init__.py +0 -0
  53. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/channel_utils.py +0 -0
  54. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/hub_formats.py +0 -0
  55. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/serialization.py +0 -0
  56. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/util.py +0 -0
  57. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/eegneuralnet.py +0 -0
  58. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/functional/__init__.py +0 -0
  59. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/functional/functions.py +0 -0
  60. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/functional/initialization.py +0 -0
  61. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/atcnet.py +0 -0
  62. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/attentionbasenet.py +0 -0
  63. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/attn_sleep.py +0 -0
  64. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/base.py +0 -0
  65. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/brainmodule.py +0 -0
  66. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/cbramod.py +0 -0
  67. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/codebrain.py +0 -0
  68. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/config.py +0 -0
  69. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/contrawr.py +0 -0
  70. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/ctnet.py +0 -0
  71. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/deep4.py +0 -0
  72. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/deepsleepnet.py +0 -0
  73. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/dgcnn.py +0 -0
  74. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegconformer.py +0 -0
  75. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eeginception_erp.py +0 -0
  76. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eeginception_mi.py +0 -0
  77. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegitnet.py +0 -0
  78. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegminer.py +0 -0
  79. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegnet.py +0 -0
  80. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegnex.py +0 -0
  81. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegpt.py +0 -0
  82. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegsimpleconv.py +0 -0
  83. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegsym.py +0 -0
  84. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegtcnet.py +0 -0
  85. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/fbcnet.py +0 -0
  86. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/fblightconvnet.py +0 -0
  87. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/fbmsnet.py +0 -0
  88. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/hybrid.py +0 -0
  89. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/ifnet.py +0 -0
  90. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/luna.py +0 -0
  91. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/medformer.py +0 -0
  92. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/msvtnet.py +0 -0
  93. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/patchedtransformer.py +0 -0
  94. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sccnet.py +0 -0
  95. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/shallow_fbcsp.py +0 -0
  96. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sinc_shallow.py +0 -0
  97. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sleep_stager_blanco_2020.py +0 -0
  98. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sleep_stager_chambon_2018.py +0 -0
  99. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sparcnet.py +0 -0
  100. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sstdpn.py +0 -0
  101. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/syncnet.py +0 -0
  102. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/tcn.py +0 -0
  103. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/tidnet.py +0 -0
  104. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/tsinception.py +0 -0
  105. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/usleep.py +0 -0
  106. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/activation.py +0 -0
  107. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/attention.py +0 -0
  108. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/blocks.py +0 -0
  109. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/convolution.py +0 -0
  110. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/filter.py +0 -0
  111. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/layers.py +0 -0
  112. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/linear.py +0 -0
  113. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/parametrization.py +0 -0
  114. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/stats.py +0 -0
  115. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/util.py +0 -0
  116. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/wrapper.py +0 -0
  117. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/__init__.py +0 -0
  118. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/eegprep_preprocess.py +0 -0
  119. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/mne_preprocess.py +0 -0
  120. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/preprocess.py +0 -0
  121. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/util.py +0 -0
  122. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/windowers.py +0 -0
  123. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/regressor.py +0 -0
  124. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/samplers/__init__.py +0 -0
  125. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/samplers/base.py +0 -0
  126. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/samplers/ssl.py +0 -0
  127. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/training/__init__.py +0 -0
  128. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/training/callbacks.py +0 -0
  129. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/training/losses.py +0 -0
  130. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/training/scoring.py +0 -0
  131. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/util.py +0 -0
  132. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/visualization/__init__.py +0 -0
  133. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/visualization/confusion_matrices.py +0 -0
  134. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/visualization/gradients.py +0 -0
  135. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode.egg-info/dependency_links.txt +0 -0
  136. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode.egg-info/requires.txt +0 -0
  137. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode.egg-info/top_level.txt +0 -0
  138. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/Makefile +0 -0
  139. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/_templates/autosummary/class.rst +0 -0
  140. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/_templates/autosummary/class_in_subdir.rst +0 -0
  141. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/_templates/autosummary/function.rst +0 -0
  142. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/_templates/autosummary/function_in_subdir.rst +0 -0
  143. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/cite.rst +0 -0
  144. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/conf.py +0 -0
  145. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/help.rst +0 -0
  146. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/index.rst +0 -0
  147. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/install/install.rst +0 -0
  148. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/install/install_pip.rst +0 -0
  149. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/install/install_source.rst +0 -0
  150. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/attention.rst +0 -0
  151. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/channel.rst +0 -0
  152. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/convolution.rst +0 -0
  153. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/filterbank.rst +0 -0
  154. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/gnn.rst +0 -0
  155. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/interpretable.rst +0 -0
  156. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/lbm.rst +0 -0
  157. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/recurrent.rst +0 -0
  158. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/spd.rst +0 -0
  159. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/models.rst +0 -0
  160. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/models_categorization.rst +0 -0
  161. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/models_table.rst +0 -0
  162. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/models_visualization.rst +0 -0
  163. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/sg_execution_times.rst +0 -0
  164. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/pyproject.toml +0 -0
  165. {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.5.0.dev985
3
+ Version: 1.5.0.dev172791986
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -4,8 +4,8 @@ from .atcnet import ATCNet
4
4
  from .attentionbasenet import AttentionBaseNet
5
5
  from .attn_sleep import AttnSleep
6
6
  from .base import EEGModuleMixin
7
- from .bendr import BENDR
8
- from .biot import BIOT
7
+ from .bendr import BENDR, InterpolatedBENDR
8
+ from .biot import BIOT, InterpolatedBIOT
9
9
  from .brainmodule import BrainModule
10
10
  from .cbramod import CBraMod
11
11
  from .codebrain import CodeBrain
@@ -30,7 +30,8 @@ from .fblightconvnet import FBLightConvNet
30
30
  from .fbmsnet import FBMSNet
31
31
  from .hybrid import HybridNet
32
32
  from .ifnet import IFNet
33
- from .labram import Labram
33
+ from .interpolated import InterpolatedModel
34
+ from .labram import InterpolatedLaBraM, Labram
34
35
  from .luna import LUNA
35
36
  from .medformer import MEDFormer
36
37
  from .msvtnet import MSVTNet
@@ -39,6 +40,7 @@ from .reve import REVE
39
40
  from .sccnet import SCCNet
40
41
  from .shallow_fbcsp import ShallowFBCSPNet
41
42
  from .signal_jepa import (
43
+ InterpolatedSignalJEPA,
42
44
  SignalJEPA,
43
45
  SignalJEPA_Contextual,
44
46
  SignalJEPA_PostLocal,
@@ -97,6 +99,11 @@ __all__ = [
97
99
  "FBMSNet",
98
100
  "HybridNet",
99
101
  "IFNet",
102
+ "InterpolatedBENDR",
103
+ "InterpolatedBIOT",
104
+ "InterpolatedLaBraM",
105
+ "InterpolatedModel",
106
+ "InterpolatedSignalJEPA",
100
107
  "Labram",
101
108
  "LUNA",
102
109
  "extract_channel_locations_from_chs_info",
@@ -1,12 +1,60 @@
1
1
  import copy
2
2
 
3
+ import numpy as np
3
4
  import torch
4
5
  from einops.layers.torch import Rearrange
5
6
  from torch import nn
6
- from torch.nn.utils.parametrize import register_parametrization
7
7
 
8
8
  from braindecode.models.base import EEGModuleMixin
9
- from braindecode.modules.parametrization import MaxNormParametrize
9
+
10
+ # The 20 channels used to pre-train BENDR, in the order expected by the
11
+ # `braindecode/braindecode-bendr` checkpoint. The first 19 entries are the
12
+ # EEG channels taken verbatim from `dn3.transforms.instance.To1020.EEG_20_div`
13
+ # (https://github.com/SPOClab-ca/dn3/blob/master/dn3/transforms/instance.py).
14
+ # Their positions come from MNE's ``standard_1005`` montage (T5/T6 are
15
+ # legacy names that share positions with P7/P8 there).
16
+ #
17
+ # The 20th entry is ``SCALE``, a relative-amplitude statistic (not an
18
+ # electrode) appended by ``To1020(include_scale_ch=True)`` during
19
+ # pre-training. Since it has no physical position, the ``loc`` below is
20
+ # the centroid of the 19 EEG positions — purely a placeholder so that
21
+ # :class:`~braindecode.modules.ChannelInterpolationLayer` (used by
22
+ # :class:`InterpolatedBENDR`) can build a valid spline interpolation
23
+ # matrix. It is NOT the SCALE the pre-training pipeline computes
24
+ # (which is an RMS-like amplitude via ``dn3.MappingDeep1010``); users
25
+ # who need a faithful SCALE must compute it themselves and feed 20
26
+ # channels to :class:`BENDR` directly.
27
+ _BENDR_TARGET_CHS_TUPLES: list[tuple[str, tuple[float, float, float]]] = [
28
+ ("FP1", (-0.0294367, +0.0839171, -0.0069900)), # standard_1005
29
+ ("FP2", (+0.0298723, +0.0848959, -0.0070800)), # standard_1005
30
+ ("F7", (-0.0702629, +0.0424743, -0.0114200)), # standard_1005
31
+ ("F3", (-0.0502438, +0.0531112, +0.0421920)), # standard_1005
32
+ ("FZ", (+0.0003122, +0.0585120, +0.0664620)), # standard_1005
33
+ ("F4", (+0.0518362, +0.0543048, +0.0408140)), # standard_1005
34
+ ("F8", (+0.0730431, +0.0444217, -0.0120000)), # standard_1005
35
+ ("T7", (-0.0841611, -0.0160187, -0.0093460)), # standard_1005
36
+ ("C3", (-0.0653581, -0.0116317, +0.0643580)), # standard_1005
37
+ ("CZ", (+0.0004009, -0.0091670, +0.1002440)), # standard_1005
38
+ ("C4", (+0.0671179, -0.0109003, +0.0635800)), # standard_1005
39
+ ("T8", (+0.0850799, -0.0150203, -0.0094900)), # standard_1005
40
+ ("T5", (-0.0724343, -0.0734527, -0.0024870)), # standard_1005 (= P7)
41
+ ("P3", (-0.0530073, -0.0787878, +0.0559400)), # standard_1005
42
+ ("PZ", (+0.0003247, -0.0811150, +0.0826150)), # standard_1005
43
+ ("P4", (+0.0556667, -0.0785602, +0.0565610)), # standard_1005
44
+ ("T6", (+0.0730557, -0.0730683, -0.0025400)), # standard_1005 (= P8)
45
+ ("O1", (-0.0294134, -0.1124490, +0.0088390)), # standard_1005
46
+ ("O2", (+0.0298426, -0.1121560, +0.0088000)), # standard_1005
47
+ (
48
+ "SCALE",
49
+ (+0.0006439, -0.0131942, +0.0278448),
50
+ ), # centroid of the 19 EEG positions (placeholder; see comment above)
51
+ ]
52
+
53
+ _BENDR_TARGET_CHS_INFO: list[dict] = [
54
+ {"ch_name": ch, "kind": "eeg", "loc": np.asarray(loc, dtype=float)}
55
+ for ch, loc in _BENDR_TARGET_CHS_TUPLES
56
+ ]
57
+ BENDR_CHANNEL_ORDER: list[str] = [ch for ch, _ in _BENDR_TARGET_CHS_TUPLES]
10
58
 
11
59
 
12
60
  class BENDR(EEGModuleMixin, nn.Module):
@@ -195,15 +243,6 @@ class BENDR(EEGModuleMixin, nn.Module):
195
243
  The contextualizer is still created (to allow loading pretrained weights) but is not
196
244
  used in the forward pass. Requires input length of at least
197
245
  ``4 * product(enc_downsample)`` samples (384 with default downsampling of 96x).
198
- n_chans_pretrained : int or None, default=None
199
- Number of input channels the pretrained weights expect (20 for the official BENDR
200
- checkpoint). When set and ``n_chans != n_chans_pretrained``, a 1x1 Conv1d with
201
- max-norm constraint projects from ``n_chans`` to ``n_chans_pretrained`` before the
202
- encoder. This allows fine-tuning pretrained BENDR on datasets with arbitrary channel
203
- counts. When using ``from_pretrained``, pass ``strict=False`` since the checkpoint
204
- will not contain ``channel_projection`` weights.
205
- chan_proj_max_norm : float, default=1.0
206
- Max-norm constraint value for the channel projection weights.
207
246
  """
208
247
 
209
248
  def __init__(
@@ -231,8 +270,6 @@ class BENDR(EEGModuleMixin, nn.Module):
231
270
  start_token=-5, # Value for start token embedding
232
271
  final_layer=True, # Whether to include the final linear layer
233
272
  encoder_only=False, # If True, bypass contextualizer and use 4-chunk pooling
234
- n_chans_pretrained=None, # Expected input channels of pretrained weights
235
- chan_proj_max_norm=1.0, # Max-norm for channel projection weights
236
273
  ):
237
274
  super().__init__(
238
275
  n_outputs=n_outputs,
@@ -246,25 +283,34 @@ class BENDR(EEGModuleMixin, nn.Module):
246
283
  # Keep these parameters if needed later, otherwise they are captured by the mixin
247
284
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
248
285
 
286
+ # If the user supplies chs_info, require it to match BENDR_CHANNEL_ORDER
287
+ # exactly (case-insensitive). Arbitrary channel sets should go through
288
+ # InterpolatedBENDR — same pattern as Labram / InterpolatedLaBraM.
289
+ # When chs_info is absent (the usual n_chans=20 path, incl.
290
+ # from_pretrained), no check is performed.
291
+ try:
292
+ _chs_info = self.chs_info
293
+ except ValueError:
294
+ _chs_info = None
295
+ if _chs_info is not None:
296
+ user_names = [ch["ch_name"] for ch in _chs_info] # type: ignore[index]
297
+ canonical = BENDR_CHANNEL_ORDER
298
+ if [n.lower() for n in user_names] != [n.lower() for n in canonical]:
299
+ raise ValueError(
300
+ f"BENDR requires chs_info to match BENDR_CHANNEL_ORDER exactly "
301
+ f"({len(canonical)} channels, specific order; last is 'SCALE'). "
302
+ f"Got {len(user_names)} channel(s). For arbitrary channel sets, "
303
+ f"use InterpolatedBENDR "
304
+ f"(from braindecode.models import InterpolatedBENDR)."
305
+ )
306
+
249
307
  self.encoder_h = encoder_h
250
308
  self.contextualizer_hidden = contextualizer_hidden
251
309
  self.include_final_layer = final_layer
252
310
  self.encoder_only = encoder_only
253
311
 
254
- # Channel projection for pretrained weight compatibility
255
- encoder_n_chans = self.n_chans
256
- if n_chans_pretrained is not None and self.n_chans != n_chans_pretrained:
257
- conv = nn.Conv1d(self.n_chans, n_chans_pretrained, 1, bias=False)
258
- register_parametrization(
259
- conv, "weight", MaxNormParametrize(chan_proj_max_norm)
260
- )
261
- self.channel_projection = conv
262
- encoder_n_chans = n_chans_pretrained
263
- else:
264
- self.channel_projection = None
265
-
266
312
  self.encoder = _ConvEncoderBENDR(
267
- in_features=encoder_n_chans,
313
+ in_features=self.n_chans,
268
314
  encoder_h=encoder_h,
269
315
  dropout=drop_prob,
270
316
  projection_head=projection_head,
@@ -308,8 +354,6 @@ class BENDR(EEGModuleMixin, nn.Module):
308
354
  self._build_head(n_outputs)
309
355
 
310
356
  def forward(self, x, return_features=False):
311
- if self.channel_projection is not None:
312
- x = self.channel_projection(x)
313
357
  encoded = self.encoder(x)
314
358
  # encoded: [batch_size, encoder_h, n_encoded_times]
315
359
 
@@ -552,3 +596,23 @@ class _BENDRContextualizer(nn.Module):
552
596
  # x: [batch_size, in_features, seq_len + 1]
553
597
 
554
598
  return x
599
+
600
+
601
+ # -----------------------------------------------------------------------------
602
+ # InterpolatedBENDR — experimental channel-interpolation variant of BENDR
603
+ # -----------------------------------------------------------------------------
604
+ # Wraps :class:`BENDR` with an MNE-backed channel-interpolation layer that
605
+ # projects arbitrary user ``chs_info`` to the canonical 20-channel BENDR
606
+ # input (:data:`_BENDR_TARGET_CHS_INFO` — the 19 pre-training EEG channels
607
+ # plus a ``SCALE`` placeholder at the centroid of those 19 positions).
608
+ # Frozen by default; set ``trainable=True`` to fine-tune the projection.
609
+ #
610
+ # NOTE: the ``SCALE`` target has no physical position, so the row of the
611
+ # interpolation matrix that produces it is a spatial spline of the user's
612
+ # EEG channels — *not* the dn3 ``MappingDeep1010`` RMS statistic the
613
+ # checkpoint saw during pre-training. Expect degraded zero-shot transfer
614
+ # from the SCALE channel; downstream fine-tuning should still work.
615
+
616
+ from braindecode.models.interpolated import InterpolatedModel # noqa: E402
617
+
618
+ InterpolatedBENDR = InterpolatedModel(BENDR, _BENDR_TARGET_CHS_INFO)
@@ -1,12 +1,57 @@
1
1
  import math
2
2
  from warnings import warn
3
3
 
4
+ import numpy as np
4
5
  import torch
5
6
  import torch.nn as nn
6
7
  from linear_attention_transformer import LinearAttentionTransformer
7
8
 
8
9
  from braindecode.models.base import EEGModuleMixin
9
10
 
11
+ # -----------------------------------------------------------------------------
12
+ # Canonical channel order for InterpolatedBIOT — the 18-channel TCP bipolar
13
+ # montage used by BIOT's shhs-prest and six-datasets pretrained checkpoints.
14
+ # Source: https://github.com/ycq091044/BIOT (README + datasets/TUAB/process.py
15
+ # + datasets/SHHS/process.py). Indices 0-15 are the TCP 16-channel bipolar
16
+ # derivations; indices 16-17 are SHHS differential channels.
17
+ #
18
+ # The `loc` values are only used to build an MNE interpolation matrix for
19
+ # InterpolatedBIOT. All entries are bipolar / differential derivations.
20
+ # TODO: positions are stored as the midpoint of the two constituent
21
+ # electrodes. This is a simplification — a bipolar signal V(A)-V(B) cannot
22
+ # be faithfully recovered by spatial interpolation at the midpoint. Revisit
23
+ # in a follow-up PR (e.g. a dedicated BipolarDerivationLayer).
24
+ # -----------------------------------------------------------------------------
25
+
26
+ # fmt: off
27
+ _BIOT_TARGET_CHS_TUPLES: list[tuple[str, tuple[float, float, float]]] = [
28
+ ("FP1-F7", (-0.04984980, 0.06319570, -0.00920500)),
29
+ ("F7-T7", (-0.07721200, 0.01322780, -0.01038300)),
30
+ ("T7-P7", (-0.07829770, -0.04473570, -0.00591650)),
31
+ ("P7-O1", (-0.05092385, -0.09295085, 0.00317600)),
32
+ ("FP2-F8", (0.05145770, 0.06465880, -0.00954000)),
33
+ ("F8-T8", (0.07906150, 0.01470070, -0.01074500)),
34
+ ("T8-P8", (0.07906780, -0.04404430, -0.00601500)),
35
+ ("P8-O2", (0.05144915, -0.09261215, 0.00313000)),
36
+ ("FP1-F3", (-0.03984025, 0.06851415, 0.01760100)),
37
+ ("F3-C3", (-0.05780095, 0.02073975, 0.05327500)),
38
+ ("C3-P3", (-0.05918270, -0.04520975, 0.06014900)),
39
+ ("P3-O1", (-0.04121035, -0.09561840, 0.03238950)),
40
+ ("FP2-F4", (0.04085425, 0.06960035, 0.01686700)),
41
+ ("F4-C4", (0.05947705, 0.02170225, 0.05219700)),
42
+ ("C4-P4", (0.06139230, -0.04473025, 0.06007050)),
43
+ ("P4-O2", (0.04275465, -0.09535810, 0.03268050)),
44
+ ("C3-A2", (0.01021790, -0.01832050, -0.00183650)),
45
+ ("C4-A1", (-0.00947910, -0.01794500, -0.00220300)),
46
+ ]
47
+ # fmt: on
48
+
49
+ _BIOT_TARGET_CHS_INFO = [
50
+ {"ch_name": ch, "kind": "eeg", "loc": np.asarray(loc, dtype=float)}
51
+ for ch, loc in _BIOT_TARGET_CHS_TUPLES
52
+ ]
53
+ BIOT_CHANNEL_ORDER = [ch for ch, _ in _BIOT_TARGET_CHS_TUPLES]
54
+
10
55
 
11
56
  class BIOT(EEGModuleMixin, nn.Module):
12
57
  r"""BIOT from Yang et al (2023) [Yang2023]_
@@ -439,7 +484,9 @@ class _BIOTEncoder(nn.Module):
439
484
  self.channel_tokens = nn.Embedding(
440
485
  num_embeddings=n_chans, embedding_dim=emb_size
441
486
  )
442
- self.register_buffer("index", torch.arange(n_chans, dtype=torch.long))
487
+ self.register_buffer(
488
+ "index", torch.arange(n_chans, dtype=torch.long), persistent=False
489
+ )
443
490
 
444
491
  def stft(self, sample):
445
492
  """
@@ -553,3 +600,16 @@ class _BIOTEncoder(nn.Module):
553
600
  # (batch_size, emb)
554
601
  emb = self.transformer(emb).mean(dim=1)
555
602
  return emb
603
+
604
+
605
+ # -----------------------------------------------------------------------------
606
+ # InterpolatedBIOT — experimental channel-interpolation variant of BIOT
607
+ # -----------------------------------------------------------------------------
608
+ # Wraps :class:`BIOT` with an MNE-backed channel-interpolation layer that
609
+ # projects arbitrary user ``chs_info`` to the canonical 18-channel BIOT
610
+ # montage (:data:`_BIOT_TARGET_CHS_INFO`). Frozen by default; set
611
+ # ``trainable=True`` to fine-tune the projection matrix.
612
+
613
+ from braindecode.models.interpolated import InterpolatedModel # noqa: E402
614
+
615
+ InterpolatedBIOT = InterpolatedModel(BIOT, _BIOT_TARGET_CHS_INFO)
@@ -0,0 +1,182 @@
1
+ # Authors: Pierre Guetschel
2
+ #
3
+ # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ from typing import Literal, Optional, Type
7
+
8
+ import numpy as np
9
+ import torch.nn as nn
10
+
11
+ from braindecode.modules.interpolation import ChannelInterpolationLayer
12
+
13
+
14
def InterpolatedModel(
    model_cls: Type,
    target_chs_info: list[dict],
    name: Optional[str] = None,
    module: Optional[str] = None,
) -> Type:
    """Return a subclass of ``model_cls`` that interpolates channels to ``target_chs_info``.

    .. warning:: Experimental. Public API may change without a deprecation cycle.

    Parameters
    ----------
    model_cls : type
        A braindecode model class (subclass of
        :class:`~braindecode.models.base.EEGModuleMixin`).
    target_chs_info : list of dict
        The canonical channel set the backbone expects internally. Every
        instance of the returned class projects its input channels to
        this set via :class:`~braindecode.modules.ChannelInterpolationLayer`.
    name : str, optional
        ``__name__`` and ``__qualname__`` to assign to the returned class.
        Defaults to ``f"Interpolated{model_cls.__name__}"``.
    module : str, optional
        ``__module__`` to assign to the returned class. Defaults to the
        ``__name__`` of the calling module (the same convention as
        :func:`collections.namedtuple`). Pickle locates classes by
        ``__module__`` + ``__qualname__``. The returned class (and thus its
        instances, e.g. via ``torch.save(model)``) is therefore only
        picklable when it is bound to a module-level name equal to ``name``
        inside ``module``.

    Returns
    -------
    type
        A new subclass of ``model_cls`` whose ``__init__`` accepts
        arbitrary user ``chs_info`` and automatically inserts a frozen
        (by default) channel-interpolation layer before the backbone.
    """
    import sys

    if module is None:
        # Without this, ``__module__`` would be *this* module. No attribute
        # named ``name`` exists there, so pickling the generated class or
        # its instances would fail.
        try:
            module = sys._getframe(1).f_globals.get("__name__", "__main__")
        except (AttributeError, ValueError):
            # ``sys._getframe`` is a CPython implementation detail; keep the
            # default ``__module__`` when it is unavailable.
            module = None

    _is_sequential = issubclass(model_cls, nn.Sequential)

    class _Interpolated(model_cls):
        _TARGET_CHS_INFO = target_chs_info

        def __init__(
            self,
            chs_info,
            n_outputs=None,
            n_times=None,
            input_window_seconds=None,
            sfreq=None,
            n_chans=None,
            interpolation_method: str = "spline",
            interpolation_mode: Literal["always", "name_match"] = "name_match",
            trainable: bool = False,
            **kwargs,
        ):
            # Signal-related params are declared EXPLICITLY here so that
            # skorch's ``EEGClassifier._set_signal_args`` (which inspects
            # the ``__init__`` signature via ``inspect.signature``) can
            # auto-forward them from the training dataset. They would not
            # be discoverable if they were collapsed into ``**kwargs``.
            # ``n_chans`` is declared for the same discoverability reason
            # but intentionally ignored: the backbone must see
            # ``len(target_chs_info)`` derived from ``chs_info``.
            del n_chans
            # Backbone init uses the target channels. During this call,
            # some backbones run a dummy forward (e.g. to size the head);
            # ``self.interpolation_layer`` does not exist yet — the
            # ``forward`` override below falls back to pass-through when
            # the attribute is absent. Assigning an ``nn.Identity()``
            # before ``super()`` is impossible: ``nn.Module.__setattr__``
            # requires ``nn.Module.__init__`` to have run, and the chain
            # would wipe ``self._modules`` when it reaches it again.
            super().__init__(
                chs_info=target_chs_info,
                n_outputs=n_outputs,
                n_times=n_times,
                input_window_seconds=input_window_seconds,
                sfreq=sfreq,
                **kwargs,
            )

            layer = ChannelInterpolationLayer(
                src_chs_info=chs_info,
                tgt_chs_info=target_chs_info,
                mode=interpolation_mode,
                method=interpolation_method,
                trainable=trainable,
            )
            if _is_sequential:
                # For nn.Sequential subclasses, prepend the interpolation
                # layer so that nn.Sequential.forward runs it first.
                # Registering via attribute assignment appends to _modules;
                # instead we rebuild _modules with the layer first.
                old_modules = list(self._modules.items())  # type: ignore[has-type]
                self._modules.clear()  # type: ignore[has-type]
                self._modules["interpolation_layer"] = layer  # type: ignore[index]
                for k, v in old_modules:
                    self._modules[k] = v  # type: ignore[index]
            else:
                self.interpolation_layer = layer

            # Rebind private attrs so the user-facing view (.chs_info,
            # .n_chans, .input_shape, build_model_config) reflects the
            # user's channels. Properties are NOT overridden — we mutate
            # the private attrs the base-class properties read from.
            self._chs_info = chs_info
            self._n_chans = len(chs_info)

        if not _is_sequential:

            def forward(self, x, *args, **kwargs):
                # During super().__init__() the interpolation_layer attr
                # does not exist yet; any dummy forward call (e.g. from
                # get_output_shape) must pass through unchanged so the
                # backbone sees its expected target-shape input.
                # Forward *args / **kwargs so backbone-specific flags
                # like ``return_features`` keep working through the
                # wrapper.
                interp = getattr(self, "interpolation_layer", None)
                if interp is not None:
                    x = interp(x)
                return super().forward(x, *args, **kwargs)

    _Interpolated.__name__ = name or f"Interpolated{model_cls.__name__}"
    _Interpolated.__qualname__ = _Interpolated.__name__
    if module is not None:
        _Interpolated.__module__ = module
    # Propagate the backbone docstring so Sphinx and the categorization tests
    # can read the class badges. Prepend a short header so the class shows
    # up clearly in documentation as distinct from the backbone.
    backbone_doc = model_cls.__doc__ or ""
    _Interpolated.__doc__ = (
        f"Channel-interpolating wrapper around :class:`{model_cls.__name__}`.\n\n"
        ":bdg-dark-line:`Channel`\n\n"
        "Accepts arbitrary user ``chs_info`` and projects them to the\n"
        "backbone's canonical channel set via\n"
        ":class:`~braindecode.modules.ChannelInterpolationLayer`.\n\n"
        "For all other parameters and behavior see the backbone\n"
        "documentation reproduced below.\n\n" + backbone_doc
    )
    return _Interpolated
145
+
146
+
147
def _build_chs_info_from_montage(names: list[str], montage: str) -> list[dict]:
    """Turn canonical channel names into ``chs_info`` dicts using an MNE montage.

    Each name is looked up in the standard MNE montage ``montage``. One dict
    is produced per name, in the order given, with keys ``ch_name``, ``kind``
    (always ``"eeg"``) and ``loc``. ``loc`` is the montage position as a float
    NumPy array of shape ``(3,)``. Used by braindecode's shipped
    ``Interpolated*`` variants to obtain the dict form of ``chs_info`` that
    ``ChannelInterpolationLayer`` expects.

    Parameters
    ----------
    names : list of str
        Channel names; the output preserves this order.
    montage : str
        Name of an MNE standard montage (e.g. ``"standard_1005"``).

    Returns
    -------
    list of dict
        One ``{"ch_name", "kind", "loc"}`` dict per entry of ``names``.

    Raises
    ------
    ValueError
        If a name is not found in the montage.
    """
    import mne  # imported inside the function, only when actually needed

    positions = mne.channels.make_standard_montage(montage).get_positions()["ch_pos"]

    def _to_ch_dict(ch_name: str) -> dict:
        # Check membership first so callers get a descriptive ValueError
        # rather than a bare KeyError from the position lookup.
        if ch_name not in positions:
            raise ValueError(f"Channel {ch_name!r} not found in montage {montage!r}.")
        return {
            "ch_name": ch_name,
            "kind": "eeg",
            "loc": np.asarray(positions[ch_name], dtype=float),
        }

    return [_to_ch_dict(ch_name) for ch_name in names]