eegdash 0.3.3.dev178374711__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic. Click here for more details.

Files changed (96) hide show
  1. eegdash-0.3.4/MANIFEST.in +7 -0
  2. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/PKG-INFO +1 -1
  3. eegdash-0.3.4/docs/Makefile +31 -0
  4. eegdash-0.3.4/docs/build/html/_downloads/22c048359758b424393a09689a41275e/challenge_1.ipynb +140 -0
  5. eegdash-0.3.4/docs/build/html/_downloads/2c592649a2079630923cb072bc1beaf3/tutorial_eoec.ipynb +176 -0
  6. eegdash-0.3.4/docs/build/html/_downloads/893ab57ca8de4ec74e7c17907d3e8a27/challenge_2.ipynb +151 -0
  7. eegdash-0.3.4/docs/build/html/_downloads/befd83bab11618bf4f65555df6b7d484/challenge_2_machine_learning.ipynb +224 -0
  8. eegdash-0.3.4/docs/build/html/_downloads/f3cf56a30a7c06a2eccae3b5b3d28e35/tutorial_feature_extractor_open_close_eye.ipynb +260 -0
  9. eegdash-0.3.4/docs/source/api/eegdash.api.rst +7 -0
  10. eegdash-0.3.4/docs/source/api/eegdash.data_config.rst +7 -0
  11. eegdash-0.3.4/docs/source/api/eegdash.data_utils.rst +7 -0
  12. eegdash-0.3.4/docs/source/api/eegdash.dataset.rst +7 -0
  13. eegdash-0.3.4/docs/source/api/eegdash.features.datasets.rst +7 -0
  14. eegdash-0.3.4/docs/source/api/eegdash.features.decorators.rst +7 -0
  15. eegdash-0.3.4/docs/source/api/eegdash.features.extractors.rst +7 -0
  16. eegdash-0.3.4/docs/source/api/eegdash.features.feature_bank.complexity.rst +7 -0
  17. eegdash-0.3.4/docs/source/api/eegdash.features.feature_bank.connectivity.rst +7 -0
  18. eegdash-0.3.4/docs/source/api/eegdash.features.feature_bank.csp.rst +7 -0
  19. eegdash-0.3.4/docs/source/api/eegdash.features.feature_bank.dimensionality.rst +7 -0
  20. eegdash-0.3.4/docs/source/api/eegdash.features.feature_bank.rst +21 -0
  21. eegdash-0.3.4/docs/source/api/eegdash.features.feature_bank.signal.rst +7 -0
  22. eegdash-0.3.4/docs/source/api/eegdash.features.feature_bank.spectral.rst +7 -0
  23. eegdash-0.3.4/docs/source/api/eegdash.features.feature_bank.utils.rst +7 -0
  24. eegdash-0.3.4/docs/source/api/eegdash.features.inspect.rst +7 -0
  25. eegdash-0.3.4/docs/source/api/eegdash.features.rst +28 -0
  26. eegdash-0.3.4/docs/source/api/eegdash.features.serialization.rst +7 -0
  27. eegdash-0.3.4/docs/source/api/eegdash.features.utils.rst +7 -0
  28. eegdash-0.3.4/docs/source/api/eegdash.mongodb.rst +7 -0
  29. eegdash-0.3.4/docs/source/api/eegdash.preprocessing.rst +7 -0
  30. eegdash-0.3.4/docs/source/api/eegdash.registry.rst +7 -0
  31. eegdash-0.3.4/docs/source/api/eegdash.rst +30 -0
  32. eegdash-0.3.4/docs/source/api/eegdash.utils.rst +7 -0
  33. eegdash-0.3.4/docs/source/api/modules.rst +7 -0
  34. eegdash-0.3.4/docs/source/conf.py +142 -0
  35. eegdash-0.3.4/docs/source/dataset_summary.rst +80 -0
  36. eegdash-0.3.4/docs/source/generated/auto_examples/core/sg_execution_times.rst +40 -0
  37. eegdash-0.3.4/docs/source/generated/auto_examples/core/tutorial_eoec.ipynb +176 -0
  38. eegdash-0.3.4/docs/source/generated/auto_examples/core/tutorial_eoec.rst +388 -0
  39. eegdash-0.3.4/docs/source/generated/auto_examples/core/tutorial_feature_extractor_open_close_eye.ipynb +260 -0
  40. eegdash-0.3.4/docs/source/generated/auto_examples/core/tutorial_feature_extractor_open_close_eye.rst +510 -0
  41. eegdash-0.3.4/docs/source/generated/auto_examples/eeg2025/challenge_1.ipynb +140 -0
  42. eegdash-0.3.4/docs/source/generated/auto_examples/eeg2025/challenge_1.rst +381 -0
  43. eegdash-0.3.4/docs/source/generated/auto_examples/eeg2025/challenge_2.ipynb +151 -0
  44. eegdash-0.3.4/docs/source/generated/auto_examples/eeg2025/challenge_2.rst +311 -0
  45. eegdash-0.3.4/docs/source/generated/auto_examples/eeg2025/challenge_2_machine_learning.ipynb +224 -0
  46. eegdash-0.3.4/docs/source/generated/auto_examples/eeg2025/challenge_2_machine_learning.rst +390 -0
  47. eegdash-0.3.4/docs/source/generated/auto_examples/eeg2025/sg_execution_times.rst +43 -0
  48. eegdash-0.3.4/docs/source/generated/auto_examples/index.rst +177 -0
  49. eegdash-0.3.4/docs/source/generated/auto_examples/sg_execution_times.rst +37 -0
  50. eegdash-0.3.4/docs/source/index.rst +56 -0
  51. eegdash-0.3.4/docs/source/install/install.rst +77 -0
  52. eegdash-0.3.4/docs/source/install/install_pip.rst +19 -0
  53. eegdash-0.3.4/docs/source/install/install_source.rst +80 -0
  54. eegdash-0.3.4/docs/source/sg_execution_times.rst +49 -0
  55. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/__init__.py +1 -1
  56. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/data_utils.py +12 -1
  57. eegdash-0.3.4/eegdash/dataset_summary.csv +255 -0
  58. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash.egg-info/PKG-INFO +1 -1
  59. eegdash-0.3.4/eegdash.egg-info/SOURCES.txt +93 -0
  60. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/pyproject.toml +4 -0
  61. eegdash-0.3.3.dev178374711/eegdash.egg-info/SOURCES.txt +0 -39
  62. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/LICENSE +0 -0
  63. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/README.md +0 -0
  64. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/api.py +0 -0
  65. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/data_config.py +0 -0
  66. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/dataset.py +0 -0
  67. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/__init__.py +0 -0
  68. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/datasets.py +0 -0
  69. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/decorators.py +0 -0
  70. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/extractors.py +0 -0
  71. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/feature_bank/__init__.py +0 -0
  72. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/feature_bank/complexity.py +0 -0
  73. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/feature_bank/connectivity.py +0 -0
  74. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/feature_bank/csp.py +0 -0
  75. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/feature_bank/dimensionality.py +0 -0
  76. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/feature_bank/signal.py +0 -0
  77. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/feature_bank/spectral.py +0 -0
  78. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/feature_bank/utils.py +0 -0
  79. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/inspect.py +0 -0
  80. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/serialization.py +0 -0
  81. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/features/utils.py +0 -0
  82. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/mongodb.py +0 -0
  83. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/preprocessing.py +0 -0
  84. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/registry.py +0 -0
  85. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash/utils.py +0 -0
  86. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash.egg-info/dependency_links.txt +0 -0
  87. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash.egg-info/requires.txt +0 -0
  88. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/eegdash.egg-info/top_level.txt +0 -0
  89. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/setup.cfg +0 -0
  90. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/tests/test_correctness.py +0 -0
  91. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/tests/test_dataset.py +0 -0
  92. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/tests/test_dataset_registration.py +0 -0
  93. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/tests/test_eegdash.py +0 -0
  94. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/tests/test_functional.py +0 -0
  95. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/tests/test_init.py +0 -0
  96. {eegdash-0.3.3.dev178374711 → eegdash-0.3.4}/tests/test_mongo_connection.py +0 -0
