py-neuromodulation 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
- py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
- py_neuromodulation/__init__.py +12 -4
- py_neuromodulation/analysis/RMAP.py +3 -3
- py_neuromodulation/analysis/decode.py +55 -2
- py_neuromodulation/analysis/feature_reader.py +1 -0
- py_neuromodulation/analysis/stats.py +3 -3
- py_neuromodulation/default_settings.yaml +24 -17
- py_neuromodulation/features/bandpower.py +65 -23
- py_neuromodulation/features/bursts.py +9 -8
- py_neuromodulation/features/coherence.py +7 -4
- py_neuromodulation/features/feature_processor.py +4 -4
- py_neuromodulation/features/fooof.py +7 -6
- py_neuromodulation/features/mne_connectivity.py +25 -3
- py_neuromodulation/features/oscillatory.py +5 -4
- py_neuromodulation/features/sharpwaves.py +21 -0
- py_neuromodulation/filter/kalman_filter.py +17 -6
- py_neuromodulation/gui/__init__.py +3 -0
- py_neuromodulation/gui/backend/app_backend.py +419 -0
- py_neuromodulation/gui/backend/app_manager.py +345 -0
- py_neuromodulation/gui/backend/app_pynm.py +244 -0
- py_neuromodulation/gui/backend/app_socket.py +95 -0
- py_neuromodulation/gui/backend/app_utils.py +306 -0
- py_neuromodulation/gui/backend/app_window.py +202 -0
- py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
- py_neuromodulation/gui/frontend/assets/index-NbJiOU5a.js +300133 -0
- py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
- py_neuromodulation/gui/frontend/charite.svg +16 -0
- py_neuromodulation/gui/frontend/index.html +14 -0
- py_neuromodulation/gui/window_api.py +115 -0
- py_neuromodulation/lsl_api.cfg +3 -0
- py_neuromodulation/processing/data_preprocessor.py +9 -2
- py_neuromodulation/processing/filter_preprocessing.py +43 -27
- py_neuromodulation/processing/normalization.py +32 -17
- py_neuromodulation/processing/projection.py +2 -2
- py_neuromodulation/processing/resample.py +6 -2
- py_neuromodulation/run_gui.py +36 -0
- py_neuromodulation/stream/__init__.py +7 -1
- py_neuromodulation/stream/backend_interface.py +47 -0
- py_neuromodulation/stream/data_processor.py +24 -3
- py_neuromodulation/stream/mnelsl_player.py +121 -21
- py_neuromodulation/stream/mnelsl_stream.py +9 -17
- py_neuromodulation/stream/settings.py +80 -34
- py_neuromodulation/stream/stream.py +82 -62
- py_neuromodulation/utils/channels.py +1 -1
- py_neuromodulation/utils/file_writer.py +110 -0
- py_neuromodulation/utils/io.py +46 -5
- py_neuromodulation/utils/perf.py +156 -0
- py_neuromodulation/utils/pydantic_extensions.py +322 -0
- py_neuromodulation/utils/types.py +33 -107
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/METADATA +18 -4
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/RECORD +55 -35
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/WHEEL +1 -1
- py_neuromodulation-0.1.0.dist-info/entry_points.txt +2 -0
- {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
import multiprocessing as mp
|
|
2
|
+
import threading
|
|
3
|
+
import os
|
|
4
|
+
import signal
|
|
5
|
+
import time
|
|
6
|
+
import platform
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
from .app_utils import force_terminate_process, create_logger, ansi_color, ansi_reset
|
|
11
|
+
|
|
12
|
+
from typing import TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from multiprocessing.synchronize import Event
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Shared memory configuration
|
|
19
|
+
ARRAY_SIZE = 1000 # Adjust based on your needs
|
|
20
|
+
|
|
21
|
+
SERVER_PORT = 50001
|
|
22
|
+
DEV_SERVER_PORT = 54321
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def run_vite(
    shutdown_event: "Event",
    debug: bool = False,
    dev_port: int = DEV_SERVER_PORT,
    backend_port: int = SERVER_PORT,
    dev_env_vars: dict | None = None,
) -> None:
    """Run Vite in a separate shell.

    Spawns ``bun run dev`` in the ``gui_dev`` directory, mirrors its
    stdout/stderr into our logger from a background thread, and blocks until
    ``shutdown_event`` is set, at which point the dev server is terminated
    (gracefully first, then killed after a timeout).

    Args:
        shutdown_event: Cross-process event that signals application shutdown.
        debug: If True, log at DEBUG level instead of INFO.
        dev_port: Port for the Vite development server.
        backend_port: Port of the backend server, exported as VITE_BACKEND_PORT.
        dev_env_vars: Extra environment variables to expose to Vite; each key
            is prefixed with ``VITE_``. Defaults to no extra variables.
            (Was a mutable ``{}`` default — replaced with a ``None`` sentinel
            to avoid the shared-mutable-default pitfall.)
    """
    import subprocess

    # SIGINT is owned by the parent AppManager; ignore it in this process.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    logger = create_logger(
        "Vite",
        "magenta",
        logging.DEBUG if debug else logging.INFO,
    )

    os.environ["VITE_BACKEND_PORT"] = str(backend_port)

    for key, value in (dev_env_vars or {}).items():
        os.environ["VITE_" + key] = value

    def output_reader(shutdown_event: "Event", process: subprocess.Popen):
        # Forward the child's stdout/stderr lines to our logger until shutdown.
        logger.debug("Initialized output stream")
        color = ansi_color(color="magenta", bright=True, styles=["BOLD"])

        def read_stream(stream, stream_name):
            for line in iter(stream.readline, ""):
                if shutdown_event.is_set():
                    break
                logger.info(f"{color}[{stream_name}]{ansi_reset} {line.strip()}")

        stdout_thread = threading.Thread(
            target=read_stream, args=(process.stdout, "stdout")
        )
        stderr_thread = threading.Thread(
            target=read_stream, args=(process.stderr, "stderr")
        )

        stdout_thread.start()
        stderr_thread.start()

        shutdown_event.wait()

        stdout_thread.join(timeout=2)
        stderr_thread.join(timeout=2)

        logger.debug("Output stream closed")

    # Handle different operating systems: on Windows we need a new process
    # group so CTRL_BREAK_EVENT reaches only the child, elsewhere SIGINT works.
    shutdown_signal = signal.CTRL_BREAK_EVENT if os.name == "nt" else signal.SIGINT
    subprocess_flags = subprocess.CREATE_NEW_PROCESS_GROUP if os.name == "nt" else 0

    process = subprocess.Popen(
        "bun run dev --port " + str(dev_port),
        cwd="gui_dev",
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        creationflags=subprocess_flags,
        shell=True,
    )

    logging_thread = threading.Thread(
        target=output_reader,
        args=(shutdown_event, process),
    )
    logging_thread.start()

    shutdown_event.wait()  # Wait for shutdown

    logger.debug("Terminating Vite server...")
    process.send_signal(shutdown_signal)

    try:
        process.wait(timeout=3)
    except subprocess.TimeoutExpired:
        logger.debug("Timeout expired, forcing termination...")
        process.kill()

    logging_thread.join(timeout=3)
    if logging_thread.is_alive():
        logger.debug("Logging thread did not finish in time")

    logger.info("Development server stopped")
