supervisely 6.73.444__py3-none-any.whl → 6.73.468__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (68)
  1. supervisely/__init__.py +24 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/api/dataset_api.py +74 -12
  5. supervisely/api/entity_annotation/figure_api.py +8 -5
  6. supervisely/api/image_api.py +4 -0
  7. supervisely/api/video/video_annotation_api.py +4 -2
  8. supervisely/api/video/video_api.py +41 -1
  9. supervisely/app/__init__.py +1 -1
  10. supervisely/app/content.py +14 -6
  11. supervisely/app/fastapi/__init__.py +1 -0
  12. supervisely/app/fastapi/custom_static_files.py +1 -1
  13. supervisely/app/fastapi/multi_user.py +88 -0
  14. supervisely/app/fastapi/subapp.py +88 -42
  15. supervisely/app/fastapi/websocket.py +77 -9
  16. supervisely/app/singleton.py +21 -0
  17. supervisely/app/v1/app_service.py +18 -2
  18. supervisely/app/v1/constants.py +7 -1
  19. supervisely/app/widgets/card/card.py +20 -0
  20. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  21. supervisely/app/widgets/dialog/dialog.py +12 -0
  22. supervisely/app/widgets/dialog/template.html +2 -1
  23. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  24. supervisely/app/widgets/fast_table/fast_table.py +121 -31
  25. supervisely/app/widgets/fast_table/template.html +1 -1
  26. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  27. supervisely/app/widgets/radio_tabs/template.html +1 -0
  28. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
  29. supervisely/app/widgets/table/table.py +68 -13
  30. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  31. supervisely/convert/image/csv/csv_converter.py +24 -15
  32. supervisely/convert/video/video_converter.py +2 -2
  33. supervisely/geometry/polyline_3d.py +110 -0
  34. supervisely/io/env.py +76 -1
  35. supervisely/nn/inference/cache.py +37 -17
  36. supervisely/nn/inference/inference.py +667 -114
  37. supervisely/nn/inference/inference_request.py +15 -8
  38. supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
  39. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  40. supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
  41. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  42. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  43. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  44. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  45. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  46. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  47. supervisely/nn/inference/session.py +43 -35
  48. supervisely/nn/model/model_api.py +9 -0
  49. supervisely/nn/model/prediction_session.py +8 -7
  50. supervisely/nn/prediction_dto.py +7 -0
  51. supervisely/nn/tracker/base_tracker.py +11 -1
  52. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  53. supervisely/nn/tracker/botsort_tracker.py +14 -7
  54. supervisely/nn/tracker/visualize.py +70 -72
  55. supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
  56. supervisely/nn/training/train_app.py +10 -5
  57. supervisely/project/project.py +9 -1
  58. supervisely/video/sampling.py +39 -20
  59. supervisely/video/video.py +41 -12
  60. supervisely/volume/stl_converter.py +2 -0
  61. supervisely/worker_api/agent_rpc.py +24 -1
  62. supervisely/worker_api/rpc_servicer.py +31 -7
  63. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/METADATA +14 -11
  64. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/RECORD +68 -66
  65. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/LICENSE +0 -0
  66. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/WHEEL +0 -0
  67. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/entry_points.txt +0 -0
  68. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/top_level.txt +0 -0
supervisely/__init__.py CHANGED
@@ -8,6 +8,22 @@ try:
8
8
  except TypeError as e:
9
9
  __version__ = "development"
10
10
 
