py-neuromodulation 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
  2. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
  3. py_neuromodulation/__init__.py +12 -4
  4. py_neuromodulation/analysis/RMAP.py +3 -3
  5. py_neuromodulation/analysis/decode.py +55 -2
  6. py_neuromodulation/analysis/feature_reader.py +1 -0
  7. py_neuromodulation/analysis/stats.py +3 -3
  8. py_neuromodulation/default_settings.yaml +25 -17
  9. py_neuromodulation/features/bandpower.py +65 -23
  10. py_neuromodulation/features/bispectra.py +3 -7
  11. py_neuromodulation/features/bursts.py +9 -8
  12. py_neuromodulation/features/coherence.py +17 -9
  13. py_neuromodulation/features/feature_processor.py +4 -4
  14. py_neuromodulation/features/fooof.py +7 -6
  15. py_neuromodulation/features/mne_connectivity.py +25 -3
  16. py_neuromodulation/features/oscillatory.py +5 -4
  17. py_neuromodulation/features/sharpwaves.py +21 -0
  18. py_neuromodulation/filter/kalman_filter.py +17 -6
  19. py_neuromodulation/gui/__init__.py +3 -0
  20. py_neuromodulation/gui/backend/app_backend.py +419 -0
  21. py_neuromodulation/gui/backend/app_manager.py +345 -0
  22. py_neuromodulation/gui/backend/app_pynm.py +244 -0
  23. py_neuromodulation/gui/backend/app_socket.py +95 -0
  24. py_neuromodulation/gui/backend/app_utils.py +306 -0
  25. py_neuromodulation/gui/backend/app_window.py +202 -0
  26. py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
  27. py_neuromodulation/gui/frontend/assets/index-NbJiOU5a.js +300133 -0
  28. py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
  29. py_neuromodulation/gui/frontend/charite.svg +16 -0
  30. py_neuromodulation/gui/frontend/index.html +14 -0
  31. py_neuromodulation/gui/window_api.py +115 -0
  32. py_neuromodulation/lsl_api.cfg +3 -0
  33. py_neuromodulation/processing/data_preprocessor.py +9 -2
  34. py_neuromodulation/processing/filter_preprocessing.py +43 -27
  35. py_neuromodulation/processing/normalization.py +32 -17
  36. py_neuromodulation/processing/projection.py +2 -2
  37. py_neuromodulation/processing/resample.py +6 -2
  38. py_neuromodulation/run_gui.py +36 -0
  39. py_neuromodulation/stream/__init__.py +7 -1
  40. py_neuromodulation/stream/backend_interface.py +47 -0
  41. py_neuromodulation/stream/data_processor.py +24 -3
  42. py_neuromodulation/stream/mnelsl_player.py +121 -21
  43. py_neuromodulation/stream/mnelsl_stream.py +9 -17
  44. py_neuromodulation/stream/settings.py +80 -34
  45. py_neuromodulation/stream/stream.py +82 -62
  46. py_neuromodulation/utils/channels.py +1 -1
  47. py_neuromodulation/utils/file_writer.py +110 -0
  48. py_neuromodulation/utils/io.py +46 -5
  49. py_neuromodulation/utils/perf.py +156 -0
  50. py_neuromodulation/utils/pydantic_extensions.py +322 -0
  51. py_neuromodulation/utils/types.py +33 -107
  52. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/METADATA +27 -22
  53. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/RECORD +56 -36
  54. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/WHEEL +1 -1
  55. py_neuromodulation-0.1.0.dist-info/entry_points.txt +2 -0
  56. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -70,7 +70,6 @@ class NiiToMNI:
70
70
  coord_arr = np.array(coord_)
71
71
  ival_non_zero = ival_arr[ival != 0]
72
72
  coord_non_zero = coord_arr[ival != 0]
73
- print(coord_non_zero.shape)
74
73
 
75
74
  return coord_non_zero, ival_non_zero
76
75
 
