py-neuromodulation 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
  2. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
  3. py_neuromodulation/__init__.py +12 -4
  4. py_neuromodulation/analysis/RMAP.py +3 -3
  5. py_neuromodulation/analysis/decode.py +55 -2
  6. py_neuromodulation/analysis/feature_reader.py +1 -0
  7. py_neuromodulation/analysis/stats.py +3 -3
  8. py_neuromodulation/default_settings.yaml +24 -17
  9. py_neuromodulation/features/bandpower.py +65 -23
  10. py_neuromodulation/features/bursts.py +9 -8
  11. py_neuromodulation/features/coherence.py +7 -4
  12. py_neuromodulation/features/feature_processor.py +4 -4
  13. py_neuromodulation/features/fooof.py +7 -6
  14. py_neuromodulation/features/mne_connectivity.py +25 -3
  15. py_neuromodulation/features/oscillatory.py +5 -4
  16. py_neuromodulation/features/sharpwaves.py +21 -0
  17. py_neuromodulation/filter/kalman_filter.py +17 -6
  18. py_neuromodulation/gui/__init__.py +3 -0
  19. py_neuromodulation/gui/backend/app_backend.py +419 -0
  20. py_neuromodulation/gui/backend/app_manager.py +345 -0
  21. py_neuromodulation/gui/backend/app_pynm.py +244 -0
  22. py_neuromodulation/gui/backend/app_socket.py +95 -0
  23. py_neuromodulation/gui/backend/app_utils.py +306 -0
  24. py_neuromodulation/gui/backend/app_window.py +202 -0
  25. py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
  26. py_neuromodulation/gui/frontend/assets/index-NbJiOU5a.js +300133 -0
  27. py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
  28. py_neuromodulation/gui/frontend/charite.svg +16 -0
  29. py_neuromodulation/gui/frontend/index.html +14 -0
  30. py_neuromodulation/gui/window_api.py +115 -0
  31. py_neuromodulation/lsl_api.cfg +3 -0
  32. py_neuromodulation/processing/data_preprocessor.py +9 -2
  33. py_neuromodulation/processing/filter_preprocessing.py +43 -27
  34. py_neuromodulation/processing/normalization.py +32 -17
  35. py_neuromodulation/processing/projection.py +2 -2
  36. py_neuromodulation/processing/resample.py +6 -2
  37. py_neuromodulation/run_gui.py +36 -0
  38. py_neuromodulation/stream/__init__.py +7 -1
  39. py_neuromodulation/stream/backend_interface.py +47 -0
  40. py_neuromodulation/stream/data_processor.py +24 -3
  41. py_neuromodulation/stream/mnelsl_player.py +121 -21
  42. py_neuromodulation/stream/mnelsl_stream.py +9 -17
  43. py_neuromodulation/stream/settings.py +80 -34
  44. py_neuromodulation/stream/stream.py +82 -62
  45. py_neuromodulation/utils/channels.py +1 -1
  46. py_neuromodulation/utils/file_writer.py +110 -0
  47. py_neuromodulation/utils/io.py +46 -5
  48. py_neuromodulation/utils/perf.py +156 -0
  49. py_neuromodulation/utils/pydantic_extensions.py +322 -0
  50. py_neuromodulation/utils/types.py +33 -107
  51. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/METADATA +18 -4
  52. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/RECORD +55 -35
  53. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/WHEEL +1 -1
  54. py_neuromodulation-0.1.0.dist-info/entry_points.txt +2 -0
  55. {py_neuromodulation-0.0.7.dist-info → py_neuromodulation-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,345 @@
1
+ import multiprocessing as mp
2
+ import threading
3
+ import os
4
+ import signal
5
+ import time
6
+ import platform
7
+
8
+ import logging
9
+
10
+ from .app_utils import force_terminate_process, create_logger, ansi_color, ansi_reset
11
+
12
+ from typing import TYPE_CHECKING
13
+
14
+ if TYPE_CHECKING:
15
+ from multiprocessing.synchronize import Event
16
+
17
+
18
# Shared memory configuration
ARRAY_SIZE: int = 1000  # Adjust based on your needs

# Default port the uvicorn/FastAPI backend binds to (used by run_uvicorn/run_backend).
SERVER_PORT: int = 50001
# Default port of the Vite development server (only started in dev mode).
DEV_SERVER_PORT: int = 54321
23
+
24
+
25
def run_vite(
    shutdown_event: "Event",
    debug: bool = False,
    dev_port: int = DEV_SERVER_PORT,
    backend_port: int = SERVER_PORT,
    dev_env_vars: dict | None = None,
) -> None:
    """Run Vite in a separate shell.

    Starts ``bun run dev --port <dev_port>`` in the ``gui_dev`` directory
    (relative to the current working directory), forwards its stdout/stderr to
    the "Vite" logger, and blocks until ``shutdown_event`` is set. It then sends
    a shutdown signal and kills the process if it has not exited after 3 seconds.

    Args:
        shutdown_event: Multiprocessing event; setting it stops the dev server.
        debug: If True, the "Vite" logger is created at DEBUG level, else INFO.
        dev_port: Port passed to ``bun run dev --port``.
        backend_port: Exported as ``VITE_BACKEND_PORT`` for the frontend.
        dev_env_vars: Extra variables exported as ``VITE_<key>``. Values are
            converted with ``str()`` because ``os.environ`` only accepts strings.
            Defaults to no extra variables.
    """
    import subprocess

    # Ctrl+C is handled by the parent process, which sets shutdown_event instead.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    logger = create_logger(
        "Vite",
        "magenta",
        logging.DEBUG if debug else logging.INFO,
    )

    # Set in this process's environment so the child `bun` process inherits them.
    os.environ["VITE_BACKEND_PORT"] = str(backend_port)
    for key, value in (dev_env_vars or {}).items():
        os.environ["VITE_" + key] = str(value)

    def output_reader(shutdown_event: "Event", process: subprocess.Popen):
        """Forward the dev server's stdout/stderr to ``logger`` until shutdown."""
        logger.debug("Initialized output stream")
        color = ansi_color(color="magenta", bright=True, styles=["BOLD"])

        def read_stream(stream, stream_name):
            # readline() returns "" at EOF, which ends the iteration.
            for line in iter(stream.readline, ""):
                if shutdown_event.is_set():
                    break
                logger.info(f"{color}[{stream_name}]{ansi_reset} {line.strip()}")

        stdout_thread = threading.Thread(
            target=read_stream, args=(process.stdout, "stdout")
        )
        stderr_thread = threading.Thread(
            target=read_stream, args=(process.stderr, "stderr")
        )

        stdout_thread.start()
        stderr_thread.start()

        shutdown_event.wait()

        stdout_thread.join(timeout=2)
        stderr_thread.join(timeout=2)

        logger.debug("Output stream closed")

    # On Windows, Popen.send_signal only supports SIGTERM / CTRL_C_EVENT /
    # CTRL_BREAK_EVENT, and CTRL_BREAK_EVENT requires the child to be started in a
    # new process group. The conditional expressions only evaluate the Windows-only
    # attributes when running on Windows.
    shutdown_signal = signal.CTRL_BREAK_EVENT if os.name == "nt" else signal.SIGINT
    subprocess_flags = subprocess.CREATE_NEW_PROCESS_GROUP if os.name == "nt" else 0

    process = subprocess.Popen(
        "bun run dev --port " + str(dev_port),
        cwd="gui_dev",
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        creationflags=subprocess_flags,
        shell=True,
    )

    logging_thread = threading.Thread(
        target=output_reader,
        args=(shutdown_event, process),
    )
    logging_thread.start()

    shutdown_event.wait()  # Wait for shutdown

    logger.debug("Terminating Vite server...")
    process.send_signal(shutdown_signal)

    try:
        process.wait(timeout=3)
    except subprocess.TimeoutExpired:
        logger.debug("Timeout expired, forcing termination...")
        process.kill()
        process.wait()  # reap the killed process

    logging_thread.join(timeout=3)
    if logging_thread.is_alive():
        logger.debug("Logging thread did not finish in time")

    logger.info("Development server stopped")