+
|
|
113
|
+
def run_uvicorn(
    debug: bool = False, reload=False, server_port: int = SERVER_PORT
) -> None:
    """Run the FastAPI backend in a uvicorn server (blocking).

    Args:
        debug: If True, set uvicorn's log levels to DEBUG.
        reload: If True, wrap the server in a (patched) auto-reloader.
        server_port: TCP port the backend listens on.
    """
    from uvicorn.server import Server
    from uvicorn.config import LOGGING_CONFIG, Config

    # Configure logging: copy uvicorn's default config and recolor/reformat it
    # so backend output is visually distinguishable from the other processes.
    color = ansi_color(color="green", bright=True, styles=["BOLD"])
    log_level = "DEBUG" if debug else "INFO"
    log_config = LOGGING_CONFIG.copy()
    log_config["loggers"]["uvicorn"]["level"] = log_level
    log_config["loggers"]["uvicorn.error"]["level"] = log_level
    log_config["loggers"]["uvicorn.access"]["level"] = log_level
    log_config["formatters"]["default"]["fmt"] = (
        f"{color}[FastAPI %(levelname)s (%(asctime)s)]:{ansi_reset} %(message)s"
    )
    log_config["formatters"]["default"]["datefmt"] = "%H:%M:%S"
    log_config["formatters"]["access"]["fmt"] = (
        f"{color}[FastAPI access (%(asctime)s)]:{ansi_reset} %(message)s"
    )
    log_config["formatters"]["access"]["datefmt"] = "%H:%M:%S"

    # app is given as an import string (factory=True) so the reloader can
    # re-import it; ws ping/size limits are relaxed for large CBOR payloads.
    config = Config(
        app="py_neuromodulation.gui.backend.app_backend:PyNMBackend",
        host="localhost",
        reload=reload,
        factory=True,
        port=server_port,
        log_level="debug" if debug else "info",
        log_config=log_config,
        http="httptools",
        ws_ping_interval=None,
        ws_ping_timeout=None,
        ws_max_size=1024 * 1024 * 1024,  # 1GB
        # uvloop is not available on Windows; fall back to asyncio there.
        loop="asyncio" if platform.system() == "Windows" else "uvloop",
        lifespan="off",
    )

    server = Server(config=config)

    if reload:
        from uvicorn.supervisors import ChangeReload
        from uvicorn._subprocess import get_subprocess

        # Overload the restart method of uvicorn so that is does not kill all of our processes
        # IMPORTANT: This is a hack and prevents shutdown events from triggering when the reloader is used
        # NOTE(review): relies on uvicorn's private _subprocess module — may
        # break on uvicorn upgrades; confirm against the pinned version.
        class CustomReloader(ChangeReload):
            def restart(self) -> None:
                self.process.terminate()  # Use terminate instead of os.kill
                self.process.join()
                self.process = get_subprocess(
                    config=self.config, target=self.target, sockets=self.sockets
                )
                self.process.start()

        sock = config.bind_socket()
        server = CustomReloader(config, target=server.run, sockets=[sock])

    server.run()
|
+
|
|
174
|
+
def run_backend(
    shutdown_event: "Event",
    dev: bool = True,
    debug: bool = False,
    reload: bool = True,
    server_port: int = SERVER_PORT,
    dev_port: int = DEV_SERVER_PORT,
) -> None:
    """Launch the uvicorn backend in a child process and block until shutdown.

    SIGINT is ignored in this process; the parent AppManager owns interrupt
    handling. Configuration is forwarded to the backend app via PYNM_*
    environment variables (inherited by the spawned server process).
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # Pass create_backend parameters through environment variables
    env_settings = {
        "PYNM_DEBUG": debug,
        "PYNM_DEV": dev,
        "PYNM_DEV_PORT": dev_port,
    }
    for env_name, env_value in env_settings.items():
        os.environ[env_name] = str(env_value)

    uvicorn_process = mp.Process(
        target=run_uvicorn,
        kwargs={"debug": debug, "reload": reload, "server_port": server_port},
        name="Server",
    )
    uvicorn_process.start()

    # Block until the application-wide shutdown is requested, then wait for
    # the server process to exit.
    shutdown_event.wait()
    uvicorn_process.join()
|
198
|
+
|
|
199
|
+
class AppManager:
    """Orchestrate the GUI application's subprocesses.

    Starts the Vite dev server (dev mode) and the uvicorn backend in separate
    processes, optionally hosts a PyWebView window, and coordinates graceful
    shutdown of everything via a shared multiprocessing event.
    """

    # Environment flag used to detect re-entrant launches from child processes
    LAUNCH_FLAG = "PYNM_RUNNING"

    def __init__(
        self,
        debug: bool = False,
        dev: bool = True,
        run_in_webview=False,
        server_port=SERVER_PORT,
        dev_port=DEV_SERVER_PORT,
        dev_env_vars: dict | None = None,
    ) -> None:
        """Initialize the application manager.

        Args:
            debug (bool, optional): If True, run the app in debug mode, which sets logging level to debug,
                and starts uvicorn, FastAPI, and Vite in debug mode. Defaults to False.
            dev (bool, optional): If True, run the app in development mode, which enables hot
                reloading and runs the frontend in Vite server. If False, run the app in production mode,
                which runs the frontend from the static files in the `frontend` directory. Defaults to True.
            run_in_webview (bool, optional): If True, open a PyWebView window to display the app. Defaults to False.
            server_port (int, optional): TCP port of the backend server.
            dev_port (int, optional): TCP port of the Vite development server.
            dev_env_vars (dict | None, optional): Extra VITE_* environment variables
                forwarded to the Vite process. Defaults to None (no extras).
        """
        self.debug = debug
        self.dev = dev
        self.run_in_webview = run_in_webview
        self.server_port = server_port
        self.dev_port = dev_port
        # Avoid the mutable-default-argument pitfall: each instance gets its own dict
        self.dev_env_vars = dev_env_vars if dev_env_vars is not None else {}

        self._reset()
        # Prevent launching multiple instances of the app due to multiprocessing
        # This allows the absence of a main guard in the main script
        self.is_child_process = os.environ.get(self.LAUNCH_FLAG) == "TRUE"
        os.environ[self.LAUNCH_FLAG] = "TRUE"

        self.logger = create_logger(
            "PyNM",
            "yellow",
            logging.DEBUG if self.debug else logging.INFO,
        )

    def _reset(self) -> None:
        """Reset the AppManager to its initial state."""
        # Flags to track the state of the application
        self.shutdown_complete = False
        self.shutdown_started = False

        # Store background tasks
        self.tasks: dict[str, mp.Process] = {}

        # Events for multiprocessing synchronization.
        # The annotation must be a string: "Event" is only imported under
        # TYPE_CHECKING, and annotations on attribute targets are evaluated
        # at runtime, so a bare `Event` here would raise NameError.
        self.shutdown_event: "Event" = mp.Event()

    def _terminate_app(self) -> None:
        """Signal all background processes to stop, force-killing stragglers.

        Idempotent: a second call while shutdown is in progress is a no-op.
        """
        if self.shutdown_started:
            self.logger.info("Termination already in progress. Skipping.")
            return

        self.shutdown_started = True

        # Shared deadline across all processes so total shutdown time is bounded
        timeout = 10
        deadline = time.time() + timeout

        self.logger.info("App closed, cleaning up background tasks...")
        self.shutdown_event.set()

        for process_name, process in self.tasks.items():
            remaining_time = max(deadline - time.time(), 0)
            process.join(timeout=remaining_time)

            if process.is_alive():
                self.logger.info(
                    f"{process_name} did not terminate in time. Forcing termination..."
                )
                force_terminate_process(process, process_name, logger=self.logger)

            self.logger.info(f"Process {process.name} terminated.")

        self.shutdown_complete = True
        self.shutdown_event.clear()
        self.logger.info("All background tasks succesfully terminated.")

    def _sigint_handler(self, signum, frame):
        """SIGINT (Ctrl+C) handler: first signal starts shutdown, repeats are ignored."""
        if not self.shutdown_started:
            self.logger.info("Received SIGINT. Initiating graceful shutdown...")
            self._terminate_app()
        else:
            self.logger.info("SIGINT received again. Ignoring...")

    def launch(self) -> None:
        """Start all subprocesses and block until the application shuts down.

        Returns immediately in child processes (detected via LAUNCH_FLAG) so
        that re-imports under multiprocessing don't spawn a second app.
        """
        if self.is_child_process:
            return

        # Handle keyboard interrupt signals
        signal.signal(signal.SIGINT, self._sigint_handler)
        # signal.signal(signal.SIGINT, signal.SIG_IGN)

        # Create and start the subprocesses
        if self.dev:
            self.logger.info("Starting Vite server...")
            self.tasks["vite"] = mp.Process(
                target=run_vite,
                kwargs={
                    "shutdown_event": self.shutdown_event,
                    "debug": self.debug,
                    "dev_port": self.dev_port,
                    "backend_port": self.server_port,
                    "dev_env_vars": self.dev_env_vars,
                },
                name="Vite",
            )

        self.logger.info("Starting backend server...")
        self.tasks["backend"] = mp.Process(
            target=run_backend,
            kwargs={
                "shutdown_event": self.shutdown_event,
                "debug": self.debug,
                "reload": self.dev,
                "dev": self.dev,
                "server_port": self.server_port,
                "dev_port": self.dev_port,
            },
            name="Backend",
        )

        for process in self.tasks.values():
            process.start()

        if self.run_in_webview:
            from .app_window import WebViewWindow

            self.logger.info("Starting PyWebView window...")
            window = WebViewWindow(debug=self.debug)  # Must be called from main thread
            window.register_event_handler("closed", self._terminate_app)
            window.start()  # Start the window, this will block until the window is closed
        else:
            try:
                while not self.shutdown_complete:
                    time.sleep(0.1)
            except KeyboardInterrupt:
                pass  # The SIGINT handler will take care of termination

        if not self.shutdown_complete:
            self._terminate_app()

        self.logger.info("All processes cleaned up. Exiting...")
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
from threading import Thread
|
|
4
|
+
import time
|
|
5
|
+
import asyncio
|
|
6
|
+
import multiprocessing as mp
|
|
7
|
+
from queue import Empty
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from py_neuromodulation.stream import Stream, NMSettings
|
|
10
|
+
from py_neuromodulation.analysis.decode import RealTimeDecoder
|
|
11
|
+
from py_neuromodulation.utils import set_channels
|
|
12
|
+
from py_neuromodulation.utils.io import read_mne_data
|
|
13
|
+
from py_neuromodulation.utils.types import _PathLike
|
|
14
|
+
from py_neuromodulation import logger
|
|
15
|
+
from py_neuromodulation.gui.backend.app_socket import WebsocketManager
|
|
16
|
+
from py_neuromodulation.stream.backend_interface import StreamBackendInterface
|
|
17
|
+
from py_neuromodulation import logger
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PyNMState:
    """Backend-side application state for the py_neuromodulation GUI.

    Holds the configured Stream (LSL or offline file), the optional real-time
    decoder, and the multiprocessing queues that carry features/raw data from
    the stream process to the websocket sender running in the async loop.
    """

    def __init__(
        self,
        log_queue_size: bool = False,
    ) -> None:
        """Create default state; the actual stream is configured later.

        Args:
            log_queue_size: If True, periodically log queue diagnostics while
                the stream is running.
        """
        self.log_queue_size = log_queue_size
        self.lsl_stream_name: str = ""
        self.is_stream_lsl: bool = False
        self.experiment_name: str = "PyNM_Experiment"  # set by set_stream_params
        self.out_dir: _PathLike = str(
            Path.home() / "PyNM" / self.experiment_name
        )  # set by set_stream_params
        self.decoding_model_path: _PathLike | None = None
        self.decoder: RealTimeDecoder | None = None

        self.backend_interface: StreamBackendInterface | None = None
        self.websocket_manager: WebsocketManager | None = None

        # Note: sfreq and data are required for stream init; a placeholder
        # stream is created here and replaced by setup_lsl_stream/setup_offline_stream
        self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1]))

        # Queues bridging the stream subprocess and the websocket sender task
        self.feature_queue = mp.Queue()
        self.rawdata_queue = mp.Queue()
        self.control_queue = mp.Queue()
        self.stop_event = asyncio.Event()

        # Counter for diagnostics only
        self.messages_sent = 0

    def start_run_function(
        self,
        websocket_manager: WebsocketManager | None = None,
    ) -> None:
        """Start stream processing in a subprocess and (optionally) the
        websocket forwarding task on the current asyncio loop.

        Args:
            websocket_manager: Manager used to forward queue contents to the
                frontend; if None, no forwarding task is started.
        """
        self.websocket_manager = websocket_manager

        # Create decoder
        if self.decoding_model_path is not None and self.decoding_model_path != "None":
            if os.path.exists(self.decoding_model_path):
                self.decoder = RealTimeDecoder(self.decoding_model_path)
            else:
                logger.debug("Passed decoding model path does't exist")

        # Initialize the backend interface if not already done
        if not self.backend_interface:
            self.backend_interface = StreamBackendInterface(
                self.feature_queue, self.rawdata_queue, self.control_queue
            )

        # The run_func_thread is terminated through the stream_handling_queue
        # which initiates to break the data generator and save the features
        stream_process = mp.Process(
            target=self.stream.run,
            kwargs={
                "out_dir": "" if self.out_dir == "default" else self.out_dir,
                "experiment_name": self.experiment_name,
                "is_stream_lsl": self.is_stream_lsl,
                "stream_lsl_name": self.lsl_stream_name,
                "simulate_real_time": True,
                "decoder": self.decoder,
                "backend_interface": self.backend_interface,
            },
        )

        stream_process.start()

        # Start websocket sender process

        if self.websocket_manager:
            # TONI: Instead of having this function be not async and send the
            # _process_queue function to the Uvicorn async loop, we could
            # have this entire "start_run_function" function be async as well

            # Get the current event loop and run the queue processor
            loop = asyncio.get_running_loop()
            queue_task = loop.create_task(self._process_queue())

            # Store task reference for cleanup
            self._queue_task = queue_task

        # Store processes for cleanup
        self.stream_process = stream_process

    def stop_run_function(self) -> None:
        """Stop the stream processing"""
        if self.backend_interface:
            self.backend_interface.send_command("stop")
        self.stop_event.set()

    def setup_lsl_stream(
        self,
        lsl_stream_name: str = "",
        line_noise: float | None = None,
    ):
        """Resolve an LSL stream by name and configure self.stream from it.

        Raises:
            ValueError: If no LSL stream with the given name is found.
        """
        from mne_lsl.lsl import resolve_streams

        logger.info("resolving streams")
        lsl_streams = resolve_streams()

        for stream in lsl_streams:
            if stream.name == lsl_stream_name:
                logger.info(f"found stream {lsl_stream_name}")

                ch_names = stream.get_channel_names()
                if ch_names is None:
                    # Fall back to generated names when the stream has none
                    ch_names = ["ch" + str(i) for i in range(stream.n_channels)]
                logger.info(f"channel names: {ch_names}")

                ch_types = stream.get_channel_types()
                if ch_types is None:
                    ch_types = ["eeg" for i in range(stream.n_channels)]

                logger.info(f"channel types: {ch_types}")

                info_ = stream.get_channel_info()
                logger.info(f"channel info: {info_}")

                channels = set_channels(
                    ch_names=ch_names,
                    ch_types=ch_types,
                    used_types=["eeg", "ecog", "dbs", "seeg"],
                )

                # set all used column to 0
                #channels.loc[:, "used"] = 0

                logger.info(channels)
                sfreq = stream.sfreq

                self.stream: Stream = Stream(
                    sfreq=sfreq,
                    line_noise=line_noise,
                    channels=channels,
                )
                logger.info("stream setup")
                logger.info("settings setup")

                self.lsl_stream_name = lsl_stream_name
                self.is_stream_lsl = True
                break
        else:
            logger.error(f"Stream {lsl_stream_name} not found")
            self.is_stream_lsl = False
            # BUGFIX: the original reassigned is_stream_lsl to "" here,
            # clobbering the bool just set; the stream *name* should be cleared
            self.lsl_stream_name = ""
            raise ValueError(f"Stream {lsl_stream_name} not found")

    def setup_offline_stream(
        self,
        file_path: str,
        line_noise: float,
    ):
        """Configure self.stream from a recording file readable by MNE."""
        data, sfreq, ch_names, ch_types, bads = read_mne_data(file_path)

        channels = set_channels(
            ch_names=ch_names,
            ch_types=ch_types,
            bads=bads,
            reference=None,
            used_types=["eeg", "ecog", "dbs", "seeg"],
            target_keywords=None,
        )

        self.stream: Stream = Stream(
            sfreq=sfreq,
            data=data,
            channels=channels,
            line_noise=line_noise,
        )
        self.is_stream_lsl = False
        self.lsl_stream_name = ""

    # Async function that will continuously run in the Uvicorn async loop
    # and handle sending data through the websocket manager
    async def _process_queue(self):
        last_queue_check = time.time()

        while not self.stop_event.is_set():
            # Use asyncio.gather to process both queues concurrently
            tasks = []
            current_time = time.time()

            # Check feature queue
            while not self.feature_queue.empty():
                try:
                    data = self.feature_queue.get_nowait()
                    tasks.append(self.websocket_manager.send_cbor(data))  # type: ignore
                    self.messages_sent += 1
                except Empty:
                    break

            # Check raw data queue
            while not self.rawdata_queue.empty():
                try:
                    data = self.rawdata_queue.get_nowait()
                    self.messages_sent += 1
                    tasks.append(self.websocket_manager.send_cbor(data))  # type: ignore
                except Empty:
                    break

            if tasks:
                # Wait for all send operations to complete
                await asyncio.gather(*tasks, return_exceptions=True)
            else:
                # Only sleep if we didn't process any messages
                await asyncio.sleep(0.001)

            # Log queue diagnostics every 5 seconds
            if self.log_queue_size:
                if current_time - last_queue_check > 5:
                    logger.info(
                        "\nQueue diagnostics:\n"
                        f"\tMessages send to websocket: {self.messages_sent}.\n"
                    )
                    try:
                        logger.info(
                            f"\tFeature queue size: ~{self.feature_queue.qsize()}\n"
                            f"\tRaw data queue size: ~{self.rawdata_queue.qsize()}"
                        )
                    except NotImplementedError:
                        # qsize() is unavailable on some platforms (e.g. macOS).
                        # BUGFIX: was `continue`, which skipped the timestamp
                        # update below and caused diagnostics to be attempted
                        # on every loop iteration (and skipped the liveness check)
                        pass

                    last_queue_check = current_time

            # Check if stream process is still alive
            if not self.stream_process.is_alive():
                break
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from fastapi import WebSocket
|
|
2
|
+
import logging
|
|
3
|
+
import cbor2
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class WebsocketManager:
    """
    Manages WebSocket connections and messages.
    Perhaps in the future it will handle multiple connections.
    """

    def __init__(self):
        # Connections currently accepted and usable for sending
        self.active_connections: list[WebSocket] = []
        self.logger = logging.getLogger("PyNM")
        # Connections that errored during a send and should be cleaned up
        self.disconnected = []
        self._queue_task = None
        self._stop_event = None
        self.loop = None
        # Diagnostics counters
        self.messages_sent = 0
        self._last_diagnostic_time = time.time()

    async def connect(self, websocket: WebSocket):
        """Accept an incoming websocket and register it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        client_address = websocket.client
        if client_address:
            # BUGFIX: original logged port twice instead of host:port
            self.logger.info(
                f"Client connected with client ID: {client_address.host}:{client_address.port}"
            )

    def disconnect(self, websocket: WebSocket):
        """Deregister a websocket that has disconnected."""
        self.active_connections.remove(websocket)
        client_address = websocket.client
        if client_address:
            # BUGFIX: original logged port twice instead of host:port
            self.logger.info(
                f"Client {client_address.host}:{client_address.port} disconnected."
            )

    # NOTE: `object` shadows the builtin, but the name is kept for
    # backward compatibility with keyword callers
    async def send_cbor(self, object: dict):
        """CBOR-encode *object* and send it to every active connection.

        Logs a warning (and returns) if no clients are connected; logs when
        serialization or sending is unusually slow.
        """
        if not self.active_connections:
            self.logger.warning("No active connection to send message.")
            return

        start_time = time.time()
        cbor_data = cbor2.dumps(object)
        serialize_time = time.time() - start_time

        if serialize_time > 0.1:  # Log slow serializations
            self.logger.warning(f"CBOR serialization took {serialize_time:.3f}s")

        send_start = time.time()
        for connection in self.active_connections:
            try:
                await connection.send_bytes(cbor_data)
            except RuntimeError as e:
                self.logger.error(f"Error sending CBOR message: {e}")
                self.disconnected.append(connection)

        send_time = time.time() - send_start
        if send_time > 0.1:  # Log slow sends
            self.logger.warning(f"WebSocket send took {send_time:.3f}s")

        self.messages_sent += 1

        # Log diagnostics every 5 seconds
        current_time = time.time()
        if current_time - self._last_diagnostic_time > 5:
            self.logger.info(f"Messages sent: {self.messages_sent}")
            self._last_diagnostic_time = current_time

    async def send_message(self, message: str | dict):
        """Send *message* as JSON (dict) or text (str) to all connections.

        Connections that fail with RuntimeError are removed and closed.
        """
        if not self.active_connections:
            self.logger.warning("No active connection to send message.")
            return

        self.logger.info(
            f"Sending message within app_socket: {message.keys() if type(message) is dict else message}"
        )
        # BUGFIX: iterate over a snapshot — the loop body may remove
        # connections from active_connections, and mutating a list while
        # iterating it skips elements
        for connection in list(self.active_connections):
            try:
                if type(message) is dict:
                    await connection.send_json(message)
                elif type(message) is str:
                    await connection.send_text(message)
                self.logger.info(f"Message sent to {connection.client}")
            except RuntimeError as e:
                self.logger.error(f"Error sending message: {e}.")
                self.active_connections.remove(connection)
                await connection.close()

    @property
    def is_connected(self):
        # BUGFIX: original returned `self.active_connections is not None`,
        # which is always True for a list; test for non-emptiness instead
        return len(self.active_connections) > 0
|