py-neuromodulation 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
- py_neuromodulation/__init__.py +12 -4
- py_neuromodulation/analysis/RMAP.py +3 -3
- py_neuromodulation/analysis/decode.py +55 -2
- py_neuromodulation/analysis/feature_reader.py +1 -0
- py_neuromodulation/analysis/stats.py +3 -3
- py_neuromodulation/default_settings.yaml +25 -17
- py_neuromodulation/features/bandpower.py +65 -23
- py_neuromodulation/features/bispectra.py +3 -7
- py_neuromodulation/features/bursts.py +9 -8
- py_neuromodulation/features/coherence.py +17 -9
- py_neuromodulation/features/feature_processor.py +4 -4
- py_neuromodulation/features/fooof.py +7 -6
- py_neuromodulation/features/mne_connectivity.py +25 -3
- py_neuromodulation/features/oscillatory.py +5 -4
- py_neuromodulation/features/sharpwaves.py +21 -0
- py_neuromodulation/filter/kalman_filter.py +17 -6
- py_neuromodulation/gui/__init__.py +3 -0
- py_neuromodulation/gui/backend/app_backend.py +419 -0
- py_neuromodulation/gui/backend/app_manager.py +345 -0
- py_neuromodulation/gui/backend/app_pynm.py +244 -0
- py_neuromodulation/gui/backend/app_socket.py +95 -0
- py_neuromodulation/gui/backend/app_utils.py +306 -0
- py_neuromodulation/gui/backend/app_window.py +202 -0
- py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
- py_neuromodulation/gui/frontend/assets/index-NbJiOU5a.js +300133 -0
- py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
- py_neuromodulation/gui/frontend/charite.svg +16 -0
- py_neuromodulation/gui/frontend/index.html +14 -0
- py_neuromodulation/gui/window_api.py +115 -0
- py_neuromodulation/lsl_api.cfg +3 -0
- py_neuromodulation/processing/data_preprocessor.py +9 -2
- py_neuromodulation/processing/filter_preprocessing.py +43 -27
- py_neuromodulation/processing/normalization.py +32 -17
- py_neuromodulation/processing/projection.py +2 -2
- py_neuromodulation/processing/resample.py +6 -2
- py_neuromodulation/run_gui.py +36 -0
- py_neuromodulation/stream/__init__.py +7 -1
- py_neuromodulation/stream/backend_interface.py +47 -0
- py_neuromodulation/stream/data_processor.py +24 -3
- py_neuromodulation/stream/mnelsl_player.py +121 -21
- py_neuromodulation/stream/mnelsl_stream.py +9 -17
- py_neuromodulation/stream/settings.py +80 -34
- py_neuromodulation/stream/stream.py +82 -62
- py_neuromodulation/utils/channels.py +1 -1
- py_neuromodulation/utils/file_writer.py +110 -0
- py_neuromodulation/utils/io.py +46 -5
- py_neuromodulation/utils/perf.py +156 -0
- py_neuromodulation/utils/pydantic_extensions.py +322 -0
- py_neuromodulation/utils/types.py +33 -107
- {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/METADATA +27 -22
- {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/RECORD +56 -36
- {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/WHEEL +1 -1
- py_neuromodulation-0.1.0.dist-info/entry_points.txt +2 -0
- {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from collections.abc import Iterable
|
|
2
2
|
import numpy as np
|
|
3
|
-
from typing import TYPE_CHECKING
|
|
3
|
+
from typing import TYPE_CHECKING, Literal
|
|
4
4
|
|
|
5
5
|
from py_neuromodulation.utils.types import NMFeature, NMBaseModel
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
6
7
|
|
|
7
8
|
if TYPE_CHECKING:
|
|
8
9
|
from py_neuromodulation import NMSettings
|
|
@@ -10,9 +11,30 @@ if TYPE_CHECKING:
|
|
|
10
11
|
from mne import Epochs
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
MNE_CONNECTIVITY_METHOD = Literal[
|
|
15
|
+
"coh",
|
|
16
|
+
"cohy",
|
|
17
|
+
"imcoh",
|
|
18
|
+
"cacoh",
|
|
19
|
+
"mic",
|
|
20
|
+
"mim",
|
|
21
|
+
"plv",
|
|
22
|
+
"ciplv",
|
|
23
|
+
"ppc",
|
|
24
|
+
"pli",
|
|
25
|
+
"dpli",
|
|
26
|
+
"wpli",
|
|
27
|
+
"wpli2_debiased",
|
|
28
|
+
"gc",
|
|
29
|
+
"gc_tr",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"]
|
|
33
|
+
|
|
34
|
+
|
|
13
35
|
class MNEConnectivitySettings(NMBaseModel):
|
|
14
|
-
method:
|
|
15
|
-
mode:
|
|
36
|
+
method: MNE_CONNECTIVITY_METHOD = NMField(default="plv")
|
|
37
|
+
mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper")
|
|
16
38
|
|
|
17
39
|
|
|
18
40
|
class MNEConnectivity(NMFeature):
|
|
@@ -3,6 +3,7 @@ import numpy as np
|
|
|
3
3
|
from itertools import product
|
|
4
4
|
|
|
5
5
|
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import NMField
|
|
6
7
|
from typing import TYPE_CHECKING
|
|
7
8
|
|
|
8
9
|
if TYPE_CHECKING:
|
|
@@ -17,12 +18,12 @@ class OscillatoryFeatures(BoolSelector):
|
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class OscillatorySettings(NMBaseModel):
|
|
20
|
-
windowlength_ms: int = 1000
|
|
21
|
+
windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"})
|
|
21
22
|
log_transform: bool = True
|
|
22
23
|
features: OscillatoryFeatures = OscillatoryFeatures(
|
|
23
24
|
mean=True, median=False, std=False, max=False
|
|
24
25
|
)
|
|
25
|
-
return_spectrum: bool =
|
|
26
|
+
return_spectrum: bool = True
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
ESTIMATOR_DICT = {
|
|
@@ -176,7 +177,7 @@ class Welch(OscillatoryFeature):
|
|
|
176
177
|
if self.settings.return_spectrum:
|
|
177
178
|
combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
|
|
178
179
|
for (ch_idx, ch_name), (idx, f) in combinations:
|
|
179
|
-
feature_results[f"{ch_name}_welch_psd_{
|
|
180
|
+
feature_results[f"{ch_name}_welch_psd_{int(f)}"] = Z[ch_idx][idx]
|
|
180
181
|
|
|
181
182
|
return feature_results
|
|
182
183
|
|
|
@@ -242,7 +243,7 @@ class STFT(OscillatoryFeature):
|
|
|
242
243
|
if self.settings.return_spectrum:
|
|
243
244
|
combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
|
|
244
245
|
for (ch_idx, ch_name), (idx, f) in combinations:
|
|
245
|
-
feature_results[f"{ch_name}_stft_psd_{
|
|
246
|
+
feature_results[f"{ch_name}_stft_psd_{int(f)}"] = Z[ch_idx].mean(
|
|
246
247
|
axis=1
|
|
247
248
|
)[idx]
|
|
248
249
|
|
|
@@ -267,6 +267,14 @@ class SharpwaveAnalyzer(NMFeature):
|
|
|
267
267
|
|
|
268
268
|
# for each feature take the respective fun.
|
|
269
269
|
for feature_name, estimator_name, estimator in estimator_combinations:
|
|
270
|
+
if feature_name == "num_peaks":
|
|
271
|
+
key_name = f"{ch_name}_Sharpwave_{feature_name}_{filter_name}"
|
|
272
|
+
if len(waveform_results[feature_name]) == 1:
|
|
273
|
+
dict_ch_features[key_name][key_name_pt] = waveform_results[feature_name][0]
|
|
274
|
+
continue
|
|
275
|
+
else:
|
|
276
|
+
raise ValueError("num_peaks should be a list with length 1")
|
|
277
|
+
# there can be only one num_peak in each batch
|
|
270
278
|
feature_data = waveform_results[feature_name]
|
|
271
279
|
key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
|
|
272
280
|
|
|
@@ -280,12 +288,25 @@ class SharpwaveAnalyzer(NMFeature):
|
|
|
280
288
|
|
|
281
289
|
# the key_name stays, since the estimator function stays between peaks and troughs
|
|
282
290
|
for key_name, estimator in self.estimator_key_map.items():
|
|
291
|
+
if len(dict_ch_features[key_name]) == 0:
|
|
292
|
+
# might happen if num_peaks was written in estimator
|
|
293
|
+
# e.g. estimator["mean"] = ["num_peaks"]
|
|
294
|
+
# for conveniance this doesn't raise an exception
|
|
295
|
+
continue
|
|
296
|
+
|
|
283
297
|
feature_results[key_name] = estimator(
|
|
284
298
|
[
|
|
285
299
|
list(dict_ch_features[key_name].values())[0],
|
|
286
300
|
list(dict_ch_features[key_name].values())[1],
|
|
287
301
|
]
|
|
288
302
|
)
|
|
303
|
+
# add here also the num_peaks features
|
|
304
|
+
if self.sw_settings.sharpwave_features.num_peaks:
|
|
305
|
+
for ch_name in self.ch_names:
|
|
306
|
+
for filter_name in self.filter_names:
|
|
307
|
+
key_name = f"{ch_name}_Sharpwave_num_peaks_{filter_name}"
|
|
308
|
+
feature_results[key_name] = np_mean([dict_ch_features[key_name]["Peak"],
|
|
309
|
+
dict_ch_features[key_name]["Trough"]])
|
|
289
310
|
else:
|
|
290
311
|
# otherwise, save all write all "flattened" key value pairs in feature_results
|
|
291
312
|
for key, subdict in dict_ch_features.items():
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
from typing import TYPE_CHECKING
|
|
3
3
|
|
|
4
|
+
|
|
4
5
|
from py_neuromodulation.utils.types import NMBaseModel
|
|
6
|
+
from py_neuromodulation.utils.pydantic_extensions import NMErrorList
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
if TYPE_CHECKING:
|
|
@@ -22,13 +24,22 @@ class KalmanSettings(NMBaseModel):
|
|
|
22
24
|
"HFA",
|
|
23
25
|
]
|
|
24
26
|
|
|
25
|
-
def validate_fbands(self, settings: "NMSettings") ->
|
|
26
|
-
|
|
27
|
+
def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
|
|
28
|
+
errors: NMErrorList = NMErrorList()
|
|
29
|
+
|
|
30
|
+
if not all(
|
|
27
31
|
(item in settings.frequency_ranges_hz for item in self.frequency_bands)
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
+
):
|
|
33
|
+
errors.add_error(
|
|
34
|
+
"Frequency bands for Kalman filter must also be specified in "
|
|
35
|
+
"frequency_ranges_hz.",
|
|
36
|
+
location=[
|
|
37
|
+
"kalman_filter_settings",
|
|
38
|
+
"frequency_bands",
|
|
39
|
+
],
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
return errors
|
|
32
43
|
|
|
33
44
|
|
|
34
45
|
def define_KF(Tp, sigma_w, sigma_v):
|
|
@@ -0,0 +1,419 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import importlib.metadata
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
from fastapi import (
|
|
8
|
+
FastAPI,
|
|
9
|
+
HTTPException,
|
|
10
|
+
Query,
|
|
11
|
+
WebSocket,
|
|
12
|
+
)
|
|
13
|
+
from fastapi.staticfiles import StaticFiles
|
|
14
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
15
|
+
from pydantic import ValidationError
|
|
16
|
+
|
|
17
|
+
from . import app_pynm
|
|
18
|
+
from .app_socket import WebsocketManager
|
|
19
|
+
from .app_utils import is_hidden, get_quick_access
|
|
20
|
+
import pandas as pd
|
|
21
|
+
|
|
22
|
+
from py_neuromodulation import PYNM_DIR, NMSettings
|
|
23
|
+
from py_neuromodulation.utils.types import FileInfo
|
|
24
|
+
|
|
25
|
+
# TODO: maybe pull this list from the MNE package?
|
|
26
|
+
ALLOWED_EXTENSIONS = [".npy", ".vhdr", ".fif", ".edf", ".bdf"]
|
|
27
|
+
|
|
28
|
+
DEV_SERVER_PORT = 54321
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class PyNMBackend(FastAPI):
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
debug: bool | None = None,
|
|
35
|
+
dev: bool | None = None,
|
|
36
|
+
dev_port: int | None = None,
|
|
37
|
+
fastapi_kwargs: dict = {},
|
|
38
|
+
) -> None:
|
|
39
|
+
if debug is None:
|
|
40
|
+
self.debug = os.environ.get("PYNM_DEBUG", "False").lower() == "true"
|
|
41
|
+
if dev is None:
|
|
42
|
+
self.dev = os.environ.get("PYNM_DEV", "False").lower() == "true"
|
|
43
|
+
if dev_port is None:
|
|
44
|
+
self.dev_port = os.environ.get("PYNM_DEV_PORT", str(DEV_SERVER_PORT))
|
|
45
|
+
|
|
46
|
+
super().__init__(
|
|
47
|
+
title="PyNeuromodulation",
|
|
48
|
+
description="PyNeuromodulation FastAPI backend",
|
|
49
|
+
version=importlib.metadata.version("py_neuromodulation"),
|
|
50
|
+
debug=self.debug,
|
|
51
|
+
**fastapi_kwargs,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
# Use the FastAPI logger for the backend
|
|
55
|
+
self.logger = logging.getLogger("uvicorn.error")
|
|
56
|
+
self.logger.warning(PYNM_DIR)
|
|
57
|
+
|
|
58
|
+
if self.dev:
|
|
59
|
+
cors_origins = (
|
|
60
|
+
["http://localhost:" + str(self.dev_port)] if self.dev else []
|
|
61
|
+
)
|
|
62
|
+
# Configure CORS
|
|
63
|
+
self.add_middleware(
|
|
64
|
+
CORSMiddleware,
|
|
65
|
+
allow_origins=cors_origins,
|
|
66
|
+
allow_credentials=True,
|
|
67
|
+
allow_methods=["*"],
|
|
68
|
+
allow_headers=["*"],
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Has to be before mounting static files
|
|
72
|
+
self.setup_routes()
|
|
73
|
+
|
|
74
|
+
# Serve static files
|
|
75
|
+
if not self.dev:
|
|
76
|
+
self.logger.info("Serving static files")
|
|
77
|
+
self.mount(
|
|
78
|
+
"/",
|
|
79
|
+
StaticFiles(directory=PYNM_DIR / "gui" / "frontend", html=True),
|
|
80
|
+
name="static",
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
self.websocket_manager = WebsocketManager()
|
|
84
|
+
self.pynm_state = app_pynm.PyNMState()
|
|
85
|
+
|
|
86
|
+
def setup_routes(self):
|
|
87
|
+
@self.get("/api/health")
|
|
88
|
+
async def healthcheck():
|
|
89
|
+
return {"message": "API is working"}
|
|
90
|
+
|
|
91
|
+
####################
|
|
92
|
+
##### SETTINGS #####
|
|
93
|
+
####################
|
|
94
|
+
@self.get("/api/settings")
|
|
95
|
+
async def get_settings(
|
|
96
|
+
reset: bool = Query(False, description="Reset settings to default"),
|
|
97
|
+
):
|
|
98
|
+
if reset:
|
|
99
|
+
settings = NMSettings.get_default()
|
|
100
|
+
else:
|
|
101
|
+
settings = self.pynm_state.stream.settings
|
|
102
|
+
|
|
103
|
+
return settings.serialize_with_metadata()
|
|
104
|
+
|
|
105
|
+
@self.post("/api/settings")
|
|
106
|
+
async def update_settings(data: dict, validate_only: bool = Query(False)):
|
|
107
|
+
try:
|
|
108
|
+
# First, validate with Pydantic
|
|
109
|
+
try:
|
|
110
|
+
# TODO: check if this works properly or needs model_validate_strings
|
|
111
|
+
validated_settings = NMSettings.model_validate(data)
|
|
112
|
+
except ValidationError as e:
|
|
113
|
+
self.logger.error(f"Error validating settings: {e}")
|
|
114
|
+
if not validate_only:
|
|
115
|
+
# If validation failed but we wanted to upload, return error
|
|
116
|
+
raise HTTPException(
|
|
117
|
+
status_code=422,
|
|
118
|
+
detail={
|
|
119
|
+
"error": "Error validating settings",
|
|
120
|
+
"details": str(e),
|
|
121
|
+
},
|
|
122
|
+
)
|
|
123
|
+
# Else return list of errors
|
|
124
|
+
return {
|
|
125
|
+
"valid": False,
|
|
126
|
+
"errors": [err for err in e.errors()],
|
|
127
|
+
"details": str(e),
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
# If validation succesful, return or update settings
|
|
131
|
+
if validate_only:
|
|
132
|
+
return {
|
|
133
|
+
"valid": True,
|
|
134
|
+
"settings": validated_settings.serialize_with_metadata(),
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
self.pynm_state.stream.settings = validated_settings
|
|
138
|
+
self.logger.info("Settings successfully updated")
|
|
139
|
+
|
|
140
|
+
return {
|
|
141
|
+
"valid": True,
|
|
142
|
+
"settings": self.pynm_state.stream.settings.serialize_with_metadata(),
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
# If something else than validation went wrong, return error
|
|
146
|
+
except Exception as e:
|
|
147
|
+
self.logger.error(f"Error validating/updating settings: {e}")
|
|
148
|
+
raise HTTPException(
|
|
149
|
+
status_code=422,
|
|
150
|
+
detail={"error": "Error uploading settings", "details": str(e)},
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
########################
|
|
154
|
+
##### PYNM CONTROL #####
|
|
155
|
+
########################
|
|
156
|
+
|
|
157
|
+
@self.post("/api/stream-control")
|
|
158
|
+
async def handle_stream_control(data: dict):
|
|
159
|
+
action = data["action"]
|
|
160
|
+
if action == "start":
|
|
161
|
+
# TODO: create out_dir and experiment_name text filds in frontend
|
|
162
|
+
self.logger.info("websocket:")
|
|
163
|
+
self.logger.info(self.websocket_manager)
|
|
164
|
+
self.logger.info("Starting stream")
|
|
165
|
+
|
|
166
|
+
self.pynm_state.start_run_function(
|
|
167
|
+
websocket_manager=self.websocket_manager,
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
if action == "stop":
|
|
171
|
+
self.logger.info("Stopping stream")
|
|
172
|
+
self.pynm_state.stop_run_function()
|
|
173
|
+
|
|
174
|
+
return {"message": f"Stream action '{action}' executed"}
|
|
175
|
+
|
|
176
|
+
####################
|
|
177
|
+
##### CHANNELS #####
|
|
178
|
+
####################
|
|
179
|
+
|
|
180
|
+
@self.get("/api/channels")
|
|
181
|
+
async def get_channels():
|
|
182
|
+
channels = self.pynm_state.stream.channels
|
|
183
|
+
self.logger.info(f"Sending channels: {channels}")
|
|
184
|
+
if isinstance(channels, pd.DataFrame):
|
|
185
|
+
return {"channels": channels.to_dict(orient="records")}
|
|
186
|
+
else:
|
|
187
|
+
raise HTTPException(
|
|
188
|
+
status_code=422,
|
|
189
|
+
detail={"error": "Channels is not a DataFrame"},
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
@self.post("/api/channels")
|
|
193
|
+
async def update_channels(data: dict):
|
|
194
|
+
try:
|
|
195
|
+
new_channels = pd.DataFrame(data["channels"])
|
|
196
|
+
self.logger.info(f"Received channels:\n {new_channels}")
|
|
197
|
+
self.pynm_state.stream.channels = new_channels
|
|
198
|
+
return {
|
|
199
|
+
"channels": self.pynm_state.stream.channels.to_dict(
|
|
200
|
+
orient="records"
|
|
201
|
+
)
|
|
202
|
+
}
|
|
203
|
+
except Exception as e:
|
|
204
|
+
raise HTTPException(
|
|
205
|
+
status_code=422,
|
|
206
|
+
detail={"error": "Error updating channels", "details": str(e)},
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
###################
|
|
210
|
+
### LSL STREAMS ###
|
|
211
|
+
###################
|
|
212
|
+
|
|
213
|
+
@self.get("/api/LSL-streams")
|
|
214
|
+
async def get_lsl_streams():
|
|
215
|
+
from mne_lsl.lsl import resolve_streams
|
|
216
|
+
|
|
217
|
+
return {
|
|
218
|
+
"message": [
|
|
219
|
+
{
|
|
220
|
+
"dtype": # MNE-LSL might return a class, so we get the name
|
|
221
|
+
getattr(stream.dtype, "__name__", str(stream.dtype)),
|
|
222
|
+
"name": stream.name,
|
|
223
|
+
"n_channels": stream.n_channels,
|
|
224
|
+
"sfreq": stream.sfreq,
|
|
225
|
+
"source_id": stream.source_id,
|
|
226
|
+
"stype": stream.stype, # Stream type (e.g. EEG)
|
|
227
|
+
"created_at": stream.created_at,
|
|
228
|
+
"hostname": stream.hostname,
|
|
229
|
+
"session_id": stream.session_id,
|
|
230
|
+
"uid": stream.uid,
|
|
231
|
+
"protocol_version": stream.protocol_version,
|
|
232
|
+
}
|
|
233
|
+
for stream in resolve_streams()
|
|
234
|
+
]
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
@self.post("/api/setup-LSL-stream")
|
|
238
|
+
async def setup_lsl_stream(data: dict):
|
|
239
|
+
try:
|
|
240
|
+
stream_name = data["stream_name"]
|
|
241
|
+
self.logger.info(f"Attempting to setup LSL stream: '{stream_name}'")
|
|
242
|
+
self.pynm_state.setup_lsl_stream(
|
|
243
|
+
lsl_stream_name=stream_name,
|
|
244
|
+
line_noise=data["line_noise"],
|
|
245
|
+
)
|
|
246
|
+
return {"message": f"LSL stream '{stream_name}' setup successfully"}
|
|
247
|
+
except Exception as e:
|
|
248
|
+
return {
|
|
249
|
+
"message": "LSL stream could not be setup",
|
|
250
|
+
"error": str(e),
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
@self.post("/api/setup-Offline-stream")
|
|
254
|
+
async def setup_offline_stream(data: dict):
|
|
255
|
+
self.logger.info("Data received to setup offline stream:")
|
|
256
|
+
self.logger.info(data)
|
|
257
|
+
try:
|
|
258
|
+
self.pynm_state.setup_offline_stream(
|
|
259
|
+
file_path=data["file_path"],
|
|
260
|
+
line_noise=float(data["line_noise"]),
|
|
261
|
+
)
|
|
262
|
+
return {"message": "Offline stream setup successfully"}
|
|
263
|
+
except ValueError:
|
|
264
|
+
return {"message": "Offline stream could not be setup"}
|
|
265
|
+
|
|
266
|
+
@self.post("/api/set-stream-params")
|
|
267
|
+
async def set_stream_params(data: dict):
|
|
268
|
+
try:
|
|
269
|
+
self.pynm_state.stream.line_noise = float(data["line_noise"])
|
|
270
|
+
self.pynm_state.stream.sfreq = float(data["sampling_rate"])
|
|
271
|
+
self.pynm_state.experiment_name = data["experiment_name"]
|
|
272
|
+
self.pynm_state.out_dir = data["out_dir"]
|
|
273
|
+
self.pynm_state.decoding_model_path = data["decoding_path"]
|
|
274
|
+
|
|
275
|
+
return {"message": "Stream parameters updated successfully"}
|
|
276
|
+
except ValueError:
|
|
277
|
+
return {"message": "Stream parameters could not be updated"}
|
|
278
|
+
|
|
279
|
+
#######################
|
|
280
|
+
### PYNM ABOUT INFO ###
|
|
281
|
+
#######################
|
|
282
|
+
|
|
283
|
+
@self.get("/api/app-info")
|
|
284
|
+
# TODO: fix this function
|
|
285
|
+
async def get_app_info():
|
|
286
|
+
metadata = importlib.metadata.metadata("py_neuromodulation")
|
|
287
|
+
url_list = metadata.get_all("Project-URL")
|
|
288
|
+
urls = (
|
|
289
|
+
{url.split(",")[0]: url.split(",")[1] for url in url_list}
|
|
290
|
+
if url_list
|
|
291
|
+
else {}
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
classifier_list = metadata.get_all("Classifier")
|
|
295
|
+
classifiers = (
|
|
296
|
+
{
|
|
297
|
+
item[: item.find("::") - 1]: item[item.find("::") + 3 :]
|
|
298
|
+
for item in classifier_list
|
|
299
|
+
}
|
|
300
|
+
if classifier_list
|
|
301
|
+
else {}
|
|
302
|
+
)
|
|
303
|
+
if "License" in classifiers:
|
|
304
|
+
classifiers["License"] = classifiers["License"].split("::")[1]
|
|
305
|
+
|
|
306
|
+
return {
|
|
307
|
+
"version": metadata.get("Version", ""),
|
|
308
|
+
"website": urls.get("Homepage", ""),
|
|
309
|
+
"authors": [metadata.get("Author-email", "")],
|
|
310
|
+
"maintainers": [metadata.get("Maintainer", "")],
|
|
311
|
+
"repository": urls.get("Repository", ""),
|
|
312
|
+
"documentation": urls.get("Documentation", ""),
|
|
313
|
+
"license": classifiers["License"],
|
|
314
|
+
# "launchMode": "debug" if app.debug else "release",
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
##############################
|
|
318
|
+
### FILE BROWSER ENDPOINTS ###
|
|
319
|
+
##############################
|
|
320
|
+
# Get home directory for the current user
|
|
321
|
+
@self.get("/api/home_directory")
|
|
322
|
+
async def home_directory():
|
|
323
|
+
try:
|
|
324
|
+
home_dir = str(Path.home())
|
|
325
|
+
return {"home_directory": home_dir}
|
|
326
|
+
except Exception as e:
|
|
327
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
328
|
+
|
|
329
|
+
# Get PYNM_DIR
|
|
330
|
+
@self.get("/api/pynm_dir")
|
|
331
|
+
async def get_pynm_dir():
|
|
332
|
+
try:
|
|
333
|
+
return {"pynm_dir": PYNM_DIR}
|
|
334
|
+
except Exception as e:
|
|
335
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
336
|
+
|
|
337
|
+
# Get list of available drives in Windows systems
|
|
338
|
+
@self.get("/api/drives")
|
|
339
|
+
async def list_drives():
|
|
340
|
+
if os.name == "nt":
|
|
341
|
+
import string
|
|
342
|
+
|
|
343
|
+
drives = []
|
|
344
|
+
for letter in string.ascii_uppercase:
|
|
345
|
+
if Path(f"{letter}:").exists():
|
|
346
|
+
drives.append(f"{letter}:")
|
|
347
|
+
|
|
348
|
+
return {"drives": drives}
|
|
349
|
+
else:
|
|
350
|
+
return {"drives": ["/"]} # Unix-like systems have a single root
|
|
351
|
+
|
|
352
|
+
# Get list of files and directories in a directory
|
|
353
|
+
@self.get("/api/files")
|
|
354
|
+
async def list_files(
|
|
355
|
+
path: str = Query(default="", description="Directory path to list"),
|
|
356
|
+
allowed_extensions: str = Query(
|
|
357
|
+
default=",".join(ALLOWED_EXTENSIONS),
|
|
358
|
+
description="Comma-separated list of allowed file extensions",
|
|
359
|
+
),
|
|
360
|
+
show_hidden: bool = Query(
|
|
361
|
+
default=False,
|
|
362
|
+
description="Whether to show hidden files and directories",
|
|
363
|
+
),
|
|
364
|
+
) -> list[FileInfo]:
|
|
365
|
+
try:
|
|
366
|
+
if not path:
|
|
367
|
+
path = str(Path.home())
|
|
368
|
+
|
|
369
|
+
if not Path(path).is_dir():
|
|
370
|
+
raise FileNotFoundError("The specified path is not a directory")
|
|
371
|
+
|
|
372
|
+
allowed_ext = allowed_extensions.split(",")
|
|
373
|
+
|
|
374
|
+
files = []
|
|
375
|
+
for entry in Path(path).iterdir():
|
|
376
|
+
# Skip hidden files/directories if show_hidden is False
|
|
377
|
+
if not show_hidden and is_hidden(entry):
|
|
378
|
+
continue
|
|
379
|
+
|
|
380
|
+
if entry.is_file() and not any(
|
|
381
|
+
entry.name.lower().endswith(ext.lower()) for ext in allowed_ext
|
|
382
|
+
):
|
|
383
|
+
continue
|
|
384
|
+
|
|
385
|
+
stats = entry.stat()
|
|
386
|
+
files.append(
|
|
387
|
+
FileInfo(
|
|
388
|
+
name=entry.name,
|
|
389
|
+
path=str(entry),
|
|
390
|
+
dir=str(entry.parent),
|
|
391
|
+
is_directory=entry.is_dir(),
|
|
392
|
+
size=stats.st_size if not entry.is_dir() else 0,
|
|
393
|
+
created_at=datetime.fromtimestamp(stats.st_birthtime),
|
|
394
|
+
modified_at=datetime.fromtimestamp(stats.st_mtime),
|
|
395
|
+
)
|
|
396
|
+
)
|
|
397
|
+
return files
|
|
398
|
+
except FileNotFoundError:
|
|
399
|
+
raise HTTPException(status_code=404, detail="Directory not found")
|
|
400
|
+
except PermissionError:
|
|
401
|
+
raise HTTPException(status_code=403, detail="Permission denied")
|
|
402
|
+
except Exception as e:
|
|
403
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
404
|
+
|
|
405
|
+
@self.get("/api/quick-access")
|
|
406
|
+
def quick_access():
|
|
407
|
+
return get_quick_access()
|
|
408
|
+
|
|
409
|
+
###########################
|
|
410
|
+
### WEBSOCKET ENDPOINTS ###
|
|
411
|
+
###########################
|
|
412
|
+
@self.websocket("/ws")
|
|
413
|
+
async def websocket_endpoint(websocket: WebSocket):
|
|
414
|
+
await self.websocket_manager.connect(websocket)
|
|
415
|
+
while True:
|
|
416
|
+
try:
|
|
417
|
+
await websocket.receive_text()
|
|
418
|
+
except Exception:
|
|
419
|
+
break
|