11
+
12
+ class _ApiProtoNotAvailable:
13
+ """Placeholder class that raises an error when accessing any attribute"""
14
+
15
+ def __getattr__(self, name):
16
+ from supervisely.app.v1.constants import PROTOBUF_REQUIRED_ERROR
17
+
18
+ raise ImportError(f"Cannot access `api_proto.{name}` : " + PROTOBUF_REQUIRED_ERROR)
19
+
20
+ def __bool__(self):
21
+ return False
22
+
23
+ def __repr__(self):
24
+ return "<api_proto: not available - install supervisely[agent] to enable>"
25
+
26
+
11
27
  from supervisely.sly_logger import (
12
28
  logger,
13
29
  ServiceType,
@@ -112,7 +128,14 @@ from supervisely.worker_api.chunking import (
112
128
  ChunkedFileWriter,
113
129
  ChunkedFileReader,
114
130
  )
115
- import supervisely.worker_proto.worker_api_pb2 as api_proto
131
+
132
+ # Global import of api_proto works only if protobuf is installed and compatible
133
+ # Otherwise, we use a placeholder that raises an error when accessed
134
+ try:
135
+ import supervisely.worker_proto.worker_api_pb2 as api_proto
136
+ except Exception:
137
+ api_proto = _ApiProtoNotAvailable()
138
+
116
139
 
117
140
  from supervisely.api.api import Api, UserSession, ApiContext
118
141
  from supervisely.api import api
supervisely/_utils.py CHANGED
@@ -319,6 +319,87 @@ def resize_image_url(
319
319
  return full_storage_url
320
320
 
321
321
 
322
+ def get_storage_url(
323
+ entity_type: Literal["dataset-entities", "dataset", "project", "file-storage"],
324
+ entity_id: int,
325
+ source_type: Literal["original", "preview"],
326
+ ) -> str:
327
+ """
328
+ Generate URL for storage resources endpoints.
329
+
330
+ :param entity_type: Type of entity ("dataset-entities", "dataset", "project", "file-storage")
331
+ :type entity_type: str
332
+ :param entity_id: ID of the entity
333
+ :type entity_id: int
334
+ :param source_type: Type of source ("original" or "preview")
335
+ :type source_type: Literal["original", "preview"]
336
+ :return: Storage URL
337
+ :rtype: str
338
+ """
339
+ relative_url = f"/storage-resources/{entity_type}/{source_type}/{entity_id}"
340
+ if is_development():
341
+ return abs_url(relative_url)
342
+ return relative_url
343
+
344
+
345
+ def get_image_storage_url(image_id: int, source_type: Literal["original", "preview"]) -> str:
346
+ """
347
+ Generate URL for image storage resources.
348
+
349
+ :param image_id: ID of the image
350
+ :type image_id: int
351
+ :param source_type: Type of source ("original" or "preview")
352
+ :type source_type: Literal["original", "preview"]
353
+ :return: Storage URL for image
354
+ :rtype: str
355
+ """
356
+ return get_storage_url("dataset-entities", image_id, source_type)
357
+
358
+
359
+ def get_dataset_storage_url(
360
+ dataset_id: int, source_type: Literal["original", "preview", "raw"]
361
+ ) -> str:
362
+ """
363
+ Generate URL for dataset storage resources.
364
+
365
+ :param dataset_id: ID of the dataset
366
+ :type dataset_id: int
367
+ :param source_type: Type of source ("original", "preview", or "raw")
368
+ :type source_type: Literal["original", "preview", "raw"]
369
+ :return: Storage URL for dataset
370
+ :rtype: str
371
+ """
372
+ return get_storage_url("dataset", dataset_id, source_type)
373
+
374
+
375
+ def get_project_storage_url(
376
+ project_id: int, source_type: Literal["original", "preview", "raw"]
377
+ ) -> str:
378
+ """
379
+ Generate URL for project storage resources.
380
+
381
+ :param project_id: ID of the project
382
+ :type project_id: int
383
+ :param source_type: Type of source ("original", "preview", or "raw")
384
+ :type source_type: Literal["original", "preview", "raw"]
385
+ :return: Storage URL for project
386
+ :rtype: str
387
+ """
388
+ return get_storage_url("project", project_id, source_type)
389
+
390
+
391
+ def get_file_storage_url(file_id: int) -> str:
392
+ """
393
+ Generate URL for file storage resources (raw files).
394
+
395
+ :param file_id: ID of the file
396
+ :type file_id: int
397
+ :return: Storage URL for file
398
+ :rtype: str
399
+ """
400
+ return get_storage_url("file-storage", file_id, "raw")
401
+
402
+
322
403
  def get_preview_link(title="preview"):
323
404
  return (
324
405
  f'<a href="javascript:;">{title}<i class="zmdi zmdi-cast" style="margin-left: 5px"></i></a>'
@@ -15,6 +15,7 @@ from supervisely.geometry.multichannel_bitmap import MultichannelBitmap
15
15
  from supervisely.geometry.closed_surface_mesh import ClosedSurfaceMesh
16
16
  from supervisely.geometry.alpha_mask import AlphaMask
17
17
  from supervisely.geometry.cuboid_2d import Cuboid2d
18
+ from supervisely.geometry.polyline_3d import Polyline3D
18
19
 
19
20
 
20
21
  _INPUT_GEOMETRIES = [
@@ -34,6 +35,7 @@ _INPUT_GEOMETRIES = [
34
35
  ClosedSurfaceMesh,
35
36
  AlphaMask,
36
37
  Cuboid2d,
38
+ Polyline3D,
37
39
  ]
38
40
  _JSON_SHAPE_TO_GEOMETRY_TYPE = {
39
41
  geometry.geometry_name(): geometry for geometry in _INPUT_GEOMETRIES
@@ -1021,13 +1021,66 @@ class DatasetApi(UpdateableModule, RemoveableModuleApi):
1021
1021
 
1022
1022
  return dataset_tree
1023
1023
 
1024
- def tree(self, project_id: int) -> Generator[Tuple[List[str], DatasetInfo], None, None]:
1024
+ def _yield_tree(
1025
+ self, tree: Dict[DatasetInfo, Dict], path: List[str]
1026
+ ) -> Generator[Tuple[List[str], DatasetInfo], None, None]:
1027
+ """
1028
+ Helper method for recursive tree traversal.
1029
+ Yields tuples of (path, dataset) for all datasets in the tree. For each node (dataset) at the current level,
1030
+ yields its (path, dataset) before recursively traversing and yielding from its children.
1031
+
1032
+ :param tree: Tree structure to yield from.
1033
+ :type tree: Dict[DatasetInfo, Dict]
1034
+ :param path: Current path (used for recursion).
1035
+ :type path: List[str]
1036
+ :return: Generator of tuples of (path, dataset).
1037
+ :rtype: Generator[Tuple[List[str], DatasetInfo], None, None]
1038
+ """
1039
+ for dataset, children in tree.items():
1040
+ yield path, dataset
1041
+ new_path = path + [dataset.name]
1042
+ if children:
1043
+ yield from self._yield_tree(children, new_path)
1044
+
1045
+ def _find_dataset_in_tree(
1046
+ self, tree: Dict[DatasetInfo, Dict], target_id: int, path: List[str] = None
1047
+ ) -> Tuple[Optional[DatasetInfo], Optional[Dict], List[str]]:
1048
+ """Find a specific dataset in the tree and return its subtree and path.
1049
+
1050
+ :param tree: Tree structure to search in.
1051
+ :type tree: Dict[DatasetInfo, Dict]
1052
+ :param target_id: ID of the dataset to find.
1053
+ :type target_id: int
1054
+ :param path: Current path (used for recursion).
1055
+ :type path: List[str], optional
1056
+ :return: Tuple of (found_dataset, its_subtree, path_to_dataset).
1057
+ :rtype: Tuple[Optional[DatasetInfo], Optional[Dict], List[str]]
1058
+ """
1059
+ if path is None:
1060
+ path = []
1061
+
1062
+ for dataset, children in tree.items():
1063
+ if dataset.id == target_id:
1064
+ return dataset, children, path
1065
+ # Search in children
1066
+ if children:
1067
+ found_dataset, found_children, found_path = self._find_dataset_in_tree(
1068
+ children, target_id, path + [dataset.name]
1069
+ )
1070
+ if found_dataset is not None:
1071
+ return found_dataset, found_children, found_path
1072
+ return None, None, []
1073
+
1074
+ def tree(self, project_id: int, dataset_id: Optional[int] = None) -> Generator[Tuple[List[str], DatasetInfo], None, None]:
1025
1075
  """Yields tuples of (path, dataset) for all datasets in the project.
1026
1076
  Path of the dataset is a list of parents, e.g. ["ds1", "ds2", "ds3"].
1027
1077
  For root datasets, the path is an empty list.
1028
1078
 
1029
1079
  :param project_id: Project ID in which the Dataset is located.
1030
1080
  :type project_id: int
1081
+ :param dataset_id: Optional Dataset ID to start the tree from. If provided, only yields
1082
+ the subtree starting from this dataset (including the dataset itself and all its children).
1083
+ :type dataset_id: Optional[int]
1031
1084
  :return: Generator of tuples of (path, dataset).
1032
1085
  :rtype: Generator[Tuple[List[str], DatasetInfo], None, None]
1033
1086
  :Usage example:
@@ -1040,11 +1093,17 @@ class DatasetApi(UpdateableModule, RemoveableModuleApi):
1040
1093
 
1041
1094
  project_id = 123
1042
1095
 
1096
+ # Get all datasets in the project
1043
1097
  for parents, dataset in api.dataset.tree(project_id):
1044
1098
  parents: List[str]
1045
1099
  dataset: sly.DatasetInfo
1046
1100
  print(parents, dataset.name)
1047
1101
 
1102
+ # Get only a specific branch starting from dataset_id = 456
1103
+ for parents, dataset in api.dataset.tree(project_id, dataset_id=456):
1104
+ parents: List[str]
1105
+ dataset: sly.DatasetInfo
1106
+ print(parents, dataset.name)
1048
1107
 
1049
1108
  # Output:
1050
1109
  # [] ds1
@@ -1052,17 +1111,20 @@ class DatasetApi(UpdateableModule, RemoveableModuleApi):
1052
1111
  # ["ds1", "ds2"] ds3
1053
1112
  """
1054
1113
 
1055
- def yield_tree(
1056
- tree: Dict[DatasetInfo, Dict], path: List[str]
1057
- ) -> Generator[Tuple[List[str], DatasetInfo], None, None]:
1058
- """Yields tuples of (path, dataset) for all datasets in the tree."""
1059
- for dataset, children in tree.items():
1060
- yield path, dataset
1061
- new_path = path + [dataset.name]
1062
- if children:
1063
- yield from yield_tree(children, new_path)
1064
-
1065
- yield from yield_tree(self.get_tree(project_id), [])
1114
+ full_tree = self.get_tree(project_id)
1115
+
1116
+ if dataset_id is None:
1117
+ # Return the full tree
1118
+ yield from self._yield_tree(full_tree, [])
1119
+ else:
1120
+ # Find the specific dataset and return only its subtree
1121
+ target_dataset, subtree, dataset_path = self._find_dataset_in_tree(full_tree, dataset_id)
1122
+ if target_dataset is not None:
1123
+ # Yield the target dataset first, then its children
1124
+ yield dataset_path, target_dataset
1125
+ if subtree:
1126
+ new_path = dataset_path + [target_dataset.name]
1127
+ yield from self._yield_tree(subtree, new_path)
1066
1128
 
1067
1129
  def get_nested(self, project_id: int, dataset_id: int) -> List[DatasetInfo]:
1068
1130
  """Returns a list of all nested datasets in the specified dataset.
@@ -24,10 +24,10 @@ from requests_toolbelt import MultipartDecoder, MultipartEncoder
24
24
  from tqdm import tqdm
25
25
 
26
26
  from supervisely._utils import batched, logger, run_coroutine
27
+ from supervisely.annotation.label import LabelingStatus
27
28
  from supervisely.api.module_api import ApiField, ModuleApi, RemoveableBulkModuleApi
28
29
  from supervisely.geometry.rectangle import Rectangle
29
30
  from supervisely.video_annotation.key_id_map import KeyIdMap
30
- from supervisely.annotation.label import LabelingStatus
31
31
 
32
32
 
33
33
  class FigureInfo(NamedTuple):
@@ -595,10 +595,13 @@ class FigureApi(RemoveableBulkModuleApi):
595
595
  """
596
596
  geometries = {}
597
597
  for idx, part in self._download_geometries_generator(ids):
598
- if progress_cb is not None:
599
- progress_cb(len(part.content))
600
- geometry_json = json.loads(part.content)
601
- geometries[idx] = geometry_json
598
+ try:
599
+ if progress_cb is not None:
600
+ progress_cb(len(part.content))
601
+ geometry_json = json.loads(part.content)
602
+ geometries[idx] = geometry_json
603
+ except Exception as e:
604
+ raise RuntimeError(f"Failed to decode geometry for figure ID {idx}") from e
602
605
 
603
606
  if len(geometries) != len(ids):
604
607
  raise RuntimeError("Not all geometries were downloaded")
@@ -397,6 +397,9 @@ class ImageInfo(NamedTuple):
397
397
  #: Format: "YYYY-MM-DDTHH:MM:SS.sssZ"
398
398
  embeddings_updated_at: Optional[str] = None
399
399
 
400
+ #: :class:`int`: :class:`Project<supervisely.project.project.Project>` ID in Supervisely.
401
+ project_id: int = None
402
+
400
403
  # DO NOT DELETE THIS COMMENT
401
404
  #! New fields must be added with default values to keep backward compatibility.
402
405
 
@@ -476,6 +479,7 @@ class ImageApi(RemoveableBulkModuleApi):
476
479
  ApiField.OFFSET_END,
477
480
  ApiField.AI_SEARCH_META,
478
481
  ApiField.EMBEDDINGS_UPDATED_AT,
482
+ ApiField.PROJECT_ID,
479
483
  ]
480
484
 
481
485
  @staticmethod
@@ -236,11 +236,13 @@ class VideoAnnotationAPI(EntityAnnotationAPI):
236
236
  dst_project_meta = ProjectMeta.from_json(
237
237
  self._api.project.get_meta(dst_dataset_info.project_id)
238
238
  )
239
- for src_ids_batch, dst_ids_batch in batched(list(zip(src_video_ids, dst_video_ids))):
239
+ for src_ids_batch, dst_ids_batch in zip(batched(src_video_ids), batched(dst_video_ids)):
240
240
  ann_jsons = self.download_bulk(src_dataset_id, src_ids_batch)
241
241
  for dst_id, ann_json in zip(dst_ids_batch, ann_jsons):
242
242
  try:
243
- ann = VideoAnnotation.from_json(ann_json, dst_project_meta)
243
+ ann = VideoAnnotation.from_json(
244
+ ann_json, dst_project_meta, key_id_map=KeyIdMap()
245
+ )
244
246
  except Exception as e:
245
247
  raise RuntimeError("Failed to validate Annotation") from e
246
248
  self.append(dst_id, ann)
@@ -5,6 +5,7 @@ import asyncio
5
5
  import datetime
6
6
  import json
7
7
  import os
8
+ import re
8
9
  import urllib.parse
9
10
  from functools import partial
10
11
  from typing import (
@@ -23,7 +24,11 @@ from typing import (
23
24
  import aiofiles
24
25
  from numerize.numerize import numerize
25
26
  from requests import Response
26
- from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
27
+ from requests_toolbelt import (
28
+ MultipartDecoder,
29
+ MultipartEncoder,
30
+ MultipartEncoderMonitor,
31
+ )
27
32
  from tqdm import tqdm
28
33
 
29
34
  import supervisely.io.fs as sly_fs
@@ -1186,6 +1191,41 @@ class VideoApi(RemoveableBulkModuleApi):
1186
1191
  if progress_cb is not None:
1187
1192
  progress_cb(len(chunk))
1188
1193
 
1194
+ def download_frames(
1195
+ self, video_id: int, frames: List[int], paths: List[str], progress_cb=None
1196
+ ) -> None:
1197
+ endpoint = "videos.bulk.download-frame"
1198
+ response: Response = self._api.get(
1199
+ endpoint,
1200
+ params={},
1201
+ data={ApiField.VIDEO_ID: video_id, ApiField.FRAMES: frames},
1202
+ stream=True,
1203
+ )
1204
+ response.raise_for_status()
1205
+
1206
+ files = {frame_n: None for frame_n in frames}
1207
+ file_paths = {frame_n: path for frame_n, path in zip(frames, paths)}
1208
+
1209
+ try:
1210
+ decoder = MultipartDecoder.from_response(response)
1211
+ for part in decoder.parts:
1212
+ content_utf8 = part.headers[b"Content-Disposition"].decode("utf-8")
1213
+ # Find name="1245" preceded by a whitespace, semicolon or beginning of line.
1214
+ # The regex has 2 capture groups: one for the prefix and one for the actual name value.
1215
+ frame_n = int(re.findall(r'(^|[\s;])name="(\d*)"', content_utf8)[0][1])
1216
+ if files[frame_n] is None:
1217
+ file_path = file_paths[frame_n]
1218
+ files[frame_n] = open(file_path, "wb")
1219
+ if progress_cb is not None:
1220
+ progress_cb(1)
1221
+ f = files[frame_n]
1222
+ f.write(part.content)
1223
+
1224
+ finally:
1225
+ for f in files.values():
1226
+ if f is not None:
1227
+ f.close()
1228
+
1189
1229
  def download_range_by_id(
1190
1230
  self,
1191
1231
  id: int,
@@ -1,7 +1,7 @@
1
1
  from fastapi import FastAPI
2
2
  from supervisely.app.content import StateJson, DataJson
3
3
  from supervisely.app.content import get_data_dir, get_synced_data_dir
4
- from supervisely.app.fastapi.subapp import call_on_autostart
4
+ from supervisely.app.fastapi.subapp import call_on_autostart, session_user_api
5
5
  import supervisely.app.fastapi as fastapi
6
6
  import supervisely.app.widgets as widgets
7
7
  import supervisely.app.development as development
@@ -11,6 +11,7 @@ import threading
11
11
  import time
12
12
  import traceback
13
13
  from concurrent.futures import ThreadPoolExecutor
14
+ from typing import Optional, Union
14
15
 
15
16
  import jsonpatch
16
17
  from fastapi import Request
@@ -109,16 +110,22 @@ class _PatchableJson(dict):
109
110
  patch.apply(self._last, in_place=True)
110
111
  self._last = copy.deepcopy(self._last)
111
112
 
112
- async def synchronize_changes(self):
113
+ async def synchronize_changes(self, user_id: Optional[Union[int, str]] = None):
113
114
  patch = self._get_patch()
114
115
  await self._apply_patch(patch)
115
- await self._ws.broadcast(self.get_changes(patch))
116
+ await self._ws.broadcast(self.get_changes(patch), user_id=user_id)
116
117
 
117
118
  async def send_changes_async(self):
118
- await self.synchronize_changes()
119
+ user_id = None
120
+ if sly_env.is_multiuser_mode_enabled():
121
+ user_id = sly_env.user_from_multiuser_app()
122
+ await self.synchronize_changes(user_id=user_id)
119
123
 
120
124
  def send_changes(self):
121
- run_sync(self.synchronize_changes())
125
+ user_id = None
126
+ if sly_env.is_multiuser_mode_enabled():
127
+ user_id = sly_env.user_from_multiuser_app()
128
+ run_sync(self.synchronize_changes(user_id=user_id))
122
129
 
123
130
  def raise_for_key(self, key: str):
124
131
  if key in self:
@@ -139,7 +146,7 @@ class StateJson(_PatchableJson, metaclass=Singleton):
139
146
  await StateJson._replace_global(dict(self))
140
147
 
141
148
  @classmethod
142
- async def from_request(cls, request: Request) -> StateJson:
149
+ async def from_request(cls, request: Request, local: bool = True) -> StateJson:
143
150
  if "application/json" not in request.headers.get("Content-Type", ""):
144
151
  return None
145
152
  content = await request.json()
@@ -149,7 +156,8 @@ class StateJson(_PatchableJson, metaclass=Singleton):
149
156
  # TODO: should we always replace STATE with {}?
150
157
  d = content.get(Field.STATE, {})
151
158
  await cls._replace_global(d)
152
- return cls(d, __local__=True)
159
+
160
+ return cls(d, __local__=local)
153
161
 
154
162
  @classmethod
155
163
  async def _replace_global(cls, d: dict):
@@ -5,6 +5,7 @@ from supervisely.app.fastapi.subapp import (
5
5
  Application,
6
6
  get_name_from_env,
7
7
  _MainServer,
8
+ session_user_api,
8
9
  )
9
10
  from supervisely.app.fastapi.templating import Jinja2Templates
10
11
  from supervisely.app.fastapi.websocket import WebsocketManager
@@ -42,7 +42,7 @@ class CustomStaticFiles(StaticFiles):
42
42
  def _get_range_header(range_header: str, file_size: int) -> typing.Tuple[int, int]:
43
43
  def _invalid_range():
44
44
  return HTTPException(
45
- status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
45
+ status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, #TODO: change to status.HTTP_416_RANGE_NOT_SATISFIABLE if update starlette to 0.48.0+
46
46
  detail=f"Invalid request range (Range:{range_header!r})",
47
47
  )
48
48
 
@@ -0,0 +1,88 @@
1
+ import hashlib
2
+ from contextlib import contextmanager
3
+ from typing import Optional, Union
4
+
5
+ from fastapi import Request
6
+
7
+ import supervisely.io.env as sly_env
8
+ from supervisely.api.module_api import ApiField
9
+ from supervisely.app.fastapi.websocket import WebsocketManager
10
+ from supervisely.sly_logger import logger
11
+
12
+
13
+ def _parse_int(value):
14
+ try:
15
+ return int(value)
16
+ except (TypeError, ValueError):
17
+ return None
18
+
19
+
20
+ def _user_identity_from_cookie(request: Request) -> Optional[str]:
21
+ cookie_header = request.headers.get("cookie")
22
+ if not cookie_header:
23
+ return None
24
+ return hashlib.sha256(cookie_header.encode("utf-8")).hexdigest()
25
+
26
+
27
+ async def extract_user_id_from_request(request: Request) -> Optional[Union[int, str]]:
28
+ """Extract user ID from various parts of the request."""
29
+ if not sly_env.is_multiuser_mode_enabled():
30
+ return None
31
+ user_id = _parse_int(request.query_params.get("userId"))
32
+ if user_id is None:
33
+ header_user = _parse_int(request.headers.get("x-user-id"))
34
+ if header_user is not None:
35
+ user_id = header_user
36
+ if user_id is None:
37
+ referer = request.headers.get("referer", "")
38
+ if referer:
39
+ from urllib.parse import parse_qs, urlparse
40
+
41
+ try:
42
+ parsed_url = urlparse(referer)
43
+ query_params = parse_qs(parsed_url.query)
44
+ referer_user = query_params.get("userId", [None])[0]
45
+ user_id = _parse_int(referer_user)
46
+ except Exception as e:
47
+ logger.error(f"Error parsing userId from referer: {e}")
48
+ if user_id is None and "application/json" in request.headers.get("Content-Type", ""):
49
+ try:
50
+ payload = await request.json()
51
+ except Exception:
52
+ payload = {}
53
+ context = payload.get("context") or {}
54
+ user_id = _parse_int(context.get("userId") or context.get(ApiField.USER_ID))
55
+ if user_id is None:
56
+ state_payload = payload.get("state") or {}
57
+ user_id = _parse_int(state_payload.get("userId") or state_payload.get(ApiField.USER_ID))
58
+ if user_id is None:
59
+ user_id = _user_identity_from_cookie(request)
60
+ return user_id
61
+
62
+
63
+ @contextmanager
64
+ def session_context(user_id: Optional[Union[int, str]]):
65
+ """
66
+ Context manager to set and reset user context for multiuser applications.
67
+ Call this at the beginning of request handling to ensure the correct user context is set in environment variables (`supervisely_multiuser_app_user_id` ContextVar).
68
+ """
69
+ if not sly_env.is_multiuser_mode_enabled() or user_id is None:
70
+ yield
71
+ return
72
+ token = sly_env.set_user_for_multiuser_app(user_id)
73
+ try:
74
+ yield
75
+ finally:
76
+ sly_env.reset_user_for_multiuser_app(token)
77
+
78
+
79
+ def remember_cookie(request: Request, user_id: Optional[Union[int, str]]):
80
+ """
81
+ Remember user cookie for the given user ID. This is used to associate WebSocket connections with users in multiuser applications based on cookies.
82
+ Allows WebSocket connections to be correctly routed to the appropriate user.
83
+ """
84
+ if not sly_env.is_multiuser_mode_enabled() or user_id is None:
85
+ return
86
+ cookie_header = request.headers.get("cookie")
87
+ if cookie_header:
88
+ WebsocketManager().remember_user_cookie(cookie_header, user_id)