111
+
112
+
113
def run_uvicorn(
    debug: bool = False, reload=False, server_port: int = SERVER_PORT
) -> None:
    """Serve the PyNM FastAPI backend with uvicorn; blocks until the server exits.

    The app is created through the factory
    ``py_neuromodulation.gui.backend.app_backend:PyNMBackend`` and bound to
    ``localhost:<server_port>``.

    Args:
        debug: Use DEBUG instead of INFO for uvicorn's loggers and log level.
        reload: If True, run the server under a ``ChangeReload`` supervisor
            (see ``CustomReloader`` below) instead of directly.
        server_port: Port to bind on ``localhost``.
    """
    import copy

    from uvicorn.server import Server
    from uvicorn.config import LOGGING_CONFIG, Config

    # Configure logging
    color = ansi_color(color="green", bright=True, styles=["BOLD"])
    log_level = "DEBUG" if debug else "INFO"
    # LOGGING_CONFIG is uvicorn's module-level default. A shallow .copy() would
    # share its nested "loggers"/"formatters" dicts, so the edits below would
    # mutate uvicorn's global config; deep-copy it instead.
    log_config = copy.deepcopy(LOGGING_CONFIG)
    for logger_name in ("uvicorn", "uvicorn.error", "uvicorn.access"):
        log_config["loggers"][logger_name]["level"] = log_level
    log_config["formatters"]["default"]["fmt"] = (
        f"{color}[FastAPI %(levelname)s (%(asctime)s)]:{ansi_reset} %(message)s"
    )
    log_config["formatters"]["default"]["datefmt"] = "%H:%M:%S"
    log_config["formatters"]["access"]["fmt"] = (
        f"{color}[FastAPI access (%(asctime)s)]:{ansi_reset} %(message)s"
    )
    log_config["formatters"]["access"]["datefmt"] = "%H:%M:%S"

    config = Config(
        app="py_neuromodulation.gui.backend.app_backend:PyNMBackend",
        host="localhost",
        reload=reload,
        factory=True,
        port=server_port,
        log_level="debug" if debug else "info",
        log_config=log_config,
        http="httptools",
        ws_ping_interval=None,
        ws_ping_timeout=None,
        ws_max_size=1024 * 1024 * 1024,  # 1GB
        loop="asyncio" if platform.system() == "Windows" else "uvloop",
        lifespan="off",
    )

    server = Server(config=config)

    if reload:
        from uvicorn.supervisors import ChangeReload
        from uvicorn._subprocess import get_subprocess

        # Overload the restart method of uvicorn so that it does not kill all of our processes
        # IMPORTANT: This is a hack and prevents shutdown events from triggering when the reloader is used
        class CustomReloader(ChangeReload):
            def restart(self) -> None:
                self.process.terminate()  # Use terminate instead of os.kill
                self.process.join()
                self.process = get_subprocess(
                    config=self.config, target=self.target, sockets=self.sockets
                )
                self.process.start()

        sock = config.bind_socket()
        server = CustomReloader(config, target=server.run, sockets=[sock])

    server.run()
