braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
@@ -1,233 +0,0 @@
1
- from collections.abc import Callable
2
- from inspect import signature
3
- from types import UnionType
4
- from typing import Annotated, Any, Literal, Union, get_args, get_origin
5
-
6
- import numpy as np
7
- from mne.utils import _soft_import
8
- from typing_extensions import TypedDict
9
-
10
- from braindecode.models.base import EEGModuleMixin
11
- from braindecode.models.util import SigArgName, models_dict, models_mandatory_parameters
12
-
13
- pydantic = _soft_import(name="pydantic", purpose="model configuration", strict=False)
14
-
15
- try:
16
- from numpydantic import NDArray, Shape
17
- except ImportError:
18
- # we can't use soft import for numpydantic because numpydantic does not define its version in __init__
19
- NDArray = Any # type: ignore
20
- Shape = Any # type: ignore
21
-
22
-
23
# Schema of a single ``chs_info`` entry.
# ``total=False`` makes every key optional and ``closed=True`` (PEP 728, via
# ``typing_extensions``) declares that no keys other than the ones listed here
# are part of the type.
# NOTE(review): the key names look like the per-channel dicts of an MNE
# ``Info`` object -- confirm against the callers that build ``chs_info``.
class ChsInfoType(TypedDict, total=False, closed=True):  # type: ignore[call-arg]
    cal: float
    ch_name: str
    coil_type: int
    coord_frame: int
    kind: str
    # 12-element float64 array; degrades to ``Any`` when numpydantic is missing.
    loc: NDArray[Shape["12"], np.float64]  # type: ignore[misc]
    logno: int
    range: float
    scanno: int
    unit: int
    unit_mul: int
35
-
36
-
37
- def _replace_type_hints(type_hint: Any) -> Any:
38
- origin = get_origin(type_hint)
39
- args = get_args(type_hint)
40
- if origin is type or origin is Callable or type_hint is Callable:
41
- return pydantic.ImportString
42
- if origin is None:
43
- return type_hint
44
- replaced_args = tuple(_replace_type_hints(arg) for arg in args)
45
- if origin is UnionType:
46
- origin = Union
47
- return origin[replaced_args]
48
-
49
-
50
# Types assigned to the signal-related constructor arguments.
# ``make_model_config`` looks argument names up here and uses ``<type> | None``
# as the field annotation, overriding whatever the model signature declares.
SIGNAL_ARGS_TYPES = {
    "n_chans": int,
    "n_times": int,
    "sfreq": float,
    "input_window_seconds": float,
    "n_outputs": int,
    "chs_info": list[ChsInfoType],
}
58
-
59
-
60
# ``pydantic`` is a soft dependency and is falsy when it is not installed (see
# the ``if not pydantic`` guards in this module). Falling back to ``object``
# keeps the module importable in that case; ``make_model_config`` still raises
# ImportError before any config class can be built without pydantic.
class BaseBraindecodeModelConfig(pydantic.BaseModel if pydantic else object):  # type: ignore
    """Common base class of the per-model configuration classes.

    The concrete subclasses are generated by ``make_model_config`` and carry a
    ``model_name_`` field that names the model class to build.
    """

    def create_instance(self) -> EEGModuleMixin:
        """Build the model described by this configuration.

        Returns
        -------
        EEGModuleMixin
            An instance of the class registered in ``models_dict`` under
            ``self.model_name_``, constructed from the dumped config fields.
        """
        model_cls = models_dict[self.model_name_]
        kwargs = self.model_dump(mode="python", exclude={"model_name_"})

        # ``n_chans`` is redundant once ``chs_info`` is given, so only
        # ``chs_info`` is forwarded to the model.
        if kwargs.get("n_chans") is not None and kwargs.get("chs_info") is not None:
            kwargs.pop("n_chans")

        # Likewise, ``n_times`` is dropped when both ``input_window_seconds``
        # and ``sfreq`` are available.
        if (
            kwargs.get("n_times") is not None
            and kwargs.get("input_window_seconds") is not None
            and kwargs.get("sfreq") is not None
        ):
            kwargs.pop("n_times")

        return model_cls(**kwargs)
