py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
- py_neuromodulation/FieldTrip.py +589 -589
- py_neuromodulation/__init__.py +74 -13
- py_neuromodulation/_write_example_dataset_helper.py +83 -65
- py_neuromodulation/data/README +6 -6
- py_neuromodulation/data/dataset_description.json +8 -8
- py_neuromodulation/data/participants.json +32 -32
- py_neuromodulation/data/participants.tsv +2 -2
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
- py_neuromodulation/grid_cortex.tsv +40 -40
- py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
- py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
- py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
- py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/nm_IO.py +413 -417
- py_neuromodulation/nm_RMAP.py +496 -531
- py_neuromodulation/nm_analysis.py +993 -1074
- py_neuromodulation/nm_artifacts.py +30 -25
- py_neuromodulation/nm_bispectra.py +154 -168
- py_neuromodulation/nm_bursts.py +292 -198
- py_neuromodulation/nm_coherence.py +251 -205
- py_neuromodulation/nm_database.py +149 -0
- py_neuromodulation/nm_decode.py +918 -992
- py_neuromodulation/nm_define_nmchannels.py +300 -302
- py_neuromodulation/nm_features.py +144 -116
- py_neuromodulation/nm_filter.py +219 -219
- py_neuromodulation/nm_filter_preprocessing.py +79 -91
- py_neuromodulation/nm_fooof.py +139 -159
- py_neuromodulation/nm_generator.py +45 -37
- py_neuromodulation/nm_hjorth_raw.py +52 -73
- py_neuromodulation/nm_kalmanfilter.py +71 -58
- py_neuromodulation/nm_linelength.py +21 -33
- py_neuromodulation/nm_logger.py +66 -0
- py_neuromodulation/nm_mne_connectivity.py +149 -112
- py_neuromodulation/nm_mnelsl_generator.py +90 -0
- py_neuromodulation/nm_mnelsl_stream.py +116 -0
- py_neuromodulation/nm_nolds.py +96 -93
- py_neuromodulation/nm_normalization.py +173 -214
- py_neuromodulation/nm_oscillatory.py +423 -448
- py_neuromodulation/nm_plots.py +585 -612
- py_neuromodulation/nm_preprocessing.py +83 -0
- py_neuromodulation/nm_projection.py +370 -394
- py_neuromodulation/nm_rereference.py +97 -95
- py_neuromodulation/nm_resample.py +59 -50
- py_neuromodulation/nm_run_analysis.py +325 -435
- py_neuromodulation/nm_settings.py +289 -68
- py_neuromodulation/nm_settings.yaml +244 -0
- py_neuromodulation/nm_sharpwaves.py +423 -401
- py_neuromodulation/nm_stats.py +464 -480
- py_neuromodulation/nm_stream.py +398 -0
- py_neuromodulation/nm_stream_abc.py +166 -218
- py_neuromodulation/nm_types.py +193 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
- py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
- py_neuromodulation/nm_EpochStream.py +0 -92
- py_neuromodulation/nm_across_patient_decoding.py +0 -927
- py_neuromodulation/nm_cohortwrapper.py +0 -435
- py_neuromodulation/nm_eval_timing.py +0 -239
- py_neuromodulation/nm_features_abc.py +0 -39
- py_neuromodulation/nm_settings.json +0 -338
- py_neuromodulation/nm_stream_offline.py +0 -359
- py_neuromodulation/utils/_logging.py +0 -24
- py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
|
@@ -1,218 +1,166 @@
|
|
|
1
|
-
"""Module that contains
|
|
2
|
-
|
|
3
|
-
from abc import ABC, abstractmethod
|
|
4
|
-
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
from .
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
self.
|
|
89
|
-
self.
|
|
90
|
-
self.
|
|
91
|
-
self.
|
|
92
|
-
self.
|
|
93
|
-
self.
|
|
94
|
-
self.
|
|
95
|
-
self.
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
""
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
def
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
"""Load sklearn model, that utilizes predict"""
|
|
168
|
-
with open(model_name, "rb") as fid:
|
|
169
|
-
self.model = cPickle.load(fid)
|
|
170
|
-
|
|
171
|
-
def save_after_stream(
|
|
172
|
-
self,
|
|
173
|
-
out_path_root: _PathLike | None = None,
|
|
174
|
-
folder_name: str = "sub",
|
|
175
|
-
feature_arr: pd.DataFrame | None = None,
|
|
176
|
-
) -> None:
|
|
177
|
-
"""Save features, settings, nm_channels and sidecar after run"""
|
|
178
|
-
|
|
179
|
-
if out_path_root is None:
|
|
180
|
-
out_path_root = os.getcwd()
|
|
181
|
-
# create derivate folder_name output folder if doesn't exist
|
|
182
|
-
if os.path.exists(os.path.join(out_path_root, folder_name)) is False:
|
|
183
|
-
os.makedirs(os.path.join(out_path_root, folder_name))
|
|
184
|
-
|
|
185
|
-
self.PATH_OUT = out_path_root
|
|
186
|
-
self.PATH_OUT_folder_name = folder_name
|
|
187
|
-
self.save_sidecar(out_path_root, folder_name)
|
|
188
|
-
|
|
189
|
-
if feature_arr is not None:
|
|
190
|
-
self.save_features(out_path_root, folder_name, feature_arr)
|
|
191
|
-
|
|
192
|
-
self.save_settings(out_path_root, folder_name)
|
|
193
|
-
|
|
194
|
-
self.save_nm_channels(out_path_root, folder_name)
|
|
195
|
-
|
|
196
|
-
def save_features(
|
|
197
|
-
self,
|
|
198
|
-
out_path_root: _PathLike,
|
|
199
|
-
folder_name: str,
|
|
200
|
-
feature_arr: pd.DataFrame,
|
|
201
|
-
) -> None:
|
|
202
|
-
nm_IO.save_features(feature_arr, out_path_root, folder_name)
|
|
203
|
-
|
|
204
|
-
def save_nm_channels(
|
|
205
|
-
self, out_path_root: _PathLike, folder_name: str
|
|
206
|
-
) -> None:
|
|
207
|
-
self.run_analysis.save_nm_channels(out_path_root, folder_name)
|
|
208
|
-
|
|
209
|
-
def save_settings(self, out_path_root: _PathLike, folder_name: str) -> None:
|
|
210
|
-
self.run_analysis.save_settings(out_path_root, folder_name)
|
|
211
|
-
|
|
212
|
-
def save_sidecar(self, out_path_root: _PathLike, folder_name: str) -> None:
|
|
213
|
-
"""Save sidecar incduing fs, coords, sess_right to
|
|
214
|
-
out_path_root and subfolder 'folder_name'"""
|
|
215
|
-
additional_args = {"sess_right": self.sess_right}
|
|
216
|
-
self.run_analysis.save_sidecar(
|
|
217
|
-
out_path_root, folder_name, additional_args
|
|
218
|
-
)
|
|
1
|
+
"""Module that contains NMStream ABC."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
from py_neuromodulation.nm_run_analysis import DataProcessor
|
|
8
|
+
from py_neuromodulation.nm_settings import NMSettings
|
|
9
|
+
from py_neuromodulation.nm_types import _PathLike, FeatureName
|
|
10
|
+
from py_neuromodulation import nm_IO, PYNM_DIR
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NMStream(ABC):
    """Abstract base class for py_neuromodulation streams.

    The constructor loads the settings and the channel table, checks the
    configured frequency bands against the Nyquist frequency and builds the
    ``DataProcessor`` stored in ``self.data_processor``. Concrete streams
    must implement :meth:`run`.
    """

    def __init__(
        self,
        sfreq: float,
        nm_channels: pd.DataFrame | _PathLike,
        settings: "NMSettings | _PathLike | None" = None,
        line_noise: float | None = 50,
        sampling_rate_features_hz: float | None = None,
        path_grids: _PathLike | None = None,
        coord_names: list | None = None,
        stream_name: str
        | None = "example_stream",  # Timon: do we need those in the nmstream_abc?
        stream_lsl: bool = False,
        coord_list: list | None = None,
        verbose: bool = True,
    ) -> None:
        """Stream initialization

        Parameters
        ----------
        sfreq : float
            sampling frequency of data in Hertz
        nm_channels : pd.DataFrame | _PathLike
            parametrization of channels (see nm_define_channels.py for initialization)
        settings : NMSettings | _PathLike | None, optional
            Initialized nm_settings.NMSettings object, by default the py_neuromodulation/nm_settings.yaml are read
            and passed into a settings object
        line_noise : float | None, optional
            line noise, by default 50
        sampling_rate_features_hz : float | None, optional
            feature sampling rate; when given it overrides the value stored
            in the settings, by default None
        path_grids : _PathLike | None, optional
            path to grid_cortex.tsv and/or grid_subcortex.tsv, by default None
            (PYNM_DIR is used in that case)
        coord_names : list | None, optional
            coordinate name in the form [coord_1_name, coord_2_name, etc], by default None
        stream_name : str | None, optional
            accepted for signature compatibility; not used by this base
            class, by default "example_stream"
        stream_lsl : bool, optional
            accepted for signature compatibility; not used by this base
            class, by default False
        coord_list : list | None, optional
            coordinates in the form [[coord_1_x, coord_1_y, coord_1_z], [coord_2_x, coord_2_y, coord_2_z],], by default None
        verbose : bool, optional
            print out stream computation time information, by default True

        Raises
        ------
        AssertionError
            If a frequency-range based feature is enabled and any band's
            upper edge is not below ``sfreq / 2``.
        ValueError
            If no channel has ``used == 1`` and ``target == 0``.
        """
        self.settings: NMSettings = NMSettings.load(settings)

        # If features that use frequency ranges are on, test them against nyquist frequency
        use_freq_ranges: list[FeatureName] = [
            "bandpass_filter",
            "stft",
            "fft",
            "welch",
            "bursts",
            "coherence",
            "nolds",
            "bispectrum",
        ]

        need_nyquist_check = any(
            f in use_freq_ranges for f in self.settings.features.get_enabled()
        )

        if need_nyquist_check:
            bands_below_nyquist = all(
                fb.frequency_high_hz < sfreq / 2
                for fb in self.settings.frequency_ranges_hz.values()
            )
            # Raised explicitly (instead of via an ``assert`` statement) so
            # the check is not stripped when Python runs with -O; the
            # exception type and message are unchanged.
            if not bands_below_nyquist:
                raise AssertionError(
                    "If a feature that uses frequency ranges is selected, "
                    "the frequency band ranges need to be smaller than the nyquist frequency.\n"
                    f"Got sfreq = {sfreq} and fband ranges:\n {self.settings.frequency_ranges_hz}"
                )

        if sampling_rate_features_hz is not None:
            self.settings.sampling_rate_features_hz = sampling_rate_features_hz

        self.nm_channels = self._load_nm_channels(nm_channels)
        if path_grids is None:
            path_grids = PYNM_DIR
        self.path_grids = path_grids
        self.verbose = verbose
        self.sfreq = sfreq
        self.line_noise = line_noise
        self.coord_names = coord_names
        self.coord_list = coord_list
        self.sess_right = None
        self.projection = None
        self.model = None

        self._init_data_processor()

    def _init_data_processor(self) -> None:
        """(Re)build ``self.data_processor`` from the current stream attributes."""
        self.data_processor = DataProcessor(
            sfreq=self.sfreq,
            settings=self.settings,
            nm_channels=self.nm_channels,
            path_grids=self.path_grids,
            coord_names=self.coord_names,
            coord_list=self.coord_list,
            line_noise=self.line_noise,
            verbose=self.verbose,
        )

    @abstractmethod
    def run(self) -> pd.DataFrame:
        """Reinitialize the stream

        This might be handy in case the nm_channels or nm_settings changed.

        The method is abstract, so subclasses must override it. This base
        implementation only rebuilds ``self.data_processor`` and does not
        return anything itself.
        """
        self._init_data_processor()

    @staticmethod
    def _load_nm_channels(
        nm_channels: pd.DataFrame | _PathLike,
    ) -> pd.DataFrame:
        """Return the channel table, loading it via nm_IO when a path is given.

        Raises
        ------
        ValueError
            If no row has ``used == 1`` and ``target == 0``.
        """
        if not isinstance(nm_channels, pd.DataFrame):
            nm_channels = nm_IO.load_nm_channels(nm_channels)

        if nm_channels.query("used == 1 and target == 0").shape[0] == 0:
            raise ValueError(
                "No channels selected for analysis that have column 'used' = 1 and 'target' = 0. Please check your nm_channels"
            )

        return nm_channels

    def save_after_stream(
        self,
        out_dir: _PathLike = "",
        prefix: str = "",
        feature_arr: pd.DataFrame | None = None,
    ) -> None:
        """Save features, settings, nm_channels and sidecar after run

        Features are only written when ``feature_arr`` is not None.
        """

        self.save_sidecar(out_dir, prefix)

        if feature_arr is not None:
            nm_IO.save_features(feature_arr, out_dir, prefix)

        self.save_settings(out_dir, prefix)

        self.save_nm_channels(out_dir, prefix)

    def save_nm_channels(self, out_dir: _PathLike, prefix: str = "") -> None:
        """Delegate saving of the channel table to the data processor."""
        self.data_processor.save_nm_channels(out_dir, prefix)

    def save_settings(self, out_dir: _PathLike, prefix: str = "") -> None:
        """Delegate saving of the settings to the data processor."""
        self.data_processor.save_settings(out_dir, prefix)

    def save_sidecar(self, out_dir: _PathLike, prefix: str = "") -> None:
        """Save the sidecar via the data processor, using ``out_dir`` and ``prefix``.

        ``sess_right`` is passed along as an additional sidecar entry.
        """
        additional_args = {"sess_right": self.sess_right}
        self.data_processor.save_sidecar(out_dir, prefix, additional_args)
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
from os import PathLike
|
|
2
|
+
from math import isnan
|
|
3
|
+
from typing import NamedTuple, Type, Any, Literal
|
|
4
|
+
from importlib import import_module
|
|
5
|
+
from pydantic import ConfigDict, Field, model_validator, BaseModel
|
|
6
|
+
from pprint import pformat
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
|
|
9
|
+
###################################
|
|
10
|
+
########## TYPE ALIASES ##########
|
|
11
|
+
###################################
|
|
12
|
+
|
|
13
|
+
_PathLike = str | PathLike
|
|
14
|
+
|
|
15
|
+
FeatureName = Literal[
|
|
16
|
+
"raw_hjorth",
|
|
17
|
+
"return_raw",
|
|
18
|
+
"bandpass_filter",
|
|
19
|
+
"stft",
|
|
20
|
+
"fft",
|
|
21
|
+
"welch",
|
|
22
|
+
"sharpwave_analysis",
|
|
23
|
+
"fooof",
|
|
24
|
+
"nolds",
|
|
25
|
+
"coherence",
|
|
26
|
+
"bursts",
|
|
27
|
+
"linelength",
|
|
28
|
+
"mne_connectivity",
|
|
29
|
+
"bispectrum",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
PreprocessorName = Literal[
|
|
33
|
+
"preprocessing_filter",
|
|
34
|
+
"notch_filter",
|
|
35
|
+
"raw_resampling",
|
|
36
|
+
"re_referencing",
|
|
37
|
+
"raw_normalization",
|
|
38
|
+
]
|
|
39
|
+
|
|
40
|
+
NormMethod = Literal[
|
|
41
|
+
"mean",
|
|
42
|
+
"median",
|
|
43
|
+
"zscore",
|
|
44
|
+
"zscore-median",
|
|
45
|
+
"quantile",
|
|
46
|
+
"power",
|
|
47
|
+
"robust",
|
|
48
|
+
"minmax",
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
###################################
|
|
53
|
+
###### LAZY MODULE IMPORTS #######
|
|
54
|
+
###################################
|
|
55
|
+
class ImportDetails(NamedTuple):
    """Module/attribute name pair consumed by ``get_class`` for deferred imports."""

    # Submodule name; get_class() prefixes it with "py_neuromodulation."
    module_name: str
    # Name of the attribute (a class, per the return annotation of
    # get_class) that is looked up on the imported module
    class_name: str
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_class(module_details: ImportDetails) -> Type[Any]:
    """Import a py_neuromodulation submodule and return one attribute of it.

    The module is imported only when this function is called, using
    ``"py_neuromodulation." + module_details.module_name`` as the dotted
    name; ``module_details.class_name`` is then looked up on that module.
    """
    dotted_name = f"py_neuromodulation.{module_details.module_name}"
    module = import_module(dotted_name)
    return getattr(module, module_details.class_name)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
###################################
|
|
68
|
+
######## PYDANTIC CLASSES ########
|
|
69
|
+
###################################
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class NMBaseModel(BaseModel):
    """Pydantic base model with positional construction and dict-style access.

    Configuration: assignments are not validated
    (``validate_assignment=False``) and unknown fields are kept
    (``extra="allow"``).
    """

    model_config = ConfigDict(validate_assignment=False, extra="allow")

    def __init__(self, *args, **kwargs) -> None:
        """Build the model from positional and/or keyword arguments.

        Positional arguments are mapped onto the declared fields in
        declaration order and combined with the keyword arguments, so a
        mix of both is honoured instead of the positional values being
        dropped.

        Raises
        ------
        TypeError
            If more positional arguments than declared fields are given,
            or if a field receives both a positional and a keyword value.
        """
        # Read the field table from the class, not the instance: the
        # instance is not initialised yet at this point.
        field_names = list(type(self).model_fields)
        if len(args) > len(field_names):
            raise TypeError(
                f"{type(self).__name__} takes at most {len(field_names)} "
                f"positional arguments but {len(args)} were given"
            )

        positional = dict(zip(field_names, args))
        duplicated = sorted(positional.keys() & kwargs.keys())
        if duplicated:
            raise TypeError(
                f"{type(self).__name__} got multiple values for "
                f"argument(s): {', '.join(duplicated)}"
            )

        super().__init__(**positional, **kwargs)

    def __str__(self):
        """Return the pretty-printed ``model_dump()`` of the model."""
        return pformat(self.model_dump())

    def __repr__(self):
        """Return the same pretty-printed dump as ``__str__``."""
        return pformat(self.model_dump())

    def validate(self) -> Any:  # type: ignore
        """Re-run validation on a dump of the current field values."""
        return self.model_validate(self.model_dump())

    def __getitem__(self, key):
        """Dict-style read access: ``model[key]`` is ``getattr(model, key)``."""
        return getattr(self, key)

    def __setitem__(self, key, value) -> None:
        """Dict-style write access: ``model[key] = v`` is ``setattr(model, key, v)``."""
        setattr(self, key, value)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class FrequencyRange(NMBaseModel):
|
|
102
|
+
frequency_low_hz: float = Field(default=0, gt=0)
|
|
103
|
+
frequency_high_hz: float = Field(default=0, gt=0)
|
|
104
|
+
|
|
105
|
+
def __init__(self, *args, **kwargs) -> None:
|
|
106
|
+
super().__init__(*args, **kwargs)
|
|
107
|
+
|
|
108
|
+
def __getitem__(self, item: int):
|
|
109
|
+
match item:
|
|
110
|
+
case 0:
|
|
111
|
+
return self.frequency_low_hz
|
|
112
|
+
case 1:
|
|
113
|
+
return self.frequency_high_hz
|
|
114
|
+
case _:
|
|
115
|
+
raise IndexError(f"Index {item} out of range")
|
|
116
|
+
|
|
117
|
+
def as_tuple(self) -> tuple[float, float]:
|
|
118
|
+
return (self.frequency_low_hz, self.frequency_high_hz)
|
|
119
|
+
|
|
120
|
+
def __iter__(self): # type: ignore
|
|
121
|
+
return iter(self.as_tuple())
|
|
122
|
+
|
|
123
|
+
@model_validator(mode="after")
|
|
124
|
+
def validate_range(self):
|
|
125
|
+
if not (isnan(self.frequency_high_hz) or isnan(self.frequency_low_hz)):
|
|
126
|
+
assert (
|
|
127
|
+
self.frequency_high_hz > self.frequency_low_hz
|
|
128
|
+
), "Frequency high must be greater than frequency low"
|
|
129
|
+
return self
|
|
130
|
+
|
|
131
|
+
@classmethod
|
|
132
|
+
def create_from(cls, input) -> "FrequencyRange":
|
|
133
|
+
match input:
|
|
134
|
+
case FrequencyRange():
|
|
135
|
+
return input
|
|
136
|
+
case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
|
|
137
|
+
return FrequencyRange(
|
|
138
|
+
input["frequency_low_hz"], input["frequency_high_hz"]
|
|
139
|
+
)
|
|
140
|
+
case Sequence() if len(input) == 2:
|
|
141
|
+
return FrequencyRange(input[0], input[1])
|
|
142
|
+
case _:
|
|
143
|
+
raise ValueError("Invalid input for FrequencyRange creation.")
|
|
144
|
+
|
|
145
|
+
@model_validator(mode="before")
|
|
146
|
+
@classmethod
|
|
147
|
+
def check_input(cls, input):
|
|
148
|
+
match input:
|
|
149
|
+
case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
|
|
150
|
+
return input
|
|
151
|
+
case Sequence() if len(input) == 2:
|
|
152
|
+
return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]}
|
|
153
|
+
case _:
|
|
154
|
+
raise ValueError(
|
|
155
|
+
"Value for FrequencyRange must be a dictionary, "
|
|
156
|
+
"or a sequence of 2 numeric values, "
|
|
157
|
+
f"but got {input} instead."
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class BoolSelector(NMBaseModel):
    """Model whose boolean fields act as on/off switches.

    Only declared fields that currently hold a ``bool`` are treated as
    switches; fields holding other values are left untouched by the
    enable/disable helpers.
    """

    def get_enabled(self):
        """Return the names of the declared fields that are currently ``True``."""
        # ``is True`` selects exactly the values that are both bool and truthy.
        # The field table is read from the class rather than the instance.
        return [name for name in type(self).model_fields if self[name] is True]

    def enable_all(self):
        """Set every declared field that holds a bool to ``True``."""
        for name in type(self).model_fields:
            if isinstance(self[name], bool):
                self[name] = True

    def disable_all(self):
        """Set every declared field that holds a bool to ``False``."""
        for name in type(self).model_fields:
            if isinstance(self[name], bool):
                self[name] = False

    def __iter__(self):  # type: ignore
        """Iterate over the keys of ``model_dump()``."""
        return iter(self.model_dump().keys())

    @classmethod
    def list_all(cls):
        """Return the names of all declared fields as a list."""
        return list(cls.model_fields)

    @classmethod
    def print_all(cls):
        """Print the name of every declared field, one per line."""
        for name in cls.list_all():
            print(name)

    @classmethod
    def get_fields(cls):
        """Return the class's ``model_fields`` mapping."""
        return cls.model_fields
|