braindecode 1.5.0.dev1010__tar.gz → 1.5.0.dev1015__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. {braindecode-1.5.0.dev1010/braindecode.egg-info → braindecode-1.5.0.dev1015}/PKG-INFO +1 -1
  2. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/augmentation/__init__.py +2 -0
  3. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/augmentation/functional.py +127 -0
  4. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/augmentation/transforms.py +97 -0
  5. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/emg2qwerty.py +207 -7
  6. braindecode-1.5.0.dev1015/braindecode/version.py +1 -0
  7. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015/braindecode.egg-info}/PKG-INFO +1 -1
  8. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/whats_new.rst +28 -0
  9. braindecode-1.5.0.dev1010/braindecode/version.py +0 -1
  10. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/LICENSE.txt +0 -0
  11. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/MANIFEST.in +0 -0
  12. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/NOTICE.txt +0 -0
  13. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/README.rst +0 -0
  14. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/__init__.py +0 -0
  15. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/augmentation/base.py +0 -0
  16. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/classifier.py +0 -0
  17. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/__init__.py +0 -0
  18. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/base.py +0 -0
  19. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bbci.py +0 -0
  20. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bcicomp.py +0 -0
  21. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bids/__init__.py +0 -0
  22. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bids/datasets.py +0 -0
  23. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bids/format.py +0 -0
  24. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bids/hub.py +0 -0
  25. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bids/hub_format.py +0 -0
  26. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bids/hub_io.py +0 -0
  27. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bids/hub_validation.py +0 -0
  28. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/bids/iterable.py +0 -0
  29. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/chb_mit.py +0 -0
  30. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/mne.py +0 -0
  31. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/moabb.py +0 -0
  32. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/nmt.py +0 -0
  33. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/registry.py +0 -0
  34. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/siena.py +0 -0
  35. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
  36. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/sleep_physionet.py +0 -0
  37. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/tuh.py +0 -0
  38. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/utils.py +0 -0
  39. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datasets/xy.py +0 -0
  40. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datautil/__init__.py +0 -0
  41. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datautil/channel_utils.py +0 -0
  42. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datautil/hub_formats.py +0 -0
  43. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datautil/serialization.py +0 -0
  44. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/datautil/util.py +0 -0
  45. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/eegneuralnet.py +0 -0
  46. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/functional/__init__.py +0 -0
  47. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/functional/functions.py +0 -0
  48. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/functional/initialization.py +0 -0
  49. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/__init__.py +0 -0
  50. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/atcnet.py +0 -0
  51. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/attentionbasenet.py +0 -0
  52. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/attn_sleep.py +0 -0
  53. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/base.py +0 -0
  54. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/bendr.py +0 -0
  55. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/biot.py +0 -0
  56. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/brainmodule.py +0 -0
  57. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/cbramod.py +0 -0
  58. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/codebrain.py +0 -0
  59. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/config.py +0 -0
  60. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/contrawr.py +0 -0
  61. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/ctnet.py +0 -0
  62. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/deep4.py +0 -0
  63. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/deepsleepnet.py +0 -0
  64. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/dgcnn.py +0 -0
  65. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegconformer.py +0 -0
  66. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eeginception_erp.py +0 -0
  67. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eeginception_mi.py +0 -0
  68. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegitnet.py +0 -0
  69. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegminer.py +0 -0
  70. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegnet.py +0 -0
  71. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegnex.py +0 -0
  72. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegpt.py +0 -0
  73. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegsimpleconv.py +0 -0
  74. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegsym.py +0 -0
  75. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/eegtcnet.py +0 -0
  76. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/fbcnet.py +0 -0
  77. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/fblightconvnet.py +0 -0
  78. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/fbmsnet.py +0 -0
  79. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/hybrid.py +0 -0
  80. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/ifnet.py +0 -0
  81. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/interpolated.py +0 -0
  82. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/labram.py +0 -0
  83. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/luna.py +0 -0
  84. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/medformer.py +0 -0
  85. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/meta_neuromotor.py +0 -0
  86. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/msvtnet.py +0 -0
  87. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/patchedtransformer.py +0 -0
  88. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/reve.py +0 -0
  89. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/sccnet.py +0 -0
  90. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/shallow_fbcsp.py +0 -0
  91. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/signal_jepa.py +0 -0
  92. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/sinc_shallow.py +0 -0
  93. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/sleep_stager_blanco_2020.py +0 -0
  94. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/sleep_stager_chambon_2018.py +0 -0
  95. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/sparcnet.py +0 -0
  96. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/sstdpn.py +0 -0
  97. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/summary.csv +0 -0
  98. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/syncnet.py +0 -0
  99. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/tcn.py +0 -0
  100. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/tidnet.py +0 -0
  101. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/tsinception.py +0 -0
  102. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/usleep.py +0 -0
  103. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/models/util.py +0 -0
  104. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/__init__.py +0 -0
  105. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/activation.py +0 -0
  106. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/attention.py +0 -0
  107. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/blocks.py +0 -0
  108. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/convolution.py +0 -0
  109. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/filter.py +0 -0
  110. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/interpolation.py +0 -0
  111. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/layers.py +0 -0
  112. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/linear.py +0 -0
  113. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/parametrization.py +0 -0
  114. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/stats.py +0 -0
  115. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/util.py +0 -0
  116. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/modules/wrapper.py +0 -0
  117. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/preprocessing/__init__.py +0 -0
  118. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/preprocessing/eegprep_preprocess.py +0 -0
  119. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/preprocessing/mne_preprocess.py +0 -0
  120. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/preprocessing/preprocess.py +0 -0
  121. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/preprocessing/util.py +0 -0
  122. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/preprocessing/windowers.py +0 -0
  123. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/regressor.py +0 -0
  124. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/samplers/__init__.py +0 -0
  125. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/samplers/base.py +0 -0
  126. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/samplers/ssl.py +0 -0
  127. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/training/__init__.py +0 -0
  128. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/training/callbacks.py +0 -0
  129. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/training/losses.py +0 -0
  130. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/training/scoring.py +0 -0
  131. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/util.py +0 -0
  132. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/visualization/__init__.py +0 -0
  133. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/visualization/attribution.py +0 -0
  134. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/visualization/confusion_matrices.py +0 -0
  135. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/visualization/frequency.py +0 -0
  136. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/visualization/metrics.py +0 -0
  137. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/visualization/sanity.py +0 -0
  138. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode/visualization/topology.py +0 -0
  139. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode.egg-info/SOURCES.txt +0 -0
  140. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode.egg-info/dependency_links.txt +0 -0
  141. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode.egg-info/requires.txt +0 -0
  142. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/braindecode.egg-info/top_level.txt +0 -0
  143. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/Makefile +0 -0
  144. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/_templates/autosummary/class.rst +0 -0
  145. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/_templates/autosummary/class_in_subdir.rst +0 -0
  146. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/_templates/autosummary/function.rst +0 -0
  147. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/_templates/autosummary/function_in_subdir.rst +0 -0
  148. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/api.rst +0 -0
  149. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/cite.rst +0 -0
  150. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/conf.py +0 -0
  151. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/help.rst +0 -0
  152. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/index.rst +0 -0
  153. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/install/install.rst +0 -0
  154. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/install/install_pip.rst +0 -0
  155. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/install/install_source.rst +0 -0
  156. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/attention.rst +0 -0
  157. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/channel.rst +0 -0
  158. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/convolution.rst +0 -0
  159. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/filterbank.rst +0 -0
  160. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/gnn.rst +0 -0
  161. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/interpretable.rst +0 -0
  162. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/lbm.rst +0 -0
  163. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/recurrent.rst +0 -0
  164. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/categorization/spd.rst +0 -0
  165. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/models.rst +0 -0
  166. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/models_categorization.rst +0 -0
  167. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/models_table.rst +0 -0
  168. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/models/models_visualization.rst +0 -0
  169. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/docs/sg_execution_times.rst +0 -0
  170. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/pyproject.toml +0 -0
  171. {braindecode-1.5.0.dev1010 → braindecode-1.5.0.dev1015}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.5.0.dev1010