73
-
74
-
75
def make_model_config(
    model_class: type[EEGModuleMixin],
    required: list[SigArgName],
) -> type[BaseBraindecodeModelConfig]:
    """Create a pydantic model config for a given model class.

    One field is generated per parameter of ``model_class.__init__``. Signal
    arguments listed in ``SIGNAL_ARGS_TYPES`` get their type from that mapping,
    and a "before" validator checks that the required signal arguments are
    given or inferable, fills in the inferable ones and checks consistency.

    Parameters
    ----------
    model_class : type[EEGModuleMixin]
        The model class for which to create the config.
    required : list of SigArgName
        The required signal arguments for the model.

    Returns
    -------
    type
        A pydantic BaseModel subclass representing the model config.

    Raises
    ------
    ImportError
        If pydantic is not installed.
    ValueError
        If ``model_class.__init__`` accepts ``*args``.
    """
    if not pydantic:
        raise ImportError(
            "pydantic is required to use make_model_config. "
            "Please install braindecode[typing]."
        )

    @pydantic.model_validator(mode="before")
    def validate_signal_params(cls, data: Any):
        # A "before" validator receives the raw input, which is not
        # necessarily a dict; anything else is left to pydantic itself.
        if not isinstance(data, dict):
            return data
        # Work on a copy so that inferred values are not written back into
        # the caller's dict.
        data = dict(data)

        n_outputs = data.get("n_outputs")
        n_chans = data.get("n_chans")
        chs_info = data.get("chs_info")
        n_times = data.get("n_times")
        input_window_seconds = data.get("input_window_seconds")
        sfreq = data.get("sfreq")

        # Check that required parameters are provided or can be inferred
        if "n_outputs" in required and n_outputs is None:
            raise ValueError("n_outputs is a required parameter but was not provided.")
        if "n_chans" in required and n_chans is None and chs_info is None:
            raise ValueError(
                "n_chans is required and could not be inferred. Either specify n_chans or chs_info."
            )
        if "chs_info" in required and chs_info is None:
            raise ValueError("chs_info is a required parameter but was not provided.")
        if "n_times" in required and (
            n_times is None and (sfreq is None or input_window_seconds is None)
        ):
            raise ValueError(
                "n_times is required and could not be inferred. "
                "Either specify n_times or input_window_seconds and sfreq."
            )
        if "sfreq" in required and (
            sfreq is None and (n_times is None or input_window_seconds is None)
        ):
            raise ValueError(
                "sfreq is required and could not be inferred. "
                "Either specify sfreq or input_window_seconds and n_times."
            )
        if "input_window_seconds" in required and (
            input_window_seconds is None and (n_times is None or sfreq is None)
        ):
            raise ValueError(
                "input_window_seconds is required and could not be inferred. "
                "Either specify input_window_seconds or n_times and sfreq."
            )

        # Infer missing parameters if possible, and check consistency
        if chs_info is not None:
            if n_chans is None:
                data["n_chans"] = len(chs_info)
            elif n_chans != len(chs_info):
                raise ValueError(
                    f"Provided {n_chans=} does not match length of chs_info: {len(chs_info)}."
                )
        if (
            n_times is not None
            and sfreq is not None
            and input_window_seconds is not None
        ):
            if n_times != round(input_window_seconds * sfreq):
                raise ValueError(
                    f"Provided {n_times=} does not match {input_window_seconds=} * {sfreq=}."
                )
        elif n_times is None and sfreq is not None and input_window_seconds is not None:
            data["n_times"] = round(input_window_seconds * sfreq)
        elif sfreq is None and n_times is not None and input_window_seconds is not None:
            data["sfreq"] = n_times / input_window_seconds
        elif input_window_seconds is None and n_times is not None and sfreq is not None:
            data["input_window_seconds"] = n_times / sfreq
        return data

    signature_params = signature(model_class.__init__, eval_str=True).parameters
    has_args = any(p.kind == p.VAR_POSITIONAL for p in signature_params.values())
    has_kwargs = any(p.kind == p.VAR_KEYWORD for p in signature_params.values())
    if has_args:
        raise ValueError("Model __init__ methods cannot have *args")

    # A ``**kwargs`` catch-all in the model signature means unknown config
    # keys have to be accepted as well.
    extra = "allow" if has_kwargs else "forbid"
    fields = {}
    for arg_name, p in signature_params.items():
        if arg_name == "self" or p.kind == p.VAR_KEYWORD:
            continue

        annot = p.annotation
        if annot is p.empty:
            annot = Any
        # case with type[nn.Module] or callable
        else:
            annot = _replace_type_hints(annot)
        # Most models did not specify types for signal args, so we add them here
        if arg_name in SIGNAL_ARGS_TYPES:
            annot = SIGNAL_ARGS_TYPES[arg_name] | None

        fields[arg_name] = (annot, p.default) if p.default is not p.empty else annot

    name = model_class.__name__
    model_config = pydantic.create_model(
        f"{name}Config",
        # Literal field holding the model class name; the module-level union
        # of configs uses it as its discriminator.
        model_name_=(Literal[name], name),
        __config__=pydantic.ConfigDict(
            arbitrary_types_allowed=True, extra=extra, validate_default=True
        ),
        __doc__=f"Pydantic config of model {model_class.__name__}\n\n{model_class.__doc__}",
        __base__=BaseBraindecodeModelConfig,
        __module__="braindecode.models.config",
        __validators__={"validate_signal_params": validate_signal_params},
        **fields,
    )
    return model_config
204
-
205
-
206
# Generate one config class per registered model, publish each of them as a
# module attribute and record its name in ``__all__``. Everything below is
# skipped when pydantic is not installed.
__all__ = ["make_model_config"]

if pydantic:
    models_configs: list[type[BaseBraindecodeModelConfig]] = []
    # Each entry unpacks to (model name, required signal args, <unused>).
    for model_name, req, _ in models_mandatory_parameters:
        model_cls = models_dict[model_name]
        model_cfg = make_model_config(model_cls, req)
        globals()[model_cfg.__name__] = model_cfg
        __all__.append(model_cfg.__name__)
        models_configs.append(model_cfg)

    # Union of all generated configs; the ``model_name_`` field selects which
    # member is validated.
    BraindecodeModelConfig = Annotated[  # type: ignore
        Union[tuple(models_configs)],
        pydantic.Field(
            discriminator="model_name_", description="Braindecode model configuration"
        ),
    ]
227
-
228
- # # Example usage:
229
- #
230
- # class DummyConfigWithModel(pydantic.BaseModel):
231
- # model: BraindecodeModelConfig
232
- #
233
- # DummyConfigWithModel.model_validate({'model': dict(model_name_='EEGNet', n_chans=16, n_outputs=1, n_times=200)})