py-neuromodulation 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56):
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
  2. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
  3. py_neuromodulation/__init__.py +12 -4
  4. py_neuromodulation/analysis/RMAP.py +3 -3
  5. py_neuromodulation/analysis/decode.py +55 -2
  6. py_neuromodulation/analysis/feature_reader.py +1 -0
  7. py_neuromodulation/analysis/stats.py +3 -3
  8. py_neuromodulation/default_settings.yaml +25 -17
  9. py_neuromodulation/features/bandpower.py +65 -23
  10. py_neuromodulation/features/bispectra.py +3 -7
  11. py_neuromodulation/features/bursts.py +9 -8
  12. py_neuromodulation/features/coherence.py +17 -9
  13. py_neuromodulation/features/feature_processor.py +4 -4
  14. py_neuromodulation/features/fooof.py +7 -6
  15. py_neuromodulation/features/mne_connectivity.py +25 -3
  16. py_neuromodulation/features/oscillatory.py +5 -4
  17. py_neuromodulation/features/sharpwaves.py +21 -0
  18. py_neuromodulation/filter/kalman_filter.py +17 -6
  19. py_neuromodulation/gui/__init__.py +3 -0
  20. py_neuromodulation/gui/backend/app_backend.py +419 -0
  21. py_neuromodulation/gui/backend/app_manager.py +345 -0
  22. py_neuromodulation/gui/backend/app_pynm.py +244 -0
  23. py_neuromodulation/gui/backend/app_socket.py +95 -0
  24. py_neuromodulation/gui/backend/app_utils.py +306 -0
  25. py_neuromodulation/gui/backend/app_window.py +202 -0
  26. py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
  27. py_neuromodulation/gui/frontend/assets/index-NbJiOU5a.js +300133 -0
  28. py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
  29. py_neuromodulation/gui/frontend/charite.svg +16 -0
  30. py_neuromodulation/gui/frontend/index.html +14 -0
  31. py_neuromodulation/gui/window_api.py +115 -0
  32. py_neuromodulation/lsl_api.cfg +3 -0
  33. py_neuromodulation/processing/data_preprocessor.py +9 -2
  34. py_neuromodulation/processing/filter_preprocessing.py +43 -27
  35. py_neuromodulation/processing/normalization.py +32 -17
  36. py_neuromodulation/processing/projection.py +2 -2
  37. py_neuromodulation/processing/resample.py +6 -2
  38. py_neuromodulation/run_gui.py +36 -0
  39. py_neuromodulation/stream/__init__.py +7 -1
  40. py_neuromodulation/stream/backend_interface.py +47 -0
  41. py_neuromodulation/stream/data_processor.py +24 -3
  42. py_neuromodulation/stream/mnelsl_player.py +121 -21
  43. py_neuromodulation/stream/mnelsl_stream.py +9 -17
  44. py_neuromodulation/stream/settings.py +80 -34
  45. py_neuromodulation/stream/stream.py +82 -62
  46. py_neuromodulation/utils/channels.py +1 -1
  47. py_neuromodulation/utils/file_writer.py +110 -0
  48. py_neuromodulation/utils/io.py +46 -5
  49. py_neuromodulation/utils/perf.py +156 -0
  50. py_neuromodulation/utils/pydantic_extensions.py +322 -0
  51. py_neuromodulation/utils/types.py +33 -107
  52. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/METADATA +27 -22
  53. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/RECORD +56 -36
  54. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/WHEEL +1 -1
  55. py_neuromodulation-0.1.0.dist-info/entry_points.txt +2 -0
  56. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,13 +1,20 @@
1
1
  import numpy as np
2
2
  import mne
3
3
  from pathlib import Path
4
+ import multiprocessing as mp
5
+ import atexit
6
+ import time
7
+ import signal
4
8
 
5
9
  from py_neuromodulation.utils.types import _PathLike
6
- from py_neuromodulation.utils import io
10
+ from py_neuromodulation.utils.io import read_BIDS_data
7
11
  from py_neuromodulation import logger
8
12
 
9
13
 
10
14
  class LSLOfflinePlayer:
