py-neuromodulation 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56):
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +0 -1
  2. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +0 -2
  3. py_neuromodulation/__init__.py +12 -4
  4. py_neuromodulation/analysis/RMAP.py +3 -3
  5. py_neuromodulation/analysis/decode.py +55 -2
  6. py_neuromodulation/analysis/feature_reader.py +1 -0
  7. py_neuromodulation/analysis/stats.py +3 -3
  8. py_neuromodulation/default_settings.yaml +25 -17
  9. py_neuromodulation/features/bandpower.py +65 -23
  10. py_neuromodulation/features/bispectra.py +3 -7
  11. py_neuromodulation/features/bursts.py +9 -8
  12. py_neuromodulation/features/coherence.py +17 -9
  13. py_neuromodulation/features/feature_processor.py +4 -4
  14. py_neuromodulation/features/fooof.py +7 -6
  15. py_neuromodulation/features/mne_connectivity.py +25 -3
  16. py_neuromodulation/features/oscillatory.py +5 -4
  17. py_neuromodulation/features/sharpwaves.py +21 -0
  18. py_neuromodulation/filter/kalman_filter.py +17 -6
  19. py_neuromodulation/gui/__init__.py +3 -0
  20. py_neuromodulation/gui/backend/app_backend.py +419 -0
  21. py_neuromodulation/gui/backend/app_manager.py +345 -0
  22. py_neuromodulation/gui/backend/app_pynm.py +244 -0
  23. py_neuromodulation/gui/backend/app_socket.py +95 -0
  24. py_neuromodulation/gui/backend/app_utils.py +306 -0
  25. py_neuromodulation/gui/backend/app_window.py +202 -0
  26. py_neuromodulation/gui/frontend/assets/Figtree-VariableFont_wght-CkXbWBDP.ttf +0 -0
  27. py_neuromodulation/gui/frontend/assets/index-NbJiOU5a.js +300133 -0
  28. py_neuromodulation/gui/frontend/assets/plotly-DTCwMlpS.js +23594 -0
  29. py_neuromodulation/gui/frontend/charite.svg +16 -0
  30. py_neuromodulation/gui/frontend/index.html +14 -0
  31. py_neuromodulation/gui/window_api.py +115 -0
  32. py_neuromodulation/lsl_api.cfg +3 -0
  33. py_neuromodulation/processing/data_preprocessor.py +9 -2
  34. py_neuromodulation/processing/filter_preprocessing.py +43 -27
  35. py_neuromodulation/processing/normalization.py +32 -17
  36. py_neuromodulation/processing/projection.py +2 -2
  37. py_neuromodulation/processing/resample.py +6 -2
  38. py_neuromodulation/run_gui.py +36 -0
  39. py_neuromodulation/stream/__init__.py +7 -1
  40. py_neuromodulation/stream/backend_interface.py +47 -0
  41. py_neuromodulation/stream/data_processor.py +24 -3
  42. py_neuromodulation/stream/mnelsl_player.py +121 -21
  43. py_neuromodulation/stream/mnelsl_stream.py +9 -17
  44. py_neuromodulation/stream/settings.py +80 -34
  45. py_neuromodulation/stream/stream.py +82 -62
  46. py_neuromodulation/utils/channels.py +1 -1
  47. py_neuromodulation/utils/file_writer.py +110 -0
  48. py_neuromodulation/utils/io.py +46 -5
  49. py_neuromodulation/utils/perf.py +156 -0
  50. py_neuromodulation/utils/pydantic_extensions.py +322 -0
  51. py_neuromodulation/utils/types.py +33 -107
  52. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/METADATA +27 -22
  53. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/RECORD +56 -36
  54. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/WHEEL +1 -1
  55. py_neuromodulation-0.1.0.dist-info/entry_points.txt +2 -0
  56. {py_neuromodulation-0.0.6.dist-info → py_neuromodulation-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,8 +1,9 @@
1
1
  from collections.abc import Iterable
2
2
  import numpy as np
3
- from typing import TYPE_CHECKING
3
+ from typing import TYPE_CHECKING, Literal
4
4
 
5
5
  from py_neuromodulation.utils.types import NMFeature, NMBaseModel
6
+ from py_neuromodulation.utils.pydantic_extensions import NMField
6
7
 
7
8
  if TYPE_CHECKING:
8
9
  from py_neuromodulation import NMSettings
@@ -10,9 +11,30 @@ if TYPE_CHECKING:
10
11
  from mne import Epochs
11
12
 
12
13
 
14
+ MNE_CONNECTIVITY_METHOD = Literal[
15
+ "coh",
16
+ "cohy",
17
+ "imcoh",
18
+ "cacoh",
19
+ "mic",
20
+ "mim",
21
+ "plv",
22
+ "ciplv",
23
+ "ppc",
24
+ "pli",
25
+ "dpli",
26
+ "wpli",
27
+ "wpli2_debiased",
28
+ "gc",
29
+ "gc_tr",
30
+ ]
31
+
32
+ MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"]
33
+
34
+
13
35
  class MNEConnectivitySettings(NMBaseModel):
14
- method: str = "plv"
15
- mode: str = "multitaper"
36
+ method: MNE_CONNECTIVITY_METHOD = NMField(default="plv")
37
+ mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper")
16
38
 
17
39
 
18
40
  class MNEConnectivity(NMFeature):
@@ -3,6 +3,7 @@ import numpy as np
3
3
  from itertools import product
4
4
 
5
5
  from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
6
+ from py_neuromodulation.utils.pydantic_extensions import NMField
6
7
  from typing import TYPE_CHECKING
7
8
 
8
9
  if TYPE_CHECKING:
@@ -17,12 +18,12 @@ class OscillatoryFeatures(BoolSelector):
17
18
 
18
19
 
19
20
  class OscillatorySettings(NMBaseModel):
20
- windowlength_ms: int = 1000
21
+ windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"})
21
22
  log_transform: bool = True
22
23
  features: OscillatoryFeatures = OscillatoryFeatures(
23
24
  mean=True, median=False, std=False, max=False
24
25
  )
25
- return_spectrum: bool = False
26
+ return_spectrum: bool = True
26
27
 
27
28
 
28
29
  ESTIMATOR_DICT = {
@@ -176,7 +177,7 @@ class Welch(OscillatoryFeature):
176
177
  if self.settings.return_spectrum:
177
178
  combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
178
179
  for (ch_idx, ch_name), (idx, f) in combinations:
179
- feature_results[f"{ch_name}_welch_psd_{str(f)}"] = Z[ch_idx][idx]
180
+ feature_results[f"{ch_name}_welch_psd_{int(f)}"] = Z[ch_idx][idx]
180
181
 
181
182
  return feature_results
182
183
 
@@ -242,7 +243,7 @@ class STFT(OscillatoryFeature):
242
243
  if self.settings.return_spectrum:
243
244
  combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
244
245
  for (ch_idx, ch_name), (idx, f) in combinations:
245
- feature_results[f"{ch_name}_stft_psd_{str(f)}"] = Z[ch_idx].mean(
246
+ feature_results[f"{ch_name}_stft_psd_{int(f)}"] = Z[ch_idx].mean(
246
247
  axis=1
247
248
  )[idx]
248
249
 
@@ -267,6 +267,14 @@ class SharpwaveAnalyzer(NMFeature):
267
267
 
268
268
  # for each feature take the respective fun.
269
269
  for feature_name, estimator_name, estimator in estimator_combinations:
270
+ if feature_name == "num_peaks":
271
+ key_name = f"{ch_name}_Sharpwave_{feature_name}_{filter_name}"
272
+ if len(waveform_results[feature_name]) == 1:
273
+ dict_ch_features[key_name][key_name_pt] = waveform_results[feature_name][0]
274
+ continue
275
+ else:
276
+ raise ValueError("num_peaks should be a list with length 1")
277
+ # there can be only one num_peak in each batch
270
278
  feature_data = waveform_results[feature_name]
271
279
  key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
272
280
 
@@ -280,12 +288,25 @@ class SharpwaveAnalyzer(NMFeature):
280
288
 
281
289
  # the key_name stays, since the estimator function stays between peaks and troughs
282
290
  for key_name, estimator in self.estimator_key_map.items():
291
+ if len(dict_ch_features[key_name]) == 0:
292
+ # might happen if num_peaks was written in estimator
293
+ # e.g. estimator["mean"] = ["num_peaks"]
294
+ # for conveniance this doesn't raise an exception
295
+ continue
296
+
283
297
  feature_results[key_name] = estimator(
284
298
  [
285
299
  list(dict_ch_features[key_name].values())[0],
286
300
  list(dict_ch_features[key_name].values())[1],
287
301
  ]
288
302
  )
303
+ # add here also the num_peaks features
304
+ if self.sw_settings.sharpwave_features.num_peaks:
305
+ for ch_name in self.ch_names:
306
+ for filter_name in self.filter_names:
307
+ key_name = f"{ch_name}_Sharpwave_num_peaks_{filter_name}"
308
+ feature_results[key_name] = np_mean([dict_ch_features[key_name]["Peak"],
309
+ dict_ch_features[key_name]["Trough"]])
289
310
  else:
290
311
  # otherwise, save all write all "flattened" key value pairs in feature_results
291
312
  for key, subdict in dict_ch_features.items():
@@ -1,7 +1,9 @@
1
1
  import numpy as np
2
2
  from typing import TYPE_CHECKING
3
3
 
4
+
4
5
  from py_neuromodulation.utils.types import NMBaseModel
6
+ from py_neuromodulation.utils.pydantic_extensions import NMErrorList
5
7
 
6
8
 
7
9
  if TYPE_CHECKING:
@@ -22,13 +24,22 @@ class KalmanSettings(NMBaseModel):
22
24
  "HFA",
23
25
  ]
24
26
 
25
- def validate_fbands(self, settings: "NMSettings") -> None:
26
- assert all(
27
+ def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
28
+ errors: NMErrorList = NMErrorList()
29
+
30
+ if not all(
27
31
  (item in settings.frequency_ranges_hz for item in self.frequency_bands)
28
- ), (
29
- "Frequency bands for Kalman filter must also be specified in "
30
- "bandpass_filter_settings."
31
- )
32
+ ):
33
+ errors.add_error(
34
+ "Frequency bands for Kalman filter must also be specified in "
35
+ "frequency_ranges_hz.",
36
+ location=[
37
+ "kalman_filter_settings",
38
+ "frequency_bands",
39
+ ],
40
+ )
41
+
42
+ return errors
32
43
 
33
44
 
34
45
  def define_KF(Tp, sigma_w, sigma_v):
@@ -0,0 +1,3 @@
1
+ from .backend.app_manager import AppManager as App
2
+
3
+ __all__ = ["App"]
@@ -0,0 +1,419 @@
1
+ import logging
2
+ import importlib.metadata
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ import os
6
+
7
+ from fastapi import (
8
+ FastAPI,
9
+ HTTPException,
10
+ Query,
11
+ WebSocket,
12
+ )
13
+ from fastapi.staticfiles import StaticFiles
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from pydantic import ValidationError
16
+
17
+ from . import app_pynm
18
+ from .app_socket import WebsocketManager
19
+ from .app_utils import is_hidden, get_quick_access
20
+ import pandas as pd
21
+
22
+ from py_neuromodulation import PYNM_DIR, NMSettings
23
+ from py_neuromodulation.utils.types import FileInfo
24
+
25
+ # TODO: maybe pull this list from the MNE package?
26
+ ALLOWED_EXTENSIONS = [".npy", ".vhdr", ".fif", ".edf", ".bdf"]
27
+
28
+ DEV_SERVER_PORT = 54321
29
+
30
+
31
def _env_flag(name: str) -> bool:
    """Return True if environment variable *name* is set to "true" (any case)."""
    return os.environ.get(name, "False").lower() == "true"


def _parse_project_urls(url_list: list[str] | None) -> dict[str, str]:
    """Turn ``Project-URL`` metadata entries into a ``{label: url}`` mapping.

    Entries look like ``"Homepage, https://example.org"``. Entries without a
    comma are skipped instead of raising ``IndexError``, and whitespace around
    the label and the URL is stripped. Later duplicates win.
    """
    urls: dict[str, str] = {}
    for entry in url_list or []:
        label, sep, url = entry.partition(",")
        if sep:
            urls[label.strip()] = url.strip()
    return urls


def _license_from_classifiers(classifier_list: list[str] | None) -> str:
    """Return the license name from ``License :: ...`` trove classifiers.

    The last component of the last ``License`` classifier is returned, e.g.
    ``"MIT License"`` for ``"License :: OSI Approved :: MIT License"``.
    Returns ``""`` when there is no License classifier.
    """
    license_name = ""
    for item in classifier_list or []:
        parts = [part.strip() for part in item.split("::")]
        if len(parts) > 1 and parts[0] == "License":
            license_name = parts[-1]
    return license_name


class PyNMBackend(FastAPI):
    """FastAPI application backing the PyNeuromodulation GUI.

    Registers REST endpoints under ``/api`` (settings, stream control,
    channels, LSL streams, package info, file browser) and a ``/ws``
    websocket. Outside dev mode, the frontend in
    ``PYNM_DIR / "gui" / "frontend"`` is additionally served from ``/``.
    """

    def __init__(
        self,
        debug: bool | None = None,
        dev: bool | None = None,
        dev_port: int | None = None,
        fastapi_kwargs: dict | None = None,
    ) -> None:
        """Create the backend application.

        Args:
            debug: FastAPI debug flag. ``None`` reads ``PYNM_DEBUG``
                ("true" enables it, case-insensitive).
            dev: Dev mode: enables CORS for the frontend dev server and skips
                serving the static frontend. ``None`` reads ``PYNM_DEV``.
            dev_port: Frontend dev-server port allowed by CORS in dev mode.
                ``None`` reads ``PYNM_DEV_PORT`` (default ``DEV_SERVER_PORT``).
            fastapi_kwargs: Extra keyword arguments forwarded to ``FastAPI``.
        """
        # Explicit arguments take precedence over the environment. Previously
        # the attributes were only assigned when the argument was None, so
        # passing e.g. dev=True left self.dev unset and raised AttributeError.
        self.debug = debug if debug is not None else _env_flag("PYNM_DEBUG")
        self.dev = dev if dev is not None else _env_flag("PYNM_DEV")
        self.dev_port = (
            dev_port
            if dev_port is not None
            else os.environ.get("PYNM_DEV_PORT", str(DEV_SERVER_PORT))
        )

        super().__init__(
            title="PyNeuromodulation",
            description="PyNeuromodulation FastAPI backend",
            version=importlib.metadata.version("py_neuromodulation"),
            debug=self.debug,
            **(fastapi_kwargs or {}),
        )

        # Use the uvicorn logger so backend messages share its handlers.
        self.logger = logging.getLogger("uvicorn.error")
        self.logger.info("PYNM_DIR: %s", PYNM_DIR)

        if self.dev:
            # The frontend dev server runs on its own origin; allow it via CORS.
            self.add_middleware(
                CORSMiddleware,
                allow_origins=[f"http://localhost:{self.dev_port}"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )

        # Has to be before mounting static files: routes are matched in
        # registration order and the static mount at "/" matches everything.
        self.setup_routes()

        if not self.dev:
            self.logger.info("Serving static files")
            self.mount(
                "/",
                StaticFiles(directory=PYNM_DIR / "gui" / "frontend", html=True),
                name="static",
            )

        self.websocket_manager = WebsocketManager()
        self.pynm_state = app_pynm.PyNMState()

    def setup_routes(self):
        """Register all HTTP and websocket endpoints on this application."""

        @self.get("/api/health")
        async def healthcheck():
            return {"message": "API is working"}

        ####################
        ##### SETTINGS #####
        ####################
        @self.get("/api/settings")
        async def get_settings(
            reset: bool = Query(False, description="Reset settings to default"),
        ):
            """Return the current settings (or the defaults if ``reset``)."""
            settings = (
                NMSettings.get_default() if reset else self.pynm_state.stream.settings
            )
            return settings.serialize_with_metadata()

        @self.post("/api/settings")
        async def update_settings(data: dict, validate_only: bool = Query(False)):
            """Validate ``data`` as NMSettings and, unless ``validate_only``, apply it.

            With ``validate_only`` a validation failure is reported in the
            response body (``"valid": False``); otherwise it is a 422 error.
            """
            try:
                # TODO: check if this works properly or needs model_validate_strings
                try:
                    validated_settings = NMSettings.model_validate(data)
                except ValidationError as e:
                    self.logger.error("Error validating settings: %s", e)
                    if not validate_only:
                        raise HTTPException(
                            status_code=422,
                            detail={
                                "error": "Error validating settings",
                                "details": str(e),
                            },
                        ) from e
                    return {
                        "valid": False,
                        "errors": list(e.errors()),
                        "details": str(e),
                    }

                if validate_only:
                    return {
                        "valid": True,
                        "settings": validated_settings.serialize_with_metadata(),
                    }

                self.pynm_state.stream.settings = validated_settings
                self.logger.info("Settings successfully updated")
                return {
                    "valid": True,
                    "settings": self.pynm_state.stream.settings.serialize_with_metadata(),
                }
            except HTTPException:
                # Already carries the intended status/detail; without this the
                # generic handler below would re-wrap it and lose the detail.
                raise
            except Exception as e:
                # Something other than validation went wrong.
                self.logger.error("Error validating/updating settings: %s", e)
                raise HTTPException(
                    status_code=422,
                    detail={"error": "Error uploading settings", "details": str(e)},
                ) from e

        ########################
        ##### PYNM CONTROL #####
        ########################

        @self.post("/api/stream-control")
        async def handle_stream_control(data: dict):
            """Start or stop the processing run; ``data["action"]`` is "start" or "stop"."""
            action = data.get("action")
            if action == "start":
                # TODO: create out_dir and experiment_name text fields in frontend
                self.logger.info(
                    "Starting stream (websocket manager: %s)", self.websocket_manager
                )
                self.pynm_state.start_run_function(
                    websocket_manager=self.websocket_manager,
                )
            elif action == "stop":
                self.logger.info("Stopping stream")
                self.pynm_state.stop_run_function()
            else:
                # Missing or unknown action: nothing was executed, so say so.
                raise HTTPException(
                    status_code=422,
                    detail={"error": f"Unknown stream action '{action}'"},
                )

            return {"message": f"Stream action '{action}' executed"}

        ####################
        ##### CHANNELS #####
        ####################

        @self.get("/api/channels")
        async def get_channels():
            """Return the stream's channel table as a list of records."""
            channels = self.pynm_state.stream.channels
            self.logger.info("Sending channels: %s", channels)
            if not isinstance(channels, pd.DataFrame):
                raise HTTPException(
                    status_code=422,
                    detail={"error": "Channels is not a DataFrame"},
                )
            return {"channels": channels.to_dict(orient="records")}

        @self.post("/api/channels")
        async def update_channels(data: dict):
            """Replace the stream's channel table with ``data["channels"]`` records."""
            try:
                new_channels = pd.DataFrame(data["channels"])
                self.logger.info("Received channels:\n %s", new_channels)
                self.pynm_state.stream.channels = new_channels
                return {
                    "channels": self.pynm_state.stream.channels.to_dict(
                        orient="records"
                    )
                }
            except Exception as e:
                raise HTTPException(
                    status_code=422,
                    detail={"error": "Error updating channels", "details": str(e)},
                ) from e

        ###################
        ### LSL STREAMS ###
        ###################

        @self.get("/api/LSL-streams")
        async def get_lsl_streams():
            """Describe every LSL stream currently resolvable on the network."""
            from mne_lsl.lsl import resolve_streams

            return {
                "message": [
                    {
                        # MNE-LSL might return a class, so we get the name
                        "dtype": getattr(stream.dtype, "__name__", str(stream.dtype)),
                        "name": stream.name,
                        "n_channels": stream.n_channels,
                        "sfreq": stream.sfreq,
                        "source_id": stream.source_id,
                        "stype": stream.stype,  # Stream type (e.g. EEG)
                        "created_at": stream.created_at,
                        "hostname": stream.hostname,
                        "session_id": stream.session_id,
                        "uid": stream.uid,
                        "protocol_version": stream.protocol_version,
                    }
                    for stream in resolve_streams()
                ]
            }

        @self.post("/api/setup-LSL-stream")
        async def setup_lsl_stream(data: dict):
            """Connect to LSL stream ``data["stream_name"]``; errors go in the body."""
            try:
                stream_name = data["stream_name"]
                self.logger.info("Attempting to setup LSL stream: '%s'", stream_name)
                self.pynm_state.setup_lsl_stream(
                    lsl_stream_name=stream_name,
                    line_noise=data["line_noise"],
                )
                return {"message": f"LSL stream '{stream_name}' setup successfully"}
            except Exception as e:
                self.logger.error("LSL stream could not be setup: %s", e)
                return {
                    "message": "LSL stream could not be setup",
                    "error": str(e),
                }

        @self.post("/api/setup-Offline-stream")
        async def setup_offline_stream(data: dict):
            """Load ``data["file_path"]`` as an offline stream; errors go in the body."""
            self.logger.info("Data received to setup offline stream:")
            self.logger.info(data)
            try:
                self.pynm_state.setup_offline_stream(
                    file_path=data["file_path"],
                    line_noise=float(data["line_noise"]),
                )
            except Exception as e:
                # Mirrors setup_lsl_stream: report the reason instead of a bare 500
                # (previously only ValueError was handled; a missing key was a 500).
                self.logger.error("Offline stream could not be setup: %s", e)
                return {
                    "message": "Offline stream could not be setup",
                    "error": str(e),
                }
            return {"message": "Offline stream setup successfully"}

        @self.post("/api/set-stream-params")
        async def set_stream_params(data: dict):
            """Set line noise, sampling rate, experiment name, out dir and decoder path."""
            try:
                # Read and convert every field first, so missing/invalid input
                # cannot leave the state half-updated.
                line_noise = float(data["line_noise"])
                sfreq = float(data["sampling_rate"])
                experiment_name = data["experiment_name"]
                out_dir = data["out_dir"]
                decoding_model_path = data["decoding_path"]

                self.pynm_state.stream.line_noise = line_noise
                self.pynm_state.stream.sfreq = sfreq
                self.pynm_state.experiment_name = experiment_name
                self.pynm_state.out_dir = out_dir
                self.pynm_state.decoding_model_path = decoding_model_path
            except (KeyError, TypeError, ValueError) as e:
                self.logger.error("Stream parameters could not be updated: %s", e)
                return {
                    "message": "Stream parameters could not be updated",
                    "error": str(e),
                }
            return {"message": "Stream parameters updated successfully"}

        #######################
        ### PYNM ABOUT INFO ###
        #######################

        @self.get("/api/app-info")
        async def get_app_info():
            """Return version, URLs, people and license from the package metadata."""
            metadata = importlib.metadata.metadata("py_neuromodulation")
            urls = _parse_project_urls(metadata.get_all("Project-URL"))
            return {
                "version": metadata.get("Version", ""),
                "website": urls.get("Homepage", ""),
                "authors": [metadata.get("Author-email", "")],
                "maintainers": [metadata.get("Maintainer", "")],
                "repository": urls.get("Repository", ""),
                "documentation": urls.get("Documentation", ""),
                # "" instead of a KeyError when no License classifier exists.
                "license": _license_from_classifiers(metadata.get_all("Classifier")),
            }

        ##############################
        ### FILE BROWSER ENDPOINTS ###
        ##############################
        @self.get("/api/home_directory")
        async def home_directory():
            """Return the current user's home directory."""
            try:
                return {"home_directory": str(Path.home())}
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.get("/api/pynm_dir")
        async def get_pynm_dir():
            """Return the installed package directory (PYNM_DIR)."""
            return {"pynm_dir": PYNM_DIR}

        @self.get("/api/drives")
        async def list_drives():
            """List existing drive letters ("C:", ...) on Windows, or ["/"] elsewhere."""
            if os.name == "nt":
                import string

                drives = [
                    f"{letter}:"
                    for letter in string.ascii_uppercase
                    if Path(f"{letter}:").exists()
                ]
                return {"drives": drives}
            return {"drives": ["/"]}  # Unix-like systems have a single root

        @self.get("/api/files")
        async def list_files(
            path: str = Query(default="", description="Directory path to list"),
            allowed_extensions: str = Query(
                default=",".join(ALLOWED_EXTENSIONS),
                description="Comma-separated list of allowed file extensions",
            ),
            show_hidden: bool = Query(
                default=False,
                description="Whether to show hidden files and directories",
            ),
        ) -> list[FileInfo]:
            """List entries of ``path`` (default: home dir).

            Directories are always listed. Files are listed only if their
            name ends with one of ``allowed_extensions`` (case-insensitive).

            Raises:
                HTTPException: 404 if ``path`` is not a directory, 403 on
                    permission errors, 500 otherwise.
            """
            try:
                directory = Path(path) if path else Path.home()
                if not directory.is_dir():
                    raise FileNotFoundError("The specified path is not a directory")

                # Normalise once rather than per entry; "" still matches every file.
                allowed_ext = [
                    ext.strip().lower() for ext in allowed_extensions.split(",")
                ]

                files: list[FileInfo] = []
                for entry in directory.iterdir():
                    # Skip hidden files/directories if show_hidden is False
                    if not show_hidden and is_hidden(entry):
                        continue

                    name_lower = entry.name.lower()
                    if entry.is_file() and not any(
                        name_lower.endswith(ext) for ext in allowed_ext
                    ):
                        continue

                    try:
                        stats = entry.stat()
                    except OSError:
                        # e.g. a dangling symlink or an entry we may not stat:
                        # skip it instead of failing the whole listing.
                        continue

                    is_dir = entry.is_dir()
                    files.append(
                        FileInfo(
                            name=entry.name,
                            path=str(entry),
                            dir=str(entry.parent),
                            is_directory=is_dir,
                            size=0 if is_dir else stats.st_size,
                            # st_birthtime does not exist on Linux (nor on Windows
                            # before Python 3.12); fall back to st_ctime there.
                            created_at=datetime.fromtimestamp(
                                getattr(stats, "st_birthtime", stats.st_ctime)
                            ),
                            modified_at=datetime.fromtimestamp(stats.st_mtime),
                        )
                    )
                return files
            except FileNotFoundError:
                raise HTTPException(status_code=404, detail="Directory not found")
            except PermissionError:
                raise HTTPException(status_code=403, detail="Permission denied")
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.get("/api/quick-access")
        def quick_access():
            return get_quick_access()

        ###########################
        ### WEBSOCKET ENDPOINTS ###
        ###########################
        @self.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            """Register a client with the websocket manager and hold the connection."""
            await self.websocket_manager.connect(websocket)
            # Incoming text is read and discarded; any error (such as the client
            # disconnecting) ends the loop.
            # NOTE(review): the manager is not told about the disconnect here;
            # confirm WebsocketManager drops closed sockets on its own.
            while True:
                try:
                    await websocket.receive_text()
                except Exception:
                    break