@@ -0,0 +1,7 @@
1
+ include README.md
2
+ include LICENSE
3
+
4
+ include eegdash/dataset_summary.csv
5
+
6
+ recursive-include docs *.ipynb *.rst conf.py Makefile
7
+ recursive-exclude docs *checkpoint.ipynb
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: eegdash
3
- Version: 0.3.3.dev178374711
3
+ Version: 0.3.4
4
4
  Summary: EEG data for machine learning
5
5
  Author-email: Young Truong <dt.young112@gmail.com>, Arnaud Delorme <adelorme@gmail.com>, Aviv Dotan <avivd220@gmail.com>, Oren Shriki <oren70@gmail.com>, Bruno Aristimunha <b.aristimunha@gmail.com>
6
6
  License-Expression: GPL-3.0-only
@@ -0,0 +1,31 @@
1
+ # Minimal makefile for Sphinx documentation
2
+ SPHINXOPTS ?=
3
+ SPHINXBUILD ?= sphinx-build
4
+ SOURCEDIR = source
5
+ BUILDDIR = build
6
+ PKG ?= eegdash
7
+ APIDIR := $(SOURCEDIR)/api
8
+
9
+ help:
10
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
11
+
12
+ .PHONY: apidoc
13
+ apidoc:
14
+ @sphinx-apidoc -o "$(APIDIR)" "../$(PKG)" -f -e -M
15
+
16
+ # Standard build runs examples
17
+ html: apidoc
18
+
19
+ # Fast build: do NOT execute examples (sphinx-gallery)
20
+ .PHONY: html-noplot
21
+ html-noplot: apidoc
22
+ @python prepare_summary_tables.py ../eegdash/ $(BUILDDIR)
23
+ @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" \
24
+ $(SPHINXOPTS) -D sphinx_gallery_conf.plot_gallery=0 $(O)
25
+
26
+ .PHONY: help apidoc
27
+ Makefile: ;
28
+
29
+ %: Makefile
30
+ @python prepare_summary_tables.py ../eegdash/ $(BUILDDIR)
31
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@@ -0,0 +1,140 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "collapsed": false
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "# For tips on running notebooks in Google Colab:\n# `pip install eegdash`\n%matplotlib inline"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "\n# Challenge 1: Transfer Learning (Contrast\u2011Change Detection, CCD)\n\nThis tutorial walks you through preparing **Challenge\u00a01** data for the EEG 2025 competition.\nYou will load a CCD recording from OpenNeuro, extract trial\u2011wise behavioral metadata\n(**stimulus side, correctness, response time**), epoch around contrast\u2011change onsets,\nand produce a :class:`braindecode.datasets.WindowsDataset` ready for training.\n\n## Why this matters\nChallenge\u00a01 evaluates representations that **transfer across subjects, sessions, and sites**.\nYour pipeline should emphasize **robust, interpretable features** over brittle task\u2011specific hacks.\n\n## What you\u2019ll do\n- Load subject ``NDARAG340ERT`` from OpenNeuro ``ds005507`` (also in the R3 minsets).\n- Read the BIDS ``events.tsv`` to access **stimulus**, **button press**, and **feedback** rows.\n- Compute **response time** and **correctness** per trial.\n- Epoch around **contrast\u2011change** onsets and attach the behavioral metadata.\n- Build a :class:`braindecode.datasets.WindowsDataset` for model training/evaluation.\n\n## Prerequisites\n- Packages: :mod:`eegdash`, :mod:`braindecode`, :mod:`mne`, :mod:`numpy`, :mod:`pandas`.\n- Data: a BIDS cache managed by EEGDash (it will download on first use).\n- Hardware: any modern CPU; a GPU is optional for later modeling steps.\n\n## Notes\n- This tutorial **only** covers the CCD task for Challenge\u00a01. The **SuS** task is not covered here.\n- Large models are allowed, but the emphasis is on **features that generalize** across cohorts.\n\n## References\n- [OpenNeuro ds005507 (CCD)](https://openneuro.org/datasets/ds005507)\n- [EEGDash documentation](https://github.com/eegdash/eegdash)\n- [Braindecode documentation](https://braindecode.org/)\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {
25
+ "collapsed": false
26
+ },
27
+ "outputs": [],
28
+ "source": [
29
+ "import os\nfrom pathlib import Path\nimport pandas as pd\nimport numpy as np\nimport mne\nfrom eegdash import EEGDashDataset\nfrom braindecode.preprocessing import create_windows_from_events\nimport warnings\n\n# Suppress warnings for cleaner output\nwarnings.filterwarnings(\"ignore\")"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "metadata": {},
35
+ "source": [
36
+ "## 1. Loading the Data\n\nWe'll load the data for subject `NDARAG340ERT` from the `ds005507` dataset. `EEGDashDataset` will handle the download and preprocessing automatically.\n\n\n"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "metadata": {},
42
+ "source": [
43
+ "Load the dataset\n\n"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {
50
+ "collapsed": false
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "cache_dir = (Path.home() / \"mne_data\" / \"eeg_challenge_cache\").resolve()\n\ndataset_name = \"ds005507\"\ndataset = EEGDashDataset(\n {\n \"dataset\": dataset_name,\n \"subject\": \"NDARAG340ERT\",\n \"task\": \"contrastChangeDetection\",\n \"run\": 1,\n },\n cache_dir=cache_dir,\n)\n\n# Get the raw EEG data\nraw = dataset.datasets[0].raw\nprint(\"Dataset loaded successfully!\")\nprint(f\"Sampling frequency: {raw.info['sfreq']} Hz\")\nprint(f\"Duration: {raw.times[-1]:.1f} seconds\")\nprint(f\"Number of channels: {len(raw.ch_names)}\")\nprint(f\"Channel names: {raw.ch_names[:10]}...\") # Show first 10 channels"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {},
60
+ "source": [
61
+ "## 2. Reading BIDS Events File with Additional Columns\n\nThe power of BIDS-formatted datasets is that they include rich metadata in standardized formats. The events.tsv file contains additional columns like `feedback` that aren't available through MNE's annotation system. Let's read the BIDS events file directly using pandas to access ALL the columns:\n\n\n"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "markdown",
66
+ "metadata": {},
67
+ "source": [
68
+ "The key insight: we can read the BIDS events.tsv file directly using pandas!\nThis gives us access to ALL columns, including the crucial 'feedback' column.\n\n"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "metadata": {
75
+ "collapsed": false
76
+ },
77
+ "outputs": [],
78
+ "source": [
79
+ "# Get the events file path from the EEGDashDataset\nbids_args = dataset.datasets[0].get_raw_bids_args()\nevents_file = os.path.join(\n cache_dir,\n dataset_name,\n f\"sub-{bids_args['subject']}/eeg/sub-{bids_args['subject']}_task-{bids_args['task']}_run-{bids_args['run']}_events.tsv\",\n)\n\n# Read the events.tsv file using pandas\nevents_df = pd.read_csv(events_file, sep=\"\\t\")\n\nprint(\"BIDS Events File Structure:\")\nprint(f\"Shape: {events_df.shape}\")\nprint(f\"Columns: {list(events_df.columns)}\")\nprint(\"\\nFirst 10 rows:\")\nprint(events_df.head(10))\n\nprint(\"\\nFeedback column unique values:\")\nprint(events_df[\"feedback\"].value_counts())"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "metadata": {},
85
+ "source": [
86
+ "## 3. Calculate Response Times and Correctness from BIDS Events\n\nNow we'll calculate response times and correctness by matching stimulus events with their corresponding button presses and feedback. This approach uses the temporal sequence of events in the BIDS file.\n\n\n"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "metadata": {
93
+ "collapsed": false
94
+ },
95
+ "outputs": [],
96
+ "source": [
97
+ "def calculate_behavioral_metrics_from_bids(events_df):\n \"\"\"Calculate response times and correctness from BIDS events DataFrame.\n\n This function matches stimulus events with subsequent button presses and feedback.\n \"\"\"\n # Get stimulus events\n stimuli = events_df[events_df[\"value\"].isin([\"left_target\", \"right_target\"])].copy()\n\n # Get button press events\n responses = events_df[\n events_df[\"value\"].isin([\"left_buttonPress\", \"right_buttonPress\"])\n ]\n\n # Get contrast trial start events\n contrast_trials = events_df[events_df[\"value\"] == \"contrastTrial_start\"]\n\n # Initialize columns\n stimuli[\"response_time\"] = np.nan\n stimuli[\"correct\"] = None\n stimuli[\"response_type\"] = None\n stimuli[\"contrast_trial_start\"] = None\n\n for idx, stimulus in stimuli.iterrows():\n # Find the next button press after this stimulus, but make sure it is before next 'contrastTrial_start'\n next_contrast_start = contrast_trials[\n contrast_trials[\"onset\"] > stimulus[\"onset\"]\n ].iloc[0][\"onset\"]\n future_responses = responses[\n (responses[\"onset\"] > stimulus[\"onset\"])\n & (responses[\"onset\"] < next_contrast_start)\n ]\n stimuli.loc[idx, \"contrast_trial_start\"] = contrast_trials[\n contrast_trials[\"onset\"] < stimulus[\"onset\"]\n ].iloc[-1][\"onset\"]\n if len(future_responses) > 0:\n # Get the first (closest) response\n next_response = future_responses.iloc[0]\n # Calculate response time\n response_time = next_response[\"onset\"] - stimulus[\"onset\"]\n stimuli.loc[idx, \"response_time\"] = response_time\n stimuli.loc[idx, \"response_type\"] = next_response[\"value\"]\n # We can use the feedback column directly!\n # Find feedback that corresponds to the button press\n if len(next_response[\"feedback\"]) > 0:\n feedback = next_response[\"feedback\"]\n # Map feedback to correctness\n if feedback == \"smiley_face\":\n stimuli.loc[idx, \"correct\"] = True\n elif feedback == \"sad_face\":\n stimuli.loc[idx, \"correct\"] = False\n # 
Note: 'non_target' feedback might indicate a different type of trial\n return stimuli\n\n\n# Calculate behavioral metrics\nstimulus_metadata = calculate_behavioral_metrics_from_bids(events_df)\nprint(\"Behavioral Analysis Results:\")\nprint(f\"Total stimulus events: {len(stimulus_metadata)}\")\nprint(f\"Events with responses: {stimulus_metadata['response_time'].notna().sum()}\")\nprint(f\"Correct responses: {stimulus_metadata['correct'].sum()}\")\nprint(\n f\"Incorrect responses: {stimulus_metadata['response_time'].notna().sum() - stimulus_metadata['correct'].sum()}\"\n)\nprint(\"Response time statistics:\")\nprint(stimulus_metadata[\"response_time\"].describe())\nprint(\"First few trials with calculated metrics:\")\nprint(\n stimulus_metadata[\n [\n \"onset\",\n \"value\",\n \"response_time\",\n \"correct\",\n \"response_type\",\n \"contrast_trial_start\",\n ]\n ].head(8)\n)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "markdown",
102
+ "metadata": {},
103
+ "source": [
104
+ "## 4. Creating Epochs with Braindecode and BIDS Metadata\nNow we'll create epochs using `braindecode`'s `create_windows_from_events`. According to the EEG 2025 challenge requirements, epochs should start from **contrast trial starts** and be **2 seconds long**. This epoching approach ensures we capture:\n\n- The entire trial from contrast trial start (t=0)\n- The stimulus presentation (usually ~2.8 seconds after trial start)\n- The response window (usually within 2 seconds of stimulus)\n- Full behavioral context for each trial\n\nWe'll use our enhanced metadata that includes the behavioral information extracted from the BIDS events file.\n\n\n"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {
111
+ "collapsed": false
112
+ },
113
+ "outputs": [],
114
+ "source": [
115
+ "# Create epochs from contrast trial starts with 2-second duration as per EEG 2025 challenge\n# IMPORTANT: Only epoch trials that have valid behavioral data (stimulus + response)\n\n# First, get all contrast trial start events from the BIDS events\nall_contrast_trials = events_df[events_df[\"value\"] == \"contrastTrial_start\"].copy()\nprint(f\"Found {len(all_contrast_trials)} total contrast trial start events\")\n\n# Filter to only include contrast trials that have valid behavioral data\n# Get the contrast trial start times that correspond to trials with valid stimulus/response data\nvalid_contrast_times = stimulus_metadata[\"contrast_trial_start\"].dropna().unique()\nprint(f\"Found {len(valid_contrast_times)} contrast trials with valid behavioral data\")\n\n# Filter contrast trial events to only those with valid behavioral data\nvalid_contrast_trials = all_contrast_trials[\n all_contrast_trials[\"onset\"].isin(valid_contrast_times)\n].copy()\n\nprint(\n f\"Epoching {len(valid_contrast_trials)} contrast trials (only those with behavioral data)\"\n)\nprint(\n f\"Excluded {len(all_contrast_trials) - len(valid_contrast_trials)} trials without behavioral data\"\n)\n\n# Convert valid contrast trial start onset times to samples for MNE\nvalid_contrast_trials[\"sample_mne\"] = (\n valid_contrast_trials[\"onset\"] * raw.info[\"sfreq\"]\n).astype(int)\n\n# Create new events array with valid contrast trial starts only\n# Format: [sample, previous_sample, event_id]\nnew_events = np.column_stack(\n [\n valid_contrast_trials[\"sample_mne\"].values,\n np.zeros(len(valid_contrast_trials), dtype=int),\n np.full(\n len(valid_contrast_trials), 99, dtype=int\n ), # Use event_id 99 for contrast_trial_start\n ]\n)\n\n# Create new annotations from these events to replace the original annotations\n# This is the key step - we need to replace the annotations in the raw object\nannot_from_events = mne.annotations_from_events(\n events=new_events,\n event_desc={99: 
\"contrast_trial_start\"},\n sfreq=raw.info[\"sfreq\"],\n orig_time=raw.info[\"meas_date\"],\n)\n\n# Replace the annotations in the raw object\nprint(f\"Original annotations: {len(raw.annotations)} events\")\nraw.set_annotations(annot_from_events)\nprint(\n f\"New annotations: {len(raw.annotations)} contrast trial start events (valid trials only)\"\n)\n\n# Verify the new annotations\nevents_check, event_id_check = mne.events_from_annotations(raw)\nprint(f\"Events from new annotations: {len(events_check)} events\")\nprint(f\"Event ID mapping: {event_id_check}\")\n\n# Now use braindecode's create_windows_from_events to create 2-second epochs\n# Calculate the window size in samples (2 seconds * sampling rate)\nwindow_size_samples = int(2.0 * raw.info[\"sfreq\"]) # 2 seconds in samples\nprint(\n f\"Window size: {window_size_samples} samples ({window_size_samples / raw.info['sfreq']:.1f} seconds)\"\n)\n\n# Create 2-second epochs from valid contrast trial starts only\nwindows_dataset = create_windows_from_events(\n dataset, # The EEGDashDataset\n trial_start_offset_samples=0, # Start from the contrast trial start (no offset)\n trial_stop_offset_samples=window_size_samples, # End 2 seconds later\n preload=True,\n)\n\nprint(f\"Created {len(windows_dataset)} epochs with behavioral data\")\nprint(\"All epochs should now have valid stimulus and response information\")\n\n\n# ## Conclusion\n# - The epoched data is now ready under `windows_dataset`.\n# - The response time is under `stimulus_metadata['response_time']`. (required for challenge 1 regression task)\n# - The correctness is under `stimulus_metadata['correct']`. (required for challenge 1 classification task)\n# - The stimulus type (left or right) is under `stimulus_metadata['value']`. (might be useful)"
116
+ ]
117
+ }
118
+ ],
119
+ "metadata": {
120
+ "kernelspec": {
121
+ "display_name": "Python 3",
122
+ "language": "python",
123
+ "name": "python3"
124
+ },
125
+ "language_info": {
126
+ "codemirror_mode": {
127
+ "name": "ipython",
128
+ "version": 3
129
+ },
130
+ "file_extension": ".py",
131
+ "mimetype": "text/x-python",
132
+ "name": "python",
133
+ "nbconvert_exporter": "python",
134
+ "pygments_lexer": "ipython3",
135
+ "version": "3.10.0"
136
+ }
137
+ },
138
+ "nbformat": 4,
139
+ "nbformat_minor": 0
140
+ }
@@ -0,0 +1,176 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "collapsed": false
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "# For tips on running notebooks in Google Colab:\n# `pip install eegdash`\n%matplotlib inline"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "\n# Eyes Open vs. Closed Classification\n\nEEGDash example for eyes open vs. closed classification.\n\nThe code below provides an example of using the *EEGDash* library in combination with PyTorch to develop a deep learning model for analyzing EEG data, specifically for eyes open vs. closed classification in a single subject.\n\n1. **Data Retrieval Using EEGDash**: An instance of *EEGDashDataset* is created to search and retrieve an EEG dataset. At this step, only the metadata is transferred.\n\n2. **Data Preprocessing Using BrainDecode**: This process preprocesses EEG data using Braindecode by reannotating events, selecting specific channels, resampling, filtering, and extracting 2-second epochs, ensuring balanced eyes-open and eyes-closed data for analysis.\n\n3. **Creating train and testing sets**: The dataset is split into training (80%) and testing (20%) sets with balanced labels, converted into PyTorch tensors, and wrapped in DataLoader objects for efficient mini-batch training.\n\n4. **Model Definition**: The model is a shallow convolutional neural network (ShallowFBCSPNet) with 24 input channels (EEG channels), 2 output classes (eyes-open and eyes-closed).\n\n5. **Model Training and Evaluation Process**: This section trains the neural network, normalizes input data, computes cross-entropy loss, updates model parameters, and evaluates classification accuracy over six epochs.\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "## Data Retrieval Using EEGDash\n\nFirst we find one resting state dataset. This dataset contains both eyes open and eyes closed data.\n\n"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {
32
+ "collapsed": false
33
+ },
34
+ "outputs": [],
35
+ "source": [
36
+ "from eegdash import EEGDashDataset\n\nds_eoec = EEGDashDataset(\n {\"dataset\": \"ds005514\", \"task\": \"RestingState\", \"subject\": \"NDARDB033FW5\"}\n)"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "metadata": {},
42
+ "source": [
43
+ "## Data Preprocessing Using Braindecode\n\n[BrainDecode](https://braindecode.org/stable/install/install.html) is a specialized library for preprocessing EEG and MEG data. In this dataset, there are two key events in the continuous data: **instructed_toCloseEyes**, marking the start of a 40-second eyes-closed period, and **instructed_toOpenEyes**, indicating the start of a 20-second eyes-open period.\n\nFor the eyes-closed event, we extract 14 seconds of data from 15 to 29 seconds after the event onset. Similarly, for the eyes-open event, we extract data from 5 to 19 seconds after the event onset. This ensures an equal amount of data for both conditions. The event extraction is handled by the custom function **hbn_ec_ec_reannotation**.\n\nNext, we apply four preprocessing steps in Braindecode:\n1.\t**Reannotation** of event markers using hbn_ec_ec_reannotation().\n2.\t**Selection** of 24 specific EEG channels from the original 128.\n3.\t**Resampling** the EEG data to a frequency of 128 Hz.\n4.\t**Filtering** the EEG signals to retain frequencies between 1 Hz and 55 Hz.\n\nWhen calling the **preprocess** function, the data is retrieved from the remote repository.\n\nFinally, we use **create_windows_from_events** to extract 2-second epochs from the data. These epochs serve as the dataset samples. At this stage, each sample is automatically labeled with the corresponding event type (eyes-open or eyes-closed). windows_ds is a PyTorch dataset, and when queried, it returns labels for eyes-open and eyes-closed (assigned as labels 0 and 1, corresponding to their respective event markers).\n\n"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {
50
+ "collapsed": false
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "from braindecode.preprocessing import (\n preprocess,\n Preprocessor,\n create_windows_from_events,\n)\nimport numpy as np\nimport mne\nimport warnings\n\nwarnings.simplefilter(\"ignore\", category=RuntimeWarning)\n\n\nclass hbn_ec_ec_reannotation(Preprocessor):\n def __init__(self):\n super().__init__(\n fn=self.transform, apply_on_array=False\n ) # Pass the transform method as the function\n\n def transform(self, raw): # Changed from 'apply' to 'transform'\n # Create events array from annotations\n events, event_id = mne.events_from_annotations(raw)\n\n print(event_id)\n\n # Create new events array for 2-second segments\n new_events = []\n sfreq = raw.info[\"sfreq\"]\n for event in events[events[:, 2] == event_id[\"instructed_toCloseEyes\"]]:\n # For each original event, create events every 2 seconds from 15s to 29s after\n start_times = event[0] + np.arange(15, 29, 2) * sfreq\n new_events.extend([[int(t), 0, 1] for t in start_times])\n\n for event in events[events[:, 2] == event_id[\"instructed_toOpenEyes\"]]:\n # For each original event, create events every 2 seconds from 5s to 19s after\n start_times = event[0] + np.arange(5, 19, 2) * sfreq\n new_events.extend([[int(t), 0, 2] for t in start_times])\n\n # replace events in raw\n new_events = np.array(new_events)\n annot_from_events = mne.annotations_from_events(\n events=new_events,\n event_desc={1: \"eyes_closed\", 2: \"eyes_open\"},\n sfreq=raw.info[\"sfreq\"],\n )\n raw.set_annotations(annot_from_events)\n return raw\n\n\n# BrainDecode preprocessors\npreprocessors = [\n hbn_ec_ec_reannotation(),\n Preprocessor(\n \"pick_channels\",\n ch_names=[\n \"E22\",\n \"E9\",\n \"E33\",\n \"E24\",\n \"E11\",\n \"E124\",\n \"E122\",\n \"E29\",\n \"E6\",\n \"E111\",\n \"E45\",\n \"E36\",\n \"E104\",\n \"E108\",\n \"E42\",\n \"E55\",\n \"E93\",\n \"E58\",\n \"E52\",\n \"E62\",\n \"E92\",\n \"E96\",\n \"E70\",\n \"Cz\",\n ],\n ),\n Preprocessor(\"resample\", sfreq=128),\n Preprocessor(\"filter\", l_freq=1, 
h_freq=55),\n]\npreprocess(ds_eoec, preprocessors)\n\n# Extract 2-second segments\nwindows_ds = create_windows_from_events(\n ds_eoec,\n trial_start_offset_samples=0,\n trial_stop_offset_samples=256,\n preload=True,\n)"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {},
60
+ "source": [
61
+ "## Plotting a Single Channel for One Sample\n\nIt\u2019s always a good practice to verify that the data has been properly loaded and processed. Here, we plot a single channel from one sample to ensure the signal is present and looks as expected.\n\n"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {
68
+ "collapsed": false
69
+ },
70
+ "outputs": [],
71
+ "source": [
72
+ "import matplotlib.pyplot as plt\n\nplt.figure()\nplt.plot(windows_ds[2][0][0, :].transpose()) # first channel of the third window (index 2)\nplt.show()"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "metadata": {},
78
+ "source": [
79
+ "## Creating training and test sets\n\nThe code below creates a training and test set. We first split the data into training and test sets using the **train_test_split** function from the **sklearn** library. We then create a **TensorDataset** for the training and test sets.\n\n1.\t**Set Random Seed** \u2013 The random seed is fixed using torch.manual_seed(random_state) to ensure reproducibility in dataset splitting and model training.\n2.\t**Extract Labels from the Dataset** \u2013 Labels (eye-open or eye-closed events) are extracted from windows_ds, stored as a NumPy array, and printed for verification.\n3.\t**Split Dataset into Train and Test Sets** \u2013 The dataset is split into training (80%) and testing (20%) subsets using train_test_split(), ensuring balanced stratification based on the extracted labels. Stratification means that we have as many eyes-open and eyes-closed samples in the training and testing sets.\n4.\t**Convert Data to PyTorch Tensors** \u2013 The selected training and testing samples are converted into FloatTensor for input features and LongTensor for labels, making them compatible with PyTorch models.\n5.\t**Create DataLoaders** \u2013 The datasets are wrapped in PyTorch DataLoader objects with a batch size of 10, enabling efficient mini-batch training and shuffling.\n\n\n"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {
86
+ "collapsed": false
87
+ },
88
+ "outputs": [],
89
+ "source": [
90
+ "import torch\nfrom sklearn.model_selection import train_test_split\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import TensorDataset\n\n# Set random seed for reproducibility\nrandom_state = 42\ntorch.manual_seed(random_state)\nnp.random.seed(random_state)\n\n# Extract labels from the dataset\neo_ec = np.array([ds[1] for ds in windows_ds]).transpose() # check labels\nprint(\"labels: \", eo_ec)\n\n# Get balanced indices for male and female subjects\ntrain_indices, test_indices = train_test_split(\n range(len(windows_ds)), test_size=0.2, stratify=eo_ec, random_state=random_state\n)\n\n# Convert the data to tensors\nX_train = torch.FloatTensor(\n np.array([windows_ds[i][0] for i in train_indices])\n) # Convert list of arrays to single tensor\nX_test = torch.FloatTensor(\n np.array([windows_ds[i][0] for i in test_indices])\n) # Convert list of arrays to single tensor\ny_train = torch.LongTensor(eo_ec[train_indices]) # Convert targets to tensor\ny_test = torch.LongTensor(eo_ec[test_indices]) # Convert targets to tensor\ndataset_train = TensorDataset(X_train, y_train)\ndataset_test = TensorDataset(X_test, y_test)\n\n# Create data loaders for training and testing (batch size 10)\ntrain_loader = DataLoader(dataset_train, batch_size=10, shuffle=True)\ntest_loader = DataLoader(dataset_test, batch_size=10, shuffle=True)\n\n# Print shapes and sizes to verify split\nprint(\n f\"Shape of data {X_train.shape} number of samples - Train: {len(train_loader)}, Test: {len(test_loader)}\"\n)\nprint(\n f\"Eyes-Open/Eyes-Closed balance, train: {np.mean(eo_ec[train_indices]):.2f}, test: {np.mean(eo_ec[test_indices]):.2f}\"\n)"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "metadata": {},
96
+ "source": [
97
+ "# Check labels\n\nIt is good practice to verify the labels and ensure the random seed is functioning correctly. If all labels are 0s (eyes closed) or 1s (eyes open), it could indicate an issue with data loading or stratification, requiring further investigation.\n\n"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "markdown",
102
+ "metadata": {},
103
+ "source": [
104
+ "Visualize a batch of target labels\n\n"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {
111
+ "collapsed": false
112
+ },
113
+ "outputs": [],
114
+ "source": [
115
+ "dataiter = iter(train_loader)\nfirst_item, label = dataiter.__next__()\nlabel"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "markdown",
120
+ "metadata": {},
121
+ "source": [
122
+ "# Create model\n\nThe model is a shallow convolutional neural network (ShallowFBCSPNet) with 24 input channels (EEG channels), 2 output classes (eyes-open and eyes-closed), and an input window size of 256 samples (2 seconds of EEG data).\n\n"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "metadata": {
129
+ "collapsed": false
130
+ },
131
+ "outputs": [],
132
+ "source": [
133
+ "import torch\nimport numpy as np\nfrom torch.nn import functional as F\nfrom braindecode.models import ShallowFBCSPNet\nfrom torchinfo import summary\n\ntorch.manual_seed(random_state)\nmodel = ShallowFBCSPNet(24, 2, n_times=256, final_conv_length=\"auto\")\nsummary(model, input_size=(1, 24, 256))"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "metadata": {},
139
+ "source": [
140
+ "# Model Training and Evaluation Process\n\nThis section trains the neural network using the Adamax optimizer, normalizes input data, computes cross-entropy loss, updates model parameters, and tracks accuracy across six epochs.\n\n1. **Set Up Optimizer and Learning Rate Scheduler** \u2013 The `Adamax` optimizer initializes with a learning rate of 0.002 and weight decay of 0.001 for regularization. An `ExponentialLR` scheduler with a decay factor of 1 keeps the learning rate constant.\n\n2. **Allocate Model to Device** \u2013 The model moves to the specified device (CPU, GPU, or MPS for Mac silicon) to optimize computation efficiency.\n\n3. **Normalize Input Data** \u2013 The `normalize_data` function standardizes input data by subtracting the mean and dividing by the standard deviation along the time dimension before transferring it to the appropriate device.\n\n4. **Train and Track Classification Accuracy Over Six Epochs** \u2013 The training loop iterates through data batches with the model in training mode. It normalizes inputs, computes predictions, calculates cross-entropy loss, performs backpropagation, updates model parameters, and steps the learning rate scheduler. It tracks correct predictions to compute accuracy.\n\n5. **Evaluate on Test Data** \u2013 After each epoch, the model runs in evaluation mode on the test set. It computes predictions on normalized data and calculates test accuracy by comparing outputs with actual labels.\n\n"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "metadata": {
147
+ "collapsed": false
148
+ },
149
+ "outputs": [],
150
+ "source": [
151
+ "optimizer = torch.optim.Adamax(model.parameters(), lr=0.002, weight_decay=0.001)\nscheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1)\n\ndevice = torch.device(\n \"cuda\"\n if torch.cuda.is_available()\n else \"mps\"\n if torch.backends.mps.is_available()\n else \"cpu\"\n)\nmodel = model.to(device=device) # move the model parameters to CPU/GPU\nepochs = 6\n\n\ndef normalize_data(x):\n mean = x.mean(dim=2, keepdim=True)\n std = x.std(dim=2, keepdim=True) + 1e-7 # add small epsilon for numerical stability\n x = (x - mean) / std\n x = x.to(device=device, dtype=torch.float32) # move to device, e.g. GPU\n return x\n\n\nfor e in range(epochs):\n # training\n correct_train = 0\n for t, (x, y) in enumerate(train_loader):\n model.train() # put model to training mode\n scores = model(normalize_data(x))\n y = y.to(device=device, dtype=torch.long)\n _, preds = scores.max(1)\n correct_train += (preds == y).sum() / len(dataset_train)\n\n loss = F.cross_entropy(scores, y)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n scheduler.step()\n\n # Validation\n correct_test = 0\n for t, (x, y) in enumerate(test_loader):\n model.eval() # put model to testing mode\n scores = model(normalize_data(x))\n y = y.to(device=device, dtype=torch.long)\n _, preds = scores.max(1)\n correct_test += (preds == y).sum() / len(dataset_test)\n\n # Reporting\n print(\n f\"Epoch {e}, Train accuracy: {correct_train:.2f}, Test accuracy: {correct_test:.2f}\"\n )"
152
+ ]
153
+ }
154
+ ],
155
+ "metadata": {
156
+ "kernelspec": {
157
+ "display_name": "Python 3",
158
+ "language": "python",
159
+ "name": "python3"
160
+ },
161
+ "language_info": {
162
+ "codemirror_mode": {
163
+ "name": "ipython",
164
+ "version": 3
165
+ },
166
+ "file_extension": ".py",
167
+ "mimetype": "text/x-python",
168
+ "name": "python",
169
+ "nbconvert_exporter": "python",
170
+ "pygments_lexer": "ipython3",
171
+ "version": "3.10.0"
172
+ }
173
+ },
174
+ "nbformat": 4,
175
+ "nbformat_minor": 0
176
+ }
@@ -0,0 +1,151 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "collapsed": false
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "# For tips on running notebooks in Google Colab:\n# `pip install eegdash`\n%matplotlib inline"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "\n# Challenge 2: Predicting the p-factor from EEG\n\nThis tutorial presents Challenge 2: regression of the p-factor (a general psychopathology factor) from EEG recordings.\nThe objective is to identify reproducible EEG biomarkers linked to mental health outcomes.\n\nThe challenge encourages learning physiologically meaningful signal representations.\nModels of any size should emphasize robust, interpretable features that generalize across subjects,\nsessions, and acquisition sites.\n\nUnlike a standard in-distribution classification task, this regression problem stresses out-of-distribution robustness\nand extrapolation. The goal is not only to minimize error on seen subjects, but also to transfer effectively to unseen data.\n\nEnsure the dataset is available locally. If not, see the [dataset download guide](https://eeg2025.github.io/data/#downloading-the-data)\n\nThis tutorial is divided as follows:\n1. **Loading the data**\n2. **Wrap the data into a PyTorch-compatible dataset**\n3. **Define, train and save a model**\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "## 1. Loading the data\n\n\n"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {
32
+ "collapsed": false
33
+ },
34
+ "outputs": [],
35
+ "source": [
36
+ "import random\nfrom pathlib import Path\nfrom eegdash import EEGChallengeDataset\nfrom braindecode.preprocessing import create_fixed_length_windows\nfrom braindecode.datasets.base import EEGWindowsDataset, BaseConcatDataset, BaseDataset"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "metadata": {},
42
+ "source": [
43
+ "### 2. Define local path and download the data\n\nIn this Challenge 2 example, we load the EEG 2025 release using EEG Dash and Braindecode.\nWe load all the public datasets available in the EEG 2025 release.\n\n"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {
50
+ "collapsed": false
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "# The first step is define the cache folder!\ncache_dir = (Path.home() / \"mne_data\" / \"eeg_challenge_cache\").resolve()\n\n# Creating the path if it does not exist\ncache_dir.mkdir(parents=True, exist_ok=True)\n\n# We define the list of releases to load.\n# Here, all releases are loaded, i.e., 1 to 11.\nrelease_list = [\"R{}\".format(i) for i in [5]] # range(1, 11 + 1)]\n\n# For this tutorial, we will only load the \"resting state\" recording,\n# but you may use all available data.\nall_datasets_list = [\n EEGChallengeDataset(\n release=release,\n query=dict(\n task=\"RestingState\",\n ),\n description_fields=[\n \"subject\",\n \"session\",\n \"run\",\n \"task\",\n \"age\",\n \"gender\",\n \"sex\",\n \"p_factor\",\n ],\n cache_dir=cache_dir,\n )\n for release in release_list\n]\nprint(\"Datasets loaded\")"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {},
60
+ "source": [
61
+ "### Combine the datasets into a single one\nHere, we combine the datasets from the different releases into a single\n``BaseConcatDataset`` object.\n\n"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {
68
+ "collapsed": false
69
+ },
70
+ "outputs": [],
71
+ "source": [
72
+ "all_datasets = BaseConcatDataset(all_datasets_list)\nprint(all_datasets.description)"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "metadata": {},
78
+ "source": [
79
+ "### Inspect your data\nWe can check what is inside the dataset by accessing the\nMNE object inside the Braindecode dataset.\n\nThe following snippet, if uncommented, will show the first 10 seconds of the raw EEG signal.\nWe can also inspect the data further by looking at the events and annotations.\nWe strongly recommend that you take a look at the details and check how the events are structured.\n\n"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {
86
+ "collapsed": false
87
+ },
88
+ "outputs": [],
89
+ "source": [
90
+ "# raw = all_datasets.datasets[0].raw # mne.io.Raw object\n# print(raw.info)\n\n# raw.plot(duration=10, scalings=\"auto\", show=True)\n\n# print(raw.annotations)\n\nSFREQ = 100"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "metadata": {},
96
+ "source": [
97
+ "## Wrap the data into a PyTorch-compatible dataset\n\nThe class below defines a dataset wrapper that will extract 2-second windows,\nuniformly sampled over the whole signal. In addition, it will add useful information\nabout the extracted windows, such as the p-factor, the subject or the task.\n\n"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {
104
+ "collapsed": false
105
+ },
106
+ "outputs": [],
107
+ "source": [
108
+ "class DatasetWrapper(BaseDataset):\n def __init__(self, dataset: EEGWindowsDataset, crop_size_samples: int, seed=None):\n self.dataset = dataset\n self.crop_size_samples = crop_size_samples\n self.rng = random.Random(seed)\n\n def __len__(self):\n return len(self.dataset)\n\n def __getitem__(self, index):\n X, _, crop_inds = self.dataset[index]\n\n # P-factor label:\n p_factor = self.dataset.description[\"p_factor\"]\n p_factor = float(p_factor)\n\n # Additional information:\n infos = {\n \"subject\": self.dataset.description[\"subject\"],\n \"sex\": self.dataset.description[\"sex\"],\n \"age\": float(self.dataset.description[\"age\"]),\n \"task\": self.dataset.description[\"task\"],\n \"session\": self.dataset.description.get(\"session\", None) or \"\",\n \"run\": self.dataset.description.get(\"run\", None) or \"\",\n }\n\n # Randomly crop the signal to the desired length:\n i_window_in_trial, i_start, i_stop = crop_inds\n assert i_stop - i_start >= self.crop_size_samples, f\"{i_stop=} {i_start=}\"\n start_offset = self.rng.randint(0, i_stop - i_start - self.crop_size_samples)\n i_start = i_start + start_offset\n i_stop = i_start + self.crop_size_samples\n X = X[:, start_offset : start_offset + self.crop_size_samples]\n\n return X, p_factor, (i_window_in_trial, i_start, i_stop), infos\n\n\n# Filter out recordings that are too short\nall_datasets = BaseConcatDataset(\n [ds for ds in all_datasets.datasets if ds.raw.n_times >= 4 * SFREQ]\n)\n\n# Create 4-seconds windows with 2-seconds stride\nwindows_ds = create_fixed_length_windows(\n all_datasets,\n window_size_samples=4 * SFREQ,\n window_stride_samples=2 * SFREQ,\n drop_last_window=True,\n)\n\n# Wrap each sub-dataset in the windows_ds\nwindows_ds = BaseConcatDataset(\n [DatasetWrapper(ds, crop_size_samples=2 * SFREQ) for ds in windows_ds.datasets]\n)"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "markdown",
113
+ "metadata": {},
114
+ "source": [
115
+ "## Define, train and save a model\nNow we have the PyTorch dataset needed for training!\n\nBelow, we define a simple EEGNetv4 model from Braindecode and train it for one epoch\nusing pure PyTorch code.\nHowever, you can use any PyTorch model or training framework you want.\n\n"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {
122
+ "collapsed": false
123
+ },
124
+ "outputs": [],
125
+ "source": [
126
+ "import torch\nfrom torch.utils.data import DataLoader\nfrom torch import optim\nfrom torch.nn.functional import l1_loss\nfrom braindecode.models import EEGNetv4\n\n# Use GPU if available\nDEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# Create PyTorch Dataloader\ndataloader = DataLoader(windows_ds, batch_size=10, shuffle=True)\n\n# Initialize model\nmodel = EEGNetv4(n_chans=129, n_outputs=1, n_times=2 * SFREQ).to(DEVICE)\n\n# All the braindecode models expect the input to be of shape (batch_size, n_channels, n_times)\n# and have a test coverage about the behavior of the model.\nprint(model)\n\n# Specify optimizer\noptimizer = optim.Adamax(params=model.parameters(), lr=0.002)\n\n# Train model for 1 epoch\nfor epoch in range(1):\n for idx, batch in enumerate(dataloader):\n # Reset gradients\n optimizer.zero_grad()\n\n # Unpack the batch\n X, y, crop_inds, infos = batch\n X = X.to(dtype=torch.float32, device=DEVICE)\n y = y.to(dtype=torch.float32, device=DEVICE).unsqueeze(1)\n\n # Forward pass\n y_pred = model(X)\n\n # Compute loss\n loss = l1_loss(y_pred, y)\n print(f\"Epoch {0} - step {idx}, loss: {loss.item()}\")\n\n # Gradient backpropagation\n loss.backward()\n optimizer.step()\n\n# Finally, we can save the model for later use\ntorch.save(model.state_dict(), \"./example_submission_challenge_2/weights.pt\")"
127
+ ]
128
+ }
129
+ ],
130
+ "metadata": {
131
+ "kernelspec": {
132
+ "display_name": "Python 3",
133
+ "language": "python",
134
+ "name": "python3"
135
+ },
136
+ "language_info": {
137
+ "codemirror_mode": {
138
+ "name": "ipython",
139
+ "version": 3
140
+ },
141
+ "file_extension": ".py",
142
+ "mimetype": "text/x-python",
143
+ "name": "python",
144
+ "nbconvert_exporter": "python",
145
+ "pygments_lexer": "ipython3",
146
+ "version": "3.10.0"
147
+ }
148
+ },
149
+ "nbformat": 4,
150
+ "nbformat_minor": 0
151
+ }