15
+ _instances: set["LSLOfflinePlayer"] = set() # Keep track of initialized players
16
+ _atexit_registered: bool = False # Flag to register atexit
17
+
11
18
  def __init__(
12
19
  self,
13
20
  stream_name: str | None = "lsl_offline_player",
@@ -16,6 +23,8 @@ class LSLOfflinePlayer:
16
23
  sfreq: int | float | None = None,
17
24
  data: np.ndarray | None = None,
18
25
  ch_type: str | None = "dbs",
26
+ chunk_size: int = 10,
27
+ n_repeat: int = 1,
19
28
  ) -> None:
20
29
  """Initialization of MNE-LSL offline player.
21
30
  Either a filename (PathLike) is provided,
@@ -42,24 +51,16 @@ class LSLOfflinePlayer:
42
51
  """
43
52
  self.sfreq = sfreq
44
53
  self.stream_name = stream_name
45
- got_raw = raw is not None
46
- got_fname = f_name is not None
47
- got_sfreq_data = sfreq is not None and data is not None
54
+ self.chunk_size = chunk_size
55
+ self.n_repeat = n_repeat
48
56
 
49
- if not (got_fname or got_sfreq_data or got_raw):
50
- error_msg = "Either f_name or raw or sfreq and data must be provided."
51
- logger.critical(error_msg)
52
- raise ValueError(error_msg)
53
-
54
- if got_fname:
57
+ if f_name:
55
58
  (self._path_raw, data, sfreq, line_noise, coord_list, coord_names) = (
56
- io.read_BIDS_data(f_name)
59
+ read_BIDS_data(f_name)
57
60
  )
58
-
59
- elif got_raw:
61
+ elif raw:
60
62
  self._path_raw = raw
