braindecode 1.5.0.dev1007__py3-none-any.whl → 1.5.0.dev1010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,7 +18,7 @@ import shutil
18
18
  import warnings
19
19
  from abc import abstractmethod
20
20
  from collections import Counter
21
- from collections.abc import Callable
21
+ from collections.abc import Callable, Hashable
22
22
  from glob import glob
23
23
  from pathlib import Path
24
24
  from typing import Any, Generic, Iterable, no_type_check
@@ -1340,6 +1340,124 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
1340
1340
  raise TypeError("target_transform must be a callable.")
1341
1341
  self._target_transform = fn
1342
1342
 
1343
+ def set_target(self, column: Hashable) -> "BaseConcatDataset":
1344
+ """Use ``column`` as the target ``y`` for every subdataset.
1345
+
1346
+ Dispatches on the subdataset type:
1347
+
1348
+ * For :class:`WindowsDataset` / :class:`EEGWindowsDataset`,
1349
+ ``column`` is looked up in per-window ``metadata`` first, then in
1350
+ the per-record ``description`` (broadcast to every window). The
1351
+ resolved values overwrite ``ds.metadata['target']`` and ``ds.y``.
1352
+ For :class:`WindowsDataset`, the underlying ``ds.windows.metadata``
1353
+ is kept in sync so ``get_metadata()`` and the repr reflect the
1354
+ new target.
1355
+ * For :class:`RawDataset`, ``column`` must exist on the
1356
+ ``description``. ``ds.target_name`` is set to ``column`` so
1357
+ ``__getitem__`` reads ``description[column]`` as ``y`` on every
1358
+ access — no rebuild needed.
1359
+
1360
+ Parameters
1361
+ ----------
1362
+ column : Hashable
1363
+ Name of a metadata column or description field (BIDS entity,
1364
+ participants.tsv extra, ...). Typically a string, but any
1365
+ hashable that pandas accepts as a column label is allowed.
1366
+
1367
+ Returns
1368
+ -------
1369
+ self : BaseConcatDataset
1370
+
1371
+ Raises
1372
+ ------
1373
+ TypeError
1374
+ If any subdataset is not a :class:`WindowsDataset`,
1375
+ :class:`EEGWindowsDataset`, or :class:`RawDataset`, or if a
1376
+ windowed subdataset has lazy (non-DataFrame) metadata.
1377
+ ValueError
1378
+ If ``column`` is not present on a subdataset's metadata or
1379
+ description, or if a windowed subdataset has
1380
+ ``targets_from='channels'`` (which would make this a silent
1381
+ no-op since ``__getitem__`` reads y from misc channels, not
1382
+ from ``metadata['target']``).
1383
+ """
1384
+ for i, ds in enumerate(self.datasets):
1385
+ if isinstance(ds, (WindowsDataset, EEGWindowsDataset)):
1386
+ if not isinstance(ds.metadata, pd.DataFrame):
1387
+ # _LazyDataFrame (lazy_metadata=True) does not implement
1388
+ # .copy()/__setitem__, so the in-place write below would
1389
+ # raise AttributeError. Surface the precondition cleanly.
1390
+ raise TypeError(
1391
+ "set_target requires a materialized metadata "
1392
+ f"DataFrame; datasets[{i}].metadata is "
1393
+ f"{type(ds.metadata).__name__}. Re-window with "
1394
+ "lazy_metadata=False to use set_target."
1395
+ )
1396
+ if getattr(ds, "targets_from", "metadata") != "metadata":
1397
+ # __getitem__ would read y from misc channels; writing
1398
+ # metadata['target']/ds.y would be a silent no-op.
1399
+ raise ValueError(
1400
+ f"datasets[{i}] has targets_from="
1401
+ f"{ds.targets_from!r}; set_target only applies when "
1402
+ "targets_from='metadata' (otherwise __getitem__ "
1403
+ "derives y from misc channels and would ignore the "
1404
+ "rewritten target column)."
1405
+ )
1406
+ n = len(ds)
1407
+ md = ds.metadata
1408
+ if column in md.columns:
1409
+ values = md[column].iloc[:n].to_list()
1410
+ elif (
1411
+ isinstance(ds.description, pd.Series)
1412
+ and column in ds.description.index
1413
+ ):
1414
+ values = [ds.description[column]] * n
1415
+ else:
1416
+ desc_keys = (
1417
+ list(ds.description.index)
1418
+ if isinstance(ds.description, pd.Series)
1419
+ else []
1420
+ )
1421
+ raise ValueError(
1422
+ f"Column {column!r} not found on datasets[{i}]: "
1423
+ f"metadata cols={list(md.columns)}, "
1424
+ f"description keys={desc_keys}."
1425
+ )
1426
+ # In-place write so the WindowsDataset's metadata and the
1427
+ # underlying mne.Epochs.metadata (which start as the same
1428
+ # object reference) both reflect the new target. Defensive
1429
+ # second write covers the case where they got de-aliased
1430
+ # earlier by a caller-side reassignment.
1431
+ md["target"] = values
1432
+ windows_obj = getattr(ds, "_windows", None)
1433
+ if windows_obj is not None and windows_obj.metadata is not md:
1434
+ windows_obj.metadata["target"] = values
1435
+ # values is already a fresh list (Series.to_list() / [x] * n);
1436
+ # no defensive copy needed — pandas keeps its own representation
1437
+ # for md["target"] so mutating ds.y won't reach back into it.
1438
+ ds.y = values
1439
+ elif isinstance(ds, RawDataset):
1440
+ if (
1441
+ not isinstance(ds.description, pd.Series)
1442
+ or column not in ds.description.index
1443
+ ):
1444
+ desc_keys = (
1445
+ list(ds.description.index)
1446
+ if isinstance(ds.description, pd.Series)
1447
+ else []
1448
+ )
1449
+ raise ValueError(
1450
+ f"Column {column!r} not found on datasets[{i}] "
1451
+ f"description (keys={desc_keys})."
1452
+ )
1453
+ ds.target_name = column
1454
+ else:
1455
+ raise TypeError(
1456
+ "set_target requires WindowsDataset, EEGWindowsDataset, "
1457
+ f"or RawDataset; datasets[{i}] is {type(ds).__name__}."
1458
+ )
1459
+ return self
1460
+
1343
1461
  def _outdated_save(self, path, overwrite=False):
1344
1462
  """This is a copy of the old saving function, that had inconsistent.
1345
1463
 
@@ -33,13 +33,25 @@ huggingface_hub = _soft_import(
33
33
  HAS_HF_HUB = huggingface_hub is not False
34
34
 
35
35
 
36
- class _BaseHubMixin:
37
- pass
36
+ _HF_INSTALL_HINT = (
37
+ "requires the `huggingface_hub` package. "
38
+ "Install with: pip install 'braindecode[hub]'"
39
+ )
40
+
41
+
42
+ class _BaseHubMixinStub:
43
+ @classmethod
44
+ def from_pretrained(cls, *args, **kwargs):
45
+ raise ImportError(f"{cls.__name__}.from_pretrained() {_HF_INSTALL_HINT}")
46
+
47
+ def push_to_hub(self, *args, **kwargs):
48
+ raise ImportError(f"{type(self).__name__}.push_to_hub() {_HF_INSTALL_HINT}")
38
49
 
39
50
 
40
51
  # Define base class for hub mixin
41
- if HAS_HF_HUB:
42
- _BaseHubMixin: Type = huggingface_hub.PyTorchModelHubMixin # type: ignore
52
+ _BaseHubMixin: Type = (
53
+ huggingface_hub.PyTorchModelHubMixin if HAS_HF_HUB else _BaseHubMixinStub
54
+ )
43
55
 
44
56
 
45
57
  def deprecated_args(obj, *old_new_args):
@@ -6,6 +6,7 @@
6
6
  # David Sabbagh <dav.sabbagh@gmail.com>
7
7
  # Bruno Aristimunha <b.aristimunha@gmail.com>
8
8
  # Léo Burgund <leo.burgund@gmail.com>
9
+ # Sarthak Tayal <sarthaktayal2@gmail.com>
9
10
  #
10
11
  # License: BSD (3-clause)
11
12
 
@@ -217,6 +218,7 @@ def preprocess(
217
218
  offset: int = 0,
218
219
  copy_data: bool | None = None,
219
220
  parallel_kwargs: dict | None = None,
221
+ max_nbytes: int | str | None = "1M",
220
222
  ):
221
223
  """Apply preprocessors to a concat dataset.
222
224
 
@@ -244,14 +246,20 @@ def preprocess(
244
246
  Additional keyword arguments forwarded to ``joblib.Parallel``.
245
247
  Defaults to None (equivalent to ``{}``).
246
248
  See https://joblib.readthedocs.io/en/stable/generated/joblib.Parallel.html for details.
249
+ max_nbytes : int, str, or None
250
+ Threshold (in bytes; or e.g. ``"1M"``) above which joblib memory-maps
251
+ preloaded arrays as read-only when dispatching to worker processes.
252
+ Effective only when ``n_jobs != 1``. Pass ``None`` to disable memory
253
+ mapping when a preprocessor resizes the underlying data (for example
254
+ ``filterbank``), which would otherwise fail with an ``mmap can't
255
+ resize a readonly`` error. ``parallel_kwargs['max_nbytes']`` takes
256
+ precedence if both are provided.
247
257
 
248
258
  Returns
249
259
  -------
250
260
  BaseConcatDataset
251
261
  Preprocessed dataset.
252
262
  """
253
- # In case of serialization, make sure directory is available before
254
- # preprocessing
255
263
  if save_dir is not None and not overwrite:
256
264
  _check_save_dir_empty(save_dir)
257
265
 
@@ -266,22 +274,35 @@ def preprocess(
266
274
  parallel_params.setdefault(
267
275
  "prefer", "threads" if platform.system() == "Windows" else None
268
276
  )
269
-
270
- list_of_ds = Parallel(n_jobs=n_jobs, **parallel_params)(
271
- delayed(_preprocess)(
272
- ds,
273
- i + offset,
274
- preprocessors,
275
- save_dir,
276
- overwrite,
277
- copy_data=(
278
- (parallel_processing and (save_dir is None))
279
- if copy_data is None
280
- else copy_data
281
- ),
277
+ parallel_params.setdefault("max_nbytes", max_nbytes)
278
+
279
+ try:
280
+ list_of_ds = Parallel(n_jobs=n_jobs, **parallel_params)(
281
+ delayed(_preprocess)(
282
+ ds,
283
+ i + offset,
284
+ preprocessors,
285
+ save_dir,
286
+ overwrite,
287
+ copy_data=(
288
+ (parallel_processing and (save_dir is None))
289
+ if copy_data is None
290
+ else copy_data
291
+ ),
292
+ )
293
+ for i, ds in enumerate(concat_ds.datasets)
282
294
  )
283
- for i, ds in enumerate(concat_ds.datasets)
284
- )
295
+ except (BufferError, ValueError, OSError) as exc:
296
+ msg = str(exc).lower().replace("-", "")
297
+ if "mmap" in msg and "readonly" in msg:
298
+ raise RuntimeError(
299
+ "Parallel preprocessing failed because joblib memory-mapped "
300
+ "a preloaded array that a preprocessor then attempted to "
301
+ "resize (e.g. ``filterbank``). Pass ``max_nbytes=None`` to "
302
+ "``preprocess`` to disable memory mapping, or supply a "
303
+ "``save_dir`` so the data is reloaded with ``preload=False``."
304
+ ) from exc
305
+ raise
285
306
 
286
307
  if save_dir is not None: # Reload datasets and replace in concat_ds
287
308
  ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
@@ -498,7 +498,7 @@ def create_fixed_length_windows(
498
498
  EEGWindowsDataset objects with the extracted windows, depending on
499
499
  the value of ``use_mne_epochs``.
500
500
  """
501
- stop_offset_samples, drop_last_window = (
501
+ stop_offset_samples, window_stride_samples, drop_last_window = (
502
502
  _check_and_set_fixed_length_window_arguments(
503
503
  start_offset_samples,
504
504
  stop_offset_samples,
@@ -886,7 +886,10 @@ def _create_fixed_length_windows(
886
886
  if mapping is not None:
887
887
  # in case of multiple targets
888
888
  if isinstance(target, pd.Series):
889
- target = target.replace(mapping).to_list()
889
+ # Plain comprehension instead of Series.replace(mapping):
890
+ # replace() emits a pandas FutureWarning about silent downcasting
891
+ # and the result is immediately list-ified anyway.
892
+ target = [mapping.get(v, v) for v in target]
890
893
  # in case of single value target
891
894
  else:
892
895
  target = mapping[target]
@@ -1245,8 +1248,14 @@ def _check_and_set_fixed_length_window_arguments(
1245
1248
  lazy_metadata,
1246
1249
  ):
1247
1250
  """Raises warnings for incorrect input arguments and will set correct default values for
1248
- stop_offset_samples & drop_last_window, if necessary.
1251
+ stop_offset_samples, window_stride_samples & drop_last_window, if necessary.
1249
1252
  """
1253
+ # default stride to window size for non-overlapping windows
1254
+ if window_size_samples is not None and window_stride_samples is None:
1255
+ window_stride_samples = window_size_samples
1256
+ if drop_last_window is None:
1257
+ drop_last_window = True
1258
+
1250
1259
  _check_windowing_arguments(
1251
1260
  start_offset_samples,
1252
1261
  stop_offset_samples,
@@ -1295,7 +1304,7 @@ def _check_and_set_fixed_length_window_arguments(
1295
1304
  raise ValueError(
1296
1305
  "Cannot have drop_last_window=False and lazy_metadata=True at the same time."
1297
1306
  )
1298
- return stop_offset_samples, drop_last_window
1307
+ return stop_offset_samples, window_stride_samples, drop_last_window
1299
1308
 
1300
1309
 
1301
1310
  def _get_windowing_kwargs(windowing_func_locals):
braindecode/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.5.0.dev1007"
1
+ __version__ = "1.5.0.dev1010"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.5.0.dev1007
3
+ Version: 1.5.0.dev1010
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -3,13 +3,13 @@ braindecode/classifier.py,sha256=7kC_oY_UzHEes_WWdCvEpiA1ZKxMeuLL5tIPp5rfcpg,962
3
3
  braindecode/eegneuralnet.py,sha256=xjE6aPZdCQPs29NIpy_m1GLMMC2WZ3Db0Fuh1-xE1h4,13827
4
4
  braindecode/regressor.py,sha256=KiMJpqCUPWA2k2JWk9HGYTzeoBqJ4gAKEudeUVcFZY4,9266
5
5
  braindecode/util.py,sha256=f8bNIwt-SwsHqheH_BADQxTtA9oPt3Lb7GFnoI-Huwc,14101
6
- braindecode/version.py,sha256=D3ZA75b03PZQ80vQOvTo6EESMVaElofg2dXgccZLJV8,30
6
+ braindecode/version.py,sha256=dhwXuCj7AR3Ot0DD1Q7xVlLGgTrkwdesqS9QyInsZcM,30
7
7
  braindecode/augmentation/__init__.py,sha256=4xune2QUK6KHMKsAqijF7I9eeiVbP0wEoQJjCNLNcKM,1081
8
8
  braindecode/augmentation/base.py,sha256=OJ1shOljI1yTY9zh2qWxQwivlY43sfx9Q-MAyMhxtPs,7338
9
9
  braindecode/augmentation/functional.py,sha256=q2k6mAXrujYlOZUndcjZN8e8b-6oJF1gGsORAI23hyE,43998
10
10
  braindecode/augmentation/transforms.py,sha256=x-3pwX0PtMHfSnPLGKNXbpTSk7j17Ci2FG_-646scg4,47268
11
11
  braindecode/datasets/__init__.py,sha256=rVOBadwqYBiMz5kl7nGiBOmMgr11xvjS4nuzzZTOn1U,1102
12
- braindecode/datasets/base.py,sha256=7qFEJNOAe_05evfqxx-yrlUfG0eWaCuKkfOlj32ShwU,64701
12
+ braindecode/datasets/base.py,sha256=3lKLZQO4hfA-dv_JJEfPwyZ5nzRkLTu4qiRAqFVZUUQ,70508
13
13
  braindecode/datasets/bbci.py,sha256=SCm7OnCObotILQ0B1EdmZPoyJtzsRXpeU_gNKtqQLSc,19288
14
14
  braindecode/datasets/bcicomp.py,sha256=YWIsRYFvwBFHnd9CRxo_BBQWFg-rg0UirRTsd0Ml2Oc,7550
15
15
  braindecode/datasets/chb_mit.py,sha256=XGUVtADLHLTa5Ldanyn8msAJY0SQh-eV9keHnWO6n3A,7231
@@ -43,7 +43,7 @@ braindecode/models/__init__.py,sha256=SwnwJ-nRF0nWq2YYvw6G4F_zMrw0QCQW2VoYE2KOpS
43
43
  braindecode/models/atcnet.py,sha256=JzLV92WYWlhQol8AqE9xeWvlnWN_uswYhVOkMnuBCXo,32231
44
44
  braindecode/models/attentionbasenet.py,sha256=k7ar7aEjANudPu7krAZsRx-ag61ugirS1Xn7qFKNfWg,30483
45
45
  braindecode/models/attn_sleep.py,sha256=F9x4spTtzfiCC1h9UYITmIDQeJW6_2CXTZktZX9R0RE,17950
46
- braindecode/models/base.py,sha256=yGJgr0f5rD-gJZ5Msw9FzGzVO-x_re-6jOSK8Iht6x4,27923
46
+ braindecode/models/base.py,sha256=O3pkmdqaBgLWHpWkVF7ld0A95T43RFyno7mtkyi08as,28327
47
47
  braindecode/models/bendr.py,sha256=RkPeHoFF0vcDBW2jMo_9oS1Pxuq3wcI4ErDeL7L8QGs,29157
48
48
  braindecode/models/biot.py,sha256=UPpT1Gv7siBCMcNDNx4DS4hgs-YlmWwU5Nmikaj9WpY,22309
49
49
  braindecode/models/brainmodule.py,sha256=idyQVTp3VBJXKF1YjMx8o1kUKrcL_E_AJsioPjcuqV8,33282
@@ -111,9 +111,9 @@ braindecode/modules/wrapper.py,sha256=3lgNjcwJ1Kre2TCUouKpEp6cidFgp6LBDDWQNLOvRQ
111
111
  braindecode/preprocessing/__init__.py,sha256=NCKhT4-_-DMBFpppEg2VguCe4x7ZW-gInWn60G6hOSs,9643
112
112
  braindecode/preprocessing/eegprep_preprocess.py,sha256=wGKKo-JvG90n7eB7Y_p-bSkf8LAyIxTU8N01qPjAYQI,59989
113
113
  braindecode/preprocessing/mne_preprocess.py,sha256=taW2H-k3yreGdItCW7ldZB8jnzxBAMxX1ykgpvj7QvI,8246
114
- braindecode/preprocessing/preprocess.py,sha256=T-MA0frxrCb61U30F9_LoljyXD5oUHgnI3uWiggfgH8,22515
114
+ braindecode/preprocessing/preprocess.py,sha256=6qm0tU5JZjCLEl125cflaDk-B54rMsN8JbUUQG6cWW0,23774
115
115
  braindecode/preprocessing/util.py,sha256=ivODABSuy-SKvPMa2U6W3uWM4cwmSg-7jSKqIRxBBw4,5925
116
- braindecode/preprocessing/windowers.py,sha256=NTplG_8pOUrE1R-w41vevuBK_4_RHEozy1iR3RyoNCU,48949
116
+ braindecode/preprocessing/windowers.py,sha256=s7wAFGAdC_mMxkxuGTho5tlXy0KuxKydrOwb4-plvXA,49495
117
117
  braindecode/samplers/__init__.py,sha256=TLuO6gXv2WioJdX671MI_CHVSsOfbjnly1Xv9K3_WdA,452
118
118
  braindecode/samplers/base.py,sha256=PTa4gGAKXH1Tnx4vBXBAb43x7wQKVvqK1mlM_zE3yY4,15133
119
119
  braindecode/samplers/ssl.py,sha256=GusCFpjOk8w57Br2JdqOLm7vbEQWDj6oWHqgUM7JrF0,9146
@@ -128,9 +128,9 @@ braindecode/visualization/frequency.py,sha256=gNwkn9yIik5SUp7d9HE9J_vPVGyzNsxxCO
128
128
  braindecode/visualization/metrics.py,sha256=j01kc04P9uEkQ2g2Tt2C76yr6soIj31PAuBMflrmODg,13615
129
129
  braindecode/visualization/sanity.py,sha256=nNClauUC8dCj_KCy_1RmaPDQAqExLczfPtUeQ7k9-Q0,4812
130
130
  braindecode/visualization/topology.py,sha256=mXxUfCCUJqa_cMF4y6GC3_A-qBCcS4uTc0EzBolkytE,2274
131
- braindecode-1.5.0.dev1007.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
132
- braindecode-1.5.0.dev1007.dist-info/licenses/NOTICE.txt,sha256=ZFFhigxIaKgDcMjCzPyAVSFV42ztU0kLOENt_kvherw,857
133
- braindecode-1.5.0.dev1007.dist-info/METADATA,sha256=fxnOI00CSFVaUb9jrwzxuNGBpOsqlvIiwIMNg1-gf2I,10275
134
- braindecode-1.5.0.dev1007.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
135
- braindecode-1.5.0.dev1007.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
136
- braindecode-1.5.0.dev1007.dist-info/RECORD,,
131
+ braindecode-1.5.0.dev1010.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
132
+ braindecode-1.5.0.dev1010.dist-info/licenses/NOTICE.txt,sha256=ZFFhigxIaKgDcMjCzPyAVSFV42ztU0kLOENt_kvherw,857
133
+ braindecode-1.5.0.dev1010.dist-info/METADATA,sha256=4sQBBGOi3h1EE-DCZxdlh3Hswxjd4jQJ_fTdFBUSsyc,10275
134
+ braindecode-1.5.0.dev1010.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
135
+ braindecode-1.5.0.dev1010.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
136
+ braindecode-1.5.0.dev1010.dist-info/RECORD,,