py-neuromodulation 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/Automated Anatomical Labeling 3 (Rolls 2020).nii +0 -0
- py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -0
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +106 -0
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +119 -0
- py_neuromodulation/ConnectivityDecoding/mni_coords_cortical_surface.mat +0 -0
- py_neuromodulation/ConnectivityDecoding/mni_coords_whole_brain.mat +0 -0
- py_neuromodulation/ConnectivityDecoding/rmap_func_all.nii +0 -0
- py_neuromodulation/ConnectivityDecoding/rmap_struc.nii +0 -0
- py_neuromodulation/data/README +6 -0
- py_neuromodulation/data/dataset_description.json +8 -0
- py_neuromodulation/data/participants.json +32 -0
- py_neuromodulation/data/participants.tsv +2 -0
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -0
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -0
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -0
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.eeg +0 -0
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -0
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -0
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -0
- py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -0
- py_neuromodulation/grid_cortex.tsv +40 -0
- py_neuromodulation/grid_subcortex.tsv +1429 -0
- py_neuromodulation/nm_settings.json +338 -0
- py_neuromodulation/nm_stream_offline.py +7 -6
- py_neuromodulation/plots/STN_surf.mat +0 -0
- py_neuromodulation/plots/Vertices.mat +0 -0
- py_neuromodulation/plots/faces.mat +0 -0
- py_neuromodulation/plots/grid.mat +0 -0
- {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.4.dist-info}/METADATA +182 -182
- py_neuromodulation-0.0.4.dist-info/RECORD +72 -0
- {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.4.dist-info}/WHEEL +1 -2
- docs/build/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
- docs/build/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -233
- docs/build/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
- docs/build/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
- docs/build/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
- docs/build/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
- docs/build/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
- docs/build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -68
- docs/build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -239
- docs/build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
- docs/build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -97
- docs/build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
- docs/build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -192
- docs/build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
- docs/source/_build/html/_downloads/09df217f95985497f45d69e2d4bdc5b1/plot_2_example_add_feature.py +0 -76
- docs/source/_build/html/_downloads/0d0d0a76e8f648d5d3cbc47da6351932/plot_real_time_demo.py +0 -97
- docs/source/_build/html/_downloads/3b4900a2b2818ff30362215b76f7d5eb/plot_1_example_BIDS.py +0 -240
- docs/source/_build/html/_downloads/5d73cadc59a8805c47e3b84063afc157/plot_example_BIDS.py +0 -233
- docs/source/_build/html/_downloads/7660317fa5a6bfbd12fcca9961457fc4/plot_example_rmap_computing.py +0 -63
- docs/source/_build/html/_downloads/7e92dd2e6cc86b239d14cafad972ae4f/plot_3_example_sharpwave_analysis.py +0 -219
- docs/source/_build/html/_downloads/839e5b319379f7fd9e867deb00fd797f/plot_example_gridPointProjection.py +0 -210
- docs/source/_build/html/_downloads/ae8be19afe5e559f011fc9b138968ba0/plot_first_demo.py +0 -192
- docs/source/_build/html/_downloads/b8b06cacc17969d3725a0b6f1d7741c5/plot_example_sharpwave_analysis.py +0 -219
- docs/source/_build/html/_downloads/c2db0bf2b334d541b00662b991682256/plot_6_real_time_demo.py +0 -121
- docs/source/_build/html/_downloads/c31a86c0b68cb4167d968091ace8080d/plot_example_add_feature.py +0 -68
- docs/source/_build/html/_downloads/ce3914826f782cbd1ea8fd024eaf0ac3/plot_5_example_rmap_computing.py +0 -64
- docs/source/_build/html/_downloads/da36848a41e6a3235d91fb7cfb6d59b4/plot_0_first_demo.py +0 -189
- docs/source/_build/html/_downloads/eaa4305c75b19a1e2eea941f742a6331/plot_4_example_gridPointProjection.py +0 -210
- docs/source/auto_examples/plot_0_first_demo.py +0 -189
- docs/source/auto_examples/plot_1_example_BIDS.py +0 -240
- docs/source/auto_examples/plot_2_example_add_feature.py +0 -76
- docs/source/auto_examples/plot_3_example_sharpwave_analysis.py +0 -219
- docs/source/auto_examples/plot_4_example_gridPointProjection.py +0 -210
- docs/source/auto_examples/plot_5_example_rmap_computing.py +0 -64
- docs/source/auto_examples/plot_6_real_time_demo.py +0 -121
- docs/source/conf.py +0 -105
- examples/plot_0_first_demo.py +0 -189
- examples/plot_1_example_BIDS.py +0 -240
- examples/plot_2_example_add_feature.py +0 -76
- examples/plot_3_example_sharpwave_analysis.py +0 -219
- examples/plot_4_example_gridPointProjection.py +0 -210
- examples/plot_5_example_rmap_computing.py +0 -64
- examples/plot_6_real_time_demo.py +0 -121
- packages/realtime_decoding/build/lib/realtime_decoding/__init__.py +0 -4
- packages/realtime_decoding/build/lib/realtime_decoding/decoder.py +0 -104
- packages/realtime_decoding/build/lib/realtime_decoding/features.py +0 -163
- packages/realtime_decoding/build/lib/realtime_decoding/helpers.py +0 -15
- packages/realtime_decoding/build/lib/realtime_decoding/run_decoding.py +0 -345
- packages/realtime_decoding/build/lib/realtime_decoding/trainer.py +0 -54
- packages/tmsi/build/lib/TMSiFileFormats/__init__.py +0 -37
- packages/tmsi/build/lib/TMSiFileFormats/file_formats/__init__.py +0 -36
- packages/tmsi/build/lib/TMSiFileFormats/file_formats/lsl_stream_writer.py +0 -200
- packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_file_writer.py +0 -496
- packages/tmsi/build/lib/TMSiFileFormats/file_formats/poly5_to_edf_converter.py +0 -236
- packages/tmsi/build/lib/TMSiFileFormats/file_formats/xdf_file_writer.py +0 -977
- packages/tmsi/build/lib/TMSiFileFormats/file_readers/__init__.py +0 -35
- packages/tmsi/build/lib/TMSiFileFormats/file_readers/edf_reader.py +0 -116
- packages/tmsi/build/lib/TMSiFileFormats/file_readers/poly5reader.py +0 -294
- packages/tmsi/build/lib/TMSiFileFormats/file_readers/xdf_reader.py +0 -229
- packages/tmsi/build/lib/TMSiFileFormats/file_writer.py +0 -102
- packages/tmsi/build/lib/TMSiPlotters/__init__.py +0 -2
- packages/tmsi/build/lib/TMSiPlotters/gui/__init__.py +0 -39
- packages/tmsi/build/lib/TMSiPlotters/gui/_plotter_gui.py +0 -234
- packages/tmsi/build/lib/TMSiPlotters/gui/plotting_gui.py +0 -440
- packages/tmsi/build/lib/TMSiPlotters/plotters/__init__.py +0 -44
- packages/tmsi/build/lib/TMSiPlotters/plotters/hd_emg_plotter.py +0 -446
- packages/tmsi/build/lib/TMSiPlotters/plotters/impedance_plotter.py +0 -589
- packages/tmsi/build/lib/TMSiPlotters/plotters/signal_plotter.py +0 -1326
- packages/tmsi/build/lib/TMSiSDK/__init__.py +0 -54
- packages/tmsi/build/lib/TMSiSDK/device.py +0 -588
- packages/tmsi/build/lib/TMSiSDK/devices/__init__.py +0 -34
- packages/tmsi/build/lib/TMSiSDK/devices/saga/TMSi_Device_API.py +0 -1764
- packages/tmsi/build/lib/TMSiSDK/devices/saga/__init__.py +0 -34
- packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_device.py +0 -1366
- packages/tmsi/build/lib/TMSiSDK/devices/saga/saga_types.py +0 -520
- packages/tmsi/build/lib/TMSiSDK/devices/saga/xml_saga_config.py +0 -165
- packages/tmsi/build/lib/TMSiSDK/error.py +0 -95
- packages/tmsi/build/lib/TMSiSDK/sample_data.py +0 -63
- packages/tmsi/build/lib/TMSiSDK/sample_data_server.py +0 -99
- packages/tmsi/build/lib/TMSiSDK/settings.py +0 -45
- packages/tmsi/build/lib/TMSiSDK/tmsi_device.py +0 -111
- packages/tmsi/build/lib/__init__.py +0 -4
- packages/tmsi/build/lib/apex_sdk/__init__.py +0 -34
- packages/tmsi/build/lib/apex_sdk/device/__init__.py +0 -41
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API.py +0 -1009
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_enums.py +0 -239
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_API_structures.py +0 -668
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_device.py +0 -1611
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_dongle.py +0 -38
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_event_reader.py +0 -57
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_channel.py +0 -44
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_config.py +0 -150
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_const.py +0 -36
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_impedance_channel.py +0 -48
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/apex_info.py +0 -108
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/apex_structures/dongle_info.py +0 -39
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/download_measurement.py +0 -77
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/eeg_measurement.py +0 -150
- packages/tmsi/build/lib/apex_sdk/device/devices/apex/measurements/impedance_measurement.py +0 -129
- packages/tmsi/build/lib/apex_sdk/device/threads/conversion_thread.py +0 -59
- packages/tmsi/build/lib/apex_sdk/device/threads/sampling_thread.py +0 -57
- packages/tmsi/build/lib/apex_sdk/device/tmsi_channel.py +0 -83
- packages/tmsi/build/lib/apex_sdk/device/tmsi_device.py +0 -201
- packages/tmsi/build/lib/apex_sdk/device/tmsi_device_enums.py +0 -103
- packages/tmsi/build/lib/apex_sdk/device/tmsi_dongle.py +0 -43
- packages/tmsi/build/lib/apex_sdk/device/tmsi_event_reader.py +0 -50
- packages/tmsi/build/lib/apex_sdk/device/tmsi_measurement.py +0 -118
- packages/tmsi/build/lib/apex_sdk/sample_data_server/__init__.py +0 -33
- packages/tmsi/build/lib/apex_sdk/sample_data_server/event_data.py +0 -44
- packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data.py +0 -50
- packages/tmsi/build/lib/apex_sdk/sample_data_server/sample_data_server.py +0 -136
- packages/tmsi/build/lib/apex_sdk/tmsi_errors/error.py +0 -126
- packages/tmsi/build/lib/apex_sdk/tmsi_sdk.py +0 -113
- packages/tmsi/build/lib/apex_sdk/tmsi_utilities/apex/apex_structure_generator.py +0 -134
- packages/tmsi/build/lib/apex_sdk/tmsi_utilities/decorators.py +0 -60
- packages/tmsi/build/lib/apex_sdk/tmsi_utilities/logger_filter.py +0 -42
- packages/tmsi/build/lib/apex_sdk/tmsi_utilities/singleton.py +0 -42
- packages/tmsi/build/lib/apex_sdk/tmsi_utilities/support_functions.py +0 -72
- packages/tmsi/build/lib/apex_sdk/tmsi_utilities/tmsi_logger.py +0 -98
- py_neuromodulation-0.0.3.dist-info/RECORD +0 -188
- py_neuromodulation-0.0.3.dist-info/top_level.txt +0 -5
- tests/__init__.py +0 -0
- tests/conftest.py +0 -117
- tests/test_all_examples.py +0 -10
- tests/test_all_features.py +0 -63
- tests/test_bispectra.py +0 -70
- tests/test_bursts.py +0 -105
- tests/test_feature_sampling_rates.py +0 -143
- tests/test_fooof.py +0 -16
- tests/test_initalization_offline_stream.py +0 -41
- tests/test_multiprocessing.py +0 -58
- tests/test_nan_values.py +0 -29
- tests/test_nm_filter.py +0 -95
- tests/test_nm_resample.py +0 -63
- tests/test_normalization_settings.py +0 -146
- tests/test_notch_filter.py +0 -31
- tests/test_osc_features.py +0 -424
- tests/test_preprocessing_filter.py +0 -151
- tests/test_rereference.py +0 -171
- tests/test_sampling.py +0 -57
- tests/test_settings_change_after_init.py +0 -76
- tests/test_sharpwave.py +0 -165
- tests/test_target_channel_add.py +0 -100
- tests/test_timing.py +0 -80
- {py_neuromodulation-0.0.3.dist-info → py_neuromodulation-0.0.4.dist-info/licenses}/LICENSE +0 -0

packages/realtime_decoding/build/lib/realtime_decoding/decoder.py
@@ -1,104 +0,0 @@
-import multiprocessing
-import multiprocessing.synchronize
-import pathlib
-import pickle
-import queue
-import tkinter
-import tkinter.filedialog
-from datetime import datetime, timezone
-
-import numpy as np
-import pandas as pd
-import pylsl
-
-import realtime_decoding
-
-from .helpers import _PathLike
-
-_timezone = timezone.utc
-
-
-class Decoder(multiprocessing.Process):
-    """Make predictions in real time."""
-
-    def __init__(
-        self,
-        queue_decoding: multiprocessing.Queue,
-        queue_features: multiprocessing.Queue,
-        interval: float,
-        out_dir: _PathLike,
-        verbose: bool,
-        model_path: _PathLike,
-    ) -> None:
-        super().__init__(name="DecodingProcess")
-        self.queue_decoding = queue_decoding
-        self.queue_feat = queue_features
-        self.interval = interval
-        self.verbose = verbose
-        self.out_dir = pathlib.Path(out_dir)
-
-        self._threshold: float = 0.5
-
-        self.filename = pathlib.Path(model_path)
-
-        with open(self.filename, "rb") as file:
-            self._model = pickle.load(file)
-        self._save_model()
-
-    def _save_model(self) -> None:
-        with open(self.out_dir / self.filename.name, "wb") as file:
-            pickle.dump(self._model, file)
-
-    def clear_queue(self) -> None:
-        for q in (self.queue_feat, self.queue_decoding):
-            realtime_decoding.clear_queue(q)
-
-    def run(self) -> None:
-        labels = ["Prediction", "Probability", "Threshold"]
-
-        info = pylsl.StreamInfo(
-            name="Decoding",
-            type="EEG",
-            channel_count=3,
-            channel_format="double64",
-            source_id="decoding_1",
-        )
-        channels = info.desc().append_child("channels")
-        for label in labels:
-            channels.append_child("channel").append_child_value("label", label)
-        outlet = pylsl.StreamOutlet(info)
-        while True:
-            try:
-                sample = self.queue_feat.get(timeout=10.0)
-            except queue.Empty:
-                break
-            else:
-                if self.verbose:
-                    print("Got features.")
-                if sample is None:
-                    print("Found None value, terminating decoder process.")
-                    break
-
-                # Predict
-                sample_ = sample[[i for i in sample.index if i != "label_train"]]
-
-                y = float(self._model.predict_proba(np.expand_dims(sample_.to_numpy(), 0))[0, 1])
-                print(f"pr: {y}")
-
-                timestamp = np.datetime64(datetime.now(_timezone), "ns")
-
-                output = pd.DataFrame(
-                    [[y >= self._threshold, y, self._threshold]],
-                    columns=labels,
-                    index=[timestamp],
-                )
-                outlet.push_sample(
-                    x=list(output.to_numpy().squeeze()),
-                    timestamp=timestamp.astype(float),
-                )
-        try:
-            self.queue_decoding.put(None, timeout=3.0)
-        except queue.Full:
-            pass
-        self.clear_queue()
-        print(f"Terminating: {self.name}")
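
The removed Decoder process above consumes feature vectors from a multiprocessing queue, applies predict_proba from a pickled model, and streams prediction, probability, and threshold over an LSL outlet; a None item on the queue acts as the shutdown sentinel. The following is a minimal, self-contained sketch of that sentinel-based queue protocol only (standard library, no pylsl, hypothetical names), not code from the package:

# Sketch only: the None-sentinel queue protocol used by the removed Decoder.run().
import multiprocessing
import queue


def consume(q: multiprocessing.Queue) -> None:
    while True:
        try:
            item = q.get(timeout=10.0)   # same timeout as the removed code
        except queue.Empty:
            break                        # producer went silent
        if item is None:
            print("Found None value, terminating consumer.")
            break
        print(f"processing feature sample: {item}")


if __name__ == "__main__":
    q = multiprocessing.Queue()
    for item in (0.12, 0.87, None):      # None shuts the consumer down
        q.put(item)
    consume(q)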

packages/realtime_decoding/build/lib/realtime_decoding/features.py
@@ -1,163 +0,0 @@
-import json
-import multiprocessing
-import multiprocessing.synchronize
-import pathlib
-import queue
-import tkinter
-import tkinter.filedialog
-from datetime import datetime
-
-import numpy as np
-import py_neuromodulation as nm
-import pylsl
-from numpy_ringbuffer import RingBuffer
-
-import realtime_decoding
-
-from .helpers import _PathLike
-
-
-class Features(multiprocessing.Process):
-    """Process class to calculate features from LSL stream."""
-
-    def __init__(
-        self,
-        name: str,
-        source_id: str,
-        n_feats: int,
-        sfreq: int,
-        interval: float,
-        queue_raw: multiprocessing.Queue,
-        queue_features: multiprocessing.Queue,
-        path_nm_settings: _PathLike,
-        path_nm_channels: _PathLike,
-        out_dir: _PathLike,
-        verbose: bool,
-        path_grids: str | None = None,
-        line_noise: int | float | None = None,
-        training_samples: int = 60,
-        training_enabled: bool = False,
-    ) -> None:
-        super().__init__(name=f"{name}Process")
-        self.interval = interval
-        self.sfreq = sfreq
-        self.queue_raw = queue_raw
-        self.queue_features = queue_features
-        self.verbose = verbose
-        self.out_dir = pathlib.Path(out_dir)
-        self.finished = multiprocessing.Event()
-        self.path_nm_settings = path_nm_settings
-        self.path_nm_channels = path_nm_channels
-
-        self.processor = nm.nm_run_analysis.DataProcessor(
-            sfreq=self.sfreq,
-            settings=path_nm_settings,
-            nm_channels=path_nm_channels,
-            line_noise=line_noise,
-            path_grids=path_grids,
-            verbose=self.verbose,
-        )
-        self.num_channels = len(self.processor.nm_channels)
-        self.buffer = RingBuffer(
-            capacity=self.sfreq,
-            dtype=(float, self.num_channels),  # type: ignore
-            allow_overwrite=True,
-        )
-        # Channels * Number of different features
-        self.n_feats_total = (
-            sum(self.processor.nm_channels["used"] == 1) * n_feats
-        )
-        self.source_id = source_id
-        self.outlet = None
-        self._save_settings()
-
-        print(f"value of training enabled: {training_enabled}")
-        self.training_enabled = training_enabled
-        if training_enabled is True:
-            print("training is enabled")
-            self.training_counter = 0
-            self.training_samples = training_samples
-            self.training_class = 0  # REST
-
-            # the labels are sent as an additional LSL channel
-            self.n_feats_total = self.n_feats_total + 1
-
-    def _save_settings(self) -> None:
-        # print("SAVING DATA ....")
-        self.processor.nm_channels.to_csv(
-            self.out_dir / self.path_nm_channels, index=False
-        )
-        with open(
-            self.out_dir / self.path_nm_settings,
-            "w",
-            encoding="utf-8",
-        ) as outfile:
-            json.dump(self.processor.settings, outfile)
-
-    def clear_queue(self) -> None:
-        realtime_decoding.clear_queue(self.queue_raw)
-
-    def run(self) -> None:
-        while True:
-            try:
-                sd = self.queue_raw.get(timeout=10.0)
-                # data = self.queue_raw.get(timeout=10.0)
-            except queue.Empty:
-                break
-            else:
-                # print("Got data")
-                if sd is None:
-                    print("Found None value, terminating features process.")
-                    break
-                if self.verbose:
-                    print("Found raw input sample.")
-                # Reshape the samples retrieved from the queue
-                data = np.reshape(
-                    sd.samples,
-                    (sd.num_samples_per_sample_set, sd.num_sample_sets),
-                    order="F",
-                )
-                # data = np.array(samples)  # shape (time, ch)
-                self.buffer.extend(data.T)
-                if not self.buffer.is_full:
-                    continue
-                features = self.processor.process(self.buffer[:].T)
-                timestamp = np.datetime64(datetime.utcnow(), "ns")
-
-                if self.training_enabled is True:
-
-                    # the analog channel data is stored in self.buffer
-                    # this channel can be added to the calculated features, and simply finished with escape
-                    #print(self.buffer[:].T)
-                    #print(f"buffer shape: {self.buffer.shape}")
-                    features["label_train"] = np.mean(self.buffer[-409:, 24])  # get index from analog
-                try:
-                    self.queue_features.put(features, timeout=self.interval)
-                except queue.Full:
-                    if self.verbose:
-                        print("Features queue Full. Skipping sample.")
-                if self.outlet is None:
-                    info = pylsl.StreamInfo(
-                        name=self.name,
-                        type="EEG",
-                        channel_count=self.n_feats_total,
-                        nominal_srate=self.sfreq,
-                        channel_format="double64",
-                        source_id=self.source_id,
-                    )
-                    channels = info.desc().append_child("channels")
-                    for label in features.index:
-                        channels.append_child("channel").append_child_value(
-                            "label", label
-                        )
-                    self.outlet = pylsl.StreamOutlet(info)
-                self.outlet.push_sample(
-                    x=features.tolist(), timestamp=timestamp.astype(float)
-                )
-
-        try:
-            self.queue_features.put(None, timeout=3.0)
-        except queue.Full:
-            pass
-        self.clear_queue()
-        print(f"Terminating: {self.name}")
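
The removed Features process above fills a numpy_ringbuffer.RingBuffer whose capacity equals the sampling rate (one second of multichannel data) and only calls the py_neuromodulation DataProcessor once the buffer is full, so each feature vector is computed on a complete one-second window. A rough, self-contained sketch of that windowing idea, using a plain deque as a stand-in for RingBuffer and a placeholder feature function (toy values, not the package's API):

# Sketch only: buffer one second of samples, then compute per-channel features.
from collections import deque

import numpy as np

SFREQ, N_CHANNELS = 10, 4                # toy values; the package uses the device rate
buffer: deque = deque(maxlen=SFREQ)      # stand-in for numpy_ringbuffer.RingBuffer


def compute_features(window: np.ndarray) -> np.ndarray:
    # placeholder for DataProcessor.process(); here just the per-channel mean
    return window.mean(axis=0)


rng = np.random.default_rng(0)
for _ in range(25):                      # simulate incoming sample sets
    buffer.append(rng.normal(size=N_CHANNELS))
    if len(buffer) < SFREQ:
        continue                         # wait until a full window is available
    print(compute_features(np.asarray(buffer)).round(2))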

packages/realtime_decoding/build/lib/realtime_decoding/helpers.py
@@ -1,15 +0,0 @@
-import os
-import queue
-
-_PathLike = str | os.PathLike
-
-
-def clear_queue(q) -> None:
-    print("Emptying queue.")
-    try:
-        while True:
-            q.get(block=False)
-    except queue.Empty:
-        print("Queue emptied.")
-    except ValueError:  # Queue is already closed
-        print("Queue was already closed.")
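
The removed helpers module above defines the `_PathLike` alias (the `str | os.PathLike` union syntax requires Python 3.10+) and `clear_queue`, which drains a multiprocessing queue non-blockingly so worker processes can shut down without hanging. A small usage sketch of that drain pattern (the helper body is copied from the hunk above; the surrounding driver code is illustrative only):

# Sketch only: draining a multiprocessing.Queue before shutdown.
import multiprocessing
import queue
import time


def clear_queue(q) -> None:
    print("Emptying queue.")
    try:
        while True:
            q.get(block=False)
    except queue.Empty:
        print("Queue emptied.")
    except ValueError:  # Queue is already closed
        print("Queue was already closed.")


if __name__ == "__main__":
    q = multiprocessing.Queue()
    for i in range(3):
        q.put(i)
    time.sleep(0.1)    # give the queue's feeder thread time to flush
    clear_queue(q)     # leaves the queue empty so processes can exit cleanly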

packages/realtime_decoding/build/lib/realtime_decoding/run_decoding.py
@@ -1,345 +0,0 @@
-import multiprocessing
-import multiprocessing.synchronize
-import os
-import pathlib
-import queue
-import signal
-import sys
-import time
-from pynput.keyboard import Key, Listener
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-import tkinter
-import tkinter.filedialog
-from typing import Generator, Literal
-
-import TMSiFileFormats
-import TMSiSDK
-
-import realtime_decoding
-
-from .helpers import _PathLike
-
-
-@contextmanager
-def open_tmsi_device(
-    out_dir: _PathLike,
-    tmsi_cfg_file: _PathLike,
-    verbose: bool = True,
-) -> Generator[TMSiSDK.SagaDevice, None, None]:
-
-    out_dir = pathlib.Path(out_dir)
-
-    tmsi_cfg_file = pathlib.Path(tmsi_cfg_file)
-    device = None
-    try:
-        print("Initializing TMSi device...")
-        # Initialise the TMSi-SDK first before starting using it
-        TMSiSDK.initialize()
-        # Execute a device discovery. This returns a list of device-objects.
-        discovery_list = TMSiSDK.discover(
-            TMSiSDK.DeviceType.saga,
-            TMSiSDK.DeviceInterfaceType.docked,
-            TMSiSDK.DeviceInterfaceType.usb,  # .network
-        )
-        if len(discovery_list) == 0:
-            raise ValueError(
-                "No TMSi device found. Please check your connections."
-            )
-        if len(discovery_list) > 1:
-            raise ValueError(
-                "More than one TMSi device found. Please check your"
-                f" connections. Found: {discovery_list}."
-            )
-        # Get the handle to the first discovered device.
-        device = discovery_list[0]
-        print(f"Found device: {device}")
-        device.open()
-        print("Connected to device.")
-        # cfg_file = TMSiSDK.get_config(saga_config)
-        device.load_config(tmsi_cfg_file)
-        TMSiSDK.xml_saga_config.xml_write_config(
-            filename=out_dir / tmsi_cfg_file.name, saga_config=device.config
-        )
-        if verbose:
-            print("\nThe active channels are : ")
-            for idx, ch in enumerate(device.channels):
-                print(
-                    "[{0}] : [{1}] in [{2}]".format(idx, ch.name, ch.unit_name)
-                )
-            print("\nCurrent device configuration:")
-            print(
-                f"Base-sample-rate: \t\t\t{device.config.base_sample_rate} Hz"
-            )
-            print(f"Sample-rate: \t\t\t\t{device.config.sample_rate} Hz")
-            print(f"Reference Method: \t\t\t{device.config.reference_method}")
-            print(
-                f"Sync out configuration: \t{device.config.get_sync_out_config()}"
-            )
-
-        device.start_measurement()
-        if device is None:
-            raise ValueError("No TMSi device found!")
-        yield device
-    except TMSiSDK.TMSiError as error:
-        print("!!! TMSiError !!! : ", error.code)
-        if (
-            device is not None
-            and error.code == TMSiSDK.error.TMSiErrorCode.device_error
-        ):
-            print(" => device error : ", hex(device.status.error))
-            TMSiSDK.DeviceErrorLookupTable(hex(device.status.error))
-    except Exception as exception:
-        if device is not None:
-            if device.status.state == TMSiSDK.DeviceState.sampling:
-                print("Stopping TMSi measurement...")
-                device.stop_measurement()
-            if device.status.state == TMSiSDK.DeviceState.connected:
-                print("Closing TMSi device...")
-                device.close()
-        raise exception
-
-
-@contextmanager
-def open_lsl_stream(
-    device,
-) -> Generator[TMSiFileFormats.FileWriter, None, None]:
-    lsl_stream = TMSiFileFormats.FileWriter(
-        TMSiFileFormats.FileFormat.lsl, "SAGA"
-    )
-    try:
-        lsl_stream.open(device)
-        yield lsl_stream
-    except Exception as exception:
-        print("Closing LSL stream...")
-        lsl_stream.close()
-        raise exception
-
-
-@contextmanager
-def open_poly5_writer(
-    device,
-    out_file: _PathLike,
-) -> Generator[TMSiFileFormats.file_writer.FileWriter, None, None]:
-    out_file = str(out_file)
-    file_writer = TMSiFileFormats.file_writer.FileWriter(
-        TMSiFileFormats.file_writer.FileFormat.poly5, out_file
-    )
-    try:
-        print("Opening poly5 writer")
-        file_writer.open(device)
-        print("Poly 5 writer opened")
-        yield file_writer
-    except Exception as exception:
-        print("Closing Poly5 file writer")
-        file_writer.close()
-        raise exception
-
-
-@dataclass
-class ProcessManager:
-    device: TMSiSDK.SagaDevice
-    lsl_stream: TMSiFileFormats.FileWriter
-    file_writer: TMSiFileFormats.FileWriter
-    out_dir: _PathLike
-    settings: dict
-    timeout: float = 0.05
-    verbose: bool = True
-    _terminated: bool = field(init=False, default=False)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback) -> None:
-        if isinstance(exc_type, BaseException):
-            if not self._terminated:
-                self.terminate()
-        return False
-
-    def __post_init__(self) -> None:
-        self.out_dir = pathlib.Path(self.out_dir)
-        self.queue_source = multiprocessing.Queue(
-            int(self.timeout * 1000 * 20)
-        )  # seconds/sample * ms/s * s
-        self.queue_raw = multiprocessing.Queue(int(self.timeout * 1000))
-        self.queue_features = multiprocessing.Queue(1)
-        self.queue_decoding = multiprocessing.Queue(1)
-        self.queues = [
-            self.queue_raw,
-            self.queue_features,
-            self.queue_decoding,
-            self.queue_source,
-        ]
-        for q in self.queues:
-            q.cancel_join_thread()
-
-    def start(self) -> None:
-        def on_press(key) -> None:
-            pass
-
-        def on_release(key) -> Literal[False] | None:
-            if key == Key.esc:
-                print("Received stop key.")
-                self.terminate()
-                return False
-
-        listener = Listener(on_press=on_press, on_release=on_release)
-        listener.start()
-        print("Listener started.")
-
-        TMSiSDK.sample_data_server.registerConsumer(
-            self.device.id, self.queue_source
-        )
-        features = realtime_decoding.Features(
-            name="Features",
-            source_id="features_1",
-            n_feats=7,
-            sfreq=self.device.config.sample_rate,
-            interval=self.timeout,
-            queue_raw=self.queue_source,
-            queue_features=self.queue_features,
-            path_nm_settings=self.settings["PATH_pynm_SETTINGS"],
-            path_nm_channels=self.settings["PATH_pynm_CHANNELS"],
-            out_dir=self.out_dir,
-            path_grids=None,
-            line_noise=50,
-            verbose=self.verbose,
-            training_enabled=True,
-            training_samples=20*50
-        )
-        decoder = realtime_decoding.Decoder(
-            queue_decoding=self.queue_decoding,
-            queue_features=self.queue_features,
-            interval=self.timeout,
-            out_dir=self.out_dir,
-            verbose=self.verbose,
-            model_path=self.settings["PATH_MODEL_PREDICT"]
-        )
-        processes = [features, decoder]
-        for process in processes:
-            process.start()
-            time.sleep(0.5)
-        print("Decoding started.")
-
-    def terminate(self) -> None:
-        """Terminate all workers."""
-        print("Terminating experiment...")
-        self._terminated = True
-        try:
-            self.queue_source.put(None, block=False)
-        except queue.Full:
-            self.queue_source.get(block=False)
-            try:
-                self.queue_source.put(None, block=False)
-            except queue.Full:
-                pass
-        print("Set terminating event.")
-        TMSiSDK.sample_data_server.unregisterConsumer(
-            self.device.id, self.queue_source
-        )
-        print("Unregistered consumer.")
-
-        try:
-            self.lsl_stream.close()
-        except Exception:
-            pass
-        try:
-            self.file_writer.close()
-        except Exception:
-            pass
-        if self.device.status.state == TMSiSDK.DeviceState.sampling:
-            self.device.stop_measurement()
-            print("Controlled stopping TMSi measurement...")
-        if self.device.status.state == TMSiSDK.DeviceState.connected:
-            self.device.close()
-            print("Controlled closing TMSi device...")
-
-        # Check if all processes have terminated
-        active_children = multiprocessing.active_children()
-        if not active_children:
-            print("No alive processes found.")
-            sys.exit()
-
-        # Wait for processes to temrinate on their own
-        print(f"Alive processes: {list(p.name for p in active_children)}")
-        print("Waiting for processes to finish. Please wait...")
-        self.wait(active_children, timeout=5)
-        active_children = multiprocessing.active_children()
-        if not active_children:
-            print("No alive processes found.")
-            sys.exit()
-
-        # Try flushing all queues
-        print(f"Alive processes: {(p.name for p in active_children)}")
-        print("Flushing all queues. Please wait...")
-        for queue_ in self.queues:
-            realtime_decoding.clear_queue(queue_)
-        self.wait(active_children, timeout=5)
-        active_children = multiprocessing.active_children()
-        if not active_children:
-            print("No alive processes found.")
-            sys.exit()
-
-        # Try killing all processes gracefully
-        print(f"Alive processes: {(p.name for p in active_children)}")
-        print("Trying to kill processes gracefully. Please wait...")
-        interrupt = (
-            signal.CTRL_C_EVENT if sys.platform == "win32" else signal.SIGINT
-        )
-        for process in active_children:
-            if process.is_alive():
-                os.kill(process.pid, interrupt)
-        self.wait(active_children, timeout=5)
-        active_children = multiprocessing.active_children()
-        if not active_children:
-            print("No alive processes found.")
-            sys.exit()
-
-        # Try forcefully terminating processes
-        print(f"Alive processes: {(p.name for p in active_children)}")
-        print("Terminating processes forcefully.")
-        for process in active_children:
-            if process.is_alive():
-                process.terminate()
-        sys.exit()
-
-    @staticmethod
-    def wait(processes, timeout=None) -> None:
-        """Wait for all workers to die."""
-        if not processes:
-            return
-        start = time.time()
-        while True:
-            try:
-                if all(not process.is_alive() for process in processes):
-                    return
-                if timeout and time.time() - start >= timeout:
-                    return
-                time.sleep(0.1)
-            except Exception:
-                pass
-
-
-def run(
-    config_settings: dict,
-) -> None:
-    """Initialize data processing by launching all necessary processes."""
-    out_dir = pathlib.Path(config_settings["PATH_OUT_DIR"])
-    file_name = config_settings["filename"]
-
-    with (
-        open_tmsi_device(out_dir, config_settings["PATH_XML_CONFIG"]) as device,
-        open_poly5_writer(device, out_dir / file_name) as file_writer,
-        open_lsl_stream(device) as stream,
-    ):
-        manager = ProcessManager(
-            device=device,
-            lsl_stream=stream,
-            file_writer=file_writer,
-            out_dir=out_dir,
-            timeout=0.05,
-            verbose=False,
-            settings=config_settings
-        )
-
-        manager.start()
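
The removed run_decoding.py above opens the TMSi SAGA amplifier, a Poly5 file writer, and an LSL stream through @contextmanager generators nested in a single parenthesized with-statement; the removed generators close their resource in their exception handlers, and normal teardown is left to ProcessManager.terminate. A generic, hardware-free sketch of just that nesting pattern, with dummy resources and made-up names (no TMSiSDK calls; try/finally is used here for simplicity):

# Sketch only: nested context managers for ordered resource setup/teardown.
from contextlib import contextmanager
from typing import Generator


@contextmanager
def open_resource(name: str) -> Generator[str, None, None]:
    print(f"opening {name}")
    try:
        yield name
    finally:
        print(f"closing {name}")


def run() -> None:
    with (
        open_resource("amplifier") as device,
        open_resource(f"poly5-writer-for-{device}") as writer,
        open_resource(f"lsl-stream-for-{device}") as stream,
    ):
        print("all resources ready:", device, writer, stream)


if __name__ == "__main__":
    run()   # resources close in reverse order, also when an exception escapes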

packages/realtime_decoding/build/lib/realtime_decoding/trainer.py
@@ -1,54 +0,0 @@
-"""
-Read saved features from timeflux .hdf5 file and train model
-"""
-from sklearn import linear_model, model_selection, metrics
-import pickle
-import pandas as pd
-import os
-import numpy as np
-#from matplotlib import pyplot as plt
-
-if __name__ == "__main__":
-
-    sub = "487_train"
-
-    PATH_HDF5_FEATURES = rf"C:\CODE\py_neuromodulation\realtime_experiment\data\sub-{sub}\ses-EcogLfpMedOff01\sub-{sub}_ses-EcogLfpMedOff01_task-RealtimeDecodingR_acq-StimOff_run-1_ieeg.hdf5"
-    PATH_MODEL_SAVE = os.path.join(
-        rf"C:\CODE\py_neuromodulation\realtime_experiment\data\sub-{sub}\ses-EcogLfpMedOff01",
-        "model_trained.p"
-    )
-
-    df = pd.read_hdf(PATH_HDF5_FEATURES, key="features")
-
-    y = np.abs(np.diff(df["label_train"]))
-    X = df[[f for f in df.columns if "time" not in f and "label" not in f]].iloc[1:, :]
-
-    #from matplotlib import pyplot as plt
-    #plt.figure()
-    #plt.plot(np.array(df["label_train"]))
-    #plt.plot(y)
-    #plt.show()
-
-    X_lim = X.iloc[850:, :]
-    y_lim = y[850:]
-
-
-    model = linear_model.LogisticRegression()
-
-    model = model.fit(X, y>0.01)
-    model = model.fit(X_lim, y_lim>0.01)
-
-    with open(PATH_MODEL_SAVE, "wb") as fid:
-        pickle.dump(model, fid)
-
-    pr = model_selection.cross_val_predict(
-        estimator=linear_model.LinearRegression(),
-        X=X_lim,
-        y=y_lim>0.01,
-        cv=model_selection.KFold(n_splits=3, shuffle=False)
-    )
-
-    plt.figure()
-    plt.plot(pr)
-    plt.plot(y_lim)
-    plt.show()
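
The removed trainer script above builds a binary label by thresholding the absolute first difference of the recorded analog label_train channel, fits a logistic regression on the stored features, and pickles the model for the Decoder process to load; note that, as removed, it calls plt for the final plots even though the matplotlib import is commented out. A self-contained sketch of the same label-and-fit idea on synthetic data (random features, a hypothetical output path, and cross_val_score in place of the script's cross_val_predict):

# Sketch only: threshold the derivative of an analog label channel, fit, pickle.
import pickle

import numpy as np
import pandas as pd
from sklearn import linear_model, model_selection

rng = np.random.default_rng(0)
n_samples = 300
label_analog = np.repeat(rng.integers(0, 2, n_samples // 10), 10).astype(float)
features = pd.DataFrame(
    rng.normal(size=(n_samples, 5)), columns=[f"feat_{i}" for i in range(5)]
)

y = np.abs(np.diff(label_analog)) > 0.01   # movement on-/offset label
X = features.iloc[1:, :]                   # align features with the diffed label

model = linear_model.LogisticRegression().fit(X, y)
print("CV accuracy:", model_selection.cross_val_score(model, X, y, cv=3).round(2))

with open("model_trained.p", "wb") as fid:  # hypothetical output path
    pickle.dump(model, fid)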