61
-
62
- elif got_sfreq_data:
63
+ elif sfreq and data:
63
64
  info = mne.create_info(
64
65
  ch_names=[f"ch{i}" for i in range(data.shape[0])],
65
66
  ch_types=[ch_type for _ in range(data.shape[0])],
@@ -68,27 +69,126 @@ class LSLOfflinePlayer:
68
69
  raw = mne.io.RawArray(data, info)
69
70
  self._path_raw = Path.cwd() / "temp_raw.fif"
70
71
  raw.save(self._path_raw, overwrite=True)
72
+ else:
73
+ error_msg = "Either f_name or raw or sfreq and data must be provided."
74
+ logger.critical(error_msg)
75
+ raise ValueError(error_msg)
71
76
 
72
- def start_player(self, chunk_size: int = 10, n_repeat: int = 1):
77
+ # Flags to control the player subprocess
78
+ self._streaming_complete = mp.Event()
79
+ self._player_process = None
80
+ self._stop_flag = mp.Event()
81
+
82
+ LSLOfflinePlayer._instances.add(self) # Register instancwe
83
+ if LSLOfflinePlayer._atexit_registered:
84
+ atexit.register(LSLOfflinePlayer._stop_all_players)
85
+ LSLOfflinePlayer._atexit_registered = True
86
+
87
def start_player(
    self,
    chunk_size: int | None = None,
    n_repeat: int | None = None,
    block: bool = False,
):
    """Start MNE-LSL Player in a background subprocess.

    Parameters
    ----------
    chunk_size : int, optional
        Number of samples to stream at once. If None, the value stored on the
        instance (set at construction, default 10) is used.
    n_repeat : int, optional
        Number of times to repeat the stream. If None, the value stored on the
        instance (set at construction, default 1) is used.
    block : bool, optional
        If True, block until streaming is complete, by default False
    """
    # Only omitted arguments fall back to the stored values; an explicit value
    # (even a falsy one) is forwarded so PlayerLSL can validate it.
    if chunk_size is not None:
        self.chunk_size = chunk_size
    if n_repeat is not None:
        self.n_repeat = n_repeat

    # Starting twice would otherwise orphan the previous subprocess, which
    # keeps publishing an LSL stream under the same name.
    if self._player_process is not None and self._player_process.is_alive():
        logger.warning(
            f"Player {self.stream_name} is already running; restarting it."
        )
        self.stop_player()

    self._stop_flag.clear()
    self._streaming_complete.clear()

    # stop_player() removes the instance from the registry; re-register so the
    # atexit cleanup also covers players that are restarted after a stop.
    LSLOfflinePlayer._instances.add(self)

    self._player_process = mp.Process(
        target=self._run_player,
        args=(
            self.chunk_size,
            self.n_repeat,
            self._stop_flag,
            self._streaming_complete,
        ),
    )
    self._player_process.start()

    if block:
        try:
            self.wait_for_completion()
        except KeyboardInterrupt:
            logger.info("\nKeyboard interrupt received. Stopping the player...")
            self.stop_player()
130
+
131
def _run_player(self, chunk_size, n_repeat, stop_flag, streaming_complete):
    """Subprocess target that streams the recording over LSL.

    Used as the ``mp.Process`` target by ``start_player``. Streams
    ``self._path_raw`` under the name ``self.stream_name`` until PlayerLSL
    reports the end of streaming or ``stop_flag`` is set. It then stops the
    player and sets ``streaming_complete``.

    Parameters
    ----------
    chunk_size : int
        Number of samples per chunk, forwarded to ``PlayerLSL``.
    n_repeat : int
        Number of repetitions of the recording, forwarded to ``PlayerLSL``.
    stop_flag : multiprocessing.Event
        Set by ``stop_player`` (or by SIGINT in this process) to request a stop.
    streaming_complete : multiprocessing.Event
        Set once the player has been shut down, also when creating or
        starting the player failed.
    """
    from mne_lsl.player import PlayerLSL

    # Python calls signal handlers as handler(signum, frame); a zero-argument
    # lambda raises TypeError instead of requesting a clean stop on Ctrl+C.
    signal.signal(signal.SIGINT, lambda signum, frame: stop_flag.set())

    player = None
    try:
        player = PlayerLSL(
            self._path_raw,
            name=self.stream_name,
            chunk_size=chunk_size,
            n_repeat=n_repeat,
        )
        player = player.start()

        # NOTE(review): ``_end_streaming`` is a private PlayerLSL attribute and
        # may change between mne-lsl versions.
        while not stop_flag.is_set() and not player._end_streaming:
            time.sleep(0.1)
    finally:
        if player is not None:
            try:
                player.stop()
            except RuntimeError:
                # player already stopped (or never started)
                pass
        # Always report completion, including when setup failed.
        streaming_complete.set()
154
+
155
def wait_for_completion(self):
    """Block the caller until the player subprocess has finished streaming.

    Returns as soon as the completion event is set or the subprocess is no
    longer alive (or was never started). A KeyboardInterrupt while waiting
    stops the player and returns.
    """
    while True:
        proc = self._player_process
        if not proc or not proc.is_alive():
            return
        try:
            # Wait in 1 s slices so a Ctrl+C is serviced promptly.
            self._streaming_complete.wait(timeout=1.0)
            if self._streaming_complete.is_set():
                return
        except KeyboardInterrupt:
            logger.info("\nKeyboard interrupt received. Stopping the player...")
            self.stop_player()
            return
91
166
 
92
167
  def stop_player(self):
93
168
  """Stop MNE-LSL Player"""
94
- self.player.stop()
169
+ if self._player_process and self._player_process.is_alive():
170
+ self._stop_flag.set()
171
+ self._player_process.join(timeout=5)
172
+ if self._player_process.is_alive():
173
+ self._player_process.terminate()
174
+ self._player_process.join(timeout=1)
175
+ if self._player_process.is_alive():
176
+ self._player_process.kill()
177
+ self._player_process = None
178
+
179
+ logger.info(f"Player stopped: {self.stream_name}")
180
+ LSLOfflinePlayer._instances.discard(self)
181
+
182
@classmethod
def _stop_all_players(cls):
    """Stop all player instances (used for atexit).

    Iterates over a snapshot of ``cls._instances`` because ``stop_player``
    removes each instance from that set. Mutating a set while iterating it
    raises ``RuntimeError: Set changed size during iteration``.
    """
    for player in list(cls._instances):
        player.stop_player()
187
+
188
+ # Enable use as a context manager
189
def __enter__(self):
    """Context-manager entry: call ``start_player()`` and hand back ``self``."""
    self.start_player()
    return self
192
+
193
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: call ``stop_player()``.

    Returns None, so an exception raised inside the ``with`` body propagates.
    """
    self.stop_player()
@@ -34,7 +34,7 @@ class LSLStream:
34
34
  from mne_lsl.stream import StreamLSL
35
35
 
36
36
  self.stream: StreamLSL
37
- self.keyboard_interrupt = False
37
+ # self.keyboard_interrupt = False
38
38
 
39
39
  self.settings = settings
40
40
  self._n_seconds_wait_before_disconnect = 3
@@ -58,11 +58,11 @@ class LSLStream:
58
58
 
59
59
  # If not running the generator when the escape key is pressed.
60
60
  self.headless: bool = not os.environ.get("DISPLAY")
61
- if not self.headless:
62
- from py_neuromodulation.utils.keyboard import KeyboardListener
61
+ # if not self.headless:
62
+ # from py_neuromodulation.utils.keyboard import KeyboardListener
63
63
 
64
- self.listener = KeyboardListener(("esc", self.set_keyboard_interrupt))
65
- self.listener.start()
64
+ # self.listener = KeyboardListener(("esc", self.set_keyboard_interrupt))
65
+ # self.listener.start()
66
66
 
67
67
  def get_next_batch(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
68
68
  self.last_time = time.time()
@@ -91,6 +91,8 @@ class LSLStream:
91
91
  if stream_start_time is None:
92
92
  stream_start_time = timestamp[0]
93
93
 
94
+ logger.info(f"Stream time: {timestamp[-1] - stream_start_time}")
95
+
94
96
  for i in range(self._n_seconds_wait_before_disconnect):
95
97
  if (
96
98
  data is not None
@@ -98,7 +100,7 @@ class LSLStream:
98
100
  and np.allclose(data, check_data, atol=1e-7, rtol=1e-7)
99
101
  ):
100
102
  logger.warning(
101
- f"No new data incoming. Disconnecting stream in {3-i} seconds."
103
+ f"No new data incoming. Disconnecting stream in {3 - i} seconds."
102
104
  )
103
105
  time.sleep(1)
104
106
  i += 1
@@ -107,14 +109,4 @@ class LSLStream:
107
109
  logger.warning("Stream disconnected.")
108
110
  break
109
111
 
110
- yield timestamp, data
111
-
112
- logger.info(f"Stream time: {timestamp[-1] - stream_start_time}")
113
-
114
- if not self.headless and self.keyboard_interrupt:
115
- logger.info("Keyboard interrupt")
116
- self.listener.stop()
117
- self.stream.disconnect()
118
-
119
- def set_keyboard_interrupt(self):
120
- self.keyboard_interrupt = True
112
+ yield timestamp - stream_start_time, data
@@ -1,22 +1,27 @@
1
1
  """Module for handling settings."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import ClassVar
5
- from pydantic import Field, model_validator
4
+ from typing import Any, ClassVar, get_args
5
+ from pydantic import model_validator, ValidationError
6
+ from pydantic.functional_validators import ModelWrapValidatorHandler
6
7
 
7
- from py_neuromodulation import PYNM_DIR, logger, user_features
8
+ from py_neuromodulation import logger, user_features, PYNM_DIR
8
9
 
9
10
  from py_neuromodulation.utils.types import (
10
11
  BoolSelector,
11
12
  FrequencyRange,
12
- PreprocessorName,
13
13
  _PathLike,
14
14
  NMBaseModel,
15
- NormMethod,
15
+ NORM_METHOD,
16
+ PREPROCESSOR_NAME,
16
17
  )
18
+ from py_neuromodulation.utils.pydantic_extensions import NMErrorList, NMField
17
19
 
18
20
  from py_neuromodulation.processing.filter_preprocessing import FilterSettings
19
- from py_neuromodulation.processing.normalization import NormalizationSettings
21
+ from py_neuromodulation.processing.normalization import (
22
+ FeatureNormalizationSettings,
23
+ NormalizationSettings,
24
+ )
20
25
  from py_neuromodulation.processing.resample import ResamplerSettings
21
26
  from py_neuromodulation.processing.projection import ProjectionSettings
22
27
 
@@ -31,7 +36,9 @@ from py_neuromodulation.features import OscillatorySettings, BandPowerSettings
31
36
  from py_neuromodulation.features import BurstsSettings
32
37
 
33
38
 
34
- class FeatureSelection(BoolSelector):
39
+ # TONI: this class has the proble that if a feature is absent,
40
+ # it won't default to False but to whatever is defined here as default
41
+ class FeatureSelector(BoolSelector):
35
42
  raw_hjorth: bool = True
36
43
  return_raw: bool = True
37
44
  bandpass_filter: bool = False
@@ -54,13 +61,24 @@ class PostprocessingSettings(BoolSelector):
54
61
  project_subcortex: bool = False
55
62
 
56
63
 
64
# Preprocessing steps (in order) used as the default of NMSettings.preprocessing
DEFAULT_PREPROCESSORS: list[PREPROCESSOR_NAME] = list(
    ("raw_resampling", "notch_filter", "re_referencing")
)
69
+
70
+
57
71
  class NMSettings(NMBaseModel):
58
72
  # Class variable to store instances
59
73
  _instances: ClassVar[list["NMSettings"]] = []
60
74
 
61
75
  # General settings
62
- sampling_rate_features_hz: float = Field(default=10, gt=0)
63
- segment_length_features_ms: float = Field(default=1000, gt=0)
76
+ sampling_rate_features_hz: float = NMField(
77
+ default=10, gt=0, custom_metadata={"unit": "Hz"}
78
+ )
79
+ segment_length_features_ms: float = NMField(
80
+ default=1000, gt=0, custom_metadata={"unit": "ms"}
81
+ )
64
82
  frequency_ranges_hz: dict[str, FrequencyRange] = {
65
83
  "theta": FrequencyRange(4, 8),
66
84
  "alpha": FrequencyRange(8, 12),
@@ -72,35 +90,40 @@ class NMSettings(NMBaseModel):
72
90
  }
73
91
 
74
92
  # Preproceessing settings
75
- preprocessing: list[PreprocessorName] = [
76
- "raw_resampling",
77
- "notch_filter",
78
- "re_referencing",
79
- ]
93
+ preprocessing: list[PREPROCESSOR_NAME] = NMField(
94
+ default=DEFAULT_PREPROCESSORS,
95
+ custom_metadata={
96
+ "field_type": "PreprocessorList",
97
+ "valid_values": list(get_args(PREPROCESSOR_NAME)),
98
+ },
99
+ )
100
+
80
101
  raw_resampling_settings: ResamplerSettings = ResamplerSettings()
81
102
  preprocessing_filter: FilterSettings = FilterSettings()
82
103
  raw_normalization_settings: NormalizationSettings = NormalizationSettings()
83
104
 
84
105
  # Postprocessing settings
85
106
  postprocessing: PostprocessingSettings = PostprocessingSettings()
86
- feature_normalization_settings: NormalizationSettings = NormalizationSettings()
107
+ feature_normalization_settings: FeatureNormalizationSettings = (
108
+ FeatureNormalizationSettings()
109
+ )
87
110
  project_cortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=20)