172
+
173
+
174
def run_backend(
    shutdown_event: "Event",
    dev: bool = True,
    debug: bool = False,
    reload: bool = True,
    server_port: int = SERVER_PORT,
    dev_port: int = DEV_SERVER_PORT,
    shutdown_timeout: float = 3.0,
) -> None:
    """Run the uvicorn server in a child process until ``shutdown_event`` is set.

    Args:
        shutdown_event: Multiprocessing event; setting it stops the server.
        dev: Exported as ``PYNM_DEV`` for the backend app factory.
        debug: Exported as ``PYNM_DEBUG`` and forwarded to ``run_uvicorn``.
        reload: Forwarded to ``run_uvicorn``.
        server_port: Port the uvicorn server binds to.
        dev_port: Exported as ``PYNM_DEV_PORT`` for the backend app factory.
        shutdown_timeout: Seconds to wait for the server process to exit on its
            own, and again after ``terminate()``, before killing it.
    """
    # Ctrl+C is handled by the parent process, which sets shutdown_event instead.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # Pass create_backend parameters through environment variables
    os.environ["PYNM_DEBUG"] = str(debug)
    os.environ["PYNM_DEV"] = str(dev)
    os.environ["PYNM_DEV_PORT"] = str(dev_port)

    server_process = mp.Process(
        target=run_uvicorn,
        kwargs={"debug": debug, "reload": reload, "server_port": server_port},
        name="Server",
    )
    server_process.start()
    shutdown_event.wait()

    # Nothing in shutdown_event reaches the uvicorn process itself. It may exit on
    # its own (e.g. it received the terminal's SIGINT); otherwise stop it
    # explicitly instead of blocking here forever.
    server_process.join(timeout=shutdown_timeout)
    if server_process.is_alive():
        server_process.terminate()
        server_process.join(timeout=shutdown_timeout)
    if server_process.is_alive():
        server_process.kill()
        server_process.join()
