py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
- py_neuromodulation/FieldTrip.py +589 -589
- py_neuromodulation/__init__.py +74 -13
- py_neuromodulation/_write_example_dataset_helper.py +83 -65
- py_neuromodulation/data/README +6 -6
- py_neuromodulation/data/dataset_description.json +8 -8
- py_neuromodulation/data/participants.json +32 -32
- py_neuromodulation/data/participants.tsv +2 -2
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
- py_neuromodulation/grid_cortex.tsv +40 -40
- py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
- py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
- py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
- py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
- py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
- py_neuromodulation/nm_IO.py +413 -417
- py_neuromodulation/nm_RMAP.py +496 -531
- py_neuromodulation/nm_analysis.py +993 -1074
- py_neuromodulation/nm_artifacts.py +30 -25
- py_neuromodulation/nm_bispectra.py +154 -168
- py_neuromodulation/nm_bursts.py +292 -198
- py_neuromodulation/nm_coherence.py +251 -205
- py_neuromodulation/nm_database.py +149 -0
- py_neuromodulation/nm_decode.py +918 -992
- py_neuromodulation/nm_define_nmchannels.py +300 -302
- py_neuromodulation/nm_features.py +144 -116
- py_neuromodulation/nm_filter.py +219 -219
- py_neuromodulation/nm_filter_preprocessing.py +79 -91
- py_neuromodulation/nm_fooof.py +139 -159
- py_neuromodulation/nm_generator.py +45 -37
- py_neuromodulation/nm_hjorth_raw.py +52 -73
- py_neuromodulation/nm_kalmanfilter.py +71 -58
- py_neuromodulation/nm_linelength.py +21 -33
- py_neuromodulation/nm_logger.py +66 -0
- py_neuromodulation/nm_mne_connectivity.py +149 -112
- py_neuromodulation/nm_mnelsl_generator.py +90 -0
- py_neuromodulation/nm_mnelsl_stream.py +116 -0
- py_neuromodulation/nm_nolds.py +96 -93
- py_neuromodulation/nm_normalization.py +173 -214
- py_neuromodulation/nm_oscillatory.py +423 -448
- py_neuromodulation/nm_plots.py +585 -612
- py_neuromodulation/nm_preprocessing.py +83 -0
- py_neuromodulation/nm_projection.py +370 -394
- py_neuromodulation/nm_rereference.py +97 -95
- py_neuromodulation/nm_resample.py +59 -50
- py_neuromodulation/nm_run_analysis.py +325 -435
- py_neuromodulation/nm_settings.py +289 -68
- py_neuromodulation/nm_settings.yaml +244 -0
- py_neuromodulation/nm_sharpwaves.py +423 -401
- py_neuromodulation/nm_stats.py +464 -480
- py_neuromodulation/nm_stream.py +398 -0
- py_neuromodulation/nm_stream_abc.py +166 -218
- py_neuromodulation/nm_types.py +193 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
- py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
- {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
- py_neuromodulation/nm_EpochStream.py +0 -92
- py_neuromodulation/nm_across_patient_decoding.py +0 -927
- py_neuromodulation/nm_cohortwrapper.py +0 -435
- py_neuromodulation/nm_eval_timing.py +0 -239
- py_neuromodulation/nm_features_abc.py +0 -39
- py_neuromodulation/nm_settings.json +0 -338
- py_neuromodulation/nm_stream_offline.py +0 -359
- py_neuromodulation/utils/_logging.py +0 -24
- py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
|
@@ -1,205 +1,251 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def
|
|
43
|
-
self
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
1
|
+
import numpy as np
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Annotated
|
|
6
|
+
from pydantic import Field, field_validator
|
|
7
|
+
|
|
8
|
+
from py_neuromodulation.nm_features import NMFeature
|
|
9
|
+
from py_neuromodulation.nm_types import BoolSelector, FrequencyRange, NMBaseModel
|
|
10
|
+
from py_neuromodulation import logger
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from py_neuromodulation.nm_settings import NMSettings
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CoherenceMethods(BoolSelector):
    """Switches selecting which coherence measures NMCoherence computes.

    Read as ``settings.coherence.method.coh`` / ``.icoh`` when the per-pair
    CoherenceObject instances are created.
    """

    coh: bool = True  # magnitude-squared coherence (feature prefix "coh")
    icoh: bool = True  # imaginary coherence (feature prefix "icoh")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CoherenceFeatures(BoolSelector):
    """Switches selecting which summaries of the coherence spectrum are emitted."""

    mean_fband: bool = True  # mean coherence within each selected frequency band
    max_fband: bool = True  # maximum coherence within each selected frequency band
    # frequency at which coherence peaks over the whole spectrum (one key per band name)
    max_allfbands: bool = True
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# A list holding exactly two strings; used for the channel pairs in
# CoherenceSettings.channels (pydantic enforces the length on validation).
ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class CoherenceSettings(NMBaseModel):
    """Configuration of the coherence feature.

    Attributes
    ----------
    features : CoherenceFeatures
        Which per-band summaries (mean, max, peak frequency) to emit.
    method : CoherenceMethods
        Which coherence measures (``coh``, ``icoh``) to compute.
    channels : list of [str, str]
        Channel pairs; every entry must contain exactly two channel names.
    frequency_bands : list of str
        Names of the frequency bands to evaluate (at least one). Spaces in the
        names are replaced by underscores during validation.
    """

    features: CoherenceFeatures = CoherenceFeatures()
    method: CoherenceMethods = CoherenceMethods()
    channels: list[ListOfTwoStr] = []
    frequency_bands: list[str] = Field(default=["high_beta"], min_length=1)

    @field_validator("frequency_bands")
    @classmethod
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Normalise band names by turning every space into an underscore."""
        normalised = [band_name.replace(" ", "_") for band_name in frequency_bands]
        return normalised
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CoherenceObject:
|
|
42
|
+
def __init__(
|
|
43
|
+
self,
|
|
44
|
+
sfreq: float,
|
|
45
|
+
window: str,
|
|
46
|
+
fbands: list[FrequencyRange],
|
|
47
|
+
fband_names: list[str],
|
|
48
|
+
ch_1_name: str,
|
|
49
|
+
ch_2_name: str,
|
|
50
|
+
ch_1_idx: int,
|
|
51
|
+
ch_2_idx: int,
|
|
52
|
+
coh: bool,
|
|
53
|
+
icoh: bool,
|
|
54
|
+
features_coh: CoherenceFeatures,
|
|
55
|
+
) -> None:
|
|
56
|
+
self.sfreq = sfreq
|
|
57
|
+
self.window = window
|
|
58
|
+
self.fbands = fbands
|
|
59
|
+
self.fband_names = fband_names
|
|
60
|
+
self.ch_1 = ch_1_name
|
|
61
|
+
self.ch_2 = ch_2_name
|
|
62
|
+
self.ch_1_idx = ch_1_idx
|
|
63
|
+
self.ch_2_idx = ch_2_idx
|
|
64
|
+
self.coh = coh
|
|
65
|
+
self.icoh = icoh
|
|
66
|
+
self.features_coh = features_coh
|
|
67
|
+
|
|
68
|
+
self.Pxx = None
|
|
69
|
+
self.Pyy = None
|
|
70
|
+
self.Pxy = None
|
|
71
|
+
self.f = None
|
|
72
|
+
self.coh_val = None
|
|
73
|
+
self.icoh_val = None
|
|
74
|
+
|
|
75
|
+
def get_coh(self, feature_results, x, y):
|
|
76
|
+
from scipy.signal import welch, csd
|
|
77
|
+
|
|
78
|
+
self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=128)
|
|
79
|
+
self.Pyy = welch(y, self.sfreq, self.window, nperseg=128)[1]
|
|
80
|
+
self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=128)[1]
|
|
81
|
+
|
|
82
|
+
if self.coh:
|
|
83
|
+
self.coh_val = np.abs(self.Pxy**2) / (self.Pxx * self.Pyy)
|
|
84
|
+
if self.icoh:
|
|
85
|
+
self.icoh_val = np.array(self.Pxy / (self.Pxx * self.Pyy)).imag
|
|
86
|
+
|
|
87
|
+
for coh_idx, coh_type in enumerate([self.coh, self.icoh]):
|
|
88
|
+
if coh_type:
|
|
89
|
+
if coh_idx == 0:
|
|
90
|
+
coh_val = self.coh_val
|
|
91
|
+
coh_name = "coh"
|
|
92
|
+
else:
|
|
93
|
+
coh_val = self.icoh_val
|
|
94
|
+
coh_name = "icoh"
|
|
95
|
+
|
|
96
|
+
for idx, fband in enumerate(self.fbands):
|
|
97
|
+
if self.features_coh.mean_fband:
|
|
98
|
+
feature_calc = np.mean(
|
|
99
|
+
coh_val[np.bitwise_and(self.f > fband[0], self.f < fband[1])]
|
|
100
|
+
)
|
|
101
|
+
feature_name = "_".join(
|
|
102
|
+
[
|
|
103
|
+
coh_name,
|
|
104
|
+
self.ch_1,
|
|
105
|
+
"to",
|
|
106
|
+
self.ch_2,
|
|
107
|
+
"mean_fband",
|
|
108
|
+
self.fband_names[idx],
|
|
109
|
+
]
|
|
110
|
+
)
|
|
111
|
+
feature_results[feature_name] = feature_calc
|
|
112
|
+
if self.features_coh.max_fband:
|
|
113
|
+
feature_calc = np.max(
|
|
114
|
+
coh_val[np.bitwise_and(self.f > fband[0], self.f < fband[1])]
|
|
115
|
+
)
|
|
116
|
+
feature_name = "_".join(
|
|
117
|
+
[
|
|
118
|
+
coh_name,
|
|
119
|
+
self.ch_1,
|
|
120
|
+
"to",
|
|
121
|
+
self.ch_2,
|
|
122
|
+
"max_fband",
|
|
123
|
+
self.fband_names[idx],
|
|
124
|
+
]
|
|
125
|
+
)
|
|
126
|
+
feature_results[feature_name] = feature_calc
|
|
127
|
+
if self.features_coh.max_allfbands:
|
|
128
|
+
feature_calc = self.f[np.argmax(coh_val)]
|
|
129
|
+
feature_name = "_".join(
|
|
130
|
+
[
|
|
131
|
+
coh_name,
|
|
132
|
+
self.ch_1,
|
|
133
|
+
"to",
|
|
134
|
+
self.ch_2,
|
|
135
|
+
"max_allfbands",
|
|
136
|
+
self.fband_names[idx],
|
|
137
|
+
]
|
|
138
|
+
)
|
|
139
|
+
feature_results[feature_name] = feature_calc
|
|
140
|
+
return feature_results
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class NMCoherence(NMFeature):
    """Coherence features between the channel pairs listed in the settings.

    One CoherenceObject (Hann window) is created per entry of
    ``settings.coherence.channels``. ``calc_feature`` runs all of them on a
    data window and merges their features into a single dict.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: list[str], sfreq: float
    ) -> None:
        """Validate the coherence settings and build the per-pair calculators.

        Parameters
        ----------
        settings : NMSettings
            Full settings; ``settings.coherence`` and
            ``settings.frequency_ranges_hz`` are used.
        ch_names : list of str
            Channel names, ordered like the rows of the data passed to
            ``calc_feature``.
        sfreq : float
            Sampling frequency of the data.
        """
        self.settings = settings.coherence
        self.frequency_ranges_hz = settings.frequency_ranges_hz
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.coherence_objects: list[CoherenceObject] = []

        self.test_settings(settings, ch_names, sfreq)

        # The band selection is identical for every channel pair; resolve it once.
        fband_names = self.settings.frequency_bands
        fband_specs = [self.frequency_ranges_hz[band_name] for band_name in fband_names]

        for ch_pair in self.settings.channels:
            ch_1_name, ch_2_name = ch_pair[0], ch_pair[1]
            self.coherence_objects.append(
                CoherenceObject(
                    sfreq,
                    "hann",
                    fband_specs,
                    fband_names,
                    ch_1_name,
                    ch_2_name,
                    self._find_channel_index(ch_1_name),
                    self._find_channel_index(ch_2_name),
                    self.settings.method.coh,
                    self.settings.method.icoh,
                    self.settings.features,
                )
            )

    def _find_channel_index(self, ch_name: str) -> int:
        """Return the index of the first channel whose name starts with ``ch_name``.

        Prefix matching lets a configured name select a (re-referenced) channel
        whose name extends it; test_settings ensures there is exactly one match.
        """
        for idx, candidate in enumerate(self.ch_names):
            if candidate.startswith(ch_name):
                return idx
        raise IndexError(f"No channel name starts with {ch_name!r}")

    @staticmethod
    def test_settings(
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: float,
    ):
        """Check the coherence settings against the channels and sampling rate.

        Raises
        ------
        RuntimeError
            If a configured channel matches no channel name, or more than one
            (prefix match).
        AssertionError
            If a selected frequency band is not defined in
            ``settings.frequency_ranges_hz`` or does not lie below Nyquist.
        """
        coh_settings = settings.coherence
        # Materialise once: the names are scanned for every configured channel.
        ch_names = list(ch_names)

        flat_channels = [ch for ch_pair in coh_settings.channels for ch in ch_pair]
        for ch_coh in flat_channels:
            n_matches = sum(ch.startswith(ch_coh) for ch in ch_names)
            if n_matches == 0:
                raise RuntimeError(
                    f"Coherence selected channel {ch_coh} does not match any channel name: \n"
                    f" - settings.coherence.channels: {coh_settings.channels}\n"
                    f" - ch_names: {ch_names} \n"
                )
            if n_matches > 1:
                raise RuntimeError(
                    f"Coherence selected channel {ch_coh} is ambiguous and matches more than one channel name: \n"
                    f" - settings.coherence.channels: {coh_settings.channels}\n"
                    f" - ch_names: {ch_names} \n"
                )

        # AssertionError is kept for backward compatibility, but raised
        # explicitly so the checks are not stripped under ``python -O``.
        if not all(
            f_band_coh in settings.frequency_ranges_hz
            for f_band_coh in coh_settings.frequency_bands
        ):
            raise AssertionError(
                "coherence selected frequency bands don't match the ones "
                "specified in s['frequency_ranges_hz']\n"
                f"coherence frequency bands: {coh_settings.frequency_bands}\n"
                f"specified frequency_ranges_hz: {settings.frequency_ranges_hz}"
            )

        nyquist = sfreq / 2
        if not all(
            settings.frequency_ranges_hz[fb][0] < nyquist
            and settings.frequency_ranges_hz[fb][1] < nyquist
            for fb in coh_settings.frequency_bands
        ):
            selected_ranges = {
                fb: settings.frequency_ranges_hz[fb] for fb in coh_settings.frequency_bands
            }
            raise AssertionError(
                "the coherence frequency band ranges need to be smaller than the Nyquist frequency, "
                f"got sfreq = {sfreq} and fband ranges {selected_ranges}"
            )

        if not coh_settings.method.get_enabled():
            logger.warning(
                "feature coherence enabled, but no coherence['method'] selected"
            )

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute coherence features for all configured channel pairs.

        Parameters
        ----------
        data : np.ndarray
            Data window indexed as ``data[channel_idx, :]``; rows are assumed
            to follow the order of ``ch_names``.

        Returns
        -------
        dict
            Feature name -> value for every channel pair.
        """
        feature_results: dict = {}

        for coh_obj in self.coherence_objects:
            feature_results = coh_obj.get_coh(
                feature_results,
                data[coh_obj.ch_1_idx, :],
                data[coh_obj.ch_2_idx, :],
            )

        return feature_results
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import sqlite3
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from py_neuromodulation.nm_types import _PathLike
|
|
5
|
+
from py_neuromodulation.nm_IO import generate_unique_filename
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class NMDatabase:
    """
    Class to create a database and insert data into it.

    Feature dictionaries are stored as rows of one SQLite table named
    ``<name>_data``. The columns are taken from the first inserted dictionary.

    Parameters
    ----------
    name : str
        Base name of the database file (``<name>.db``) and of its table.
    out_dir : _PathLike
        The directory to save the database. It is created if missing.
    csv_path : _PathLike, optional
        The path to save the csv file. If not provided, it will be saved in the
        same folder as the database, using the database file stem.
    """

    def __init__(
        self,
        name: str,
        out_dir: "_PathLike",
        csv_path: "_PathLike | None" = None,
    ):
        # Make sure out_dir exists
        Path(out_dir).mkdir(parents=True, exist_ok=True)

        self.db_path = Path(out_dir, f"{name}.db")

        self.table_name = f"{name}_data"  # change to param?
        self.table_created = False

        # If the file already exists, switch to a unique filename; the default
        # CSV name follows the new stem.
        if self.db_path.exists():
            self.db_path = generate_unique_filename(self.db_path)
            name = self.db_path.stem

        if csv_path is None:
            self.csv_path = Path(out_dir, f"{name}.csv")
        else:
            self.csv_path = Path(csv_path)

        self.csv_path.parent.mkdir(parents=True, exist_ok=True)

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()

        # Database config and optimization, prioritize data integrity
        self.cursor.execute("PRAGMA journal_mode=WAL")  # Write-Ahead Logging mode
        self.cursor.execute("PRAGMA synchronous=FULL")  # Sync on every commit
        self.cursor.execute("PRAGMA temp_store=MEMORY")  # Store temp tables in memory
        self.cursor.execute(
            "PRAGMA wal_autocheckpoint = 1000"
        )  # WAL checkpoint every 1000 pages (default, 4MB, might change)
        self.cursor.execute(
            f"PRAGMA mmap_size = {2 * 1024 * 1024 * 1024}"
        )  # 2GB of memory mapped

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Return ``identifier`` as a double-quoted SQL identifier with ``"`` escaped."""
        return '"' + str(identifier).replace('"', '""') + '"'

    def infer_type(self, value):
        """Infer the type of the value to create the table schema.

        Parameters
        ----------
        value : int, float, str
            The value to infer the type.

        Returns
        -------
        str
            ``"REAL"`` for int/float (including bool), ``"TEXT"`` for str,
            ``"BLOB"`` otherwise.
        """
        if isinstance(value, (int, float)):
            return "REAL"
        elif isinstance(value, str):
            return "TEXT"
        else:
            return "BLOB"

    def create_table(self, feature_dict: dict):
        """
        Create a table in the database.

        Parameters
        ----------
        feature_dict : dict
            The dictionary with the feature names and values; its keys become
            the columns (in order) and its values determine the column types.
        """
        self._column_names: list[str] = list(feature_dict)

        columns_schema = ", ".join(
            f"{self._quote_identifier(column)} {self.infer_type(value)}"
            for column, value in feature_dict.items()
        )

        self.cursor.execute(
            f"CREATE TABLE IF NOT EXISTS {self._quote_identifier(self.table_name)} ({columns_schema})"
        )

        # Create column names and placeholders for insert statement
        self.columns: str = ", ".join(
            self._quote_identifier(column) for column in self._column_names
        )
        # Named placeholders bind values by column name rather than by dict
        # order. SQLite parameter names must be identifier-like, which feature
        # names (e.g. containing "-") often are not, so generic names
        # :p0, :p1, ... are used and mapped to the columns in insert_data.
        self.placeholders = ", ".join(
            f":p{idx}" for idx in range(len(self._column_names))
        )

    def insert_data(self, feature_dict: dict):
        """
        Insert data into the database.

        The table is created from the first dictionary inserted.

        Parameters
        ----------
        feature_dict : dict
            The dictionary with the feature names and values. It must contain
            every column of the table; key order does not matter.

        Raises
        ------
        sqlite3.ProgrammingError
            If a table column is missing from ``feature_dict``.
        """
        if not self.table_created:
            self.create_table(feature_dict)
            self.table_created = True

        try:
            params = {
                f"p{idx}": feature_dict[column]
                for idx, column in enumerate(self._column_names)
            }
        except KeyError as err:
            raise sqlite3.ProgrammingError(
                f"You did not supply a value for column {err.args[0]!r}"
            ) from err

        insert_sql = (
            f"INSERT INTO {self._quote_identifier(self.table_name)} "
            f"({self.columns}) VALUES ({self.placeholders})"
        )

        self.cursor.execute(insert_sql, params)

    def commit(self):
        """Commit all pending inserts to the database file."""
        self.conn.commit()

    def fetch_all(self):
        """
        Fetch all the data from the database.

        Returns
        -------
        pd.DataFrame
            The data in a pandas DataFrame.
        """
        return pd.read_sql_query(
            f"SELECT * FROM {self._quote_identifier(self.table_name)}", self.conn
        )

    def head(self, n: int = 5):
        """
        Returns the first N rows of the database.

        Parameters
        ----------
        n : int, optional
            The number of rows to fetch, by default 5

        Returns
        -------
        pd.DataFrame
            The data in a pandas DataFrame.
        """
        return pd.read_sql_query(
            f"SELECT * FROM {self._quote_identifier(self.table_name)} LIMIT ?",
            self.conn,
            params=(int(n),),
        )

    def save_as_csv(self):
        """Write the whole table to ``self.csv_path`` (without the index)."""
        df = self.fetch_all()
        df.to_csv(self.csv_path, index=False)

    def close(self):
        """Commit pending inserts, optimize and close the connection."""
        # sqlite3's Connection.close() does not commit; without this, rows
        # inserted since the last commit() would be silently discarded.
        self.conn.commit()
        # Optimize before closing is recommended:
        # https://www.sqlite.org/pragma.html#pragma_optimize
        self.cursor.execute("PRAGMA optimize")
        self.conn.close()
|