braindecode 1.5.0.dev179040812__tar.gz → 1.5.0.dev182195895__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. {braindecode-1.5.0.dev179040812/braindecode.egg-info → braindecode-1.5.0.dev182195895}/PKG-INFO +1 -1
  2. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/attentionbasenet.py +3 -2
  3. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/biot.py +1 -1
  4. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/syncnet.py +17 -13
  5. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/tsinception.py +1 -1
  6. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/attention.py +45 -16
  7. braindecode-1.5.0.dev182195895/braindecode/version.py +1 -0
  8. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895/braindecode.egg-info}/PKG-INFO +1 -1
  9. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/whats_new.rst +23 -1
  10. braindecode-1.5.0.dev179040812/braindecode/version.py +0 -1
  11. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/LICENSE.txt +0 -0
  12. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/MANIFEST.in +0 -0
  13. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/NOTICE.txt +0 -0
  14. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/README.rst +0 -0
  15. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/__init__.py +0 -0
  16. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/augmentation/__init__.py +0 -0
  17. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/augmentation/base.py +0 -0
  18. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/augmentation/functional.py +0 -0
  19. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/augmentation/transforms.py +0 -0
  20. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/classifier.py +0 -0
  21. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/__init__.py +0 -0
  22. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/base.py +0 -0
  23. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bbci.py +0 -0
  24. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bcicomp.py +0 -0
  25. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bids/__init__.py +0 -0
  26. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bids/datasets.py +0 -0
  27. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bids/format.py +0 -0
  28. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bids/hub.py +0 -0
  29. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bids/hub_format.py +0 -0
  30. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bids/hub_io.py +0 -0
  31. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bids/hub_validation.py +0 -0
  32. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/bids/iterable.py +0 -0
  33. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/chb_mit.py +0 -0
  34. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/mne.py +0 -0
  35. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/moabb.py +0 -0
  36. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/nmt.py +0 -0
  37. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/registry.py +0 -0
  38. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/siena.py +0 -0
  39. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
  40. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/sleep_physionet.py +0 -0
  41. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/tuh.py +0 -0
  42. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/utils.py +0 -0
  43. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datasets/xy.py +0 -0
  44. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datautil/__init__.py +0 -0
  45. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datautil/channel_utils.py +0 -0
  46. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datautil/hub_formats.py +0 -0
  47. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datautil/serialization.py +0 -0
  48. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/datautil/util.py +0 -0
  49. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/eegneuralnet.py +0 -0
  50. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/functional/__init__.py +0 -0
  51. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/functional/functions.py +0 -0
  52. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/functional/initialization.py +0 -0
  53. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/__init__.py +0 -0
  54. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/atcnet.py +0 -0
  55. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/attn_sleep.py +0 -0
  56. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/base.py +0 -0
  57. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/bendr.py +0 -0
  58. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/brainmodule.py +0 -0
  59. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/cbramod.py +0 -0
  60. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/config.py +0 -0
  61. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/contrawr.py +0 -0
  62. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/ctnet.py +0 -0
  63. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/deep4.py +0 -0
  64. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/deepsleepnet.py +0 -0
  65. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/dgcnn.py +0 -0
  66. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegconformer.py +0 -0
  67. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eeginception_erp.py +0 -0
  68. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eeginception_mi.py +0 -0
  69. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegitnet.py +0 -0
  70. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegminer.py +0 -0
  71. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegnet.py +0 -0
  72. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegnex.py +0 -0
  73. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegpt.py +0 -0
  74. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegsimpleconv.py +0 -0
  75. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegsym.py +0 -0
  76. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/eegtcnet.py +0 -0
  77. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/fbcnet.py +0 -0
  78. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/fblightconvnet.py +0 -0
  79. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/fbmsnet.py +0 -0
  80. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/hybrid.py +0 -0
  81. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/ifnet.py +0 -0
  82. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/labram.py +0 -0
  83. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/luna.py +0 -0
  84. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/medformer.py +0 -0
  85. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/msvtnet.py +0 -0
  86. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/patchedtransformer.py +0 -0
  87. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/reve.py +0 -0
  88. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/sccnet.py +0 -0
  89. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/shallow_fbcsp.py +0 -0
  90. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/signal_jepa.py +0 -0
  91. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/sinc_shallow.py +0 -0
  92. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/sleep_stager_blanco_2020.py +0 -0
  93. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/sleep_stager_chambon_2018.py +0 -0
  94. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/sparcnet.py +0 -0
  95. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/sstdpn.py +0 -0
  96. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/summary.csv +0 -0
  97. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/tcn.py +0 -0
  98. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/tidnet.py +0 -0
  99. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/usleep.py +0 -0
  100. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/models/util.py +0 -0
  101. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/__init__.py +0 -0
  102. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/activation.py +0 -0
  103. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/blocks.py +0 -0
  104. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/convolution.py +0 -0
  105. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/filter.py +0 -0
  106. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/layers.py +0 -0
  107. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/linear.py +0 -0
  108. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/parametrization.py +0 -0
  109. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/stats.py +0 -0
  110. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/util.py +0 -0
  111. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/modules/wrapper.py +0 -0
  112. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/preprocessing/__init__.py +0 -0
  113. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/preprocessing/eegprep_preprocess.py +0 -0
  114. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/preprocessing/mne_preprocess.py +0 -0
  115. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/preprocessing/preprocess.py +0 -0
  116. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/preprocessing/util.py +0 -0
  117. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/preprocessing/windowers.py +0 -0
  118. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/regressor.py +0 -0
  119. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/samplers/__init__.py +0 -0
  120. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/samplers/base.py +0 -0
  121. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/samplers/ssl.py +0 -0
  122. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/training/__init__.py +0 -0
  123. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/training/callbacks.py +0 -0
  124. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/training/losses.py +0 -0
  125. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/training/scoring.py +0 -0
  126. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/util.py +0 -0
  127. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/visualization/__init__.py +0 -0
  128. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/visualization/confusion_matrices.py +0 -0
  129. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode/visualization/gradients.py +0 -0
  130. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode.egg-info/SOURCES.txt +0 -0
  131. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode.egg-info/dependency_links.txt +0 -0
  132. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode.egg-info/requires.txt +0 -0
  133. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/braindecode.egg-info/top_level.txt +0 -0
  134. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/Makefile +0 -0
  135. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/_templates/autosummary/class.rst +0 -0
  136. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/_templates/autosummary/class_in_subdir.rst +0 -0
  137. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/_templates/autosummary/function.rst +0 -0
  138. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/_templates/autosummary/function_in_subdir.rst +0 -0
  139. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/api.rst +0 -0
  140. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/cite.rst +0 -0
  141. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/conf.py +0 -0
  142. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/help.rst +0 -0
  143. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/index.rst +0 -0
  144. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/install/install.rst +0 -0
  145. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/install/install_pip.rst +0 -0
  146. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/install/install_source.rst +0 -0
  147. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/attention.rst +0 -0
  148. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/channel.rst +0 -0
  149. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/convolution.rst +0 -0
  150. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/filterbank.rst +0 -0
  151. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/gnn.rst +0 -0
  152. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/interpretable.rst +0 -0
  153. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/lbm.rst +0 -0
  154. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/recurrent.rst +0 -0
  155. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/categorization/spd.rst +0 -0
  156. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/models.rst +0 -0
  157. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/models_categorization.rst +0 -0
  158. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/models_table.rst +0 -0
  159. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/models/models_visualization.rst +0 -0
  160. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/docs/sg_execution_times.rst +0 -0
  161. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/pyproject.toml +0 -0
  162. {braindecode-1.5.0.dev179040812 → braindecode-1.5.0.dev182195895}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.5.0.dev179040812