197
+
198
+
199
class AppManager:
    """Start and supervise the PyNM GUI processes (Vite dev server and backend)."""

    # Environment variable set by the first AppManager. A process that already sees
    # it set is treated as a child, so launch() becomes a no-op there.
    LAUNCH_FLAG = "PYNM_RUNNING"

    def __init__(
        self,
        debug: bool = False,
        dev: bool = True,
        run_in_webview=False,
        server_port=SERVER_PORT,
        dev_port=DEV_SERVER_PORT,
        dev_env_vars: dict | None = None,
    ) -> None:
        """Configure the app manager; nothing is started until ``launch()``.

        Args:
            debug (bool, optional): If True, run the app in debug mode, which sets logging level to debug,
                and starts uvicorn, FastAPI, and Vite in debug mode. Defaults to False.
            dev (bool, optional): If True, run the app in development mode, which enables hot
                reloading and runs the frontend in Vite server. If False, run the app in production mode,
                which runs the frontend from the static files in the `frontend` directory. Defaults to True.
            run_in_webview (bool, optional): If True, open a PyWebView window to display the app. Defaults to False.
            server_port (int, optional): Port for the uvicorn backend. Defaults to SERVER_PORT.
            dev_port (int, optional): Port for the Vite dev server. Defaults to DEV_SERVER_PORT.
            dev_env_vars (dict, optional): Extra variables passed to the Vite dev server
                as ``VITE_<key>``. Defaults to no extra variables.
        """
        self.debug = debug
        self.dev = dev
        self.run_in_webview = run_in_webview
        self.server_port = server_port
        self.dev_port = dev_port
        # Fresh dict per instance instead of one shared mutable default.
        self.dev_env_vars = dev_env_vars if dev_env_vars is not None else {}

        self._reset()
        # Prevent launching multiple instances of the app due to multiprocessing
        # This allows the absence of a main guard in the main script
        self.is_child_process = os.environ.get(self.LAUNCH_FLAG) == "TRUE"
        os.environ[self.LAUNCH_FLAG] = "TRUE"

        self.logger = create_logger(
            "PyNM",
            "yellow",
            logging.DEBUG if self.debug else logging.INFO,
        )

    def _reset(self) -> None:
        """Reset the AppManager to its initial state."""
        # Flags to track the state of the application
        self.shutdown_complete = False
        self.shutdown_started = False

        # Store background tasks
        self.tasks: dict[str, mp.Process] = {}

        # Events for multiprocessing synchronization
        self.shutdown_event: "Event" = mp.Event()

    def _terminate_app(self) -> None:
        """Signal all tasks to stop and wait for them, force-terminating stragglers.

        All processes share a single 10 second budget. Repeated calls while a
        shutdown is in progress are ignored.
        """
        if self.shutdown_started:
            self.logger.info("Termination already in progress. Skipping.")
            return

        self.shutdown_started = True

        timeout = 10
        deadline = time.time() + timeout

        self.logger.info("App closed, cleaning up background tasks...")
        self.shutdown_event.set()

        for process_name, process in self.tasks.items():
            remaining_time = max(deadline - time.time(), 0)
            process.join(timeout=remaining_time)

            if process.is_alive():
                self.logger.info(
                    f"{process_name} did not terminate in time. Forcing termination..."
                )
                force_terminate_process(process, process_name, logger=self.logger)

            self.logger.info(f"Process {process.name} terminated.")

        self.shutdown_complete = True
        self.shutdown_event.clear()
        self.logger.info("All background tasks succesfully terminated.")

    def _sigint_handler(self, signum, frame):
        """SIGINT handler: start a graceful shutdown once, ignore repeats."""
        if not self.shutdown_started:
            self.logger.info("Received SIGINT. Initiating graceful shutdown...")
            self._terminate_app()
        else:
            self.logger.info("SIGINT received again. Ignoring...")

    def launch(self) -> None:
        """Start the background processes and block until the app shuts down.

        Does nothing in a child process (see ``LAUNCH_FLAG``). In dev mode, the
        Vite server is started in addition to the backend. With
        ``run_in_webview``, blocks on the PyWebView window; otherwise polls until
        shutdown completes (e.g. after SIGINT).
        """
        if self.is_child_process:
            return

        # Handle keyboard interrupt signals
        signal.signal(signal.SIGINT, self._sigint_handler)

        # Create and start the subprocesses
        if self.dev:
            self.logger.info("Starting Vite server...")
            self.tasks["vite"] = mp.Process(
                target=run_vite,
                kwargs={
                    "shutdown_event": self.shutdown_event,
                    "debug": self.debug,
                    "dev_port": self.dev_port,
                    "backend_port": self.server_port,
                    "dev_env_vars": self.dev_env_vars,
                },
                name="Vite",
            )

        self.logger.info("Starting backend server...")
        self.tasks["backend"] = mp.Process(
            target=run_backend,
            kwargs={
                "shutdown_event": self.shutdown_event,
                "debug": self.debug,
                "reload": self.dev,
                "dev": self.dev,
                "server_port": self.server_port,
                "dev_port": self.dev_port,
            },
            name="Backend",
        )

        for process in self.tasks.values():
            process.start()

        if self.run_in_webview:
            from .app_window import WebViewWindow

            self.logger.info("Starting PyWebView window...")
            window = WebViewWindow(debug=self.debug)  # Must be called from main thread
            window.register_event_handler("closed", self._terminate_app)
            window.start()  # Start the window, this will block until the window is closed
        else:
            try:
                while not self.shutdown_complete:
                    time.sleep(0.1)
            except KeyboardInterrupt:
                pass  # The SIGINT handler will take care of termination

        if not self.shutdown_complete:
            self._terminate_app()

        self.logger.info("All processes cleaned up. Exiting...")
