braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
@@ -1,717 +0,0 @@
1
- # mypy: ignore-errors
2
- """
3
- BIDS-inspired format utilities for Hub integration.
4
-
5
- This module provides BIDS-inspired structures for storing EEG data optimized
6
- for deep learning training. It leverages mne_bids for BIDS path handling and
7
- metadata generation.
8
-
9
- The data is stored in the ``sourcedata/`` directory, which according to BIDS:
10
- - Is NOT validated by BIDS validators (so .zarr files won't cause errors)
11
- - Has no naming restrictions ("BIDS does not prescribe a particular naming
12
- scheme for source data")
13
- - Is intended for data before file format conversion
14
-
15
- This approach allows us to use efficient Zarr storage while maintaining
16
- BIDS-style organization for discoverability.
17
-
18
- Structure:
19
- - sourcedata/<pipeline-name>/
20
- - dataset_description.json (BIDS-style metadata)
21
- - participants.tsv (BIDS-style metadata)
22
- - sub-<label>/
23
- - [ses-<label>/]
24
- - eeg/
25
- - *_events.tsv (BIDS-style metadata)
26
- - *_channels.tsv (BIDS-style metadata)
27
- - *_eeg.json (BIDS-style metadata)
28
- - *_eeg.zarr/ (Zarr data - efficient for training)
29
- - dataset.zarr/ (Main data store for training)
30
-
31
- References:
32
- - BIDS sourcedata: https://bids-specification.readthedocs.io/en/stable/common-principles.html#source-vs-raw-vs-derived-data
33
- - BIDS EEG: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/electroencephalography.html
34
- """
35
-
36
- # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
37
- # Kuntal Kokate
38
- #
39
- # License: BSD (3-clause)
40
-
41
- from __future__ import annotations
42
-
43
- import json
44
- from pathlib import Path
45
- from tempfile import TemporaryDirectory
46
- from typing import Any, Optional, Union
47
-
48
- import mne
49
- import mne_bids
50
- import numpy as np
51
- import pandas as pd
52
- from mne_bids.write import _channels_tsv, _sidecar_json
53
-
54
- import braindecode
55
-
56
- # Default pipeline name for braindecode derivatives
57
- DEFAULT_PIPELINE_NAME = "braindecode"
58
-
59
-
60
def _raw_from_info(
    info: "mne.Info",
    bad_channels: Optional[list[str]] = None,
) -> "mne.io.RawArray":
    """
    Build a one-sample, all-zero ``RawArray`` that carries a copy of ``info``.

    Parameters
    ----------
    info : mne.Info
        Measurement info to attach. It is copied, so the caller's object is
        left untouched.
    bad_channels : list of str | None
        When not None, replaces ``info["bads"]`` on the copy.

    Returns
    -------
    mne.io.RawArray
        Placeholder raw object with shape ``(n_channels, 1)``.
    """
    info_copy = info.copy()
    if bad_channels is not None:
        info_copy["bads"] = list(bad_channels)

    n_channels = len(info_copy["ch_names"])
    placeholder = mne.io.RawArray(
        np.zeros((n_channels, 1), dtype=float),
        info_copy,
        verbose="error",
    )

    # Give the in-memory object a file name when it has none.
    # NOTE(review): presumably required by the mne_bids writers that consume
    # this object -- confirm.
    filenames = placeholder.filenames
    has_filename = bool(filenames) and filenames[0] is not None
    if not has_filename:
        placeholder._filenames = [Path("dummy.fif")]
    return placeholder
72
-
73
-
74
def _read_tsv(writer, *args) -> pd.DataFrame:
    """
    Run a TSV-writing callable against a scratch file and load the result.

    Parameters
    ----------
    writer : callable
        Called as ``writer(*args, tsv_path, overwrite=True)``.
    *args
        Positional arguments forwarded to ``writer`` ahead of the path.

    Returns
    -------
    pd.DataFrame
        The tab-separated file produced by ``writer``, parsed with pandas.
    """
    with TemporaryDirectory() as scratch_dir:
        target = Path(scratch_dir, "sidecar.tsv")
        # The MNE log level is raised to WARNING while writing and reading.
        with mne.utils.use_log_level("WARNING"):
            writer(*args, target, overwrite=True)
            table = pd.read_csv(target, sep="\t")
    return table