88
111
  project_subcortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=5)
89
112
 
90
113
  # Feature settings
91
- features: FeatureSelection = FeatureSelection()
114
+ features: FeatureSelector = FeatureSelector()
92
115
 
93
116
  fft_settings: OscillatorySettings = OscillatorySettings()
94
117
  welch_settings: OscillatorySettings = OscillatorySettings()
95
118
  stft_settings: OscillatorySettings = OscillatorySettings()
96
119
  bandpass_filter_settings: BandPowerSettings = BandPowerSettings()
97
120
  kalman_filter_settings: KalmanSettings = KalmanSettings()
98
- burst_settings: BurstsSettings = BurstsSettings()
121
+ bursts_settings: BurstsSettings = BurstsSettings()
99
122
  sharpwave_analysis_settings: SharpwaveSettings = SharpwaveSettings()
100
123
  mne_connectivity_settings: MNEConnectivitySettings = MNEConnectivitySettings()
101
124
  coherence_settings: CoherenceSettings = CoherenceSettings()
102
125
  fooof_settings: FooofSettings = FooofSettings()
103
- nolds_settings: NoldsSettings = NoldsSettings()
126
+ nolds_features: NoldsSettings = NoldsSettings()
104
127
  bispectrum_settings: BispectraSettings = BispectraSettings()
