py-neuromodulation 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
  2. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
  3. py_neuromodulation/__init__.py +12 -4
  4. py_neuromodulation/analysis/RMAP.py +3 -3
  5. py_neuromodulation/analysis/decode.py +55 -2
  6. py_neuromodulation/analysis/feature_reader.py +1 -0
  7. py_neuromodulation/analysis/stats.py +3 -3
  8. py_neuromodulation/default_settings.yaml +25 -20
  9. py_neuromodulation/features/bandpower.py +65 -23
  10. py_neuromodulation/features/bursts.py +9 -8
  11. py_neuromodulation/features/coherence.py +7 -4
  12. py_neuromodulation/features/feature_processor.py +4 -4
  13. py_neuromodulation/features/fooof.py +7 -6
  14. py_neuromodulation/features/mne_connectivity.py +60 -87
  15. py_neuromodulation/features/oscillatory.py +5 -4
  16. py_neuromodulation/features/sharpwaves.py +21 -0
  17. py_neuromodulation/filter/kalman_filter.py +17 -6
  18. py_neuromodulation/gui/__init__.py +3 -0
  19. py_neuromodulation/gui/backend/app_backend.py +419 -0
  20. py_neuromodulation/gui/backend/app_manager.py +345 -0
  21. py_neuromodulation/gui/backend/app_pynm.py +253 -0
  22. py_neuromodulation/gui/backend/app_socket.py +97 -0
  23. py_neuromodulation/gui/backend/app_utils.py +306 -0
  24. py_neuromodulation/gui/backend/app_window.py +202 -0
  25. py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
  26. py_neuromodulation/gui/frontend/assets/index-_6V8ZfAS.js +300137 -0
  27. py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
  28. py_neuromodulation/gui/frontend/charite.svg +16 -0
  29. py_neuromodulation/gui/frontend/index.html +14 -0
  30. py_neuromodulation/gui/window_api.py +115 -0
  31. py_neuromodulation/lsl_api.cfg +3 -0
  32. py_neuromodulation/processing/data_preprocessor.py +9 -2
  33. py_neuromodulation/processing/filter_preprocessing.py +43 -27
  34. py_neuromodulation/processing/normalization.py +32 -17
  35. py_neuromodulation/processing/projection.py +2 -2
  36. py_neuromodulation/processing/resample.py +6 -2
  37. py_neuromodulation/run_gui.py +36 -0
  38. py_neuromodulation/stream/__init__.py +7 -1
  39. py_neuromodulation/stream/backend_interface.py +47 -0
  40. py_neuromodulation/stream/data_processor.py +24 -3
  41. py_neuromodulation/stream/mnelsl_player.py +121 -21
  42. py_neuromodulation/stream/mnelsl_stream.py +9 -17
  43. py_neuromodulation/stream/settings.py +80 -34
  44. py_neuromodulation/stream/stream.py +83 -62
  45. py_neuromodulation/utils/channels.py +1 -1
  46. py_neuromodulation/utils/file_writer.py +110 -0
  47. py_neuromodulation/utils/io.py +46 -5
  48. py_neuromodulation/utils/perf.py +156 -0
  49. py_neuromodulation/utils/pydantic_extensions.py +322 -0
  50. py_neuromodulation/utils/types.py +33 -107
  51. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/METADATA +23 -4
  52. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/RECORD +55 -35
  53. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/WHEEL +1 -1
  54. py_neuromodulation-0.1.1.dist-info/entry_points.txt +2 -0
  55. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/licenses/LICENSE +0 -0
py_neuromodulation/features/mne_connectivity.py
@@ -1,18 +1,44 @@
  from collections.abc import Iterable
  import numpy as np
- from typing import TYPE_CHECKING
+
+ from typing import TYPE_CHECKING, Annotated, Literal
+ from pydantic import Field

  from py_neuromodulation.utils.types import NMFeature, NMBaseModel
+ from py_neuromodulation.utils.pydantic_extensions import NMField

  if TYPE_CHECKING:
      from py_neuromodulation import NMSettings
-     from mne.io import RawArray
-     from mne import Epochs
+
+
+ ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]
+
+
+ MNE_CONNECTIVITY_METHOD = Literal[
+     "coh",
+     "cohy",
+     "imcoh",
+     "cacoh",
+     "mic",
+     "mim",
+     "plv",
+     "ciplv",
+     "ppc",
+     "pli",
+     "dpli",
+     "wpli",
+     "wpli2_debiased",
+     "gc",
+     "gc_tr",
+ ]
+
+ MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"]


  class MNEConnectivitySettings(NMBaseModel):
-     method: str = "plv"
-     mode: str = "multitaper"
+     method: MNE_CONNECTIVITY_METHOD = NMField(default="plv")
+     mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper")
+     channels: list[ListOfTwoStr] = []


  class MNEConnectivity(NMFeature):
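The typed settings above gain a `channels` field for explicitly selecting seed/target pairs. A minimal configuration sketch, assuming `NMBaseModel` behaves like a regular pydantic model; the channel names are invented:

```python
from py_neuromodulation.features.mne_connectivity import MNEConnectivitySettings

# method/mode are checked against the Literal types defined above, and each
# channels entry must be a [seed, target] pair of exactly two strings.
conn_settings = MNEConnectivitySettings(
    method="plv",
    mode="multitaper",
    channels=[["LFP_L_1", "ECOG_L_1"], ["LFP_L_2", "ECOG_L_2"]],  # hypothetical names
)
```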
@@ -22,102 +48,42 @@ class MNEConnectivity(NMFeature):
          ch_names: Iterable[str],
          sfreq: float,
      ) -> None:
-         from mne import create_info
-
          self.settings = settings

          self.ch_names = ch_names
          self.sfreq = sfreq

+         self.channels = settings.mne_connectivity_settings.channels
+
          # Params used by spectral_connectivity_epochs
          self.mode = settings.mne_connectivity_settings.mode
          self.method = settings.mne_connectivity_settings.method
+         self.indices = ([], [])  # convert channel names to channel indices in data
+         for con_idx in range(len(self.channels)):
+             seed_name = self.channels[con_idx][0]
+             target_name = self.channels[con_idx][1]
+             seed_name_reref = [ch for ch in self.ch_names if ch.startswith(seed_name)][0]
+             target_name_reref = [ch for ch in self.ch_names if ch.startswith(target_name)][0]
+             self.indices[0].append(self.ch_names.index(seed_name_reref))
+             self.indices[1].append(self.ch_names.index(target_name_reref))

          self.fbands = settings.frequency_ranges_hz
          self.fband_ranges: list = []
          self.result_keys = []

-         self.raw_info = create_info(ch_names=self.ch_names, sfreq=self.sfreq)
-         self.raw_array: "RawArray"
-         self.epochs: "Epochs"
          self.prev_batch_shape: tuple = (-1, -1)  # sentinel value

      def calc_feature(self, data: np.ndarray) -> dict:
-         from mne.io import RawArray
-         from mne import Epochs
          from mne_connectivity import spectral_connectivity_epochs
-         import pandas as pd
-
-         time_samples_s = data.shape[1] / self.sfreq
-         epoch_length: float = 1  # TODO: Make this a parameter?
-
-         if epoch_length > time_samples_s:
-             raise ValueError(
-                 f"the intended epoch length for mne connectivity: {epoch_length}s"
-                 f" are longer than the passed data array {np.round(time_samples_s, 2)}s"
-             )
-
-         # Only reinitialize the raw_array and epochs object if the data shape has changed
-         # That could mean that the channels have been re-selected, or we're in the last batch
-         # TODO: If sfreq or channels change, do we re-initialize the whole Stream object?
-         if data.shape != self.prev_batch_shape:
-             self.raw_array = RawArray(
-                 data=data,
-                 info=self.raw_info,
-                 copy=None,  # type: ignore
-                 verbose=False,
-             )
-
-             # self.events = make_fixed_length_events(self.raw_array, duration=epoch_length)
-             # Equivalent code for those parameters:
-             event_times = np.arange(
-                 0, data.shape[-1], self.sfreq * epoch_length, dtype=int
-             )
-             events = np.column_stack(
-                 (
-                     event_times,
-                     np.zeros_like(event_times, dtype=int),
-                     np.ones_like(event_times, dtype=int),
-                 )
-             )
-
-             # there need to be minimum 2 of two epochs, otherwise mne_connectivity
-             # is not correctly initialized
-             if events.shape[0] < 2:
-                 raise RuntimeError(
-                     f"A minimum of 2 epochs is required for mne_connectivity,"
-                     f" got only {events.shape[0]}. Increase settings['segment_length_features_ms']"
-                 )
-
-             self.epochs = Epochs(
-                 self.raw_array,
-                 events=events,
-                 event_id={"rest": 1},
-                 tmin=0,
-                 tmax=epoch_length,
-                 baseline=None,
-                 reject_by_annotation=True,
-                 verbose=False,
-             )
-
-             # Trick the function "spectral_connectivity_epochs" into not calling "add_annotations_to_metadata"
-             # TODO: This is a hack, and maybe needs a fix in the mne_connectivity library
-             self.epochs._metadata = pd.DataFrame(index=np.arange(events.shape[0]))
-
-         else:
-             # As long as the initialization parameters, channels, sfreq and batch size are the same
-             # We can re-use the existing epochs object by updating the raw data
-             self.raw_array._data = data
-             self.epochs._raw = self.raw_array

          # n_jobs is here kept to 1, since setup of the multiprocessing Pool
          # takes longer than most batch computing sizes
          spec_out = spectral_connectivity_epochs(
-             data=self.epochs,
+             data=np.expand_dims(data, axis=0),  # add singleton epoch dimension
              sfreq=self.sfreq,
              method=self.method,
              mode=self.mode,
-             indices=(np.array([0, 0, 1, 1]), np.array([2, 3, 2, 3])),
+             indices=self.indices,
              verbose=False,
          )
          dat_conn: np.ndarray = spec_out.get_data()
127
93
  for fband_range in self.fbands.values():
128
94
  self.fband_ranges.append(
129
95
  np.where(
130
- (np.array(spec_out.freqs) > fband_range[0])
131
- & (np.array(spec_out.freqs) < fband_range[1])
96
+ (np.array(spec_out.freqs) >= fband_range[0])
97
+ & (np.array(spec_out.freqs) <= fband_range[1])
132
98
  )[0]
133
99
  )
134
100
 
135
- # TODO: If I compute the mean for the entire fband, results are almost the same before
136
- # normalization (0.9999999... vs 1.0), but some change wildly after normalization (-3 vs 0)
137
- # Investigate why, is this a bug in normalization?
138
101
  feature_results = {}
139
- for conn in np.arange(dat_conn.shape[0]):
140
- for fband_idx, fband in enumerate(self.fbands):
141
- feature_results["_".join(["ch1", self.method, str(conn), fband])] = (
142
- np.mean(dat_conn[conn, self.fband_ranges[fband_idx]])
143
- )
102
+ for con_idx in np.arange(dat_conn.shape[0]):
103
+ for fband_idx, fband_name in enumerate(self.fbands):
104
+ # TODO: Add support for max_fband and max_allfbands
105
+ feature_results[
106
+ "_".join(
107
+ [
108
+ self.method,
109
+ self.channels[con_idx][0], # seed channel name
110
+ "to",
111
+ self.channels[con_idx][1], # target channel name
112
+ "mean_fband",
113
+ fband_name,
114
+ ]
115
+ )
116
+ ] = np.mean(dat_conn[con_idx, self.fband_ranges[fband_idx]])
144
117
 
145
118
  # Store current experiment parameters to check if re-initialization is needed
146
119
  self.prev_batch_shape = data.shape
@@ -3,6 +3,7 @@ import numpy as np
3
3
  from itertools import product
4
4
 
5
5
  from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
6
+ from py_neuromodulation.utils.pydantic_extensions import NMField
6
7
  from typing import TYPE_CHECKING
7
8
 
8
9
  if TYPE_CHECKING:
@@ -17,12 +18,12 @@ class OscillatoryFeatures(BoolSelector):
17
18
 
18
19
 
19
20
  class OscillatorySettings(NMBaseModel):
20
- windowlength_ms: int = 1000
21
+ windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"})
21
22
  log_transform: bool = True
22
23
  features: OscillatoryFeatures = OscillatoryFeatures(
23
24
  mean=True, median=False, std=False, max=False
24
25
  )
25
- return_spectrum: bool = False
26
+ return_spectrum: bool = True
26
27
 
27
28
 
28
29
  ESTIMATOR_DICT = {
@@ -176,7 +177,7 @@ class Welch(OscillatoryFeature):
176
177
  if self.settings.return_spectrum:
177
178
  combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
178
179
  for (ch_idx, ch_name), (idx, f) in combinations:
179
- feature_results[f"{ch_name}_welch_psd_{str(f)}"] = Z[ch_idx][idx]
180
+ feature_results[f"{ch_name}_welch_psd_{int(f)}"] = Z[ch_idx][idx]
180
181
 
181
182
  return feature_results
182
183
 
@@ -242,7 +243,7 @@ class STFT(OscillatoryFeature):
242
243
  if self.settings.return_spectrum:
243
244
  combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
244
245
  for (ch_idx, ch_name), (idx, f) in combinations:
245
- feature_results[f"{ch_name}_stft_psd_{str(f)}"] = Z[ch_idx].mean(
246
+ feature_results[f"{ch_name}_stft_psd_{int(f)}"] = Z[ch_idx].mean(
246
247
  axis=1
247
248
  )[idx]
248
249
 
@@ -267,6 +267,14 @@ class SharpwaveAnalyzer(NMFeature):
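The `str(f)` → `int(f)` switch in both Welch and STFT only changes how the frequency bin is rendered in the spectrum column names; a minimal comparison (channel name invented):

```python
f = 10.0  # frequency bin in Hz
old_key = f"ch1_welch_psd_{str(f)}"  # "ch1_welch_psd_10.0"
new_key = f"ch1_welch_psd_{int(f)}"  # "ch1_welch_psd_10"
```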
267
267
 
268
268
  # for each feature take the respective fun.
269
269
  for feature_name, estimator_name, estimator in estimator_combinations:
270
+ if feature_name == "num_peaks":
271
+ key_name = f"{ch_name}_Sharpwave_{feature_name}_{filter_name}"
272
+ if len(waveform_results[feature_name]) == 1:
273
+ dict_ch_features[key_name][key_name_pt] = waveform_results[feature_name][0]
274
+ continue
275
+ else:
276
+ raise ValueError("num_peaks should be a list with length 1")
277
+ # there can be only one num_peak in each batch
270
278
  feature_data = waveform_results[feature_name]
271
279
  key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
272
280
 
@@ -280,12 +288,25 @@ class SharpwaveAnalyzer(NMFeature):
280
288
 
281
289
  # the key_name stays, since the estimator function stays between peaks and troughs
282
290
  for key_name, estimator in self.estimator_key_map.items():
291
+ if len(dict_ch_features[key_name]) == 0:
292
+ # might happen if num_peaks was written in estimator
293
+ # e.g. estimator["mean"] = ["num_peaks"]
294
+ # for conveniance this doesn't raise an exception
295
+ continue
296
+
283
297
  feature_results[key_name] = estimator(
284
298
  [
285
299
  list(dict_ch_features[key_name].values())[0],
286
300
  list(dict_ch_features[key_name].values())[1],
287
301
  ]
288
302
  )
303
+ # add here also the num_peaks features
304
+ if self.sw_settings.sharpwave_features.num_peaks:
305
+ for ch_name in self.ch_names:
306
+ for filter_name in self.filter_names:
307
+ key_name = f"{ch_name}_Sharpwave_num_peaks_{filter_name}"
308
+ feature_results[key_name] = np_mean([dict_ch_features[key_name]["Peak"],
309
+ dict_ch_features[key_name]["Trough"]])
289
310
  else:
290
311
  # otherwise, save all write all "flattened" key value pairs in feature_results
291
312
  for key, subdict in dict_ch_features.items():
@@ -1,7 +1,9 @@
1
1
  import numpy as np
2
2
  from typing import TYPE_CHECKING
3
3
 
4
+
4
5
  from py_neuromodulation.utils.types import NMBaseModel
6
+ from py_neuromodulation.utils.pydantic_extensions import NMErrorList
5
7
 
6
8
 
7
9
  if TYPE_CHECKING:
@@ -22,13 +24,22 @@ class KalmanSettings(NMBaseModel):
22
24
  "HFA",
23
25
  ]
24
26
 
25
- def validate_fbands(self, settings: "NMSettings") -> None:
26
- assert all(
27
+ def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
28
+ errors: NMErrorList = NMErrorList()
29
+
30
+ if not all(
27
31
  (item in settings.frequency_ranges_hz for item in self.frequency_bands)
28
- ), (
29
- "Frequency bands for Kalman filter must also be specified in "
30
- "bandpass_filter_settings."
31
- )
32
+ ):
33
+ errors.add_error(
34
+ "Frequency bands for Kalman filter must also be specified in "
35
+ "frequency_ranges_hz.",
36
+ location=[
37
+ "kalman_filter_settings",
38
+ "frequency_bands",
39
+ ],
40
+ )
41
+
42
+ return errors
32
43
 
33
44
 
34
45
  def define_KF(Tp, sigma_w, sigma_v):
@@ -0,0 +1,3 @@
1
+ from .backend.app_manager import AppManager as App
2
+
3
+ __all__ = ["App"]