79
-
80
-
81
def description_to_bids_path(
    description: pd.Series,
    root: Union[str, Path],
    datatype: str = "eeg",
    suffix: str = "eeg",
    extension: str = ".zarr",
    desc_label: str = "preproc",
    check: bool = False,
) -> mne_bids.BIDSPath:
    """
    Turn a dataset description into an ``mne_bids.BIDSPath``.

    Parameters
    ----------
    description : pd.Series
        Dataset description holding BIDS entities.
    root : str | Path
        Root directory of the BIDS dataset.
    datatype : str
        Data type (eeg, meg, etc.).
    suffix : str
        BIDS suffix.
    extension : str
        File extension.
    desc_label : str
        Value used for the ``description`` entity of the path.
    check : bool
        Whether to enforce BIDS conformity.

    Returns
    -------
    mne_bids.BIDSPath
        BIDS path object.
    """
    entities = _extract_bids_entities(description)

    # Entities that are forwarded as-is and are None when absent.
    passthrough = (
        "session",
        "acquisition",
        "run",
        "processing",
        "recording",
        "space",
    )
    optional_entities = {entity: entities.get(entity) for entity in passthrough}

    # Subject and task fall back to placeholder labels when missing.
    return mne_bids.BIDSPath(
        root=root,
        subject=entities.get("subject", "unknown"),
        task=entities.get("task", "task"),
        description=desc_label,
        suffix=suffix,
        extension=extension,
        datatype=datatype,
        check=check,
        **optional_entities,
    )