105
128
 
106
129
  def __init__(self, *args, **kwargs) -> None:
@@ -126,10 +149,38 @@ class NMSettings(NMBaseModel):
126
149
  for instance in cls._instances:
127
150
  delattr(instance.features, feature)
128
151
 
129
- @model_validator(mode="after")
130
- def validate_settings(self):
152
+ @model_validator(mode="wrap") # type: ignore[reportIncompatibleMethodOverride]
153
+ def validate_settings(self, handler: ModelWrapValidatorHandler) -> Any:
154
+ # Perform all necessary custom validations in the settings class and also
155
+ # all validations in the feature classes that need additional information from
156
+ # the settings class
157
+ errors: NMErrorList = NMErrorList()
158
+
159
+ def remove_private_keys(data):
160
+ if isinstance(data, dict):
161
+ if "__value__" in data:
162
+ return data["__value__"]
163
+ else:
164
+ return {
165
+ key: remove_private_keys(value)
166
+ for key, value in data.items()
167
+ if not key.startswith("__")
168
+ }
169
+ elif isinstance(data, (list, tuple, set)):
170
+ return type(data)(remove_private_keys(item) for item in data)
171
+ else:
172
+ return data
173
+
174
+ self = remove_private_keys(self)
175
+
176
+ try:
177
+ self = handler(self) # validate the model
178
+ except ValidationError as e:
179
+ self = NMSettings.unvalidated(**self) # type: ignore
180
+ errors.extend(NMErrorList(e.errors()))
181
+
131
182
  if len(self.features.get_enabled()) == 0:
132
- raise ValueError("At least one feature must be selected.")
183
+ errors.add_error("At least one feature must be selected.")
133
184
 
134
185
  # Replace spaces with underscores in frequency band names
135
186
  self.frequency_ranges_hz = {
@@ -138,32 +189,27 @@ class NMSettings(NMBaseModel):
138
189
 
139
190
  if self.features.bandpass_filter:
140
191
  # Check BandPass settings frequency bands
141
- self.bandpass_filter_settings.validate_fbands(self)
192
+ errors.extend(self.bandpass_filter_settings.validate_fbands(self))
142
193
 
143
194
  # Check Kalman filter frequency bands
144
195
  if self.bandpass_filter_settings.kalman_filter:
145
- self.kalman_filter_settings.validate_fbands(self)
196
+ errors.extend(self.kalman_filter_settings.validate_fbands(self))
146
197
 
147
- for k, v in self.frequency_ranges_hz.items():
148
- if not isinstance(v, FrequencyRange):
149
- self.frequency_ranges_hz[k] = FrequencyRange.create_from(v)
198
+ if len(errors) > 0:
199
+ raise errors.create_error()
150
200
 
151
201
  return self
152
202
 
153
203
  def reset(self) -> "NMSettings":
154
204
  self.features.disable_all()
155
- self.preprocessing = []
205
+ self.preprocessing = DEFAULT_PREPROCESSORS
156
206
  self.postprocessing.disable_all()
157
207
  return self
158
208
 
159
209
  def set_fast_compute(self) -> "NMSettings":
160
210
  self.reset()
161
211
  self.features.fft = True
162
- self.preprocessing = [
163
- "raw_resampling",
164
- "notch_filter",
165
- "re_referencing",
166
- ]
212
+ self.preprocessing = DEFAULT_PREPROCESSORS
167
213
  self.postprocessing.feature_normalization = True
168
214
  self.postprocessing.project_cortex = False
169
215
  self.postprocessing.project_subcortex = False
@@ -253,7 +299,7 @@ class NMSettings(NMBaseModel):
253
299
  return NMSettings.from_file(PYNM_DIR / "default_settings.yaml")
254
300
 
255
301
  @staticmethod
256
- def list_normalization_methods() -> list[NormMethod]:
302
+ def list_normalization_methods() -> list[NORM_METHOD]:
257
303
  return NormalizationSettings.list_normalization_methods()
258
304
 
259
305
  def save(
@@ -261,7 +307,7 @@ class NMSettings(NMBaseModel):
261
307
  ) -> None:
262
308
  filename = f"{prefix}_SETTINGS.{format}" if prefix else f"SETTINGS.{format}"
263
309
 
264
- path_out = Path(out_dir) / filename
310
+ path_out = Path(out_dir) / prefix / filename
265
311
 
266
312
  with open(path_out, "w") as f:
267
313
  match format: