py-neuromodulation 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
  2. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
  3. py_neuromodulation/__init__.py +12 -4
  4. py_neuromodulation/analysis/RMAP.py +3 -3
  5. py_neuromodulation/analysis/decode.py +55 -2
  6. py_neuromodulation/analysis/feature_reader.py +1 -0
  7. py_neuromodulation/analysis/stats.py +3 -3
  8. py_neuromodulation/default_settings.yaml +25 -20
  9. py_neuromodulation/features/bandpower.py +65 -23
  10. py_neuromodulation/features/bursts.py +9 -8
  11. py_neuromodulation/features/coherence.py +7 -4
  12. py_neuromodulation/features/feature_processor.py +4 -4
  13. py_neuromodulation/features/fooof.py +7 -6
  14. py_neuromodulation/features/mne_connectivity.py +60 -87
  15. py_neuromodulation/features/oscillatory.py +5 -4
  16. py_neuromodulation/features/sharpwaves.py +21 -0
  17. py_neuromodulation/filter/kalman_filter.py +17 -6
  18. py_neuromodulation/gui/__init__.py +3 -0
  19. py_neuromodulation/gui/backend/app_backend.py +419 -0
  20. py_neuromodulation/gui/backend/app_manager.py +345 -0
  21. py_neuromodulation/gui/backend/app_pynm.py +253 -0
  22. py_neuromodulation/gui/backend/app_socket.py +97 -0
  23. py_neuromodulation/gui/backend/app_utils.py +306 -0
  24. py_neuromodulation/gui/backend/app_window.py +202 -0
  25. py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
  26. py_neuromodulation/gui/frontend/assets/index-_6V8ZfAS.js +300137 -0
  27. py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
  28. py_neuromodulation/gui/frontend/charite.svg +16 -0
  29. py_neuromodulation/gui/frontend/index.html +14 -0
  30. py_neuromodulation/gui/window_api.py +115 -0
  31. py_neuromodulation/lsl_api.cfg +3 -0
  32. py_neuromodulation/processing/data_preprocessor.py +9 -2
  33. py_neuromodulation/processing/filter_preprocessing.py +43 -27
  34. py_neuromodulation/processing/normalization.py +32 -17
  35. py_neuromodulation/processing/projection.py +2 -2
  36. py_neuromodulation/processing/resample.py +6 -2
  37. py_neuromodulation/run_gui.py +36 -0
  38. py_neuromodulation/stream/__init__.py +7 -1
  39. py_neuromodulation/stream/backend_interface.py +47 -0
  40. py_neuromodulation/stream/data_processor.py +24 -3
  41. py_neuromodulation/stream/mnelsl_player.py +121 -21
  42. py_neuromodulation/stream/mnelsl_stream.py +9 -17
  43. py_neuromodulation/stream/settings.py +80 -34
  44. py_neuromodulation/stream/stream.py +83 -62
  45. py_neuromodulation/utils/channels.py +1 -1
  46. py_neuromodulation/utils/file_writer.py +110 -0
  47. py_neuromodulation/utils/io.py +46 -5
  48. py_neuromodulation/utils/perf.py +156 -0
  49. py_neuromodulation/utils/pydantic_extensions.py +322 -0
  50. py_neuromodulation/utils/types.py +33 -107
  51. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/METADATA +23 -4
  52. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/RECORD +55 -35
  53. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/WHEEL +1 -1
  54. py_neuromodulation-0.1.1.dist-info/entry_points.txt +2 -0
  55. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -70,7 +70,6 @@ class NiiToMNI:
70
70
  coord_arr = np.array(coord_)
71
71
  ival_non_zero = ival_arr[ival != 0]
72
72
  coord_non_zero = coord_arr[ival != 0]
73
- print(coord_non_zero.shape)
74
73
 
75
74
  return coord_non_zero, ival_non_zero
76
75
 
@@ -58,8 +58,6 @@ def write_connectome_mat(
58
58
  dict_connectome[
59
59
  f[f.find("ROI-") + 4 : f.find("_func_seed_AvgR_Fz.nii")]
60
60
  ] = fp
61
-
62
- print(idx)
63
61
  # save the dictionary
64
62
  sio.savemat(
65
63
  PATH_CONNECTOME,
@@ -4,11 +4,12 @@ from pathlib import PurePath
4
4
  from importlib.metadata import version
5
5
  from py_neuromodulation.utils.logging import NMLogger
6
6
 
7
+
7
8
  #####################################
8
9
  # Globals and environment variables #
9
10
  #####################################
10
11
 
11
- __version__ = version("py_neuromodulation") # get version from pyproject.toml
12
+ __version__ = version("py_neuromodulation")
12
13
 
13
14
  # Check if the module is running headless (no display) for tests and doc builds
14
15
  PYNM_HEADLESS: bool = not os.environ.get("DISPLAY")
@@ -16,6 +17,9 @@ PYNM_DIR = PurePath(__file__).parent # Define constant for py_nm directory
16
17
 
17
18
  os.environ["MPLBACKEND"] = "agg" if PYNM_HEADLESS else "qtagg" # Set matplotlib backend
18
19
 
20
+ os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file
21
+
22
+
19
23
  # Set environment variable MNE_LSL_LIB (required to import Stream below)
20
24
  LSL_DICT = {
21
25
  "windows_32bit": "windows/x86/liblsl.1.16.2.dll",
@@ -34,15 +38,17 @@ LSL_DICT = {
34
38
 
35
39
  PLATFORM = platform.system().lower().strip()
36
40
  ARCH = platform.architecture()[0]
41
+
37
42
  match PLATFORM:
38
43
  case "windows":
39
44
  KEY = PLATFORM + "_" + ARCH
40
45
  case "darwin":
41
46
  KEY = PLATFORM + "_" + platform.processor()
42
47
  case "linux":
43
- DIST = platform.freedesktop_os_release()["VERSION_CODENAME"]
44
- KEY = PLATFORM + "_" + DIST + "_" + ARCH
45
- if KEY not in LSL_DICT:
48
+ if "VERSION_CODENAME" in platform.freedesktop_os_release().keys():
49
+ DIST = platform.freedesktop_os_release()["VERSION_CODENAME"]
50
+ KEY = PLATFORM + "_" + DIST + "_" + ARCH
51
+ else:
46
52
  KEY = PLATFORM + "_" + ARCH
47
53
  case _:
48
54
  KEY = ""
@@ -78,3 +84,5 @@ from . import stream
78
84
  from . import analysis
79
85
 
80
86
  from .stream.settings import get_default_settings, get_fast_compute, reset_settings
87
+
88
+ from .gui.backend.app_manager import AppManager as App
@@ -8,8 +8,8 @@ import pandas as pd
8
8
  import nibabel as nib
9
9
  from matplotlib import pyplot as plt
10
10
 
11
- from py_neuromodulation.plots import reg_plot
12
- from py_neuromodulation.types import _PathLike
11
+ from py_neuromodulation.analysis import reg_plot
12
+ from py_neuromodulation.utils.types import _PathLike
13
13
  from py_neuromodulation import PYNM_DIR
14
14
 
15
15
  LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL = [256, 385, 417, 447, 819, 914]
@@ -335,7 +335,7 @@ class RMAPCross_Val_ChannelSelector:
335
335
  reshape: bool = True,
336
336
  ):
337
337
  if reshape:
338
- fp = np.reshape(fp, (91, 109, 91), order="F")
338
+ fp = np.reshape(fp, (91, 109, 91), order="C")
339
339
 
340
340
  img = nib.nifti1.Nifti1Image(fp, affine=affine)
341
341
 
@@ -6,14 +6,67 @@ from sklearn.metrics import r2_score
6
6
  import pandas as pd
7
7
  import numpy as np
8
8
  from copy import deepcopy
9
- from pathlib import PurePath
9
+ from pathlib import Path, PurePath
10
10
  import pickle
11
11
 
12
12
  from py_neuromodulation import logger
13
+ from py_neuromodulation.utils.types import _PathLike
13
14
 
14
15
  from typing import Callable
15
16
 
16
17
 
18
+ class RealTimeDecoder:
19
+ def __init__(self, model_path: _PathLike):
20
+ self.model_path = Path(model_path)
21
+ if not self.model_path.exists():
22
+ raise FileNotFoundError(f"Model file {self.model_path} not found")
23
+ if not self.model_path.is_file():
24
+ raise IsADirectoryError(f"Model file {self.model_path} is a directory")
25
+
26
+ if self.model_path.suffix == ".skops":
27
+ from skops import io as skops_io
28
+
29
+ self.model = skops_io.load(self.model_path)
30
+ else:
31
+ return NotImplementedError("Only skops models are supported")
32
+
33
+ def predict(
34
+ self,
35
+ feature_dict: dict,
36
+ channel: str | None = None,
37
+ fft_bands_only: bool = True,
38
+ ) -> dict:
39
+ try:
40
+ if channel is not None:
41
+ features_ch = {
42
+ f: feature_dict[f]
43
+ for f in feature_dict.keys()
44
+ if f.startswith(channel)
45
+ }
46
+ if fft_bands_only is True:
47
+ features_ch_fft = {
48
+ f: features_ch[f]
49
+ for f in features_ch.keys()
50
+ if "fft" in f and "psd" not in f
51
+ }
52
+ out = self.model.predict_proba(
53
+ np.array(list(features_ch_fft.values())).reshape(1, -1)
54
+ )
55
+ else:
56
+ out = self.model.predict_proba(features_ch)
57
+ else:
58
+ out = self.model.predict(feature_dict)
59
+ for decode_output_idx in range(out.shape[1]):
60
+ feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[
61
+ decode_output_idx
62
+ ]
63
+ logger.debug(f"Decoded values: {out}")
64
+ return feature_dict
65
+ except Exception as e:
66
+ logger.error(f"Error in decoding: {e}")
67
+ return feature_dict
68
+
69
+
17
70
  class CV_res:
18
71
  def __init__(
19
72
  self,
@@ -168,7 +221,7 @@ class Decoder:
168
221
  self.feature_names = [
169
222
  col
170
223
  for col in self.features.columns
171
- if not (("time" in col) or (self.label_name in col))
224
+ if any(col.startswith(used_ch) for used_ch in self.used_chs)
172
225
  ]
173
226
  self.data = np.nan_to_num(np.array(self.features[self.feature_names]))
174
227
 
@@ -51,6 +51,7 @@ class FeatureReader:
51
51
  self.feature_file = feature_file if feature_file else self.feature_list[0]
52
52
 
53
53
  FILE_BASENAME = PurePath(self.feature_file).stem
54
+ # the features are saved in a directory named feature_file, and each file name starts with feature_file
54
55
  PATH_READ_FILE = str(PurePath(self.feature_dir, FILE_BASENAME, FILE_BASENAME))
55
56
 
56
57
  self.settings = NMSettings.from_file(PATH_READ_FILE)
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
6
6
  # from numba import njit
7
7
  import numpy as np
8
8
  import pandas as pd
9
- import scipy.stats as stats
9
+ import scipy.stats as scipy_stats
10
10
 
11
11
 
12
12
  def fitlm_kfold(x, y, kfold_splits=5):
@@ -51,7 +51,7 @@ def permutationTestSpearmansRho(x, y, plot_distr=True, x_unit=None, p=5000):
51
51
  """
52
52
 
53
53
  # compute ground truth difference
54
- gT = stats.spearmanr(x, y)[0]
54
+ gT = scipy_stats.spearmanr(x, y)[0]
55
55
  #
56
56
  pV = np.array((x, y))
57
57
  # Initialize permutation:
@@ -65,7 +65,7 @@ def permutationTestSpearmansRho(x, y, plot_distr=True, x_unit=None, p=5000):
65
65
  random.shuffle(args_order_2)
66
66
  # Compute permuted absolute difference of your two sampled
67
67
  # distributions and store it in pD:
68
- pD.append(stats.spearmanr(pV[0, args_order], pV[1, args_order_2])[0])
68
+ pD.append(scipy_stats.spearmanr(pV[0, args_order], pV[1, args_order_2])[0])
69
69
 
70
70
  # calculate p value
71
71
  if gT < 0:
@@ -1,4 +1,11 @@
1
- ---
1
+ # Settings should be modified either directly through
2
+ # the default_settings.yaml file or by creating a new
3
+ # settings.yaml file than can be loaded with
4
+ # settings.NMSettings.load(FILE_PATH)
5
+ #
6
+ # Alternatively, the settings can also be modified through the
7
+ # settings object directly, e.g. settings.features.raw_hjorth = False
8
+
2
9
  ########################
3
10
  ### General settings ###
4
11
  ########################
@@ -10,9 +17,9 @@ frequency_ranges_hz: # frequency band ranges can be added, removed and altered
10
17
  alpha: [8, 12]
11
18
  low_beta: [13, 20]
12
19
  high_beta: [20, 35]
13
- low_gamma: [60, 80]
14
- high_gamma: [90, 200]
15
- HFA: [200, 400]
20
+ #low_gamma: [60, 80]
21
+ #high_gamma: [90, 200]
22
+ #HFA: [200, 400]
16
23
 
17
24
  # Enabled features
18
25
  features:
@@ -51,12 +58,8 @@ preprocessing_filter:
51
58
  lowpass_filter: true
52
59
  highpass_filter: true
53
60
  bandpass_filter: true
54
- bandstop_filter_settings:
55
- frequency_low_hz: 100
56
- frequency_high_hz: 160
57
- bandpass_filter_settings:
58
- frequency_low_hz: 3
59
- frequency_high_hz: 200
61
+ bandstop_filter_settings: [100, 160] # [low_hz, high_hz]
62
+ bandpass_filter_settings: [3, 200] # [low_hz, high_hz]
60
63
  lowpass_filter_cutoff_hz: 200
61
64
  highpass_filter_cutoff_hz: 3
62
65
 
@@ -71,6 +74,7 @@ postprocessing:
71
74
  feature_normalization_settings:
72
75
  normalization_time_s: 30
73
76
  normalization_method: zscore # supported methods: mean, median, zscore, zscore-median, quantile, power, robust, minmax
77
+ normalize_psd: false
74
78
  clip: 3
75
79
 
76
80
  project_cortex_settings:
@@ -136,10 +140,10 @@ kalman_filter_settings:
136
140
  frequency_bands:
137
141
  [theta, alpha, low_beta, high_beta, low_gamma, high_gamma, HFA]
138
142
 
139
- burst_settings:
143
+ bursts_settings:
140
144
  threshold: 75
141
145
  time_duration_s: 30
142
- frequency_bands: [low_beta, high_beta, low_gamma]
146
+ frequency_bands: [low_beta, high_beta] # low_gamma
143
147
  burst_features:
144
148
  duration: true
145
149
  amplitude: true
@@ -161,11 +165,9 @@ sharpwave_analysis_settings:
161
165
  rise_steepness: false
162
166
  decay_steepness: false
163
167
  slope_ratio: false
164
- filter_ranges_hz:
165
- - frequency_low_hz: 5
166
- frequency_high_hz: 80
167
- - frequency_low_hz: 5
168
- frequency_high_hz: 30
168
+ filter_ranges_hz: # list of [low_hz, high_hz]
169
+ - [5, 80]
170
+ - [5, 30]
169
171
  detect_troughs:
170
172
  estimate: true
171
173
  distance_troughs_ms: 10
@@ -174,6 +176,7 @@ sharpwave_analysis_settings:
174
176
  estimate: true
175
177
  distance_troughs_ms: 5
176
178
  distance_peaks_ms: 10
179
+ # TONI: Reverse this setting? e.g. interval: [mean, var]
177
180
  estimator:
178
181
  mean: [interval]
179
182
  median: []
@@ -183,7 +186,7 @@ sharpwave_analysis_settings:
183
186
  apply_estimator_between_peaks_and_troughs: true
184
187
 
185
188
  coherence_settings:
186
- channels: [] # List of channel pairs, empty by default. Each pair is a list of two channels.
189
+ channels: [] # List of channel pairs, empty by default. Each pair is a list of two channels, where the first channel is the seed and the second channel is the target.
187
190
  # Example channels: [[STN_RIGHT_0, ECOG_RIGHT_0], [STN_RIGHT_1, ECOG_RIGHT_1]]
188
191
  frequency_bands: [high_beta]
189
192
  features:
@@ -223,8 +226,10 @@ nolds_settings:
223
226
  frequency_bands: [low_beta]
224
227
 
225
228
  mne_connectiviy_settings:
226
- method: plv
227
- mode: multitaper
229
+ channels: [] # List of channel pairs, empty by default. Each pair is a list of two channels, where the first channel is the seed and the second channel is the target.
230
+ # Example channels: [[STN_RIGHT_0, ECOG_RIGHT_0], [STN_RIGHT_1, ECOG_RIGHT_1]]
231
+ method: plv # One of ['coh', 'cohy', 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli','wpli', 'wpli2_debiased', 'gc', 'gc_tr']
232
+ mode: multitaper # One of ['multitaper', 'fourier', 'cwt_morlet']
228
233
 
229
234
  bispectrum_settings:
230
235
  f1s: [5, 35]
@@ -2,8 +2,13 @@ import numpy as np
2
2
  from collections.abc import Sequence
3
3
  from typing import TYPE_CHECKING
4
4
  from pydantic import field_validator
5
-
6
5
  from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
6
+ from py_neuromodulation.utils.pydantic_extensions import (
7
+ NMField,
8
+ NMErrorList,
9
+ create_validation_error,
10
+ )
11
+ from py_neuromodulation import logger
7
12
 
8
13
  if TYPE_CHECKING:
9
14
  from py_neuromodulation.stream.settings import NMSettings
@@ -17,15 +22,18 @@ class BandpowerFeatures(BoolSelector):
17
22
 
18
23
 
19
24
  class BandPowerSettings(NMBaseModel):
20
- segment_lengths_ms: dict[str, int] = {
21
- "theta": 1000,
22
- "alpha": 500,
23
- "low beta": 333,
24
- "high beta": 333,
25
- "low gamma": 100,
26
- "high gamma": 100,
27
- "HFA": 100,
28
- }
25
+ segment_lengths_ms: dict[str, int] = NMField(
26
+ default={
27
+ "theta": 1000,
28
+ "alpha": 500,
29
+ "low beta": 333,
30
+ "high beta": 333,
31
+ "low gamma": 100,
32
+ "high gamma": 100,
33
+ "HFA": 100,
34
+ },
35
+ custom_metadata={"field_type": "FrequencySegmentLength"},
36
+ )
29
37
  bandpower_features: BandpowerFeatures = BandpowerFeatures()
30
38
  log_transform: bool = True
31
39
  kalman_filter: bool = False
@@ -33,24 +41,58 @@ class BandPowerSettings(NMBaseModel):
33
41
  @field_validator("bandpower_features")
34
42
  @classmethod
35
43
  def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
36
- assert (
37
- len(bandpower_features.get_enabled()) > 0
38
- ), "Set at least one bandpower_feature to True."
39
-
44
+ if not len(bandpower_features.get_enabled()) > 0:
45
+ raise create_validation_error(
46
+ error_message="Set at least one bandpower_feature to True.",
47
+ location=["bandpass_filter_settings", "bandpower_features"],
48
+ )
40
49
  return bandpower_features
41
50
 
42
- def validate_fbands(self, settings: "NMSettings") -> None:
51
+ def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
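52
+ """Validate the per-band segment lengths against the global settings.
53
+
54
+ :param settings: settings object providing frequency_ranges_hz and segment_length_features_ms
55
+ :type settings: NMSettings
56
+ :return: errors collected for segment lengths exceeding segment_length_features_ms
57
+ and for frequency bands in frequency_ranges_hz missing from segment_lengths_ms
58
+ :rtype: NMErrorList
59
+ """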
60
+ errors: NMErrorList = NMErrorList()
61
+
43
62
  for fband_name, seg_length_fband in self.segment_lengths_ms.items():
44
- assert seg_length_fband <= settings.segment_length_features_ms, (
45
- f"segment length {seg_length_fband} needs to be smaller than "
46
- f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
47
- )
63
+ # Check that all frequency bands are defined in settings.frequency_ranges_hz
64
+ if fband_name not in settings.frequency_ranges_hz:
65
+ logger.warning(
66
+ f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms"
67
+ " is not defined in settings.frequency_ranges_hz"
68
+ )
69
+
70
+ # Check that all segment lengths are smaller than or equal to settings.segment_length_features_ms
71
+ if not seg_length_fband <= settings.segment_length_features_ms:
72
+ errors.add_error(
73
+ f"segment length {seg_length_fband} needs to be smaller than "
74
+ f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}",
75
+ location=[
76
+ "bandpass_filter_settings",
77
+ "segment_lengths_ms",
78
+ fband_name,
79
+ ],
80
+ )
48
81
 
82
+ # Check that all frequency bands defined in settings.frequency_ranges_hz have a segment length defined
49
83
  for fband_name in settings.frequency_ranges_hz.keys():
50
- assert fband_name in self.segment_lengths_ms, (
51
- f"frequency range {fband_name} "
52
- "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms]"
53
- )
84
+ if fband_name not in self.segment_lengths_ms:
85
+ errors.add_error(
86
+ f"frequency range {fband_name} "
87
+ "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms",
88
+ location=[
89
+ "bandpass_filter_settings",
90
+ "segment_lengths_ms",
91
+ fband_name,
92
+ ],
93
+ )
94
+
95
+ return errors
54
96
 
55
97
 
56
98
  class BandPower(NMFeature):
@@ -7,11 +7,12 @@ else:
7
7
  from collections.abc import Sequence
8
8
  from itertools import product
9
9
 
10
- from pydantic import Field, field_validator
10
+ from pydantic import field_validator
11
11
  from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
12
+ from py_neuromodulation.utils.pydantic_extensions import NMField
12
13
 
13
14
  from typing import TYPE_CHECKING, Callable
14
- from py_neuromodulation.utils.types import create_validation_error
15
+ from py_neuromodulation.utils.pydantic_extensions import create_validation_error
15
16
 
16
17
  if TYPE_CHECKING:
17
18
  from py_neuromodulation import NMSettings
@@ -46,8 +47,8 @@ class BurstFeatures(BoolSelector):
46
47
 
47
48
 
48
49
  class BurstsSettings(NMBaseModel):
49
- threshold: float = Field(default=75, ge=0, le=100)
50
- time_duration_s: float = Field(default=30, ge=0)
50
+ threshold: float = NMField(default=75, ge=0)
51
+ time_duration_s: float = NMField(default=30, ge=0, custom_metadata={"unit": "s"})
51
52
  frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
52
53
  burst_features: BurstFeatures = BurstFeatures()
53
54
 
@@ -64,16 +65,16 @@ class Bursts(NMFeature):
64
65
  settings.validate()
65
66
 
66
67
  # Validate that all frequency bands are defined in the settings
67
- for fband_burst in settings.burst_settings.frequency_bands:
68
+ for fband_burst in settings.bursts_settings.frequency_bands:
68
69
  if fband_burst not in list(settings.frequency_ranges_hz.keys()):
69
70
  raise create_validation_error(
70
71
  f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']",
71
- loc=["burst_settings", "frequency_bands"],
72
+ location=["burst_settings", "frequency_bands"],
72
73
  )
73
74
 
74
75
  from py_neuromodulation.filter import MNEFilter
75
76
 
76
- self.settings = settings.burst_settings
77
+ self.settings = settings.bursts_settings
77
78
  self.sfreq = sfreq
78
79
  self.ch_names = ch_names
79
80
  self.segment_length_features_s = settings.segment_length_features_ms / 1000
@@ -83,7 +84,7 @@ class Bursts(NMFeature):
83
84
  / settings.sampling_rate_features_hz
84
85
  )
85
86
 
86
- self.fband_names = settings.burst_settings.frequency_bands
87
+ self.fband_names = settings.bursts_settings.frequency_bands
87
88
 
88
89
  f_ranges: list[tuple[float, float]] = [
89
90
  (
@@ -3,7 +3,7 @@ from collections.abc import Iterable
3
3
 
4
4
 
5
5
  from typing import TYPE_CHECKING, Annotated
6
- from pydantic import Field, field_validator
6
+ from pydantic import field_validator
7
7
 
8
8
  from py_neuromodulation.utils.types import (
9
9
  NMFeature,
@@ -11,6 +11,7 @@ from py_neuromodulation.utils.types import (
11
11
  FrequencyRange,
12
12
  NMBaseModel,
13
13
  )
14
+ from py_neuromodulation.utils.pydantic_extensions import NMField
14
15
  from py_neuromodulation import logger
15
16
 
16
17
  if TYPE_CHECKING:
@@ -28,15 +29,17 @@ class CoherenceFeatures(BoolSelector):
28
29
  max_allfbands: bool = True
29
30
 
30
31
 
31
- ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]
32
+ # TODO: make this into a pydantic model that only accepts names from
33
+ # the channels objects and does not accept the same string twice
34
+ ListOfTwoStr = Annotated[list[str], NMField(min_length=2, max_length=2)]
32
35
 
33
36
 
34
37
  class CoherenceSettings(NMBaseModel):
35
38
  features: CoherenceFeatures = CoherenceFeatures()
36
39
  method: CoherenceMethods = CoherenceMethods()
37
40
  channels: list[ListOfTwoStr] = []
38
- nperseg: int = Field(default=128, ge=0)
39
- frequency_bands: list[str] = Field(default=["high_beta"], min_length=1)
41
+ nperseg: int = NMField(default=256, ge=1)
42
+ frequency_bands: list[str] = NMField(default=["high_beta"], min_length=1)
40
43
 
41
44
  @field_validator("frequency_bands")
42
45
  def fbands_spaces_to_underscores(cls, frequency_bands):
@@ -1,13 +1,13 @@
1
1
  from typing import Type, TYPE_CHECKING
2
2
 
3
- from py_neuromodulation.utils.types import NMFeature, FeatureName
3
+ from py_neuromodulation.utils.types import NMFeature, FEATURE_NAME
4
4
 
5
5
  if TYPE_CHECKING:
6
6
  import numpy as np
7
7
  from py_neuromodulation import NMSettings
8
8
 
9
9
 
10
- FEATURE_DICT: dict[FeatureName | str, str] = {
10
+ FEATURE_DICT: dict[FEATURE_NAME | str, str] = {
11
11
  "raw_hjorth": "Hjorth",
12
12
  "return_raw": "Raw",
13
13
  "bandpass_filter": "BandPower",
@@ -42,7 +42,7 @@ class FeatureProcessors:
42
42
  from importlib import import_module
43
43
 
44
44
  # Accept 'str' for custom features
45
- self.features: dict[FeatureName | str, NMFeature] = {
45
+ self.features: dict[FEATURE_NAME | str, NMFeature] = {
46
46
  feature_name: getattr(
47
47
  import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name]
48
48
  )(settings, ch_names, sfreq)
@@ -83,7 +83,7 @@ class FeatureProcessors:
83
83
 
84
84
  return feature_results
85
85
 
86
- def get_feature(self, fname: FeatureName) -> NMFeature:
86
+ def get_feature(self, fname: FEATURE_NAME) -> NMFeature:
87
87
  return self.features[fname]
88
88
 
89
89
 
@@ -9,6 +9,7 @@ from py_neuromodulation.utils.types import (
9
9
  BoolSelector,
10
10
  FrequencyRange,
11
11
  )
12
+ from py_neuromodulation.utils.pydantic_extensions import NMField
12
13
 
13
14
  if TYPE_CHECKING:
14
15
  from py_neuromodulation import NMSettings
@@ -29,11 +30,11 @@ class FooofPeriodicSettings(BoolSelector):
29
30
  class FooofSettings(NMBaseModel):
30
31
  aperiodic: FooofAperiodicSettings = FooofAperiodicSettings()
31
32
  periodic: FooofPeriodicSettings = FooofPeriodicSettings()
32
- windowlength_ms: float = 800
33
+ windowlength_ms: float = NMField(800, gt=0, custom_metadata={"unit": "ms"})
33
34
  peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12)
34
- max_n_peaks: int = 3
35
- min_peak_height: float = 0
36
- peak_threshold: float = 2
35
+ max_n_peaks: int = NMField(3, ge=0)
36
+ min_peak_height: float = NMField(0, ge=0)
37
+ peak_threshold: float = NMField(2, ge=0)
37
38
  freq_range_hz: FrequencyRange = FrequencyRange(2, 40)
38
39
  knee: bool = True
39
40
 
@@ -63,12 +64,12 @@ class FooofAnalyzer(NMFeature):
63
64
 
64
65
  assert (
65
66
  self.settings.windowlength_ms <= settings.segment_length_features_ms
66
- ), f"fooof windowlength_ms ({settings.fooof.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."
67
+ ), f"fooof windowlength_ms ({settings.fooof_settings.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."
67
68
 
68
69
  assert (
69
70
  self.settings.freq_range_hz[0] < sfreq
70
71
  and self.settings.freq_range_hz[1] < sfreq
71
- ), f"fooof frequency range needs to be below sfreq, got {settings.fooof.freq_range_hz}"
72
+ ), f"fooof frequency range needs to be below sfreq, got {settings.fooof_settings.freq_range_hz}"
72
73
 
73
74
  from fooof import FOOOFGroup
74
75