@@ -0,0 +1,244 @@
1
+ import os
2
+ import numpy as np
3
+ from threading import Thread
4
+ import time
5
+ import asyncio
6
+ import multiprocessing as mp
7
+ from queue import Empty
8
+ from pathlib import Path
9
+ from py_neuromodulation.stream import Stream, NMSettings
10
+ from py_neuromodulation.analysis.decode import RealTimeDecoder
11
+ from py_neuromodulation.utils import set_channels
12
+ from py_neuromodulation.utils.io import read_mne_data
13
+ from py_neuromodulation.utils.types import _PathLike
14
+ from py_neuromodulation import logger
15
+ from py_neuromodulation.gui.backend.app_socket import WebsocketManager
16
+ from py_neuromodulation.stream.backend_interface import StreamBackendInterface
17
+ from py_neuromodulation import logger
18
+
19
+
20
class PyNMState:
    """State of a GUI-driven py_neuromodulation run.

    Holds the ``Stream`` being configured, the multiprocessing queues handed to
    the stream process through ``StreamBackendInterface``, and the asyncio task
    that forwards queued results to websocket clients.
    """

    def __init__(
        self,
        log_queue_size: bool = False,
    ) -> None:
        """
        Args:
            log_queue_size: If True, ``_process_queue`` logs the number of sent
                messages and approximate queue sizes roughly every 5 seconds.
        """
        self.log_queue_size = log_queue_size
        self.lsl_stream_name: str = ""
        self.is_stream_lsl: bool = False
        self.experiment_name: str = "PyNM_Experiment"  # set by set_stream_params
        self.out_dir: _PathLike = str(
            Path.home() / "PyNM" / self.experiment_name
        )  # set by set_stream_params
        self.decoding_model_path: _PathLike | None = None
        self.decoder: RealTimeDecoder | None = None

        self.backend_interface: StreamBackendInterface | None = None
        self.websocket_manager: WebsocketManager | None = None

        # Placeholder stream (sfreq and data are required for stream init);
        # replaced by setup_lsl_stream / setup_offline_stream.
        self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1]))

        self.feature_queue = mp.Queue()
        self.rawdata_queue = mp.Queue()
        self.control_queue = mp.Queue()
        self.stop_event = asyncio.Event()

        # Set by start_run_function
        self.stream_process: mp.Process | None = None
        self._queue_task: asyncio.Task | None = None

        self.messages_sent = 0

    def start_run_function(
        self,
        websocket_manager: WebsocketManager | None = None,
    ) -> None:
        """Run ``self.stream`` in a child process and forward its output.

        If ``websocket_manager`` is given, ``_process_queue`` is scheduled on the
        running asyncio loop, so this must then be called from inside a running
        loop (``asyncio.get_running_loop`` raises RuntimeError otherwise).
        """
        self.websocket_manager = websocket_manager

        # A previous stop_run_function() leaves stop_event set, which would make
        # the new _process_queue task exit immediately. Stop any previous
        # forwarding task so two tasks never drain the same queues.
        if self._queue_task is not None and not self._queue_task.done():
            self._queue_task.cancel()
        self.stop_event.clear()

        # Create decoder
        if self.decoding_model_path is not None and self.decoding_model_path != "None":
            if os.path.exists(self.decoding_model_path):
                self.decoder = RealTimeDecoder(self.decoding_model_path)
            else:
                logger.debug("Passed decoding model path does't exist")

        # Initialize the backend interface if not already done
        if not self.backend_interface:
            self.backend_interface = StreamBackendInterface(
                self.feature_queue, self.rawdata_queue, self.control_queue
            )

        # The stream process is stopped through the control queue
        # (see stop_run_function), which breaks the data generator and saves the features
        stream_process = mp.Process(
            target=self.stream.run,
            kwargs={
                "out_dir": "" if self.out_dir == "default" else self.out_dir,
                "experiment_name": self.experiment_name,
                "is_stream_lsl": self.is_stream_lsl,
                "stream_lsl_name": self.lsl_stream_name,
                "simulate_real_time": True,
                "decoder": self.decoder,
                "backend_interface": self.backend_interface,
            },
        )
        stream_process.start()

        # Store before scheduling the task: _process_queue polls self.stream_process.
        self.stream_process = stream_process

        if self.websocket_manager:
            # TONI: Instead of having this function be not async and send the
            # _process_queue function to the Uvicorn async loop, we could
            # have this entire "start_run_function" function be async as well
            loop = asyncio.get_running_loop()
            # Store task reference for cleanup
            self._queue_task = loop.create_task(self._process_queue())

    def stop_run_function(self) -> None:
        """Stop the stream processing"""
        if self.backend_interface:
            self.backend_interface.send_command("stop")
        self.stop_event.set()

    def setup_lsl_stream(
        self,
        lsl_stream_name: str = "",
        line_noise: float | None = None,
    ):
        """Configure ``self.stream`` from the LSL stream named ``lsl_stream_name``.

        Missing channel names default to ``ch<i>`` and missing channel types to
        ``"eeg"``.

        Raises:
            ValueError: If no resolved LSL stream has that name.
        """
        from mne_lsl.lsl import resolve_streams

        logger.info("resolving streams")
        lsl_streams = resolve_streams()

        for lsl_stream in lsl_streams:
            if lsl_stream.name != lsl_stream_name:
                continue

            logger.info(f"found stream {lsl_stream_name}")

            ch_names = lsl_stream.get_channel_names()
            if ch_names is None:
                ch_names = ["ch" + str(i) for i in range(lsl_stream.n_channels)]
            logger.info(f"channel names: {ch_names}")

            ch_types = lsl_stream.get_channel_types()
            if ch_types is None:
                ch_types = ["eeg"] * lsl_stream.n_channels

            logger.info(f"channel types: {ch_types}")

            info_ = lsl_stream.get_channel_info()
            logger.info(f"channel info: {info_}")

            channels = set_channels(
                ch_names=ch_names,
                ch_types=ch_types,
                used_types=["eeg", "ecog", "dbs", "seeg"],
            )

            logger.info(channels)
            sfreq = lsl_stream.sfreq

            self.stream: Stream = Stream(
                sfreq=sfreq,
                line_noise=line_noise,
                channels=channels,
            )
            logger.info("stream setup")
            logger.info("settings setup")

            self.lsl_stream_name = lsl_stream_name
            self.is_stream_lsl = True
            return

        logger.error(f"Stream {lsl_stream_name} not found")
        self.is_stream_lsl = False
        self.lsl_stream_name = ""
        raise ValueError(f"Stream {lsl_stream_name} not found")

    def setup_offline_stream(
        self,
        file_path: str,
        line_noise: float,
    ):
        """Configure ``self.stream`` from a recording file read with ``read_mne_data``."""
        data, sfreq, ch_names, ch_types, bads = read_mne_data(file_path)

        channels = set_channels(
            ch_names=ch_names,
            ch_types=ch_types,
            bads=bads,
            reference=None,
            used_types=["eeg", "ecog", "dbs", "seeg"],
            target_keywords=None,
        )

        self.stream: Stream = Stream(
            sfreq=sfreq,
            data=data,
            channels=channels,
            line_noise=line_noise,
        )
        self.is_stream_lsl = False
        self.lsl_stream_name = ""

    # Async function that will continuously run in the Uvicorn async loop
    # and handle sending data through the websocket manager
    async def _process_queue(self):
        """Forward queued stream output to the websocket until stopped.

        Each iteration drains ``feature_queue`` then ``rawdata_queue``, sends every
        item with ``websocket_manager.send_cbor`` concurrently, and exits once
        ``stop_event`` is set or the stream process is no longer alive.
        """
        last_queue_check = time.time()

        while not self.stop_event.is_set():
            tasks = []
            current_time = time.time()

            for source_queue in (self.feature_queue, self.rawdata_queue):
                while not source_queue.empty():
                    try:
                        data = source_queue.get_nowait()
                    except Empty:
                        break
                    tasks.append(self.websocket_manager.send_cbor(data))  # type: ignore
                    self.messages_sent += 1

            if tasks:
                # Wait for all send operations to complete
                await asyncio.gather(*tasks, return_exceptions=True)
            else:
                # Only sleep if we didn't process any messages
                await asyncio.sleep(0.001)

            # Log queue diagnostics every 5 seconds
            if self.log_queue_size and current_time - last_queue_check > 5:
                logger.info(
                    "\nQueue diagnostics:\n"
                    f"\tMessages send to websocket: {self.messages_sent}.\n"
                )
                try:
                    logger.info(
                        f"\tFeature queue size: ~{self.feature_queue.qsize()}\n"
                        f"\tRaw data queue size: ~{self.rawdata_queue.qsize()}"
                    )
                except NotImplementedError:
                    # qsize() is unavailable on some platforms (e.g. macOS). Skip
                    # the sizes but still reset the timer and run the liveness
                    # check below (a `continue` here would skip both).
                    pass

                last_queue_check = current_time

            # Check if stream process is still alive
            if self.stream_process is None or not self.stream_process.is_alive():
                break