3
+ Version: 1.5.0.dev182195895
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -1,5 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
+ # Authors: Sarthak Tayal <sarthaktayal2@gmail.com>
4
+ #
5
+ # License: BSD (3-clause)
3
6
  import math
4
7
 
5
8
  from einops.layers.torch import Rearrange
@@ -275,8 +278,6 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
275
278
  activation: type[nn.Module] = nn.ELU,
276
279
  extra_params: bool = False,
277
280
  ):
278
- super(AttentionBaseNet, self).__init__()
279
-
280
281
  super().__init__(
281
282
  n_outputs=n_outputs,
282
283
  n_chans=n_chans,
@@ -439,7 +439,7 @@ class _BIOTEncoder(nn.Module):
439
439
  self.channel_tokens = nn.Embedding(
440
440
  num_embeddings=n_chans, embedding_dim=emb_size
441
441
  )
442
- self.index = nn.Parameter(torch.LongTensor(range(n_chans)), requires_grad=False)
442
+ self.register_buffer("index", torch.arange(n_chans, dtype=torch.long))
443
443
 
444
444
  def stft(self, sample):
445
445
  """
@@ -1,3 +1,7 @@
1
+ # Authors: Sarthak Tayal <sarthaktayal2@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+
1
5
  import torch
2
6
  import torch.nn as nn
3
7
  import torch.nn.functional as F
@@ -51,11 +55,11 @@ class SyncNet(EEGModuleMixin, nn.Module):
51
55
  The initialization range for omega parameters using uniform
52
56
  distribution. Default is (0, 1).
53
57
  beta_init_values : tuple of float, optional
54
- The initialization range for beta parameters using uniform
55
- distribution. Default is (0, 1). Default is (0, 0.05).
58
+ The initialization range for beta (decay) parameters using uniform
59
+ distribution. Default is (0, 0.05).
56
60
  phase_init_values : tuple of float, optional
57
- The initialization range for phase parameters using `normal`
58
- distribution. Default is (0, 1). Default is (0, 0.05).
61
+ The initialization mean and standard deviation for phase
62
+ parameters using normal distribution. Default is (0, 0.05).
59
63
 
60
64
 
61
65
  Notes
@@ -146,12 +150,12 @@ class SyncNet(EEGModuleMixin, nn.Module):
146
150
  # Phase Shift
147
151
  self.phi_ini = nn.Parameter(
148
152
  torch.FloatTensor(1, 1, self.n_chans, self.num_filters).normal_(
149
- self.beta_init_values[0], self.beta_init_values[1]
153
+ self.phase_init_values[0], self.phase_init_values[1]
150
154
  )
151
155
  )
152
156
  self.beta = nn.Parameter(
153
157
  torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
154
- self.phase_init_values[0], self.phase_init_values[1]
158
+ self.beta_init_values[0], self.beta_init_values[1]
155
159
  )
156
160
  )
157
161
 
@@ -185,21 +189,21 @@ class SyncNet(EEGModuleMixin, nn.Module):
185
189
  # Output: (batch_size, n_chans, 1, n_times)
186
190
 
187
191
  # Compute the oscillatory component
192
+ # Shape: (1, filter_width, n_chans, num_filters)
188
193
  W_osc = self.amplitude * torch.cos(self.t * self.omega + self.phi_ini)
189
- # W_osc is (1, filter_width, n_chans, 1)
190
194
 
191
195
  # Compute the decay component
192
- t_squared = torch.pow(self.t, 2) # Shape: (filter_width,)
193
- t_squared_beta = t_squared * self.beta # Shape: (filter_width, num_filters)
196
+ # Shape: (1, filter_width, 1, num_filters)
197
+ t_squared = torch.pow(self.t, 2)
198
+ t_squared_beta = t_squared * self.beta
194
199
  W_decay = torch.exp(-t_squared_beta)
195
- # W_osc is (1, filter_width, 1, 1)
196
200
 
197
201
  # Combine oscillatory and decay components
198
- # W shape: (1, n_chans, num_filters, filter_width)
202
+ # Shape: (1, filter_width, n_chans, num_filters)
199
203
  W = W_osc * W_decay
200
- # W shape will be: (1, filter_width, n_chans, 1)
201
204
 
202
- W = W.view(self.num_filters, self.n_chans, 1, self.filter_width)
205
+ # Permute to conv2d weight shape (out_channels, in_channels, kH, kW)
206
+ W = W.permute(3, 2, 0, 1).contiguous()
203
207
 
204
208
  # Apply convolution
205
209
  x_padded = self.pad_input(x.float())
@@ -1,4 +1,4 @@
1
- # Authors: Bruno Aristimunha <b.aristimunha>
1
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
2
2
  #
3
3
  # License: BSD (3-clause)
4
4
 
@@ -837,6 +837,18 @@ class CATLite(nn.Module):
837
837
  class MultiHeadAttention(nn.Module):
838
838
  """Multi-head self-attention block.
839
839
 
840
+ Uses ``F.scaled_dot_product_attention`` for optimized attention
841
+ kernels (flash-attention on CUDA, memory-efficient on other devices).
842
+
843
+ Parameters
844
+ ----------
845
+ emb_size : int
846
+ The embedding dimension.
847
+ num_heads : int
848
+ Number of attention heads. Must evenly divide ``emb_size``.
849
+ dropout : float, optional
850
+ Dropout probability applied to attention weights. Default: 0.0.
851
+
840
852
  Examples
841
853
  --------
842
854
  >>> import torch
@@ -848,40 +860,57 @@ class MultiHeadAttention(nn.Module):
848
860
  torch.Size([2, 10, 32])
849
861
  """
850
862
 
851
- def __init__(self, emb_size, num_heads, dropout):
863
+ def __init__(self, emb_size, num_heads, dropout=0.0):
852
864
  super().__init__()
865
+ if emb_size % num_heads != 0:
866
+ raise ValueError(
867
+ f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})."
868
+ )
853
869
  self.emb_size = emb_size
854
870
  self.num_heads = num_heads
871
+ self.head_dim = emb_size // num_heads
855
872
  self.keys = nn.Linear(emb_size, emb_size)
856
873
  self.queries = nn.Linear(emb_size, emb_size)
857
874
  self.values = nn.Linear(emb_size, emb_size)
858
- self.att_drop = nn.Dropout(dropout)
875
+ self.att_drop = dropout
859
876
  self.projection = nn.Linear(emb_size, emb_size)
860
877
 
861
878
  self.rearrange_stack = Rearrange(
862
- "b n (h d) -> b h n d",
863
- h=num_heads,
879
+ "batch seq (heads head_dim) -> batch heads seq head_dim",
880
+ heads=num_heads,
864
881
  )
865
882
  self.rearrange_unstack = Rearrange(
866
- "b h n d -> b n (h d)",
883
+ "batch heads seq head_dim -> batch seq (heads head_dim)",
867
884
  )
868
885
 
869
886
  def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
887
+ """Forward pass.
888
+
889
+ Parameters
890
+ ----------
891
+ x : Tensor
892
+ Input tensor of shape ``(batch, seq, emb_size)``.
893
+ mask : Tensor, optional
894
+ Attention mask following PyTorch SDPA convention: for boolean
895
+ masks ``True`` means **ignore** that position; for float
896
+ masks the values are **added** to attention scores before
897
+ softmax.
898
+ """
870
899
  queries = self.rearrange_stack(self.queries(x))
871
900
  keys = self.rearrange_stack(self.keys(x))
872
901
  values = self.rearrange_stack(self.values(x))
873
- energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
874
- if mask is not None:
875
- fill_value = float("-inf")
876
- energy = energy.masked_fill(~mask, fill_value)
877
-
878
- scaling = self.emb_size ** (1 / 2)
879
- att = F.softmax(energy / scaling, dim=-1)
880
- att = self.att_drop(att)
881
- out = torch.einsum("bhal, bhlv -> bhav ", att, values)
902
+
903
+ dp = self.att_drop if self.training else 0.0
904
+ out = F.scaled_dot_product_attention(
905
+ queries,
906
+ keys,
907
+ values,
908
+ attn_mask=mask,
909
+ dropout_p=dp,
910
+ )
911
+
882
912
  out = self.rearrange_unstack(out)
883
- out = self.projection(out)
884
- return out
913
+ return self.projection(out)
885
914
 
886
915
 
887
916
  class CrissCrossTransformerEncoderLayer(nn.Module):
@@ -0,0 +1 @@
1
+ __version__ = "1.5.0.dev182195895"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.5.0.dev179040812
3
+ Version: 1.5.0.dev182195895
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -28,10 +28,19 @@ Current 1.5.0 (GitHub)
28
28
  Enhancements
29
29
  ============
30
30
 
31
+ - Use ``F.scaled_dot_product_attention`` in :class:`braindecode.modules.MultiHeadAttention`,
32
+ enabling optimized attention kernels (flash-attention on CUDA,
33
+ memory-efficient backends on other devices).
34
+ By `Léo Burgund`_ and `Bruno Aristimunha`_.
35
+ (:gh:`902`)
36
+
31
37
  API and behavior changes
32
38
  ========================
33
39
 
34
- - None yet
40
+ - :class:`braindecode.modules.MultiHeadAttention` now follows PyTorch's SDPA mask
41
+ convention: boolean masks use ``True`` to **ignore** a position (previously
42
+ ``True`` meant keep). The scaling factor is now ``1/sqrt(head_dim)`` instead of
43
+ ``1/sqrt(emb_size)``. (:gh:`902`)
35
44
 
36
45
  Requirements
37
46
  ============
@@ -41,6 +50,16 @@ Requirements
41
50
  Bug fixes
42
51
  ==========
43
52
 
53
+ - Fix :class:`braindecode.models.SyncNet` swapped parameter initialization where
54
+ ``phi_ini`` (phase shift) was using ``beta_init_values`` and ``beta`` (decay) was
55
+ using ``phase_init_values``, replaced incorrect ``.view()`` reshape with ``.permute()``
56
+ for proper conv2d filter weight layout, and fixed duplicate default values in docstring
57
+ (by `Sarthak Tayal`_)
58
+ - Fix :class:`braindecode.models.AttentionBaseNet` redundant
59
+ ``super().__init__()`` call that ran the parent ``nn.Module.__init__`` twice
60
+ (by `Sarthak Tayal`_)
61
+ - Fix incomplete author email in :class:`braindecode.models.TSception` header
62
+ (by `Sarthak Tayal`_)
44
63
  - Fix a time-of-check-time-of-use race in
45
64
  :func:`braindecode.datasets.base._zarr_to_memmap` that caused
46
65
  concurrent workers to repeatedly ``rename``-replace the published
@@ -50,6 +69,9 @@ Bug fixes
50
69
  is never replaced, making the cache safe under arbitrary
51
70
  concurrent access on local POSIX, NFSv3, Lustre and SMB
52
71
  (:gh:`986` by `Pierre Guetschel`_)
72
+ - Register :class:`braindecode.models.BIOT` encoder ``index`` as a non-trainable
73
+ buffer instead of a parameter (``torch.long``), so it is treated as module
74
+ state rather than trainable weights (:gh:`988` by `Pierre Guetschel`_)
53
75
 
54
76
  Code health
55
77
  ============
@@ -1 +0,0 @@
1
- __version__ = "1.5.0.dev179040812"