supervisely 6.73.444__py3-none-any.whl → 6.73.468__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic; the details were linked from the original registry page.

Files changed (68)
  1. supervisely/__init__.py +24 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/api/dataset_api.py +74 -12
  5. supervisely/api/entity_annotation/figure_api.py +8 -5
  6. supervisely/api/image_api.py +4 -0
  7. supervisely/api/video/video_annotation_api.py +4 -2
  8. supervisely/api/video/video_api.py +41 -1
  9. supervisely/app/__init__.py +1 -1
  10. supervisely/app/content.py +14 -6
  11. supervisely/app/fastapi/__init__.py +1 -0
  12. supervisely/app/fastapi/custom_static_files.py +1 -1
  13. supervisely/app/fastapi/multi_user.py +88 -0
  14. supervisely/app/fastapi/subapp.py +88 -42
  15. supervisely/app/fastapi/websocket.py +77 -9
  16. supervisely/app/singleton.py +21 -0
  17. supervisely/app/v1/app_service.py +18 -2
  18. supervisely/app/v1/constants.py +7 -1
  19. supervisely/app/widgets/card/card.py +20 -0
  20. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  21. supervisely/app/widgets/dialog/dialog.py +12 -0
  22. supervisely/app/widgets/dialog/template.html +2 -1
  23. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  24. supervisely/app/widgets/fast_table/fast_table.py +121 -31
  25. supervisely/app/widgets/fast_table/template.html +1 -1
  26. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  27. supervisely/app/widgets/radio_tabs/template.html +1 -0
  28. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
  29. supervisely/app/widgets/table/table.py +68 -13
  30. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  31. supervisely/convert/image/csv/csv_converter.py +24 -15
  32. supervisely/convert/video/video_converter.py +2 -2
  33. supervisely/geometry/polyline_3d.py +110 -0
  34. supervisely/io/env.py +76 -1
  35. supervisely/nn/inference/cache.py +37 -17
  36. supervisely/nn/inference/inference.py +667 -114
  37. supervisely/nn/inference/inference_request.py +15 -8
  38. supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
  39. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  40. supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
  41. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  42. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  43. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  44. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  45. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  46. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  47. supervisely/nn/inference/session.py +43 -35
  48. supervisely/nn/model/model_api.py +9 -0
  49. supervisely/nn/model/prediction_session.py +8 -7
  50. supervisely/nn/prediction_dto.py +7 -0
  51. supervisely/nn/tracker/base_tracker.py +11 -1
  52. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  53. supervisely/nn/tracker/botsort_tracker.py +14 -7
  54. supervisely/nn/tracker/visualize.py +70 -72
  55. supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
  56. supervisely/nn/training/train_app.py +10 -5
  57. supervisely/project/project.py +9 -1
  58. supervisely/video/sampling.py +39 -20
  59. supervisely/video/video.py +41 -12
  60. supervisely/volume/stl_converter.py +2 -0
  61. supervisely/worker_api/agent_rpc.py +24 -1
  62. supervisely/worker_api/rpc_servicer.py +31 -7
  63. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/METADATA +14 -11
  64. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/RECORD +68 -66
  65. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/LICENSE +0 -0
  66. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/WHEEL +0 -0
  67. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/entry_points.txt +0 -0
  68. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,25 @@
1
+ import hashlib
1
2
  import inspect
2
3
  import json
3
4
  import os
4
5
  import signal
5
6
  import sys
6
7
  import time
7
- from contextlib import suppress
8
+ from contextlib import contextmanager, suppress
8
9
  from contextvars import ContextVar
9
10
  from functools import wraps
10
11
  from pathlib import Path
11
12
  from threading import Event as ThreadingEvent
12
13
  from threading import Thread
13
14
  from time import sleep
14
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
15
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
15
16
 
16
17
  import arel
17
18
  import jinja2
18
19
  import numpy as np
19
20
  import psutil
20
21
  from async_asgi_testclient import TestClient