@@ -0,0 +1,95 @@
1
+ from fastapi import WebSocket
2
+ import logging
3
+ import cbor2
4
+ import time
5
+
6
+
7
class WebsocketManager:
    """
    Manages WebSocket connections and messages.
    Perhaps in the future it will handle multiple connections.
    """

    def __init__(self):
        self.active_connections: list[WebSocket] = []
        self.logger = logging.getLogger("PyNM")
        # Connections whose send failed in send_cbor; they are also dropped from
        # active_connections so they are not retried on every message.
        self.disconnected = []
        self._queue_task = None
        self._stop_event = None
        self.loop = None
        self.messages_sent = 0
        self._last_diagnostic_time = time.time()

    @staticmethod
    def _client_id(websocket: WebSocket) -> str | None:
        """Combine IP and port to create a unique client ID (None if unknown)."""
        client_address = websocket.client
        if not client_address:
            return None
        return f"{client_address.host}:{client_address.port}"

    def _drop(self, websocket: WebSocket) -> None:
        """Remove ``websocket`` from the active list if it is still there."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def connect(self, websocket: WebSocket):
        """Accept ``websocket`` and register it as active."""
        await websocket.accept()
        self.active_connections.append(websocket)
        client_id = self._client_id(websocket)
        if client_id:
            self.logger.info(f"Client connected with client ID: {client_id}")

    def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket``; tolerates connections already dropped after a failed send."""
        self._drop(websocket)
        client_id = self._client_id(websocket)
        if client_id:
            self.logger.info(f"Client {client_id} disconnected.")

    async def send_cbor(self, object: dict):
        """Serialize ``object`` with CBOR and send it as bytes to every active client.

        Connections whose send raises RuntimeError are recorded in
        ``self.disconnected`` and removed from ``active_connections``.
        """
        if not self.active_connections:
            self.logger.warning("No active connection to send message.")
            return

        start_time = time.time()
        cbor_data = cbor2.dumps(object)
        serialize_time = time.time() - start_time

        if serialize_time > 0.1:  # Log slow serializations
            self.logger.warning(f"CBOR serialization took {serialize_time:.3f}s")

        send_start = time.time()
        failed = []
        # Iterate over a snapshot: connect()/disconnect() may modify the list while
        # this coroutine is suspended in send_bytes().
        for connection in list(self.active_connections):
            try:
                await connection.send_bytes(cbor_data)
            except RuntimeError as e:
                self.logger.error(f"Error sending CBOR message: {e}")
                failed.append(connection)

        for connection in failed:
            self.disconnected.append(connection)
            self._drop(connection)

        send_time = time.time() - send_start
        if send_time > 0.1:  # Log slow sends
            self.logger.warning(f"WebSocket send took {send_time:.3f}s")

        self.messages_sent += 1

        # Log diagnostics every 5 seconds
        current_time = time.time()
        if current_time - self._last_diagnostic_time > 5:
            self.logger.info(f"Messages sent: {self.messages_sent}")
            self._last_diagnostic_time = current_time

    async def send_message(self, message: str | dict):
        """Send ``message`` to every active client: dicts as JSON, strings as text.

        A connection whose send raises RuntimeError is removed and closed.
        """
        if not self.active_connections:
            self.logger.warning("No active connection to send message.")
            return

        self.logger.info(
            f"Sending message within app_socket: {message.keys() if isinstance(message, dict) else message}"
        )
        # Iterate over a snapshot because failed connections are removed below.
        for connection in list(self.active_connections):
            try:
                if isinstance(message, dict):
                    await connection.send_json(message)
                elif isinstance(message, str):
                    await connection.send_text(message)
                self.logger.info(f"Message sent to {connection.client}")
            except RuntimeError as e:
                self.logger.error(f"Error sending message: {e}.")
                self._drop(connection)
                try:
                    await connection.close()
                except RuntimeError:
                    pass  # connection already closed; nothing left to clean up

    @property
    def is_connected(self):
        """True if at least one client connection is active."""
        return bool(self.active_connections)