3
+ Version: 1.5.0.dev1015
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -4,6 +4,7 @@ from . import functional
4
4
  from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
5
5
  from .transforms import (
6
6
  AmplitudeScale,
7
+ BandRotation,
7
8
  BandstopFilter,
8
9
  ChannelsDropout,
9
10
  ChannelsReref,
@@ -47,6 +48,7 @@ __all__ = [
47
48
  "SegmentationReconstruction",
48
49
  "MaskEncoding",
49
50
  "AmplitudeScale",
51
+ "BandRotation",
50
52
  "ChannelsReref",
51
53
  "functional",
52
54
  ]
@@ -1298,3 +1298,130 @@ def amplitude_scale(
1298
1298
  X = s * X
1299
1299
 
1300
1300
  return X, y
1301
+
1302
+
1303
def band_rotation(
    X: torch.Tensor,
    y: torch.Tensor,
    num_bands: int = 2,
    electrodes_per_band: int = 16,
    band_offsets: tuple[int, ...] = (-1, 0, 1),
    max_temporal_jitter: int = 0,
    circular_jitter: bool = True,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate electrodes within each band and time-jitter band 1.

    Simulates a wristband sitting slightly rotated between recording
    sessions, plus relative timing noise between two arms, as proposed
    in [Sivakumar2024]_ for the emg2qwerty CTC keystroke decoding task.
    Each band receives an independent circular roll along the channel
    axis (offset drawn uniformly from ``band_offsets``); when at least
    two bands are present, band 1 additionally receives a sample-level
    shift along the time axis drawn uniformly from
    ``[-max_temporal_jitter, +max_temporal_jitter]``.

    Bands are assumed contiguous along the channel axis, i.e. the input
    layout is ``(B, num_bands * electrodes_per_band, T)``. One set of
    random parameters is drawn per call and shared by the whole batch.

    Parameters
    ----------
    X : torch.Tensor
        EMG batch of shape ``(B, C, T)`` where
        ``C == num_bands * electrodes_per_band``.
    y : torch.Tensor
        Labels, passed through untouched.
    num_bands : int, optional
        How many electrode bands the channel axis holds (``2`` for a
        left + right wristband setup). Must be ``>= 1``. Defaults to 2.
    electrodes_per_band : int, optional
        Channels per band (e.g. ``16``). Must be ``>= 1``. Defaults
        to 16.
    band_offsets : tuple of int, optional
        Candidate roll amounts, sampled uniformly per band. The default
        ``(-1, 0, 1)`` models up to one electrode of misalignment.
        Must be non-empty.
    max_temporal_jitter : int, optional
        Largest ±-sample shift applied to band 1 (only effective when
        ``num_bands >= 2``). Must be ``>= 0``. Defaults to 0, i.e. no
        temporal jitter.
    circular_jitter : bool, optional
        When True (default, matching the paper), the temporal shift is
        a circular ``torch.roll`` so samples wrap around the window
        edges. When False, wrapped-off samples are dropped and the
        vacated region is zero-filled instead, trading a small zeroed
        margin for the absence of a wrap-around discontinuity.
        Irrelevant if ``max_temporal_jitter == 0``.
    random_state : int | numpy.random.RandomState, optional
        Seed or generator controlling the sampled offsets and shift.

    Returns
    -------
    torch.Tensor
        Augmented inputs.
    torch.Tensor
        The labels ``y``, unchanged.

    References
    ----------
    .. [Sivakumar2024] Sivakumar, V., Seely, J., Du, A., Bittner, S. R.,
       Berenzweig, A., Bolarinwa, A., Gramfort, A., & Mandel, M. I. (2024).
       "emg2qwerty: A Large Dataset with Baselines for Touch Typing using
       Surface Electromyography." *NeurIPS Datasets and Benchmarks Track*.
    """
    if num_bands < 1:
        raise ValueError(f"num_bands must be >= 1, got {num_bands}")
    if electrodes_per_band < 1:
        raise ValueError(f"electrodes_per_band must be >= 1, got {electrodes_per_band}")
    # Coerce to a tuple before any truth test: an ``np.ndarray`` would
    # otherwise raise the ambiguous-truth-value error, and any
    # sequence-like input should be accepted.
    band_offsets = tuple(band_offsets)
    if not band_offsets:
        raise ValueError("band_offsets must be non-empty")
    if any(not isinstance(off, (int, np.integer)) for off in band_offsets):
        raise ValueError(f"band_offsets must contain integers, got {band_offsets!r}")
    if max_temporal_jitter < 0:
        raise ValueError(f"max_temporal_jitter must be >= 0, got {max_temporal_jitter}")
    expected_channels = num_bands * electrodes_per_band
    if X.shape[1] != expected_channels:
        raise ValueError(
            f"X.shape[1]={X.shape[1]} != num_bands * electrodes_per_band="
            f"{expected_channels}"
        )

    rng = check_random_state(random_state)
    offset_pool = np.asarray(band_offsets)
    augmented = X.clone()

    # Channel-axis roll, one independent draw per band. A single
    # vectorized ``torch.gather`` was benchmarked ~16 % slower than two
    # contiguous rolls for the common ``num_bands == 2`` CPU case and
    # only pays off from ``num_bands >= 8``.
    for band_idx in range(num_bands):
        roll_by = int(rng.choice(offset_pool))
        if roll_by:
            start = band_idx * electrodes_per_band
            band_slice = slice(start, start + electrodes_per_band)
            augmented[:, band_slice, :] = torch.roll(
                augmented[:, band_slice, :], roll_by, dims=1
            )

    # Inter-band temporal jitter — the paper recipe moves band 1 only.
    if max_temporal_jitter > 0 and num_bands >= 2:
        shift = int(rng.randint(-max_temporal_jitter, max_temporal_jitter + 1))
        if shift:
            band_slice = slice(electrodes_per_band, 2 * electrodes_per_band)
            band1 = augmented[:, band_slice, :]
            if circular_jitter:
                # Paper-faithful circular shift: samples falling off one
                # window edge reappear at the other.
                augmented[:, band_slice, :] = torch.roll(band1, shift, dims=2)
            else:
                # Crop-and-pad shift: drop what falls off one end and
                # zero-fill the gap on the other, avoiding wrap-around
                # discontinuity at the cost of a ``|shift|`` margin.
                shifted = torch.zeros_like(band1)
                if shift > 0:
                    shifted[:, :, shift:] = band1[:, :, :-shift]
                else:  # shift < 0
                    shifted[:, :, :shift] = band1[:, :, -shift:]
                augmented[:, band_slice, :] = shifted

    return augmented, y
@@ -16,6 +16,7 @@ from mne.channels import make_standard_montage
16
16
  from .base import Transform
17
17
  from .functional import (
18
18
  amplitude_scale,
19
+ band_rotation,
19
20
  bandstop_filter,
20
21
  channels_dropout,
21
22
  channels_permute,
@@ -1356,3 +1357,99 @@ class AmplitudeScale(Transform):
1356
1357
  def get_augmentation_params(self, *batch):
1357
1358
  """Return transform parameters."""
1358
1359
  return {"random_state": self.rng, "scale": self.scale}
1360
+
1361
+
1362
class BandRotation(Transform):
    """Electrode rotation per band plus optional inter-band time jitter.

    Simulates a wristband sitting slightly rotated between recording
    sessions and relative timing noise between two arms. Introduced in
    [Sivakumar2024]_ for the emg2qwerty surface-EMG keystroke decoding
    task: with the channel axis laid out as
    ``(B, num_bands * electrodes_per_band, T)`` and bands contiguous,
    each band receives a uniform circular roll along the channel axis,
    and when ``num_bands >= 2``, band 1 additionally receives a
    sample-level temporal shift. One set of random parameters is drawn
    per call and shared by every sample in the transformed sub-batch.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    num_bands : int, optional
        How many electrode bands the channel axis holds (``2`` for a
        left + right wristband setup). Must be ``>= 1``. Defaults to 2.
    electrodes_per_band : int, optional
        Channels per band (e.g. ``16``). Must be ``>= 1``. Defaults
        to 16.
    band_offsets : tuple of int, optional
        Candidate roll amounts, sampled uniformly per band. The default
        ``(-1, 0, 1)`` models up to one electrode of misalignment.
        Must be non-empty.
    max_temporal_jitter : int, optional
        Largest ±-sample shift applied to band 1. Must be ``>= 0``.
        Defaults to 0 (jitter disabled). The emg2qwerty paper uses
        120 samples (60 ms at 2 kHz).
    circular_jitter : bool, optional
        When True (default, paper-faithful) the jitter is a circular
        roll; when False the vacated region is zero-padded instead.
        See :func:`band_rotation`.
    random_state : int | numpy.random.RandomState, optional
        Seed for the rotation / jitter sampler. Defaults to None.

    References
    ----------
    .. [Sivakumar2024] Sivakumar, V., Seely, J., Du, A., Bittner, S. R.,
       Berenzweig, A., Bolarinwa, A., Gramfort, A., & Mandel, M. I. (2024).
       "emg2qwerty: A Large Dataset with Baselines for Touch Typing using
       Surface Electromyography." *NeurIPS Datasets and Benchmarks Track*.
    """

    operation = staticmethod(band_rotation)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        num_bands=2,
        electrodes_per_band=16,
        band_offsets=(-1, 0, 1),
        max_temporal_jitter=0,
        circular_jitter=True,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)
        # Validate eagerly so a misconfiguration fails when the
        # Transform is built rather than on the first augmented batch
        # (``band_rotation`` repeats these checks defensively at call
        # time).
        if num_bands < 1:
            raise ValueError(f"num_bands must be >= 1, got {num_bands}")
        if electrodes_per_band < 1:
            raise ValueError(
                f"electrodes_per_band must be >= 1, got {electrodes_per_band}"
            )
        band_offsets = tuple(band_offsets)
        if not band_offsets:
            raise ValueError("band_offsets must be non-empty")
        if any(not isinstance(off, (int, np.integer)) for off in band_offsets):
            raise ValueError(
                f"band_offsets must contain integers, got {band_offsets!r}"
            )
        if max_temporal_jitter < 0:
            raise ValueError(
                f"max_temporal_jitter must be >= 0, got {max_temporal_jitter}"
            )
        self.num_bands = num_bands
        self.electrodes_per_band = electrodes_per_band
        self.band_offsets = band_offsets
        self.max_temporal_jitter = max_temporal_jitter
        self.circular_jitter = circular_jitter

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments forwarded to ``band_rotation``."""
        return dict(
            num_bands=self.num_bands,
            electrodes_per_band=self.electrodes_per_band,
            band_offsets=self.band_offsets,
            max_temporal_jitter=self.max_temporal_jitter,
            circular_jitter=self.circular_jitter,
            random_state=self.rng,
        )
@@ -56,7 +56,11 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
56
56
  Returns ``(batch, T_out, n_outputs)``. With ``n_times=8000`` and
57
57
  defaults, ``T_out=373``. For :class:`~torch.nn.CTCLoss`, transpose
58
58
  to ``(T_out, batch, n_outputs)``; use :meth:`compute_output_lengths`
59
- for emission lengths.
59
+ for emission lengths. Pass ``return_features=True`` to return the
60
+ pre-classifier encoder representation as a
61
+ ``{"features": (batch, T_out, num_features), "cls_token": None}``
62
+ dict, matching the BIOT / signal-JEPA convention used by downstream
63
+ wrappers (e.g. neuroai's ``DownstreamWrapperModel``).
60
64
 
61
65
  .. rubric:: Paper training recipe
62
66
 
@@ -69,7 +73,9 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
69
73
  local minimum).
70
74
  - **Augmentation**: per-band electrode rotations by -1/0/+1 positions,
71
75
  ±60-sample temporal jitter, and SpecAugment [park2019specaug]_ on
72
- the log-spectrogram.
76
+ the log-spectrogram. SpecAugment is built into the model
77
+ (``spec_augment=True``) and only fires in training mode; the
78
+ time/frequency-jitter pieces are dataset-side augmentations.
73
79
  - **Decoding**: greedy CTC. Upstream also reports a 6-gram KenLM
74
80
  beam decoder, not ported here.
75
81
 
@@ -145,6 +151,37 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
145
151
  layers and again after the second :class:`~torch.nn.Linear`.
146
152
  Default ``0.0`` matches the upstream paper recipe (no dropout).
147
153
  Set ``> 0`` for regularized training.
154
+ spec_augment : bool
155
+ If ``True``, apply SpecAugment [park2019specaug]_ time/frequency
156
+ masking on the log-spectrogram during training only. Disabled in
157
+ ``eval`` mode and absent from the parameter / state-dict count.
158
+ Defaults to ``False``; set to ``True`` to match the upstream
159
+ emg2qwerty paper recipe.
160
+ n_time_masks : int
161
+ Maximum number of time masks applied per call. Each forward pass
162
+ samples a uniform integer in ``[0, n_time_masks]``. Defaults to
163
+ ``3`` (Sivakumar et al. Sec 5.2).
164
+ time_mask_param : int
165
+ Maximum time-mask width in spectrogram frames. Defaults to ``25``.
166
+ n_freq_masks : int
167
+ Maximum number of frequency masks applied per call. Each forward
168
+ pass samples a uniform integer in ``[0, n_freq_masks]``. Defaults
169
+ to ``2``.
170
+ freq_mask_param : int
171
+ Maximum frequency-mask width in STFT bins. Defaults to ``4``.
172
+ spec_augment_prob : float
173
+ Probability of running SpecAugment on a given training batch
174
+ (Bernoulli gate before sampling mask counts). Defaults to ``1.0``.
175
+ return_feature : bool
176
+ If ``True``, ``forward`` returns a tuple
177
+ ``(emissions, features)`` instead of just ``emissions`` —
178
+ :class:`braindecode.models.BIOT`-style legacy feature path. Lets
179
+ configuration-driven downstream wrappers (e.g. neuroai's
180
+ ``DownstreamWrapperModel`` with ``model_output_key=1``) pick up
181
+ the encoder representation without passing a runtime kwarg.
182
+ Defaults to ``False``. Mutually compatible with the runtime
183
+ ``return_features`` (plural) flag, which still wins when set
184
+ to ``True``.
148
185
 
149
186
  Examples
150
187
  --------
@@ -216,6 +253,13 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
216
253
  log_softmax: bool = False,
217
254
  activation: type[nn.Module] = nn.ReLU,
218
255
  drop_prob: float = 0.0,
256
+ spec_augment: bool = False,
257
+ n_time_masks: int = 3,
258
+ time_mask_param: int = 25,
259
+ n_freq_masks: int = 2,
260
+ freq_mask_param: int = 4,
261
+ spec_augment_prob: float = 1.0,
262
+ return_feature: bool = False,
219
263
  # Standard braindecode args
220
264
  n_times: int | None = None,
221
265
  input_window_seconds: float | None = None,
@@ -256,6 +300,7 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
256
300
  self.hop_length = hop_length
257
301
  self.kernel_width = kernel_width
258
302
  self.log_softmax = log_softmax
303
+ self.return_feature = return_feature
259
304
 
260
305
  n_freq_bins = n_fft // 2 + 1
261
306
  in_features = electrodes_per_band * n_freq_bins
@@ -269,6 +314,23 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
269
314
  log_eps=log_eps,
270
315
  )
271
316
 
317
+ # Built-in SpecAugment lives between the spectrogram and the BatchNorm
318
+ # so it operates on the log-power tensor (matches upstream emg2qwerty
319
+ # and the previous neuralbench callback). ``nn.Identity`` keeps the
320
+ # forward path symmetrical without contributing parameters or
321
+ # state-dict keys when SpecAugment is disabled.
322
+ self.spec_augment: nn.Module
323
+ if spec_augment:
324
+ self.spec_augment = _SpecAugment(
325
+ n_time_masks=n_time_masks,
326
+ time_mask_param=time_mask_param,
327
+ n_freq_masks=n_freq_masks,
328
+ freq_mask_param=freq_mask_param,
329
+ prob=spec_augment_prob,
330
+ )
331
+ else:
332
+ self.spec_augment = nn.Identity()
333
+
272
334
  # Indices 0/1/3 match upstream's ``TDSConvCTCModule.model``;
273
335
  # index 2 is a parameter-free Flatten; upstream's index 4 (head)
274
336
  # is broken out as ``self.final_layer`` and remapped via :attr:`mapping`.
@@ -298,7 +360,13 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
298
360
  isinstance(m, _TDSConv2dBlock) for m in self.model[3].tds_conv_blocks
299
361
  )
300
362
 
301
- def forward(self, x: torch.Tensor) -> torch.Tensor:
363
+ def forward(
364
+ self, x: torch.Tensor, return_features: bool = False
365
+ ) -> (
366
+ torch.Tensor
367
+ | dict[str, torch.Tensor | None]
368
+ | tuple[torch.Tensor, torch.Tensor]
369
+ ):
302
370
  """Run the full pipeline.
303
371
 
304
372
  Parameters
@@ -307,12 +375,37 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
307
375
  Raw EMG of shape ``(batch, n_chans=32, n_times)``. ``n_times``
308
376
  must be at least the encoder's receptive field, ``n_fft +
309
377
  n_conv_blocks * (kernel_width - 1) * hop_length``.
378
+ return_features : bool
379
+ If ``True``, return a ``dict`` with the encoder representation
380
+ instead of the classification emissions. The encoder is the
381
+ full TDS-Conv stack up to (but not including)
382
+ ``self.final_layer`` — i.e. what downstream wrappers want
383
+ when they apply their own probe/aggregation. Matches the
384
+ BIOT / signal-JEPA convention so the same neuroai
385
+ ``DownstreamWrapperModel(model_output_key="features")``
386
+ can consume it. Wins over the constructor-time
387
+ ``return_feature`` flag when set.
310
388
 
311
389
  Returns
312
390
  -------
313
- emissions : torch.Tensor
314
- Shape ``(batch, T_out, n_outputs)``. Log-probabilities if
391
+ torch.Tensor or dict or tuple
392
+ Default (``return_features=False``, init
393
+ ``return_feature=False``): ``torch.Tensor`` of shape
394
+ ``(batch, T_out, n_outputs)``. Log-probabilities if
315
395
  ``log_softmax=True``, otherwise logits.
396
+
397
+ If runtime ``return_features=True``: ``dict`` with
398
+ ``"features"`` (shape ``(batch, T_out, num_features)``,
399
+ where ``num_features = num_bands * mlp_features[-1]``) and
400
+ ``"cls_token"`` (always ``None`` — TDS-Conv has no
401
+ ``[CLS]``).
402
+
403
+ If init ``return_feature=True`` and runtime
404
+ ``return_features=False``: tuple ``(emissions, features)``
405
+ where ``features`` has shape ``(batch, T_out,
406
+ num_features)``. Same layout BIOT exposes for
407
+ configuration-driven feature extraction (e.g. neuroai's
408
+ ``model_output_key=1``).
316
409
  """
317
410
  if x.ndim != 3 or x.shape[-2] != self.n_chans:
318
411
  raise ValueError(
@@ -331,11 +424,24 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
331
424
  f"kernel_width={self.kernel_width})."
332
425
  )
333
426
  spectrogram = self.spectrogram(x)
427
+ spectrogram = self.spec_augment(spectrogram)
334
428
  encoded = self.model(spectrogram)
429
+ # ``encoded`` is (T_out, B, num_features); only materialise the
430
+ # batch-first features tensor in the branches that actually return
431
+ # it, so the default emissions-only path skips the extra transpose
432
+ # + contiguous copy on every forward.
433
+ if return_features:
434
+ return {
435
+ "features": encoded.transpose(0, 1).contiguous(),
436
+ "cls_token": None,
437
+ }
335
438
  emissions = self.final_layer(encoded)
336
439
  if self.log_softmax:
337
440
  emissions = F.log_softmax(emissions, dim=-1)
338
- return emissions.transpose(0, 1).contiguous()
441
+ emissions = emissions.transpose(0, 1).contiguous()
442
+ if self.return_feature:
443
+ return emissions, encoded.transpose(0, 1).contiguous()
444
+ return emissions
339
445
 
340
446
  def reset_head(self, n_outputs: int) -> None:
341
447
  """Replace the classification head for a new vocabulary size.
@@ -411,7 +517,13 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
411
517
  dummy_input = torch.zeros(
412
518
  1, self.n_chans, n_times, dtype=dtype, device=device
413
519
  )
414
- return tuple(self.forward(dummy_input).shape)
520
+ # ``return_features=False`` keeps the dict path off; the init
521
+ # ``return_feature`` flag may still produce a tuple, so unpack
522
+ # the emissions explicitly to report the public output shape.
523
+ out = self.forward(dummy_input, return_features=False)
524
+ emissions = out[0] if isinstance(out, tuple) else out
525
+ assert isinstance(emissions, torch.Tensor)
526
+ return tuple(emissions.shape)
415
527
 
416
528
 
417
529
  class _LogSpectrogram(nn.Module):
@@ -483,6 +595,94 @@ class _LogSpectrogram(nn.Module):
483
595
  ).movedim(-1, 0)
484
596
 
485
597
 
598
+ class _SpecAugment(nn.Module):
599
+ r"""SpecAugment masking on the log-spectrogram during training.
600
+
601
+ Applies up to ``n_time_masks`` × ``time_mask_param``-frame time
602
+ bands and ``n_freq_masks`` × ``freq_mask_param``-bin frequency
603
+ bands. Masks are independent per ``(sample × band × electrode)``
604
+ triple — same recipe as the upstream emg2qwerty
605
+ :class:`emg2qwerty.transforms.SpecAugment` dataset transform
606
+ (Sivakumar et al. Sec 5.2 / NeurIPS 2024), which is
607
+ :func:`torchaudio.functional.mask_along_axis_iid`-style masking
608
+ sampled per leading dim of a spectrogram with shape
609
+ ``(..., freq, time)``. No-op outside ``training``.
610
+
611
+ The mask fill value is the on-device mean of the spectrogram —
612
+ ``log(power=1)=0`` would sit well above the typical log-power
613
+ distribution and inject artificial spikes — and stays a 0-D
614
+ tensor so the forward pass adds no host round-trip on GPU.
615
+ """
616
+
617
+ def __init__(
618
+ self,
619
+ n_time_masks: int = 3,
620
+ time_mask_param: int = 25,
621
+ n_freq_masks: int = 2,
622
+ freq_mask_param: int = 4,
623
+ prob: float = 1.0,
624
+ ) -> None:
625
+ super().__init__()
626
+ if n_time_masks < 0 or n_freq_masks < 0:
627
+ raise ValueError(
628
+ f"n_time_masks and n_freq_masks must be >= 0; got "
629
+ f"n_time_masks={n_time_masks}, n_freq_masks={n_freq_masks}."
630
+ )
631
+ if time_mask_param < 0 or freq_mask_param < 0:
632
+ raise ValueError(
633
+ f"time_mask_param and freq_mask_param must be >= 0; got "
634
+ f"time_mask_param={time_mask_param}, "
635
+ f"freq_mask_param={freq_mask_param}."
636
+ )
637
+ if not 0.0 <= prob <= 1.0:
638
+ raise ValueError(f"prob must be in [0, 1]; got {prob}.")
639
+ self.n_time_masks = n_time_masks
640
+ self.time_mask_param = time_mask_param
641
+ self.n_freq_masks = n_freq_masks
642
+ self.freq_mask_param = freq_mask_param
643
+ self.prob = prob
644
+ # ``iid_masks=True`` so masking is sampled over every leading dim
645
+ # except the trailing ``(freq, time)`` pair — i.e. one mask per
646
+ # ``(sample × band × electrode)`` on a 5-D
647
+ # ``(B, num_bands, electrodes, freq, T)`` input. Matches upstream
648
+ # emg2qwerty's per-``(band × electrode)`` dataset-time recipe.
649
+ self.time_mask = ta_transforms.TimeMasking(time_mask_param, iid_masks=True)
650
+ self.freq_mask = ta_transforms.FrequencyMasking(freq_mask_param, iid_masks=True)
651
+
652
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
653
+ # ``x``: (T_spec, B, num_bands, electrodes, freq).
654
+ if (
655
+ not self.training
656
+ or self.prob <= 0.0
657
+ or (self.n_time_masks == 0 and self.n_freq_masks == 0)
658
+ ):
659
+ return x
660
+ # All RNG draws use ``x.device`` so reproducibility seeds the same
661
+ # stream regardless of whether the user calls ``torch.manual_seed``
662
+ # or ``torch.cuda.manual_seed`` — and so torchaudio's internal
663
+ # device-side RNG and our Python-level gate stay in sync. ``.item()``
664
+ # still forces a host sync for the Python ``if``/loop bound, but
665
+ # that is unavoidable for control flow.
666
+ if self.prob < 1.0 and torch.rand((), device=x.device).item() >= self.prob:
667
+ return x
668
+ # ``torchaudio`` masking expects ``(..., freq, time)``; here that means
669
+ # ``(B, num_bands, electrodes, freq, T_spec)``. Move time to the end
670
+ # rather than reshaping into 4D, because ``mask_along_axis_iid`` draws
671
+ # one mask per leading-axis index, so the 5-D layout already gives the
672
+ # desired per-``(B × num_bands × electrodes)`` independence.
673
+ spec = x.movedim(0, -1).contiguous()
674
+ # 0-D on-device tensor — ``masked_fill`` / ``torch.where`` accept it
675
+ # without a host sync.
676
+ mask_value = spec.mean()
677
+ n_t = int(torch.randint(self.n_time_masks + 1, (), device=x.device).item())
678
+ for _ in range(n_t):
679
+ spec = self.time_mask(spec, mask_value=mask_value)
680
+ n_f = int(torch.randint(self.n_freq_masks + 1, (), device=x.device).item())
681
+ for _ in range(n_f):
682
+ spec = self.freq_mask(spec, mask_value=mask_value)
683
+ return spec.movedim(-1, 0)
684
+
685
+
486
686
  class _SpectrogramNorm(nn.Module):
487
687
  r""":class:`~torch.nn.BatchNorm2d` over (band × electrode) channels.
488
688
 
@@ -0,0 +1 @@
1
+ __version__ = "1.5.0.dev1015"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.5.0.dev1010
3
+ Version: 1.5.0.dev1015
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -28,6 +28,34 @@ Current 1.5.0 (stable)
28
28
  Enhancements
29
29
  ============
30
30
 
31
+ - Add :class:`braindecode.augmentation.BandRotation` and
32
+ :func:`braindecode.augmentation.functional.band_rotation`: per-band
33
+ circular roll along the channel axis plus inter-band temporal jitter,
34
+ for surface-EMG inputs shaped ``(B, num_bands * electrodes_per_band, T)``.
35
+ Models small wristband rotation between sessions and relative timing
36
+ noise between two arms, from the emg2qwerty paper (Sivakumar et al.,
37
+ NeurIPS 2024). By `Bruno Aristimunha`_.
38
+
39
+ - Build SpecAugment (Park et al., Interspeech 2019) into
40
+ :class:`braindecode.models.EMG2QwertyNet` as a parameter-free submodule
41
+ gated by a new ``spec_augment`` constructor flag (default ``False``).
42
+ When enabled, applies up to ``n_time_masks`` × ``time_mask_param``
43
+ time bands and ``n_freq_masks`` × ``freq_mask_param`` frequency bands
44
+ on the log-spectrogram during ``train()`` only, with masks sampled
45
+ IID per ``(sample × band × electrode)`` triple — same recipe as the
46
+ upstream ``emg2qwerty.transforms.SpecAugment`` dataset transform. The
47
+ mask fill value stays a 0-D on-device tensor (``spec.mean()``) so the
48
+ forward pass adds no host round-trip on GPU. Also adds a
49
+ ``return_features`` runtime flag to
50
+ :meth:`braindecode.models.EMG2QwertyNet.forward` (returns
51
+ ``{"features": (B, T_out, num_features), "cls_token": None}``,
52
+ BIOT / signal-JEPA convention) and a matching ``return_feature``
53
+ constructor flag (returns ``(emissions, features)`` tuple, BIOT-style
54
+ legacy path) so downstream wrappers — such as neuroai's
55
+ ``DownstreamWrapperModel`` — can pick up the encoder representation
56
+ via ``model_output_key="features"`` (dict) or ``model_output_key=1``
57
+ (tuple) without changes to their call site. By `Bruno Aristimunha`_.
58
+
31
59
  - Add :meth:`braindecode.datasets.BaseConcatDataset.set_target` to swap
32
60
  any per-window metadata column or per-record description field
33
61
  (e.g. a BIDS entity, a participants.tsv extra) into the dataset's
@@ -1 +0,0 @@
1
- __version__ = "1.5.0.dev1010"