22
+ from cachetools import TTLCache
21
23
  from fastapi import (
22
24
  Depends,
23
25
  FastAPI,
@@ -32,6 +34,7 @@ from fastapi.responses import JSONResponse
32
34
  from fastapi.routing import APIRouter
33
35
  from fastapi.staticfiles import StaticFiles
34
36
 
37
+ import supervisely.app.fastapi.multi_user as multi_user
35
38
  import supervisely.io.env as sly_env
36
39
  from supervisely._utils import (
37
40
  is_debug_with_sly_net,
@@ -68,6 +71,10 @@ HEALTH_ENDPOINTS = ["/health", "/is_ready"]
68
71
  # Context variable for response time
69
72
  response_time_ctx: ContextVar[float] = ContextVar("response_time", default=None)
70
73
 
74
# Api instances cached per multi-user session, keyed by user_id.
# At most 500 users are kept; each entry expires 15 minutes (ttl in seconds)
# after it was last stored.
_USER_API_CACHE = TTLCache(maxsize=500, ttl=15 * 60)
76
+
77
+
71
78
  class ReadyzFilter(logging.Filter):
72
79
  def filter(self, record):
73
80
  if "/readyz" in record.getMessage() or "/livez" in record.getMessage():
@@ -623,18 +630,30 @@ def create(
623
630
  shutdown(process_id, before_shutdown_callbacks)
624
631
 
625
632
  if headless is False:
626
-
627
633
  @app.post("/data")
628
634
  async def send_data(request: Request):
629
- data = DataJson()
630
- response = JSONResponse(content=dict(data))
635
+ if not sly_env.is_multiuser_mode_enabled():
636
+ data = DataJson()
637
+ response = JSONResponse(content=dict(data))
638
+ return response
639
+ user_id = await multi_user.extract_user_id_from_request(request)
640
+ multi_user.remember_cookie(request, user_id)
641
+ with multi_user.session_context(user_id):
642
+ data = DataJson()
643
+ response = JSONResponse(content=dict(data))
631
644
  return response
632
645
 
633
646
  @app.post("/state")
634
647
  async def send_state(request: Request):
635
- state = StateJson()
636
-
637
- response = JSONResponse(content=dict(state))
648
+ if not sly_env.is_multiuser_mode_enabled():
649
+ state = StateJson()
650
+ response = JSONResponse(content=dict(state))
651
+ else:
652
+ user_id = await multi_user.extract_user_id_from_request(request)
653
+ multi_user.remember_cookie(request, user_id)
654
+ with multi_user.session_context(user_id):
655
+ state = StateJson()
656
+ response = JSONResponse(content=dict(state))
638
657
  gettrace = getattr(sys, "gettrace", None)
639
658
  if (gettrace is not None and gettrace()) or is_development():
640
659
  response.headers["x-debug-mode"] = "1"
@@ -813,41 +832,59 @@ def _init(
813
832
  async def get_state_from_request(request: Request, call_next):
814
833
  # Start timer for response time measurement
815
834
  start_time = time.perf_counter()
816
- if headless is False:
817
- await StateJson.from_request(request)
818
-
819
- if not ("application/json" not in request.headers.get("Content-Type", "")):
820
- # {'command': 'inference_batch_ids', 'context': {}, 'state': {'dataset_id': 49711, 'batch_ids': [3120204], 'settings': None}, 'user_api_key': 'XXX', 'api_token': 'XXX', 'instance_type': None, 'server_address': 'https://app.supervisely.com'}
821
- content = await request.json()
822
-
823
- request.state.context = content.get("context")
824
- request.state.state = content.get("state")
825
- request.state.api_token = content.get(
826
- "api_token",
827
- (
828
- request.state.context.get("apiToken")
829
- if request.state.context is not None
830
- else None
831
- ),
832
- )
833
- # logger.debug(f"middleware request api_token {request.state.api_token}")
834
- request.state.server_address = content.get(
835
- "server_address", sly_env.server_address(raise_not_found=False)
836
- )
837
- # request.state.server_address = sly_env.server_address(raise_not_found=False)
838
- # logger.debug(f"middleware request server_address {request.state.server_address}")
839
- # logger.debug(f"middleware request context {request.state.context}")
840
- # logger.debug(f"middleware request state {request.state.state}")
841
- if request.state.server_address is not None and request.state.api_token is not None:
842
- request.state.api = Api(request.state.server_address, request.state.api_token)
843
- else:
844
- request.state.api = None
845
835
 
846
- try:
847
- response = await call_next(request)
848
- except Exception as exc:
849
- need_to_handle_error = is_production()
850
- response = await process_server_error(request, exc, need_to_handle_error)
836
async def _process_request(request: Request, call_next):
    """Attach request metadata to ``request.state`` and run the next handler.

    For JSON requests the body's ``context`` and ``state`` are copied onto
    ``request.state``. ``api_token`` falls back to ``context["apiToken"]``, and
    ``server_address`` falls back to the environment. ``request.state.api`` is
    built only when both credentials are present. In multi-user mode that Api is
    also cached for the current session user.

    Exceptions raised downstream are turned into a response by
    ``process_server_error``.
    """
    if "application/json" in request.headers.get("Content-Type", ""):
        payload = await request.json()
        context = payload.get("context")
        request.state.context = context
        request.state.state = payload.get("state")

        token_from_context = context.get("apiToken") if context is not None else None
        api_token = payload.get("api_token", token_from_context)
        request.state.api_token = api_token

        server_address = payload.get(
            "server_address", sly_env.server_address(raise_not_found=False)
        )
        request.state.server_address = server_address

        if server_address is None or api_token is None:
            request.state.api = None
        else:
            request.state.api = Api(server_address, api_token)
            if sly_env.is_multiuser_mode_enabled():
                session_user = sly_env.user_from_multiuser_app()
                if session_user is not None:
                    _USER_API_CACHE[session_user] = request.state.api
        # NOTE(review): non-JSON requests never set request.state.api (or the
        # other attributes); downstream handlers must not assume they exist.

    try:
        return await call_next(request)
    except Exception as exc:
        return await process_server_error(request, exc, is_production())
875
+
876
+ if not sly_env.is_multiuser_mode_enabled():
877
+ if headless is False:
878
+ await StateJson.from_request(request)
879
+ response = await _process_request(request, call_next)
880
+ else:
881
+ user_id = await multi_user.extract_user_id_from_request(request)
882
+ multi_user.remember_cookie(request, user_id)
883
+
884
+ with multi_user.session_context(user_id):
885
+ if headless is False:
886
+ await StateJson.from_request(request, local=False)
887
+ response = await _process_request(request, call_next)
851
888
  # Calculate response time and set it for uvicorn logger in ms
852
889
  elapsed_ms = round((time.perf_counter() - start_time) * 1000)
853
890
  response_time_ctx.set(elapsed_ms)
@@ -1277,3 +1314,12 @@ def call_on_autostart(
1277
1314
 
1278
1315
  def get_name_from_env(default="Supervisely App"):
1279
1316
  return os.environ.get("APP_NAME", default)
1317
+
1318
def session_user_api() -> Optional[Api]:
    """Return the Api instance to use for the current session user.

    Outside multi-user mode this is ``Api.from_env()``. In multi-user mode it is
    the Api the request middleware cached for the current session user. It is
    ``None`` when no user is set or the user has no unexpired cache entry.
    """
    if sly_env.is_multiuser_mode_enabled():
        user_id = sly_env.user_from_multiuser_app()
        if user_id is None:
            return None
        return _USER_API_CACHE.get(user_id)
    return Api.from_env()
@@ -1,5 +1,10 @@
1
- from typing import List
1
+ import hashlib
2
+ import time
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
2
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
+
7
+ import supervisely.io.env as sly_env
3
8
  from supervisely.app.singleton import Singleton
4
9
 
5
10
 
@@ -8,6 +13,9 @@ class WebsocketManager(metaclass=Singleton):
8
13
  self.app = None
9
14
  self.path = path
10
15
  self.active_connections: List[WebSocket] = []
16
+ self._connection_users: Dict[WebSocket, Optional[Union[int, str]]] = {}
17
+ self._cookie_user_map: Dict[str, Tuple[Union[int, str], float]] = {}
18
+ self._cookie_ttl_seconds = 60 * 60
11
19
 
12
20
  def set_app(self, app: FastAPI):
13
21
  if self.app is not None:
@@ -17,17 +25,42 @@ class WebsocketManager(metaclass=Singleton):
17
25
 
18
26
  async def connect(self, websocket: WebSocket):
19
27
  await websocket.accept()
28
+ user_id = self._resolve_user_id(websocket)
20
29
  self.active_connections.append(websocket)
30
+ self._connection_users[websocket] = user_id
21
31
 
22
32
  def disconnect(self, websocket: WebSocket):
23
- self.active_connections.remove(websocket)
24
-
25
- async def broadcast(self, d: dict):
26
- # if self.app is None:
27
- # raise ValueError(
28
- # "WebSocket is not initialized, use Websocket middleware for that"
29
- # )
30
- for connection in self.active_connections:
33
+ if websocket in self.active_connections:
34
+ self.active_connections.remove(websocket)
35
+ self._connection_users.pop(websocket, None)
36
+
37
def remember_user_cookie(
    self, cookie_header: Optional[str], user_id: Optional[Union[int, str]]
):
    """Associate a raw Cookie header with ``user_id`` for later socket lookups.

    Only the header's fingerprint is stored, together with the
    ``time.monotonic()`` timestamp of this call. Expired entries are purged
    before the new one is written. Calls with a missing header or user are
    no-ops.
    """
    if cookie_header is None or user_id is None:
        return
    key = self._cookie_fingerprint(cookie_header)
    if key is not None:
        self._purge_cookie_cache()
        self._cookie_user_map[key] = (user_id, time.monotonic())
47
+
48
async def broadcast(self, d: dict, user_id: Optional[Union[int, str]] = None):
    """Send ``d`` as JSON to connected websocket clients.

    Outside multi-user mode every connection receives it. In multi-user mode the
    target user is ``user_id`` or, when omitted, the current session user. Only
    connections registered for that user receive the message. If no user can be
    determined, every connection receives it.
    """
    recipients = list(self.active_connections)
    if sly_env.is_multiuser_mode_enabled():
        target_user = (
            user_id if user_id is not None else sly_env.user_from_multiuser_app()
        )
        if target_user is not None:
            recipients = [
                ws
                for ws in recipients
                if self._connection_users.get(ws) == target_user
            ]
        # NOTE(review): with no resolvable user this still fans out to every
        # user's sockets in multi-user mode; confirm that is intended.
    for ws in recipients:
        await ws.send_json(d)
32
65
 
33
66
  async def endpoint(self, websocket: WebSocket):
@@ -37,3 +70,38 @@ class WebsocketManager(metaclass=Singleton):
37
70
  data = await websocket.receive_text()
38
71
  except WebSocketDisconnect:
39
72
  self.disconnect(websocket)
73
+
74
def _resolve_user_id(self, websocket: WebSocket) -> Optional[Union[int, str]]:
    """Work out which session user a newly connected websocket belongs to.

    Returns ``None`` outside multi-user mode. Otherwise it checks, in order:

    1. a ``userId`` query parameter that parses as an int;
    2. a user remembered for this connection's Cookie header via
       ``remember_user_cookie``, if that entry has not outlived
       ``_cookie_ttl_seconds``.

    Expired cookie entries found here are dropped.

    :return: The user id (int from the query string, or whatever type was
        remembered for the cookie), or ``None`` if it cannot be resolved.
    """
    if not sly_env.is_multiuser_mode_enabled():
        return None

    # NOTE(review): `userId` is taken from the client-controlled query string
    # without verification, so a client could claim another user's id and
    # receive that user's broadcasts. Confirm this is acceptable for the
    # deployment, or validate it against an authenticated source.
    query_user = websocket.query_params.get("userId")
    if query_user is not None:
        try:
            return int(query_user)
        except ValueError:
            pass  # not an integer id; fall back to the cookie mapping

    fingerprint = self._cookie_fingerprint(websocket.headers.get("cookie"))
    if fingerprint is None:
        return None
    cached = self._cookie_user_map.get(fingerprint)
    if cached is None:
        return None
    user_id, remembered_at = cached
    if time.monotonic() - remembered_at > self._cookie_ttl_seconds:
        self._cookie_user_map.pop(fingerprint, None)
        return None
    return user_id
94
+
95
@staticmethod
def _cookie_fingerprint(cookie_header: Optional[str]) -> Optional[str]:
    """Return the SHA-256 hex digest of a Cookie header, or None if it is empty.

    The digest is used as the cache key so raw cookie values are not stored.
    """
    if cookie_header:
        digest = hashlib.sha256(cookie_header.encode("utf-8"))
        return digest.hexdigest()
    return None
100
+
101
def _purge_cookie_cache(self) -> None:
    """Remove cookie→user entries last stored more than ``_cookie_ttl_seconds`` ago."""
    if self._cookie_user_map:
        cutoff = time.monotonic() - self._cookie_ttl_seconds
        stale_keys = [
            fingerprint
            for fingerprint, (_, stored_at) in self._cookie_user_map.items()
            if stored_at < cutoff
        ]
        for fingerprint in stale_keys:
            self._cookie_user_map.pop(fingerprint, None)
@@ -1,11 +1,32 @@
1
+ import supervisely.io.env as sly_env
2
+
3
+
1
4
  class Singleton(type):
2
5
  _instances = {}
6
+ _nested_instances = {}
3
7
 
4
8
  def __call__(cls, *args, **kwargs):
5
9
  local = kwargs.pop("__local__", False)
6
10
  if local is False:
7
11
  if cls not in cls._instances:
8
12
  cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
13
+
14
+ if sly_env.is_multiuser_mode_enabled():
15
+ from supervisely.app.content import DataJson, StateJson
16
+ from copy import deepcopy
17
+
18
+ # Initialize nested instances dict once
19
+ nested_instances = cls._nested_instances.setdefault(cls, {})
20
+
21
+ user_id = sly_env.user_from_multiuser_app()
22
+ if user_id is not None and cls in (StateJson, DataJson):
23
+ if user_id not in nested_instances:
24
+ # Create new instance and copy data
25
+ instance = super(Singleton, cls).__call__(*args, **kwargs)
26
+ instance.update(deepcopy(dict(cls._instances[cls])))
27
+ nested_instances[user_id] = instance
28
+
29
+ return nested_instances[user_id]
9
30
  return cls._instances[cls]
10
31
  else:
11
32
  return super(Singleton, cls).__call__(*args, **kwargs)
@@ -1,3 +1,5 @@
1
+ # isort: skip_file
2
+
1
3
  import json
2
4
  import os
3
5
  import time
@@ -12,7 +14,8 @@ import queue
12
14
  import re
13
15
 
14
16
  from supervisely.worker_api.agent_api import AgentAPI
15
- from supervisely.worker_proto import worker_api_pb2 as api_proto
17
+
18
+ # from supervisely.worker_proto import worker_api_pb2 as api_proto # Import moved to methods where needed
16
19
  from supervisely.function_wrapper import function_wrapper
17
20
  from supervisely._utils import take_with_default
18
21
  from supervisely.sly_logger import logger as default_logger
@@ -30,7 +33,6 @@ from supervisely._utils import _remove_sensitive_information
30
33
  from supervisely.worker_api.agent_rpc import send_from_memory_generator
31
34
  from supervisely.io.fs_cache import FileCache
32
35
 
33
-
34
36
  # https://www.roguelynn.com/words/asyncio-we-did-it-wrong/
35
37
 
36
38
 
@@ -390,6 +392,13 @@ class AppService:
390
392
  )
391
393
 
392
394
  def publish_sync(self, initial_events=None):
395
+ try:
396
+ from supervisely.worker_proto import worker_api_pb2 as api_proto
397
+ except Exception as e:
398
+ from supervisely.app.v1.constants import PROTOBUF_REQUIRED_ERROR
399
+
400
+ raise ImportError(PROTOBUF_REQUIRED_ERROR) from e
401
+
393
402
  if initial_events is not None:
394
403
  for event_obj in initial_events:
395
404
  event_obj["api_token"] = os.environ[API_TOKEN]
@@ -507,6 +516,13 @@ class AppService:
507
516
  self._error = error
508
517
 
509
518
  def send_response(self, request_id, data):
519
+ try:
520
+ from supervisely.worker_proto import worker_api_pb2 as api_proto
521
+ except Exception as e:
522
+ from supervisely.app.v1.constants import PROTOBUF_REQUIRED_ERROR
523
+
524
+ raise ImportError(PROTOBUF_REQUIRED_ERROR) from e
525
+
510
526
  out_bytes = json.dumps(data).encode("utf-8")
511
527
  self.api.put_stream_with_data(
512
528
  "SendGeneralEventData",
@@ -7,4 +7,10 @@ SHARED_DATA = '/sessions'
7
7
 
8
8
  STOP_COMMAND = "stop"
9
9
 
10
IMAGE_ANNOTATION_EVENTS = ["manual_selected_figure_changed"]

# Message of the ImportError raised when the optional protobuf worker bindings
# (supervisely.worker_proto.worker_api_pb2) cannot be imported.
PROTOBUF_REQUIRED_ERROR = (
    "protobuf is required for agent/worker/app_v1 functionality. "
    "Please install supervisely with agent extras: pip install 'supervisely[agent]'"
)
@@ -157,3 +157,23 @@ class Card(Widget):
157
157
  :rtype: bool
158
158
  """
159
159
  return self._disabled["disabled"]
160
+
161
@property
def description(self) -> Optional[str]:
    """The card's description text.

    :return: Current description.
    :rtype: Optional[str]
    """
    return self._description

@description.setter
def description(self, value: str) -> None:
    """Replace the card's description and call ``StateJson().send_changes()``.

    :param value: New description.
    :type value: str
    """
    self._description = value
    # NOTE(review): the value is written into the widget's *state*; confirm the
    # card template reads `description` from state rather than from data.
    widget_state = StateJson()[self.widget_id]
    widget_state["description"] = value
    StateJson().send_changes()
@@ -1,10 +1,10 @@
1
1
  import datetime
2
2
  import tempfile
3
+ import threading
3
4
  from pathlib import Path
4
5
  from typing import Any, Dict, List, Literal
5
6
 
6
7
  import pandas as pd
7
- import yaml
8
8
 
9
9
  from supervisely._utils import logger
10
10
  from supervisely.api.api import Api
@@ -12,8 +12,6 @@ from supervisely.api.app_api import ModuleInfo
12
12
  from supervisely.app.widgets.agent_selector.agent_selector import AgentSelector
13
13
  from supervisely.app.widgets.button.button import Button
14
14
  from supervisely.app.widgets.container.container import Container
15
- from supervisely.app.widgets.card.card import Card
16
- from supervisely.app.widgets.model_info.model_info import ModelInfo
17
15
  from supervisely.app.widgets.ecosystem_model_selector.ecosystem_model_selector import (
18
16
  EcosystemModelSelector,
19
17
  )
@@ -23,7 +21,8 @@ from supervisely.app.widgets.experiment_selector.experiment_selector import (
23
21
  from supervisely.app.widgets.fast_table.fast_table import FastTable
24
22
  from supervisely.app.widgets.field.field import Field
25
23
  from supervisely.app.widgets.flexbox.flexbox import Flexbox
26
- from supervisely.app.widgets.tabs.tabs import Tabs
24
+ from supervisely.app.widgets.model_info.model_info import ModelInfo
25
+ from supervisely.app.widgets.radio_tabs.radio_tabs import RadioTabs
27
26
  from supervisely.app.widgets.text.text import Text
28
27
  from supervisely.app.widgets.widget import Widget
29
28
  from supervisely.io import env
@@ -211,24 +210,31 @@ class DeployModel(Widget):
211
210
  return self._layout
212
211
 
213
212
  def _create_layout(self) -> Container:
214
- frameworks = self.deploy_model.get_frameworks()
215
- experiment_infos = []
216
- for framework_name in frameworks:
217
- experiment_infos.extend(
218
- get_experiment_infos(self.api, self.team_id, framework_name=framework_name)
219
- )
220
213
  self.experiment_table = ExperimentSelector(
221
- experiment_infos=experiment_infos,
222
- team_id=self.team_id,
223
214
  api=self.api,
215
+ team_id=self.team_id,
224
216
  )
225
217
 
226
218
  @self.experiment_table.checkpoint_changed
227
219
  def _checkpoint_changed(row: ExperimentSelector.ModelRow, checkpoint_value: str):
228
220
  print(f"Checkpoint changed for {row._experiment_info.task_id}: {checkpoint_value}")
229
221
 
222
+ threading.Thread(target=self.refresh_experiments, daemon=True).start()
223
+
230
224
  return self.experiment_table
231
225
 
226
def refresh_experiments(self):
    """Reload experiment infos for every framework into the experiment table.

    Sets ``experiment_table.loading`` while fetching. The flag is cleared in a
    ``finally`` so that a failure in ``get_frameworks``, ``get_experiment_infos``
    or ``set_experiment_infos`` cannot leave the table stuck in its loading
    state. This matters because the method runs in a background thread, where
    the exception would otherwise only reach the thread's excepthook. Errors
    still propagate to the caller.
    """
    self.experiment_table.loading = True
    try:
        experiment_infos = []
        for framework_name in self.deploy_model.get_frameworks():
            experiment_infos.extend(
                get_experiment_infos(self.api, self.team_id, framework_name=framework_name)
            )
        self.experiment_table.set_experiment_infos(experiment_infos)
    finally:
        self.experiment_table.loading = False
237
+
232
238
  def get_deploy_parameters(self) -> Dict[str, Any]:
233
239
  experiment_info = self.experiment_table.get_selected_experiment_info()
234
240
  return {
@@ -267,8 +273,8 @@ class DeployModel(Widget):
267
273
  MODES = [str(MODE.CONNECT), str(MODE.PRETRAINED), str(MODE.CUSTOM)]
268
274
  MODE_TO_CLASS = {
269
275
  str(MODE.CONNECT): Connect,
270
- str(MODE.PRETRAINED): Pretrained,
271
276
  str(MODE.CUSTOM): Custom,
277
+ str(MODE.PRETRAINED): Pretrained,
272
278
  }
273
279
 
274
280
  def __init__(
@@ -295,6 +301,11 @@ class DeployModel(Widget):
295
301
  self.MODE.PRETRAINED: "Pretrained",
296
302
  self.MODE.CUSTOM: "Custom",
297
303
  }
304
+ self.modes_descriptions = {
305
+ self.MODE.CONNECT: "Connect to an already deployed model",
306
+ self.MODE.PRETRAINED: "Deploy a pretrained model from the ecosystem",
307
+ self.MODE.CUSTOM: "Deploy a custom model from your experiments",
308
+ }
298
309
 
299
310
  # GUI
300
311
  self.layout: Widget = None
@@ -444,31 +455,41 @@ class DeployModel(Widget):
444
455
 
445
456
  self._init_modes(modes)
446
457
  _labels = []
458
+ _descriptions = []
447
459
  _contents = []
460
+ self.statuses_widgets = Container(
461
+ widgets=[
462
+ self.sesson_link,
463
+ self._model_info_container,
464
+ ],
465
+ gap=20,
466
+ )
467
+ self.statuses_widgets.hide()
448
468
  for mode_name, mode in self.modes.items():
449
469
  label = self.modes_labels[mode_name]
470
+ description = self.modes_descriptions[mode_name]
450
471
  if mode_name == str(self.MODE.CONNECT):
451
472
  widgets = [
452
473
  mode.layout,
453
- self._model_info_card,
454
- self.connect_stop_buttons,
455
474
  self.status,
456
- self.sesson_link,
475
+ self.statuses_widgets,
476
+ self.connect_stop_buttons,
457
477
  ]
458
478
  else:
459
479
  widgets = [
460
480
  mode.layout,
461
- self._model_info_card,
462
481
  self.select_agent_field,
463
- self.deploy_stop_buttons,
464
482
  self.status,
465
- self.sesson_link,
483
+ self.statuses_widgets,
484
+ self.deploy_stop_buttons,
466
485
  ]
486
+
467
487
  content = Container(widgets=widgets, gap=20)
468
488
  _labels.append(label)
489
+ _descriptions.append(description)
469
490
  _contents.append(content)
470
491
 
471
- self.tabs = Tabs(labels=_labels, contents=_contents)
492
+ self.tabs = RadioTabs(titles=_labels, descriptions=_descriptions, contents=_contents)
472
493
  if len(self.modes) == 1:
473
494
  self.layout = _contents[0]
474
495
  else:
@@ -490,7 +511,7 @@ class DeployModel(Widget):
490
511
  def _disconnect_button_clicked():
491
512
  self.disconnect()
492
513
 
493
- @self.tabs.click
514
+ @self.tabs.value_changed
494
515
  def _active_tab_changed(tab_name: str):
495
516
  self.set_model_message_by_tab(tab_name)
496
517
 
@@ -573,6 +594,7 @@ class DeployModel(Widget):
573
594
  f"Model {framework}: {model_name} deployed with session ID {model_api.task_id}."
574
595
  )
575
596
  self.model_api = model_api
597
+ self.statuses_widgets.show()
576
598
  self.set_model_status("connected")
577
599
  self.set_session_info(task_info)
578
600
  self.set_model_info(model_api.task_id)
@@ -603,12 +625,14 @@ class DeployModel(Widget):
603
625
  self.set_session_info(task_info)
604
626
  self.set_model_info(model_api.task_id)
605
627
  self.show_stop()
628
+ self.statuses_widgets.show()
606
629
  except Exception as e:
607
630
  logger.error(f"Failed to deploy model: {e}", exc_info=True)
608
631
  self.set_model_status("error", str(e))
609
632
  self.set_session_info(None)
610
633
  self.reset_model_info()
611
634
  self.show_deploy_button()
635
+ self.statuses_widgets.hide()
612
636
  self.enable_modes()
613
637
  else:
614
638
  if str(self.MODE.CONNECT) in self.modes:
@@ -634,6 +658,7 @@ class DeployModel(Widget):
634
658
  self.enable_modes()
635
659
  self.reset_model_info()
636
660
  self.show_deploy_button()
661
+ self.statuses_widgets.hide()
637
662
  if str(self.MODE.CONNECT) in self.modes:
638
663
  self.modes[str(self.MODE.CONNECT)]._update_sessions()
639
664
 
@@ -645,6 +670,7 @@ class DeployModel(Widget):
645
670
  self.set_session_info(None)
646
671
  self.reset_model_info()
647
672
  self.show_deploy_button()
673
+ self.statuses_widgets.hide()
648
674
  self.enable_modes()
649
675
 
650
676
  def load_from_json(self, data: Dict[str, Any]) -> None:
@@ -690,29 +716,24 @@ class DeployModel(Widget):
690
716
  title="Model Info",
691
717
  description="Information about the deployed model",
692
718
  )
693
-
694
- self._model_info_container = Container([self._model_info_widget_field])
695
- self._model_info_container.hide()
696
719
  self._model_info_message = Text("Connect to model to see the session information.")
697
-
698
- self._model_info_card = Card(
699
- title="Session Info",
700
- description="Model parameters and classes",
701
- collapsable=True,
702
- content=Container([self._model_info_container, self._model_info_message]),
720
+ self._model_info_container = Container(
721
+ [self._model_info_widget_field, self._model_info_message], gap=0
703
722
  )
704
- self._model_info_card.collapse()
723
+ self._model_info_widget_field.hide()
724
+
725
+ self._model_info_container.hide()
705
726
 
706
727
  def set_model_info(self, session_id):
707
- self._model_info_widget.set_model_info(session_id)
728
+ self._model_info_widget.set_session_id(session_id)
708
729
 
709
730
  self._model_info_message.hide()
731
+ self._model_info_widget_field.show()
710
732
  self._model_info_container.show()
711
- self._model_info_card.uncollapse()
712
733
 
713
734
  def reset_model_info(self):
714
- self._model_info_card.collapse()
715
735
  self._model_info_container.hide()
736
+ self._model_info_widget_field.hide()
716
737
  self._model_info_message.show()
717
738
 
718
739
  def set_model_message_by_tab(self, tab_name: str):
@@ -724,6 +745,6 @@ class DeployModel(Widget):
724
745
  self._model_info_message.set(
725
746
  "Deploy model to see the session information.", status="text"
726
747
  )
727
- self._model_info_card.collapse()
748
+ self._model_info_widget_field.hide()
728
749
 
729
750
  # ------------------------------------------------------------ #
@@ -28,6 +28,8 @@ class Dialog(Widget):
28
28
  dialog = Dialog(title="Dialog title", content=Input("Input"), size="large")
29
29
  dialog.show()
30
30
  """
31
class Routes:
    """Route names that a Dialog registers on the app server."""

    # Posted when the dialog is closed on the client side.
    ON_CLOSE = "close_cb"
31
33
 
32
34
  def __init__(
33
35
  self,
@@ -41,6 +43,16 @@ class Dialog(Widget):
41
43
  self._size = size
42
44
  super().__init__(widget_id=widget_id, file_path=__file__)
43
45
 
46
+ server = self._sly_app.get_server()
47
+ route = self.get_route_path(Dialog.Routes.ON_CLOSE)
48
+ @server.post(route)
49
+ def _on_close():
50
+ # * Change visibility state to False when dialog is closed on client side
51
+ visible = StateJson()[self.widget_id]["visible"]
52
+ if visible is True:
53
+ StateJson()[self.widget_id]["visible"] = False
54
+ # * no need to call send_changes(), as it is already changed on client side
55
+
44
56
  def get_json_data(self) -> Dict[str, str]:
45
57
  """Returns dictionary with widget data, which defines the appearance and behavior of the widget.
46
58