@@ -58,8 +58,6 @@ def write_connectome_mat(
58
58
  dict_connectome[
59
59
  f[f.find("ROI-") + 4 : f.find("_func_seed_AvgR_Fz.nii")]
60
60
  ] = fp
61
-
62
- print(idx)
63
61
  # save the dictionary
64
62
  sio.savemat(
65
63
  PATH_CONNECTOME,
@@ -4,11 +4,12 @@ from pathlib import PurePath
4
4
  from importlib.metadata import version
5
5
  from py_neuromodulation.utils.logging import NMLogger
6
6
 
7
+
7
8
  #####################################
8
9
  # Globals and environment variables #
9
10
  #####################################
10
11
 
11
- __version__ = version("py_neuromodulation") # get version from pyproject.toml
12
+ __version__ = version("py_neuromodulation")
12
13
 
13
14
  # Check if the module is running headless (no display) for tests and doc builds
14
15
  PYNM_HEADLESS: bool = not os.environ.get("DISPLAY")
@@ -16,6 +17,9 @@ PYNM_DIR = PurePath(__file__).parent # Define constant for py_nm directory
16
17
 
17
18
  os.environ["MPLBACKEND"] = "agg" if PYNM_HEADLESS else "qtagg" # Set matplotlib backend
18
19
 
20
+ os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file
21
+
22
+
19
23
  # Set environment variable MNE_LSL_LIB (required to import Stream below)
20
24
  LSL_DICT = {
21
25
  "windows_32bit": "windows/x86/liblsl.1.16.2.dll",
@@ -34,15 +38,17 @@ LSL_DICT = {
34
38
 
35
39
  PLATFORM = platform.system().lower().strip()
36
40
  ARCH = platform.architecture()[0]
41
+
37
42
  match PLATFORM:
38
43
  case "windows":
39
44
  KEY = PLATFORM + "_" + ARCH
40
45
  case "darwin":
41
46
  KEY = PLATFORM + "_" + platform.processor()
42
47
  case "linux":
43
- DIST = platform.freedesktop_os_release()["VERSION_CODENAME"]
44
- KEY = PLATFORM + "_" + DIST + "_" + ARCH
45
- if KEY not in LSL_DICT:
48
+ if "VERSION_CODENAME" in platform.freedesktop_os_release().keys():
49
+ DIST = platform.freedesktop_os_release()["VERSION_CODENAME"]
50
+ KEY = PLATFORM + "_" + DIST + "_" + ARCH
51
+ else:
46
52
  KEY = PLATFORM + "_" + ARCH
47
53
  case _:
48
54
  KEY = ""
@@ -78,3 +84,5 @@ from . import stream
78
84
  from . import analysis
79
85
 
80
86
  from .stream.settings import get_default_settings, get_fast_compute, reset_settings
87
+
88
+ from .gui.backend.app_manager import AppManager as App
@@ -8,8 +8,8 @@ import pandas as pd
8
8
  import nibabel as nib
9
9
  from matplotlib import pyplot as plt
10
10
 
11
- from py_neuromodulation.plots import reg_plot
12
- from py_neuromodulation.types import _PathLike
11
+ from py_neuromodulation.analysis import reg_plot
12
+ from py_neuromodulation.utils.types import _PathLike
13
13
  from py_neuromodulation import PYNM_DIR
14
14
 
15
15
  LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL = [256, 385, 417, 447, 819, 914]
@@ -335,7 +335,7 @@ class RMAPCross_Val_ChannelSelector:
335
335
  reshape: bool = True,
336
336
  ):
337
337
  if reshape:
338
- fp = np.reshape(fp, (91, 109, 91), order="F")
338
+ fp = np.reshape(fp, (91, 109, 91), order="C")
339
339
 
340
340
  img = nib.nifti1.Nifti1Image(fp, affine=affine)
341
341
 
@@ -6,14 +6,67 @@ from sklearn.metrics import r2_score
6
6
  import pandas as pd
7
7
  import numpy as np
8
8
  from copy import deepcopy
9
- from pathlib import PurePath
9
+ from pathlib import Path, PurePath
10
10
  import pickle
11
11
 
12
12
  from py_neuromodulation import logger
13
+ from py_neuromodulation.utils.types import _PathLike
13
14
 
14
15
  from typing import Callable
15
16
 
16
17
 
18
+ class RealTimeDecoder:
19
+ def __init__(self, model_path: _PathLike):
20
+ self.model_path = Path(model_path)
21
+ if not self.model_path.exists():
22
+ raise FileNotFoundError(f"Model file {self.model_path} not found")
23
+ if not self.model_path.is_file():
24
+ raise IsADirectoryError(f"Model file {self.model_path} is a directory")
25
+
26
+ if self.model_path.suffix == ".skops":
27
+ from skops import io as skops_io
28
+
29
+ self.model = skops_io.load(self.model_path)
30
+ else:
31
+ return NotImplementedError("Only skops models are supported")
32
+
33
+ def predict(
34
+ self,
35
+ feature_dict: dict,
36
+ channel: str | None = None,
37
+ fft_bands_only: bool = True,
38
+ ) -> dict:
39
+ try:
40
+ if channel is not None:
41
+ features_ch = {
42
+ f: feature_dict[f]
43
+ for f in feature_dict.keys()
44
+ if f.startswith(channel)
45
+ }
46
+ if fft_bands_only is True:
47
+ features_ch_fft = {
48
+ f: features_ch[f]
49
+ for f in features_ch.keys()
50
+ if "fft" in f and "psd" not in f
51
+ }
52
+ out = self.model.predict_proba(
53
+ np.array(list(features_ch_fft.values())).reshape(1, -1)
54
+ )
55
+ else:
56
+ out = self.model.predict_proba(features_ch)
57
+ else:
58
+ out = self.model.predict(feature_dict)
59
+ for decode_output_idx in range(out.shape[1]):
60
+ feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[
61
+ decode_output_idx
62
+ ]
63
+ logger.debug(f"Decoded values: {out}")
64
+ return feature_dict
65
+ except Exception as e:
66
+ logger.error(f"Error in decoding: {e}")
67
+ return feature_dict
68
+
69
+
17
70
  class CV_res:
18
71
  def __init__(
19
72
  self,
@@ -168,7 +221,7 @@ class Decoder:
168
221
  self.feature_names = [
169
222
  col
170
223
  for col in self.features.columns
171
- if not (("time" in col) or (self.label_name in col))
224
+ if any(col.startswith(used_ch) for used_ch in self.used_chs)
172
225
  ]
173
226
  self.data = np.nan_to_num(np.array(self.features[self.feature_names]))
174
227
 
@@ -51,6 +51,7 @@ class FeatureReader:
51
51
  self.feature_file = feature_file if feature_file else self.feature_list[0]
52
52
 
53
53
  FILE_BASENAME = PurePath(self.feature_file).stem
54
+ # the features are saved in a directory feature_file, and each file hold the name starting with feature_file
54
55
  PATH_READ_FILE = str(PurePath(self.feature_dir, FILE_BASENAME, FILE_BASENAME))
55
56
 
56
57
  self.settings = NMSettings.from_file(PATH_READ_FILE)
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
6
6
  # from numba import njit
7
7
  import numpy as np
8
8
  import pandas as pd
9
- import scipy.stats as stats
9
+ import scipy.stats as scipy_stats
10
10
 
11
11
 
12
12
  def fitlm_kfold(x, y, kfold_splits=5):
@@ -51,7 +51,7 @@ def permutationTestSpearmansRho(x, y, plot_distr=True, x_unit=None, p=5000):
51
51
  """
52
52
 
53
53
  # compute ground truth difference
54
- gT = stats.spearmanr(x, y)[0]
54
+ gT = scipy_stats.spearmanr(x, y)[0]
55
55
  #
56
56
  pV = np.array((x, y))
57
57
  # Initialize permutation:
@@ -65,7 +65,7 @@ def permutationTestSpearmansRho(x, y, plot_distr=True, x_unit=None, p=5000):
65
65
  random.shuffle(args_order_2)
66
66
  # Compute permuted absolute difference of your two sampled
67
67
  # distributions and store it in pD:
68
- pD.append(stats.spearmanr(pV[0, args_order], pV[1, args_order_2])[0])
68
+ pD.append(scipy_stats.spearmanr(pV[0, args_order], pV[1, args_order_2])[0])
69
69
 
70
70
  # calculate p value
71
71
  if gT < 0:
@@ -1,4 +1,15 @@
1
- ---
1
+ # We
2
+ # should
3
+ # have
4
+ # a
5
+ # brief
6
+ # explanation
7
+ # of
8
+ # the
9
+ # settings
10
+ # format
11
+ # here
12
+
2
13
  ########################
3
14
  ### General settings ###
4
15
  ########################
@@ -51,12 +62,8 @@ preprocessing_filter:
51
62
  lowpass_filter: true
52
63
  highpass_filter: true
53
64
  bandpass_filter: true
54
- bandstop_filter_settings:
55
- frequency_low_hz: 100
56
- frequency_high_hz: 160
57
- bandpass_filter_settings:
58
- frequency_low_hz: 3
59
- frequency_high_hz: 200
65
+ bandstop_filter_settings: [100, 160] # [low_hz, high_hz]
66
+ bandpass_filter_settings: [3, 200] # [hz, _hz]
60
67
  lowpass_filter_cutoff_hz: 200
61
68
  highpass_filter_cutoff_hz: 3
62
69
 
@@ -71,6 +78,7 @@ postprocessing:
71
78
  feature_normalization_settings:
72
79
  normalization_time_s: 30
73
80
  normalization_method: zscore # supported methods: mean, median, zscore, zscore-median, quantile, power, robust, minmax
81
+ normalize_psd: false
74
82
  clip: 3
75
83
 
76
84
  project_cortex_settings:
@@ -91,7 +99,7 @@ fft_settings:
91
99
  median: false
92
100
  std: false
93
101
  max: false
94
- return_spectrum: false
102
+ return_spectrum: true
95
103
 
96
104
  welch_settings:
97
105
  windowlength_ms: 1000
@@ -101,7 +109,7 @@ welch_settings:
101
109
  median: false
102
110
  std: false
103
111
  max: false
104
- return_spectrum: false
112
+ return_spectrum: true
105
113
 
106
114
  stft_settings:
107
115
  windowlength_ms: 500
@@ -136,7 +144,7 @@ kalman_filter_settings:
136
144
  frequency_bands:
137
145
  [theta, alpha, low_beta, high_beta, low_gamma, high_gamma, HFA]
138
146
 
139
- burst_settings:
147
+ bursts_settings:
140
148
  threshold: 75
141
149
  time_duration_s: 30
142
150
  frequency_bands: [low_beta, high_beta, low_gamma]
@@ -161,11 +169,9 @@ sharpwave_analysis_settings:
161
169
  rise_steepness: false
162
170
  decay_steepness: false
163
171
  slope_ratio: false
164
- filter_ranges_hz:
165
- - frequency_low_hz: 5
166
- frequency_high_hz: 80
167
- - frequency_low_hz: 5
168
- frequency_high_hz: 30
172
+ filter_ranges_hz: # list of [low_hz, high_hz]
173
+ - [5, 80]
174
+ - [5, 30]
169
175
  detect_troughs:
170
176
  estimate: true
171
177
  distance_troughs_ms: 10
@@ -174,6 +180,7 @@ sharpwave_analysis_settings:
174
180
  estimate: true
175
181
  distance_troughs_ms: 5
176
182
  distance_peaks_ms: 10
183
+ # TONI: Reverse this setting? e.g. interval: [mean, var]
177
184
  estimator:
178
185
  mean: [interval]
179
186
  median: []
@@ -193,6 +200,7 @@ coherence_settings:
193
200
  method:
194
201
  coh: true
195
202
  icoh: true
203
+ nperseg: 128
196
204
 
197
205
  fooof_settings:
198
206
  aperiodic:
@@ -222,8 +230,8 @@ nolds_settings:
222
230
  frequency_bands: [low_beta]
223
231
 
224
232
  mne_connectiviy_settings:
225
- method: plv
226
- mode: multitaper
233
+ method: plv # One of ['coh', 'cohy', 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli','wpli', 'wpli2_debiased', 'gc', 'gc_tr']
234
+ mode: multitaper # One of ['multitaper', 'fourier', 'cwt_morlet']
227
235
 
228
236
  bispectrum_settings:
229
237
  f1s: [5, 35]
@@ -2,8 +2,13 @@ import numpy as np
2
2
  from collections.abc import Sequence
3
3
  from typing import TYPE_CHECKING
4
4
  from pydantic import field_validator
5
-
6
5
  from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
6
+ from py_neuromodulation.utils.pydantic_extensions import (
7
+ NMField,
8
+ NMErrorList,
9
+ create_validation_error,
10
+ )
11
+ from py_neuromodulation import logger
7
12
 
8
13
  if TYPE_CHECKING:
9
14
  from py_neuromodulation.stream.settings import NMSettings
@@ -17,15 +22,18 @@ class BandpowerFeatures(BoolSelector):
17
22
 
18
23
 
19
24
  class BandPowerSettings(NMBaseModel):
20
- segment_lengths_ms: dict[str, int] = {
21
- "theta": 1000,
22
- "alpha": 500,
23
- "low beta": 333,
24
- "high beta": 333,
25
- "low gamma": 100,
26
- "high gamma": 100,
27
- "HFA": 100,
28
- }
25
+ segment_lengths_ms: dict[str, int] = NMField(
26
+ default={
27
+ "theta": 1000,
28
+ "alpha": 500,
29
+ "low beta": 333,
30
+ "high beta": 333,
31
+ "low gamma": 100,
32
+ "high gamma": 100,
33
+ "HFA": 100,
34
+ },
35
+ custom_metadata={"field_type": "FrequencySegmentLength"},
36
+ )
29
37
  bandpower_features: BandpowerFeatures = BandpowerFeatures()
30
38
  log_transform: bool = True
31
39
  kalman_filter: bool = False
@@ -33,24 +41,58 @@ class BandPowerSettings(NMBaseModel):
33
41
  @field_validator("bandpower_features")
34
42
  @classmethod
35
43
  def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
36
- assert (
37
- len(bandpower_features.get_enabled()) > 0
38
- ), "Set at least one bandpower_feature to True."
39
-
44
+ if not len(bandpower_features.get_enabled()) > 0:
45
+ raise create_validation_error(
46
+ error_message="Set at least one bandpower_feature to True.",
47
+ location=["bandpass_filter_settings", "bandpower_features"],
48
+ )
40
49
  return bandpower_features
41
50
 
42
- def validate_fbands(self, settings: "NMSettings") -> None:
51
+ def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
52
+ """_summary_
53
+
54
+ :param settings: _description_
55
+ :type settings: NMSettings
56
+ :raises create_validation_error: _description_
57
+ :raises create_validation_error: _description_
58
+ :raises ValueError: _description_
59
+ """
60
+ errors: NMErrorList = NMErrorList()
61
+
43
62
  for fband_name, seg_length_fband in self.segment_lengths_ms.items():
44
- assert seg_length_fband <= settings.segment_length_features_ms, (
45
- f"segment length {seg_length_fband} needs to be smaller than "
46
- f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
47
- )
63
+ # Check that all frequency bands are defined in settings.frequency_ranges_hz
64
+ if fband_name not in settings.frequency_ranges_hz:
65
+ logger.warning(
66
+ f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms"
67
+ " is not defined in settings.frequency_ranges_hz"
68
+ )
69
+
70
+ # Check that all segment lengths are smaller than settings.segment_length_features_ms
71
+ if not seg_length_fband <= settings.segment_length_features_ms:
72
+ errors.add_error(
73
+ f"segment length {seg_length_fband} needs to be smaller than "
74
+ f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}",
75
+ location=[
76
+ "bandpass_filter_settings",
77
+ "segment_lengths_ms",
78
+ fband_name,
79
+ ],
80
+ )
48
81
 
82
+ # Check that all frequency bands defined in settings.frequency_ranges_hz
49
83
  for fband_name in settings.frequency_ranges_hz.keys():
50
- assert fband_name in self.segment_lengths_ms, (
51
- f"frequency range {fband_name} "
52
- "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms]"
53
- )
84
+ if fband_name not in self.segment_lengths_ms:
85
+ errors.add_error(
86
+ f"frequency range {fband_name} "
87
+ "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms",
88
+ location=[
89
+ "bandpass_filter_settings",
90
+ "segment_lengths_ms",
91
+ fband_name,
92
+ ],
93
+ )
94
+
95
+ return errors
54
96
 
55
97
 
56
98
  class BandPower(NMFeature):
@@ -96,11 +96,8 @@ class Bispectra(NMFeature):
96
96
  def calc_feature(self, data: np.ndarray) -> dict:
97
97
  from pybispectra import compute_fft, WaveShape
98
98
 
99
- # PyBispectra's compute_fft uses PQDM to parallelize the calculation per channel
100
- # Is this necessary? Maybe the overhead of parallelization is not worth it
101
- # considering that we incur in it once per batch of data
102
99
  fft_coeffs, freqs = compute_fft(
103
- data=np.expand_dims(data, axis=(0)),
100
+ data=np.expand_dims(data, axis=0),
104
101
  sampling_freq=self.sfreq,
105
102
  n_points=data.shape[1],
106
103
  verbose=False,
@@ -127,12 +124,11 @@ class Bispectra(NMFeature):
127
124
  f1s=tuple(self.settings.f1s), # type: ignore
128
125
  f2s=tuple(self.settings.f2s), # type: ignore
129
126
  )
127
+ waveshape = waveshape.results.get_results(copy=False) # can overwrite obj with array
130
128
 
131
129
  feature_results = {}
132
130
  for ch_idx, ch_name in enumerate(self.ch_names):
133
- bispectrum = waveshape._bicoherence[
134
- ch_idx
135
- ] # Same as waveshape.results._data, skips a copy
131
+ bispectrum = waveshape[ch_idx]
136
132
 
137
133
  for component in self.settings.components.get_enabled():
138
134
  spectrum_ch = COMPONENT_DICT[component](bispectrum)
@@ -7,11 +7,12 @@ else:
7
7
  from collections.abc import Sequence
8
8
  from itertools import product
9
9
 
10
- from pydantic import Field, field_validator
10
+ from pydantic import field_validator
11
11
  from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
12
+ from py_neuromodulation.utils.pydantic_extensions import NMField
12
13
 
13
14
  from typing import TYPE_CHECKING, Callable
14
- from py_neuromodulation.utils.types import create_validation_error
15
+ from py_neuromodulation.utils.pydantic_extensions import create_validation_error
15
16
 
16
17
  if TYPE_CHECKING:
17
18
  from py_neuromodulation import NMSettings
@@ -46,8 +47,8 @@ class BurstFeatures(BoolSelector):
46
47
 
47
48
 
48
49
  class BurstsSettings(NMBaseModel):
49
- threshold: float = Field(default=75, ge=0, le=100)
50
- time_duration_s: float = Field(default=30, ge=0)
50
+ threshold: float = NMField(default=75, ge=0)
51
+ time_duration_s: float = NMField(default=30, ge=0, custom_metadata={"unit": "s"})
51
52
  frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
52
53
  burst_features: BurstFeatures = BurstFeatures()
53
54
 
@@ -64,16 +65,16 @@ class Bursts(NMFeature):
64
65
  settings.validate()
65
66
 
66
67
  # Validate that all frequency bands are defined in the settings
67
- for fband_burst in settings.burst_settings.frequency_bands:
68
+ for fband_burst in settings.bursts_settings.frequency_bands:
68
69
  if fband_burst not in list(settings.frequency_ranges_hz.keys()):
69
70
  raise create_validation_error(
70
71
  f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']",
71
- loc=["burst_settings", "frequency_bands"],
72
+ location=["burst_settings", "frequency_bands"],
72
73
  )
73
74
 
74
75
  from py_neuromodulation.filter import MNEFilter
75
76
 
76
- self.settings = settings.burst_settings
77
+ self.settings = settings.bursts_settings
77
78
  self.sfreq = sfreq
78
79
  self.ch_names = ch_names
79
80
  self.segment_length_features_s = settings.segment_length_features_ms / 1000
@@ -83,7 +84,7 @@ class Bursts(NMFeature):
83
84
  / settings.sampling_rate_features_hz
84
85
  )
85
86
 
86
- self.fband_names = settings.burst_settings.frequency_bands
87
+ self.fband_names = settings.bursts_settings.frequency_bands
87
88
 
88
89
  f_ranges: list[tuple[float, float]] = [
89
90
  (
@@ -3,7 +3,7 @@ from collections.abc import Iterable
3
3
 
4
4
 
5
5
  from typing import TYPE_CHECKING, Annotated
6
- from pydantic import Field, field_validator
6
+ from pydantic import field_validator
7
7
 
8
8
  from py_neuromodulation.utils.types import (
9
9
  NMFeature,
@@ -11,6 +11,7 @@ from py_neuromodulation.utils.types import (
11
11
  FrequencyRange,
12
12
  NMBaseModel,
13
13
  )
14
+ from py_neuromodulation.utils.pydantic_extensions import NMField
14
15
  from py_neuromodulation import logger
15
16
 
16
17
  if TYPE_CHECKING:
@@ -26,16 +27,19 @@ class CoherenceFeatures(BoolSelector):
26
27
  mean_fband: bool = True
27
28
  max_fband: bool = True
28
29
  max_allfbands: bool = True
30
+
29
31
 
30
-
31
- ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]
32
+ # TODO: make this into a pydantic model that only accepts names from
33
+ # the channels objects and does not accept the same string twice
34
+ ListOfTwoStr = Annotated[list[str], NMField(min_length=2, max_length=2)]
32
35
 
33
36
 
34
37
  class CoherenceSettings(NMBaseModel):
35
38
  features: CoherenceFeatures = CoherenceFeatures()
36
39
  method: CoherenceMethods = CoherenceMethods()
37
40
  channels: list[ListOfTwoStr] = []
38
- frequency_bands: list[str] = Field(default=["high_beta"], min_length=1)
41
+ nperseg: int = NMField(default=256, ge=1)
42
+ frequency_bands: list[str] = NMField(default=["high_beta"], min_length=1)
39
43
 
40
44
  @field_validator("frequency_bands")
41
45
  def fbands_spaces_to_underscores(cls, frequency_bands):
@@ -49,6 +53,7 @@ class CoherenceObject:
49
53
  window: str,
50
54
  fbands: list[FrequencyRange],
51
55
  fband_names: list[str],
56
+ nperseg: int,
52
57
  ch_1_name: str,
53
58
  ch_2_name: str,
54
59
  ch_1_idx: int,
@@ -65,6 +70,7 @@ class CoherenceObject:
65
70
  self.ch_2 = ch_2_name
66
71
  self.ch_1_idx = ch_1_idx
67
72
  self.ch_2_idx = ch_2_idx
73
+ self.nperseg = nperseg
68
74
  self.coh = coh
69
75
  self.icoh = icoh
70
76
  self.features_coh = features_coh
@@ -79,14 +85,15 @@ class CoherenceObject:
79
85
  def get_coh(self, feature_results, x, y):
80
86
  from scipy.signal import welch, csd
81
87
 
82
- self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=128)
83
- self.Pyy = welch(y, self.sfreq, self.window, nperseg=128)[1]
84
- self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=128)[1]
88
+ self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=self.nperseg)
89
+ self.Pyy = welch(y, self.sfreq, self.window, nperseg=self.nperseg)[1]
90
+ self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=self.nperseg)[1]
85
91
 
86
92
  if self.coh:
87
- self.coh_val = np.abs(self.Pxy**2) / (self.Pxx * self.Pyy)
93
+ # XXX: gives different output to abs(Sxy) / sqrt(Sxx * Syy)
94
+ self.coh_val = np.abs(self.Pxy) ** 2 / (self.Pxx * self.Pyy)
88
95
  if self.icoh:
89
- self.icoh_val = np.array(self.Pxy / (self.Pxx * self.Pyy)).imag
96
+ self.icoh_val = self.Pxy.imag / np.sqrt(self.Pxx * self.Pyy)
90
97
 
91
98
  for coh_idx, coh_type in enumerate([self.coh, self.icoh]):
92
99
  if coh_type:
@@ -180,6 +187,7 @@ class Coherence(NMFeature):
180
187
  "hann",
181
188
  fband_specs,
182
189
  fband_names,
190
+ self.settings.nperseg,
183
191
  ch_1_name,
184
192
  ch_2_name,
185
193
  ch_1_idx,
@@ -1,13 +1,13 @@
1
1
  from typing import Type, TYPE_CHECKING
2
2
 
3
- from py_neuromodulation.utils.types import NMFeature, FeatureName
3
+ from py_neuromodulation.utils.types import NMFeature, FEATURE_NAME
4
4
 
5
5
  if TYPE_CHECKING:
6
6
  import numpy as np
7
7
  from py_neuromodulation import NMSettings
8
8
 
9
9
 
10
- FEATURE_DICT: dict[FeatureName | str, str] = {
10
+ FEATURE_DICT: dict[FEATURE_NAME | str, str] = {
11
11
  "raw_hjorth": "Hjorth",
12
12
  "return_raw": "Raw",
13
13
  "bandpass_filter": "BandPower",
@@ -42,7 +42,7 @@ class FeatureProcessors:
42
42
  from importlib import import_module
43
43
 
44
44
  # Accept 'str' for custom features
45
- self.features: dict[FeatureName | str, NMFeature] = {
45
+ self.features: dict[FEATURE_NAME | str, NMFeature] = {
46
46
  feature_name: getattr(
47
47
  import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name]
48
48
  )(settings, ch_names, sfreq)
@@ -83,7 +83,7 @@ class FeatureProcessors:
83
83
 
84
84
  return feature_results
85
85
 
86
- def get_feature(self, fname: FeatureName) -> NMFeature:
86
+ def get_feature(self, fname: FEATURE_NAME) -> NMFeature:
87
87
  return self.features[fname]
88
88
 
89
89
 
@@ -9,6 +9,7 @@ from py_neuromodulation.utils.types import (
9
9
  BoolSelector,
10
10
  FrequencyRange,
11
11
  )
12
+ from py_neuromodulation.utils.pydantic_extensions import NMField
12
13
 
13
14
  if TYPE_CHECKING:
14
15
  from py_neuromodulation import NMSettings
@@ -29,11 +30,11 @@ class FooofPeriodicSettings(BoolSelector):
29
30
  class FooofSettings(NMBaseModel):
30
31
  aperiodic: FooofAperiodicSettings = FooofAperiodicSettings()
31
32
  periodic: FooofPeriodicSettings = FooofPeriodicSettings()
32
- windowlength_ms: float = 800
33
+ windowlength_ms: float = NMField(800, gt=0, custom_metadata={"unit": "ms"})
33
34
  peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12)
34
- max_n_peaks: int = 3
35
- min_peak_height: float = 0
36
- peak_threshold: float = 2
35
+ max_n_peaks: int = NMField(3, ge=0)
36
+ min_peak_height: float = NMField(0, ge=0)
37
+ peak_threshold: float = NMField(2, ge=0)
37
38
  freq_range_hz: FrequencyRange = FrequencyRange(2, 40)
38
39
  knee: bool = True
39
40
 
@@ -63,12 +64,12 @@ class FooofAnalyzer(NMFeature):
63
64
 
64
65
  assert (
65
66
  self.settings.windowlength_ms <= settings.segment_length_features_ms
66
- ), f"fooof windowlength_ms ({settings.fooof.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."
67
+ ), f"fooof windowlength_ms ({settings.fooof_settings.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."
67
68
 
68
69
  assert (
69
70
  self.settings.freq_range_hz[0] < sfreq
70
71
  and self.settings.freq_range_hz[1] < sfreq
71
- ), f"fooof frequency range needs to be below sfreq, got {settings.fooof.freq_range_hz}"
72
+ ), f"fooof frequency range needs to be below sfreq, got {settings.fooof_settings.freq_range_hz}"
72
73
 
73
74
  from fooof import FOOOFGroup
74
75