139
-
140
-
141
- def _extract_bids_entities(description: pd.Series) -> dict[str, Any]:
142
- """
143
- Extract BIDS entities from a dataset description.
144
-
145
- Parameters
146
- ----------
147
- description : pd.Series
148
- Dataset description containing metadata.
149
-
150
- Returns
151
- -------
152
- dict
153
- Dictionary with BIDS entity keys.
154
- """
155
- if description is None or len(description) == 0:
156
- return {}
157
-
158
- # Common mappings from description keys to BIDS entities
159
- key_mappings = {
160
- "subject": "subject",
161
- "sub": "subject",
162
- "subject_id": "subject",
163
- "session": "session",
164
- "ses": "session",
165
- "task": "task",
166
- "run": "run",
167
- "acquisition": "acquisition",
168
- "acq": "acquisition",
169
- "processing": "processing",
170
- "proc": "processing",
171
- "recording": "recording",
172
- "rec": "recording",
173
- "space": "space",
174
- "split": "split",
175
- "description": "description",
176
- "desc": "description",
177
- }
178
-
179
- entities = {}
180
- for key in description.index:
181
- key_lower = str(key).lower()
182
- if key_lower in key_mappings:
183
- bids_key = key_mappings[key_lower]
184
- value = description[key]
185
- # Convert to string, handling None and NaN
186
- if pd.notna(value):
187
- # Clean up the value for BIDS compatibility
188
- str_value = str(value)
189
- # Remove any characters not allowed in BIDS entities
190
- str_value = "".join(c for c in str_value if c.isalnum() or c in "-_")
191
- if str_value:
192
- entities[bids_key] = str_value
193
-
194
- return entities
195
-
196
-
197
def make_dataset_description(
    path: Union[str, Path],
    name: str = "Braindecode Dataset",
    pipeline_name: str = DEFAULT_PIPELINE_NAME,
    source_datasets: Optional[list[dict]] = None,
    overwrite: bool = True,
) -> Path:
    """
    Create a BIDS-compliant dataset_description.json for derivatives.

    Uses mne_bids.make_dataset_description for proper BIDS compliance.

    Parameters
    ----------
    path : str | Path
        Path to the derivatives directory. Created (with parents) if missing.
    name : str
        Name of the dataset.
    pipeline_name : str
        Name of the pipeline that generated the derivatives. Written as the
        ``Name`` of the ``GeneratedBy`` entry; the ``Version`` and ``CodeURL``
        of that entry always describe braindecode itself.
    source_datasets : list of dict | None
        List of source dataset references.
    overwrite : bool
        Whether to overwrite existing file.

    Returns
    -------
    Path
        Path to the created dataset_description.json.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Previously the pipeline name was accepted but ignored and "braindecode"
    # was hard-coded here; the default value keeps that output unchanged.
    generated_by = [
        {
            "Name": pipeline_name,
            "Version": braindecode.__version__,
            "CodeURL": "https://github.com/braindecode/braindecode",
        }
    ]

    mne_bids.make_dataset_description(
        path=path,
        name=name,
        dataset_type="derivative",
        generated_by=generated_by,
        source_datasets=source_datasets,
        overwrite=overwrite,
    )

    return path / "dataset_description.json"
247
-
248
-
249
def create_events_tsv(
    metadata: pd.DataFrame,
    sfreq: float,
    target_column: str = "target",
    extra_columns: Optional[list[str]] = None,
) -> pd.DataFrame:
    """
    Build an events table (BIDS events.tsv layout) from window metadata.

    Parameters
    ----------
    metadata : pd.DataFrame
        Window metadata. ``i_start_in_trial`` (default 0) and
        ``i_stop_in_trial`` (default start + 1) are read per row, together
        with ``target_column``.
    sfreq : float
        Sampling frequency in Hz, used to convert samples to seconds.
    target_column : str
        Column holding the trial type / target.
    extra_columns : list of str | None
        Additional metadata columns to copy; a column missing from a row
        yields "n/a".

    Returns
    -------
    pd.DataFrame
        One row per window with columns ``onset`` and ``duration`` (seconds),
        ``trial_type`` (target as string), ``sample`` (start index) and
        ``value`` (raw target), followed by the extra columns. Missing
        targets are written as "n/a".
    """
    extras = list(extra_columns or [])

    onsets: list[Any] = []
    durations: list[Any] = []
    trial_types: list[Any] = []
    samples: list[Any] = []
    values: list[Any] = []
    extra_values: dict[str, list[Any]] = {column: [] for column in extras}

    for _, window in metadata.iterrows():
        start = window.get("i_start_in_trial", 0)
        stop = window.get("i_stop_in_trial", start + 1)

        # Sample indices -> seconds
        onset_seconds = start / sfreq
        duration_seconds = (stop - start) / sfreq

        label = window.get(target_column, "n/a")
        if pd.notna(label):
            trial_type = str(label)
        else:
            trial_type = "n/a"

        onsets.append(onset_seconds)
        durations.append(duration_seconds)
        trial_types.append(trial_type)
        samples.append(int(start))
        values.append(label if pd.notna(label) else "n/a")

        for column in extras:
            extra_values[column].append(window.get(column, "n/a"))

    table: dict[str, list[Any]] = {
        "onset": onsets,
        "duration": durations,
        "trial_type": trial_types,
        "sample": samples,
        "value": values,
    }
    table.update(extra_values)
    return pd.DataFrame(table)
320
-
321
-
322
def create_participants_tsv(
    descriptions: list[pd.Series],
    extra_columns: Optional[list[str]] = None,
) -> pd.DataFrame:
    """
    Build a participants table (BIDS participants.tsv layout).

    Each subject appears once, in order of first appearance. Descriptions
    that are None or carry no subject identifier are ignored.

    Parameters
    ----------
    descriptions : list of pd.Series
        Dataset descriptions. The subject is read from the first non-missing
        of ``subject``, ``sub``, ``subject_id``.
    extra_columns : list of str | None
        Additional description fields to copy; absent fields yield "n/a".

    Returns
    -------
    pd.DataFrame
        Columns ``participant_id``, ``age``, ``sex``, ``hand`` followed by the
        extra columns. ``sex`` falls back to ``gender`` and ``hand`` to
        ``handedness``; missing values are written as "n/a".
    """
    subject_keys = ("subject", "sub", "subject_id")
    extras = list(extra_columns or [])

    def _or_na(value):
        # Replace None / NaN by the BIDS missing-value marker.
        return value if pd.notna(value) else "n/a"

    ids: list[Any] = []
    ages: list[Any] = []
    sexes: list[Any] = []
    hands: list[Any] = []
    extra_values: dict[str, list[Any]] = {column: [] for column in extras}
    already_listed = set()

    for desc in descriptions:
        if desc is None:
            continue

        subject = next(
            (
                str(desc[key])
                for key in subject_keys
                if key in desc.index and pd.notna(desc[key])
            ),
            None,
        )
        if subject is None or subject in already_listed:
            continue
        already_listed.add(subject)

        # Prefix with "sub-" unless the identifier already carries it.
        if subject.startswith("sub-"):
            ids.append(subject)
        else:
            ids.append(f"sub-{subject}")

        ages.append(_or_na(desc.get("age", "n/a")))
        sexes.append(_or_na(desc.get("sex", desc.get("gender", "n/a"))))
        hands.append(_or_na(desc.get("hand", desc.get("handedness", "n/a"))))

        for column in extras:
            extra_values[column].append(desc.get(column, "n/a"))

    table: dict[str, list[Any]] = {
        "participant_id": ids,
        "age": ages,
        "sex": sexes,
        "hand": hands,
    }
    table.update(extra_values)
    return pd.DataFrame(table)
393
-
394
-
395
def create_channels_tsv(
    info: "mne.Info",
    bad_channels: Optional[list[str]] = None,
) -> pd.DataFrame:
    """
    Create a BIDS-compliant channels.tsv from MNE Info.

    Delegates channel formatting to mne_bids.write._channels_tsv.

    Parameters
    ----------
    info : mne.Info
        MNE Info object containing channel information.
    bad_channels : list of str | None
        List of bad channel names. When None, the bads recorded in ``info``
        are used. An explicit empty list means "no bad channels".

    Returns
    -------
    pd.DataFrame
        BIDS-compliant channels DataFrame.
    """
    # Only fall back to info's bads when the caller gave nothing; a plain
    # truthiness test would discard an explicit empty list.
    if bad_channels is None:
        bad_channels = info.get("bads", [])
    raw = _raw_from_info(info, bad_channels)
    return _read_tsv(_channels_tsv, raw)
419
-
420
-
421
def create_eeg_json_sidecar(
    info: "mne.Info",
    task_name: str = "unknown",
    task_description: Optional[str] = None,
    instructions: Optional[str] = None,
    institution_name: Optional[str] = None,
    manufacturer: Optional[str] = None,
    recording_duration: Optional[float] = None,
    recording_type: Optional[str] = None,
    epoch_length: Optional[float] = None,
    extra_metadata: Optional[dict] = None,
) -> dict:
    """
    Assemble the content of a BIDS EEG sidecar JSON.

    The base content is produced by ``mne_bids.write._sidecar_json`` in a
    temporary directory and then overlaid with the optional fields below.

    Parameters
    ----------
    info : mne.Info
        MNE Info object.
    task_name : str
        Name of the task.
    task_description : str | None
        Description of the task; written only when non-empty.
    instructions : str | None
        Instructions given to the participant; written only when non-empty.
    institution_name : str | None
        Name of the institution; written only when non-empty.
    manufacturer : str | None
        Manufacturer of the EEG equipment; "n/a" when empty or None.
    recording_duration : float | None
        Length of the recording in seconds (BIDS RECOMMENDED).
    recording_type : str | None
        Type of recording: "continuous", "epoched", or "discontinuous"
        (BIDS RECOMMENDED).
    epoch_length : float | None
        Duration of individual epochs in seconds. RECOMMENDED if
        recording_type is "epoched".
    extra_metadata : dict | None
        Additional entries merged last, so they override everything else.

    Returns
    -------
    dict
        Sidecar JSON content.
    """
    placeholder_raw = _raw_from_info(info)
    manufacturer = manufacturer or "n/a"

    with TemporaryDirectory() as tmpdir, mne.utils.use_log_level("WARNING"):
        json_file = Path(tmpdir) / "eeg.json"
        _sidecar_json(
            placeholder_raw,
            task_name,
            manufacturer,
            json_file,
            "eeg",
            overwrite=True,
        )
        with open(json_file, encoding="utf-8") as handle:
            sidecar = json.load(handle)

    # Free-text fields are skipped when empty or None.
    text_fields = {
        "TaskDescription": task_description,
        "Instructions": instructions,
        "InstitutionName": institution_name,
    }
    for field, text in text_fields.items():
        if text:
            sidecar[field] = text

    # Recording fields are skipped only when None, so 0 is still written.
    recording_fields = {
        "RecordingDuration": recording_duration,
        "RecordingType": recording_type,
        "EpochLength": epoch_length,
    }
    for field, setting in recording_fields.items():
        if setting is not None:
            sidecar[field] = setting

    if extra_metadata:
        sidecar.update(extra_metadata)

    return sidecar
499
-
500
-
501
def save_bids_sidecar_files(
    bids_path: mne_bids.BIDSPath,
    info: "mne.Info",
    metadata: Optional[pd.DataFrame] = None,
    sfreq: Optional[float] = None,
    task_name: str = "unknown",
    recording_duration: Optional[float] = None,
    recording_type: Optional[str] = None,
    epoch_length: Optional[float] = None,
) -> dict[str, Path]:
    """
    Write the events, channels and EEG JSON sidecars next to a recording.

    Parameters
    ----------
    bids_path : mne_bids.BIDSPath
        BIDS path object for the recording; its directory is created.
    info : mne.Info
        MNE Info object.
    metadata : pd.DataFrame | None
        Window metadata for events.tsv. No events file is written when this
        is None or empty.
    sfreq : float | None
        Sampling frequency; ``info["sfreq"]`` is used when not given.
    task_name : str
        Task name for sidecar JSON.
    recording_duration : float | None
        Length of the recording in seconds (BIDS RECOMMENDED).
    recording_type : str | None
        Type of recording: "continuous", "epoched", or "discontinuous"
        (BIDS RECOMMENDED).
    epoch_length : float | None
        Duration of individual epochs in seconds. RECOMMENDED if
        recording_type is "epoched".

    Returns
    -------
    dict
        Maps "events" (when written), "channels" and "sidecar" to the paths
        of the files that were saved.
    """
    bids_path.mkdir(exist_ok=True)

    sfreq = sfreq or info["sfreq"]
    template = bids_path.copy()
    tsv_options = dict(sep="\t", index=False, na_rep="n/a", encoding="utf-8")

    def _sibling(suffix, extension):
        # Same entities as the recording, different suffix/extension.
        return template.copy().update(suffix=suffix, extension=extension).fpath

    written = {}

    # events.tsv -- only when there is window metadata
    if metadata is not None and len(metadata) > 0:
        events_file = _sibling("events", ".tsv")
        create_events_tsv(metadata, sfreq).to_csv(events_file, **tsv_options)
        written["events"] = events_file

    # channels.tsv
    channels_file = _sibling("channels", ".tsv")
    create_channels_tsv(info).to_csv(channels_file, **tsv_options)
    written["channels"] = channels_file

    # eeg.json
    sidecar_content = create_eeg_json_sidecar(
        info,
        task_name=task_name,
        recording_duration=recording_duration,
        recording_type=recording_type,
        epoch_length=epoch_length,
    )
    sidecar_file = _sibling("eeg", ".json")
    with open(sidecar_file, "w", encoding="utf-8") as stream:
        json.dump(sidecar_content, stream, indent=2)
    written["sidecar"] = sidecar_file

    return written
580
-
581
-
582
class BIDSSourcedataLayout:
    """
    Helper for laying out files under ``<root>/sourcedata/<pipeline>/``.

    The module documentation describes ``sourcedata/`` as a directory that
    BIDS validators do not check, which is why Zarr stores are placed there.

    Intended layout::

        sourcedata/<pipeline>/
            dataset_description.json
            participants.tsv
            sub-<label>/
                [ses-<label>/]
                    eeg/
                        sub-<label>_task-<label>_desc-preproc_events.tsv
                        sub-<label>_task-<label>_desc-preproc_channels.tsv
                        sub-<label>_task-<label>_desc-preproc_eeg.json
                        sub-<label>_task-<label>_desc-preproc_eeg.zarr/
            dataset.zarr    (main data file for efficient loading)
    """

    def __init__(
        self,
        root: Union[str, Path],
        pipeline_name: str = DEFAULT_PIPELINE_NAME,
    ):
        """
        Record the root, the pipeline name and the derived sourcedata folder.

        Nothing is created on disk here; see ``create_structure``.

        Parameters
        ----------
        root : str | Path
            Directory that will contain ``sourcedata/``.
        pipeline_name : str
            Name of the processing pipeline (sub-folder of ``sourcedata/``).
        """
        self.root = Path(root)
        self.pipeline_name = pipeline_name
        self.sourcedata_dir = self.root.joinpath("sourcedata", pipeline_name)

    def create_structure(self) -> Path:
        """Create the sourcedata directory (and parents) and return it."""
        target = self.sourcedata_dir
        target.mkdir(parents=True, exist_ok=True)
        return target

    def get_bids_path(
        self,
        description: pd.Series,
        suffix: str = "eeg",
        extension: str = ".zarr",
        desc_label: str = "preproc",
    ) -> mne_bids.BIDSPath:
        """
        Build the BIDSPath of a recording inside the sourcedata folder.

        Parameters
        ----------
        description : pd.Series
            Dataset description.
        suffix : str
            BIDS suffix.
        extension : str
            File extension.
        desc_label : str
            Description label.

        Returns
        -------
        mne_bids.BIDSPath
            BIDS path for the recording, with datatype "eeg" and BIDS
            conformity checks disabled.
        """
        path_options = {
            "root": self.sourcedata_dir,
            "datatype": "eeg",
            "suffix": suffix,
            "extension": extension,
            "desc_label": desc_label,
            "check": False,
        }
        return description_to_bids_path(description=description, **path_options)

    def save_dataset_description(
        self,
        name: str = "Braindecode Dataset",
        source_datasets: Optional[list[dict]] = None,
    ) -> Path:
        """
        Write dataset_description.json into the sourcedata folder.

        An existing file is always overwritten.

        Parameters
        ----------
        name : str
            Name of the dataset.
        source_datasets : list of dict | None
            Source dataset references.

        Returns
        -------
        Path
            Path to saved file.
        """
        description_file = make_dataset_description(
            path=self.sourcedata_dir,
            name=name,
            pipeline_name=self.pipeline_name,
            source_datasets=source_datasets,
            overwrite=True,
        )
        return description_file

    def save_participants(
        self,
        descriptions: list[pd.Series],
        extra_columns: Optional[list[str]] = None,
    ) -> Path:
        """
        Write participants.tsv into the sourcedata folder.

        Parameters
        ----------
        descriptions : list of pd.Series
            List of dataset descriptions.
        extra_columns : list of str | None
            Additional columns to include.

        Returns
        -------
        Path
            Path to saved file.
        """
        destination = self.sourcedata_dir / "participants.tsv"
        table = create_participants_tsv(descriptions, extra_columns)
        table.to_csv(
            destination,
            sep="\t",
            index=False,
            na_rep="n/a",
            encoding="utf-8",
        )
        return destination