supervisely 6.73.419__py3-none-any.whl → 6.73.421__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. supervisely/api/api.py +10 -5
  2. supervisely/api/app_api.py +71 -4
  3. supervisely/api/module_api.py +4 -0
  4. supervisely/api/nn/deploy_api.py +15 -9
  5. supervisely/api/nn/ecosystem_models_api.py +201 -0
  6. supervisely/api/nn/neural_network_api.py +12 -3
  7. supervisely/api/project_api.py +35 -6
  8. supervisely/api/task_api.py +5 -1
  9. supervisely/app/widgets/__init__.py +8 -1
  10. supervisely/app/widgets/agent_selector/template.html +1 -0
  11. supervisely/app/widgets/deploy_model/__init__.py +0 -0
  12. supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
  13. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  14. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  15. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  16. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  17. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
  18. supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
  19. supervisely/app/widgets/fast_table/fast_table.py +402 -74
  20. supervisely/app/widgets/fast_table/script.js +364 -96
  21. supervisely/app/widgets/fast_table/style.css +24 -0
  22. supervisely/app/widgets/fast_table/template.html +43 -3
  23. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  24. supervisely/app/widgets/select/select.py +6 -4
  25. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
  26. supervisely/app/widgets/tabs/tabs.py +22 -6
  27. supervisely/app/widgets/tabs/template.html +5 -1
  28. supervisely/nn/artifacts/__init__.py +1 -1
  29. supervisely/nn/artifacts/artifacts.py +10 -2
  30. supervisely/nn/artifacts/detectron2.py +1 -0
  31. supervisely/nn/artifacts/hrda.py +1 -0
  32. supervisely/nn/artifacts/mmclassification.py +20 -0
  33. supervisely/nn/artifacts/mmdetection.py +5 -3
  34. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  35. supervisely/nn/artifacts/ritm.py +1 -0
  36. supervisely/nn/artifacts/rtdetr.py +1 -0
  37. supervisely/nn/artifacts/unet.py +1 -0
  38. supervisely/nn/artifacts/utils.py +3 -0
  39. supervisely/nn/artifacts/yolov5.py +2 -0
  40. supervisely/nn/artifacts/yolov8.py +1 -0
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  42. supervisely/nn/experiments.py +9 -0
  43. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  44. supervisely/nn/inference/inference.py +160 -94
  45. supervisely/nn/inference/predict_app/__init__.py +0 -0
  46. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  47. supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
  48. supervisely/nn/inference/predict_app/gui/gui.py +710 -0
  49. supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
  50. supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
  51. supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
  52. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  53. supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
  54. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  55. supervisely/nn/inference/predict_app/gui/utils.py +282 -0
  56. supervisely/nn/inference/predict_app/predict_app.py +184 -0
  57. supervisely/nn/inference/uploader.py +9 -5
  58. supervisely/nn/model/prediction.py +2 -0
  59. supervisely/nn/model/prediction_session.py +20 -3
  60. supervisely/nn/training/gui/gui.py +131 -44
  61. supervisely/nn/training/gui/model_selector.py +8 -6
  62. supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
  63. supervisely/nn/training/gui/training_artifacts.py +0 -5
  64. supervisely/nn/training/train_app.py +161 -44
  65. supervisely/project/project.py +211 -73
  66. supervisely/template/experiment/experiment.html.jinja +74 -17
  67. supervisely/template/experiment/experiment_generator.py +258 -112
  68. supervisely/template/experiment/header.html.jinja +31 -13
  69. supervisely/template/experiment/sly-style.css +7 -2
  70. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/METADATA +3 -1
  71. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/RECORD +75 -57
  72. supervisely/app/widgets/experiment_selector/style.css +0 -27
  73. supervisely/app/widgets/experiment_selector/template.html +0 -61
  74. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/LICENSE +0 -0
  75. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/WHEEL +0 -0
  76. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/entry_points.txt +0 -0
  77. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/top_level.txt +0 -0
@@ -141,6 +141,10 @@ class SelectDatasetTree(Widget):
141
141
  self._select_dataset = None
142
142
  self._width = width
143
143
 
144
+ # Flags
145
+ self._team_is_selectable = team_is_selectable
146
+ self._workspace_is_selectable = workspace_is_selectable
147
+
144
148
  # List of widgets will be used to create a Container.
145
149
  self._widgets = []
146
150
 
@@ -165,11 +169,25 @@ class SelectDatasetTree(Widget):
165
169
  for widget in self._widgets:
166
170
  widget.disable()
167
171
 
172
+ if hasattr(self, "_select_team"):
173
+ if not self._team_is_selectable:
174
+ self._select_team.disable()
175
+ if hasattr(self, "_select_workspace"):
176
+ if not self._workspace_is_selectable:
177
+ self._select_workspace.disable()
178
+
168
179
  def enable(self) -> None:
169
180
  """Enable the widget in the UI."""
170
181
  for widget in self._widgets:
171
182
  widget.enable()
172
183
 
184
+ if hasattr(self, "_select_team"):
185
+ if not self._team_is_selectable:
186
+ self._select_team.disable()
187
+ if hasattr(self, "_select_workspace"):
188
+ if not self._workspace_is_selectable:
189
+ self._select_workspace.disable()
190
+
173
191
  @property
174
192
  def team_id(self) -> int:
175
193
  """The ID of the team selected in the widget.
@@ -1,9 +1,10 @@
1
- from typing import List, Optional, Dict
2
- from supervisely.app import StateJson
3
- from supervisely.app.widgets import Widget
4
1
  import traceback
5
- from supervisely import logger
2
+ from typing import Dict, List, Optional
6
3
 
4
+ from supervisely import logger
5
+ from supervisely.app import StateJson
6
+ from supervisely.app.content import DataJson
7
+ from supervisely.app.widgets import Widget
7
8
 
8
9
  try:
9
10
  from typing import Literal
@@ -34,7 +35,7 @@ class Tabs(Widget):
34
35
  raise ValueError("You can specify up to 10 tabs.")
35
36
  if len(set(labels)) != len(labels):
36
37
  raise ValueError("All of tab labels should be unique.")
37
- self._items = []
38
+ self._items: List[Tabs.TabPane] = []
38
39
  for label, widget in zip(labels, contents):
39
40
  self._items.append(Tabs.TabPane(label=label, content=widget))
40
41
  self._value = labels[0]
@@ -43,7 +44,10 @@ class Tabs(Widget):
43
44
  super().__init__(widget_id=widget_id, file_path=__file__)
44
45
 
45
46
  def get_json_data(self) -> Dict:
46
- return {"type": self._type}
47
+ return {
48
+ "type": self._type,
49
+ "tabsOptions": {item.name: {"disabled": False} for item in self._items},
50
+ }
47
51
 
48
52
  def get_json_state(self) -> Dict:
49
53
  return {"value": self._value}
@@ -56,6 +60,18 @@ class Tabs(Widget):
56
60
  def get_active_tab(self) -> str:
57
61
  return StateJson()[self.widget_id]["value"]
58
62
 
63
+ def disable_tab(self, tab_name: str):
64
+ if tab_name not in [item.name for item in self._items]:
65
+ raise ValueError(f"Tab with name '{tab_name}' does not exist.")
66
+ DataJson()[self.widget_id]["tabsOptions"][tab_name]["disabled"] = True
67
+ DataJson().send_changes()
68
+
69
+ def enable_tab(self, tab_name: str):
70
+ if tab_name not in [item.name for item in self._items]:
71
+ raise ValueError(f"Tab with name '{tab_name}' does not exist.")
72
+ DataJson()[self.widget_id]["tabsOptions"][tab_name]["disabled"] = False
73
+ DataJson().send_changes()
74
+
59
75
  def click(self, func):
60
76
  route_path = self.get_route_path(Tabs.Routes.CLICK)
61
77
  server = self._sly_app.get_server()
@@ -11,7 +11,11 @@
11
11
  %}
12
12
  >
13
13
  {% for tab_pane in widget._items %}
14
- <el-tab-pane label="{{{tab_pane.label}}}" name="{{{tab_pane.name}}}">
14
+ <el-tab-pane
15
+ label="{{{tab_pane.label}}}"
16
+ name="{{{tab_pane.name}}}"
17
+ :disabled="data.{{{widget.widget_id}}}.tabsOptions['{{{tab_pane.name}}}'].disabled"
18
+ >
15
19
  {{{ tab_pane.content }}}
16
20
  </el-tab-pane>
17
21
  {% endfor %}
@@ -1,6 +1,6 @@
1
1
  from supervisely.nn.artifacts.detectron2 import Detectron2
2
2
  from supervisely.nn.artifacts.hrda import HRDA
3
- from supervisely.nn.artifacts.mmclassification import MMClassification
3
+ from supervisely.nn.artifacts.mmclassification import MMClassification, MMPretrain
4
4
  from supervisely.nn.artifacts.mmdetection import MMDetection, MMDetection3
5
5
  from supervisely.nn.artifacts.mmsegmentation import MMSegmentation
6
6
  from supervisely.nn.artifacts.ritm import RITM
@@ -68,6 +68,7 @@ class BaseTrainArtifacts:
68
68
  self._pattern: str = None
69
69
  self._available_task_types: List[str] = []
70
70
  self._require_runtime = False
71
+ self._has_benchmark_evaluation = False
71
72
 
72
73
  @property
73
74
  def team_id(self) -> int:
@@ -209,6 +210,13 @@ class BaseTrainArtifacts:
209
210
  """
210
211
  return self._require_runtime
211
212
 
213
+ @property
214
+ def has_benchmark_evaluation(self):
215
+ """
216
+ Whether the framework has integrated benchmark evaluation.
217
+ """
218
+ return self._has_benchmark_evaluation
219
+
212
220
  def is_valid_artifacts_path(self, path):
213
221
  """
214
222
  Check if the provided path is valid and follows specified session path pattern.
@@ -610,9 +618,9 @@ class BaseTrainArtifacts:
610
618
  date_time = parsed_datetime.strftime("%Y-%m-%d %H:%M:%S")
611
619
 
612
620
  experiment_info_data = {
613
- "experiment_name": f"Unknown {self.framework_name} experiment",
621
+ "experiment_name": f"{self.framework_name} experiment",
614
622
  "framework_name": self.framework_name,
615
- "model_name": f"Unknown {self.framework_name} model",
623
+ "model_name": f"{self.framework_name} model",
616
624
  "task_type": train_info.task_type,
617
625
  "project_id": project_id,
618
626
  "task_id": train_info.task_id,
@@ -25,6 +25,7 @@ class Detectron2(BaseTrainArtifacts):
25
25
  self._pattern = re_compile(r"^/detectron2/\d+_[^/]+/?$")
26
26
  self._available_task_types: List[str] = ["instance segmentation"]
27
27
  self._require_runtime = False
28
+ self._has_benchmark_evaluation = False
28
29
 
29
30
  def get_task_id(self, artifacts_folder: str) -> str:
30
31
  parts = artifacts_folder.split("/")
@@ -20,6 +20,7 @@ class HRDA(BaseTrainArtifacts):
20
20
  # self._config_file = "config.py"
21
21
  # self._available_task_types: List[str] = ["semantic segmentation"]
22
22
  # self._require_runtime = False
23
+ # self._has_benchmark_evaluation = False
23
24
 
24
25
  def get_task_id(self, artifacts_folder: str) -> str:
25
26
  raise NotImplementedError
@@ -21,6 +21,7 @@ class MMClassification(BaseTrainArtifacts):
21
21
  self._pattern = re_compile(r"^/mmclassification/\d+_[^/]+/?$")
22
22
  self._available_task_types: List[str] = ["classification"]
23
23
  self._require_runtime = False
24
+ self._has_benchmark_evaluation = False
24
25
 
25
26
  def get_task_id(self, artifacts_folder: str) -> str:
26
27
  parts = artifacts_folder.split("/")
@@ -44,3 +45,22 @@ class MMClassification(BaseTrainArtifacts):
44
45
 
45
46
  def get_config_path(self, artifacts_folder: str) -> str:
46
47
  return None
48
+
49
+
50
+ class MMPretrain(MMClassification):
51
+ def __init__(self, team_id: int):
52
+ super().__init__(team_id)
53
+
54
+ self._app_name = "Train MMPretrain"
55
+ self._slug = "supervisely-ecosystem/mmpretrain/supervisely/train"
56
+ self._serve_app_name = "Serve MMPretrain"
57
+ self._serve_slug = "supervisely-ecosystem/mmpretrain/supervisely/serve"
58
+ self._framework_name = "MMPretrain"
59
+ self._framework_folder = "/mmclassification-v2"
60
+ self._weights_folder = "checkpoints"
61
+ self._task_type = "classification"
62
+ self._weights_ext = ".pth"
63
+ self._pattern = re_compile(r"^/mmclassification-v2/\d+_[^/]+/?$")
64
+ self._available_task_types: List[str] = ["classification"]
65
+ self._require_runtime = False
66
+ self._has_benchmark_evaluation = False
@@ -26,6 +26,7 @@ class MMDetection(BaseTrainArtifacts):
26
26
  self._pattern = re_compile(r"^/mmdetection/\d+_[^/]+/?$")
27
27
  self._available_task_types: List[str] = ["object detection", "instance segmentation"]
28
28
  self._require_runtime = False
29
+ self._has_benchmark_evaluation = False
29
30
 
30
31
  def get_task_id(self, artifacts_folder: str) -> str:
31
32
  parts = artifacts_folder.split("/")
@@ -63,8 +64,8 @@ class MMDetection3(BaseTrainArtifacts):
63
64
  super().__init__(team_id)
64
65
 
65
66
  self._app_name = "Train MMDetection 3.0"
66
- self._slug = "Serve MMDetection 3.0"
67
- self._serve_app_name = "supervisely-ecosystem/train-mmdetection-v3"
67
+ self._slug = "supervisely-ecosystem/train-mmdetection-v3"
68
+ self._serve_app_name = "Serve MMDetection 3.0"
68
69
  self._serve_slug = "supervisely-ecosystem/serve-mmdetection-v3"
69
70
  self._framework_name = "MMDetection 3.0"
70
71
  self._framework_folder = "/mmdetection-3"
@@ -75,7 +76,8 @@ class MMDetection3(BaseTrainArtifacts):
75
76
  self._pattern = re_compile(r"^/mmdetection-3/\d+_[^/]+/?$")
76
77
  self._available_task_types: List[str] = ["object detection", "instance segmentation"]
77
78
  self._require_runtime = False
78
-
79
+ self._has_benchmark_evaluation = True
80
+
79
81
  def get_task_id(self, artifacts_folder: str) -> str:
80
82
  parts = artifacts_folder.split("/")
81
83
  if len(parts) < 3:
@@ -22,6 +22,7 @@ class MMSegmentation(BaseTrainArtifacts):
22
22
  self._pattern = re_compile(r"^/mmsegmentation/\d+_[^/]+/?$")
23
23
  self._available_task_types: List[str] = ["instance segmentation"]
24
24
  self._require_runtime = False
25
+ self._has_benchmark_evaluation = True
25
26
 
26
27
  def get_task_id(self, artifacts_folder: str) -> str:
27
28
  return artifacts_folder.split("/")[2].split("_")[0]
@@ -22,6 +22,7 @@ class RITM(BaseTrainArtifacts):
22
22
  self._pattern = re_compile(r"^/RITM_training/\d+_[^/]+/?$")
23
23
  self._available_task_types: List[str] = ["interactive segmentation"]
24
24
  self._require_runtime = False
25
+ self._has_benchmark_evaluation = False
25
26
 
26
27
  def get_task_id(self, artifacts_folder: str) -> str:
27
28
  parts = artifacts_folder.split("/")
@@ -22,6 +22,7 @@ class RTDETR(BaseTrainArtifacts):
22
22
  self._pattern = re_compile(r"^/RT-DETR/[^/]+/\d+/?$")
23
23
  self._available_task_types: List[str] = ["object detection"]
24
24
  self._require_runtime = False
25
+ self._has_benchmark_evaluation = True
25
26
 
26
27
  def get_task_id(self, artifacts_folder: str) -> str:
27
28
  return artifacts_folder.split("/")[-1]
@@ -22,6 +22,7 @@ class UNet(BaseTrainArtifacts):
22
22
  self._pattern = re_compile(r"^/unet/\d+_[^/]+/?$")
23
23
  self._available_task_types: List[str] = ["semantic segmentation"]
24
24
  self._require_runtime = False
25
+ self._has_benchmark_evaluation = True
25
26
 
26
27
  def get_task_id(self, artifacts_folder: str) -> str:
27
28
  parts = artifacts_folder.split("/")
@@ -4,6 +4,7 @@ from supervisely.nn.artifacts import (
4
4
  YOLOv5v2,
5
5
  YOLOv8,
6
6
  MMClassification,
7
+ MMPretrain,
7
8
  MMSegmentation,
8
9
  MMDetection,
9
10
  MMDetection3,
@@ -19,6 +20,7 @@ class FrameworkName:
19
20
  YOLOV5V2 = "YOLOv5 2.0"
20
21
  YOLOV8 = "YOLOv8+"
21
22
  MMCLASSIFICATION = "MMClassification"
23
+ MMPRETRAIN = "MMPretrain"
22
24
  MMSEGMENTATION = "MMSegmentation"
23
25
  MMDETECTION = "MMDetection"
24
26
  MMDETECTION3 = "MMDetection 3.0"
@@ -34,6 +36,7 @@ class FrameworkMapper:
34
36
  FrameworkName.YOLOV5V2: YOLOv5v2,
35
37
  FrameworkName.YOLOV8: YOLOv8,
36
38
  FrameworkName.MMCLASSIFICATION: MMClassification,
39
+ FrameworkName.MMPRETRAIN: MMPretrain,
37
40
  FrameworkName.MMSEGMENTATION: MMSegmentation,
38
41
  FrameworkName.MMDETECTION: MMDetection,
39
42
  FrameworkName.MMDETECTION3: MMDetection3,
@@ -22,6 +22,7 @@ class YOLOv5(BaseTrainArtifacts):
22
22
  self._pattern = re_compile(r"^/yolov5_train/[^/]+/\d+/?$")
23
23
  self._available_task_types: List[str] = ["object detection"]
24
24
  self._require_runtime = False
25
+ self._has_benchmark_evaluation = False
25
26
 
26
27
  def get_task_id(self, artifacts_folder: str) -> str:
27
28
  return artifacts_folder.split("/")[-1]
@@ -55,3 +56,4 @@ class YOLOv5v2(YOLOv5):
55
56
  self._config_file = None
56
57
  self._pattern = re_compile(r"^/yolov5_2.0_train/[^/]+/\d+/?$")
57
58
  self._available_task_types: List[str] = ["object detection"]
59
+ self._has_benchmark_evaluation = False
@@ -28,6 +28,7 @@ class YOLOv8(BaseTrainArtifacts):
28
28
  "pose estimation",
29
29
  ]
30
30
  self._require_runtime = True
31
+ self._has_benchmark_evaluation = True
31
32
 
32
33
  def get_task_id(self, artifacts_folder: str) -> str:
33
34
  parts = artifacts_folder.split("/")
@@ -70,24 +70,24 @@ class MetricProvider:
70
70
 
71
71
  def json_metrics(self):
72
72
  return {
73
- "mIoU": self.eval_data.loc["mean"]["IoU"] * 100,
74
- "mE_boundary_oU": self.eval_data.loc["mean"]["E_boundary_oU"] * 100,
75
- "mFP_boundary_oU": self.eval_data.loc["mean"]["FP_boundary_oU"] * 100,
76
- "mFN_boundary_oU": self.eval_data.loc["mean"]["FN_boundary_oU"] * 100,
77
- "mE_boundary_oU_renormed": self.eval_data.loc["mean"]["E_boundary_oU_renormed"] * 100,
78
- "mE_extent_oU": self.eval_data.loc["mean"]["E_extent_oU"] * 100,
79
- "mFP_extent_oU": self.eval_data.loc["mean"]["FP_extent_oU"] * 100,
80
- "mFN_extent_oU": self.eval_data.loc["mean"]["FN_extent_oU"] * 100,
81
- "mE_extent_oU_renormed": self.eval_data.loc["mean"]["E_extent_oU_renormed"] * 100,
82
- "mE_segment_oU": self.eval_data.loc["mean"]["E_segment_oU"] * 100,
83
- "mFP_segment_oU": self.eval_data.loc["mean"]["FP_segment_oU"] * 100,
84
- "mFN_segment_oU": self.eval_data.loc["mean"]["FN_segment_oU"] * 100,
85
- "mE_segment_oU_renormed": self.eval_data.loc["mean"]["E_segment_oU_renormed"] * 100,
86
- "mPrecision": self.eval_data.loc["mean"]["precision"] * 100,
87
- "mRecall": self.eval_data.loc["mean"]["recall"] * 100,
88
- "mF1_score": self.eval_data.loc["mean"]["F1_score"] * 100,
89
- "PixelAcc": self.pixel_accuracy * 100,
90
- "mBoundaryIoU": self.eval_data.loc["mean"]["boundary_IoU"] * 100,
73
+ "mIoU": self.eval_data.loc["mean"]["IoU"],
74
+ "mE_boundary_oU": self.eval_data.loc["mean"]["E_boundary_oU"],
75
+ "mFP_boundary_oU": self.eval_data.loc["mean"]["FP_boundary_oU"],
76
+ "mFN_boundary_oU": self.eval_data.loc["mean"]["FN_boundary_oU"],
77
+ "mE_boundary_oU_renormed": self.eval_data.loc["mean"]["E_boundary_oU_renormed"],
78
+ "mE_extent_oU": self.eval_data.loc["mean"]["E_extent_oU"],
79
+ "mFP_extent_oU": self.eval_data.loc["mean"]["FP_extent_oU"],
80
+ "mFN_extent_oU": self.eval_data.loc["mean"]["FN_extent_oU"],
81
+ "mE_extent_oU_renormed": self.eval_data.loc["mean"]["E_extent_oU_renormed"],
82
+ "mE_segment_oU": self.eval_data.loc["mean"]["E_segment_oU"],
83
+ "mFP_segment_oU": self.eval_data.loc["mean"]["FP_segment_oU"],
84
+ "mFN_segment_oU": self.eval_data.loc["mean"]["FN_segment_oU"],
85
+ "mE_segment_oU_renormed": self.eval_data.loc["mean"]["E_segment_oU_renormed"],
86
+ "mPrecision": self.eval_data.loc["mean"]["precision"],
87
+ "mRecall": self.eval_data.loc["mean"]["recall"],
88
+ "mF1_score": self.eval_data.loc["mean"]["F1_score"],
89
+ "PixelAcc": self.pixel_accuracy,
90
+ "mBoundaryIoU": self.eval_data.loc["mean"]["boundary_IoU"],
91
91
  }
92
92
 
93
93
  def metric_table(self):
@@ -54,6 +54,8 @@ class ExperimentInfo:
54
54
  """Number of images in the validation set"""
55
55
  datetime: Optional[str] = None
56
56
  """Date and time when the experiment was started"""
57
+ experiment_report_id: Optional[int] = None
58
+ """ID of the experiment report"""
57
59
  evaluation_report_id: Optional[int] = None
58
60
  """ID of the model benchmark evaluation report"""
59
61
  evaluation_report_link: Optional[str] = None
@@ -62,6 +64,12 @@ class ExperimentInfo:
62
64
  """Evaluation metrics"""
63
65
  logs: Optional[dict] = None
64
66
  """Dictionary with link and type of logger"""
67
+ train_collection_id: Optional[int] = None
68
+ """ID of the collection with train images"""
69
+ val_collection_id: Optional[int] = None
70
+ """ID of the collection with validation images"""
71
+ project_version: Optional[int] = None
72
+ """Version of the project"""
65
73
 
66
74
  def __init__(self, **kwargs):
67
75
  required_fieds = {
@@ -82,6 +90,7 @@ class ExperimentInfo:
82
90
  for field in fields(self.__class__):
83
91
  value = getattr(self, field.name)
84
92
  data[field.name] = value
93
+ return data
85
94
 
86
95
 
87
96
  def get_experiment_infos(api: Api, team_id: int, framework_name: str) -> List[ExperimentInfo]:
@@ -1,3 +1,4 @@
1
+ import os
1
2
  from os.path import join
2
3
  from typing import Any, Dict, List, Optional, Union
3
4
 
@@ -7,14 +8,25 @@ import supervisely.io.env as sly_env
7
8
  import supervisely.io.fs as sly_fs
8
9
  import supervisely.io.json as sly_json
9
10
  from supervisely import Api
10
- from supervisely.app.widgets import Card, Container, Field, RadioTabs, SelectString, Text, Widget
11
- from supervisely.app.widgets.experiment_selector.experiment_selector import ExperimentSelector
11
+ from supervisely.app.widgets import (
12
+ Card,
13
+ Container,
14
+ Field,
15
+ RadioTabs,
16
+ SelectString,
17
+ Text,
18
+ Widget,
19
+ )
20
+ from supervisely.app.widgets.experiment_selector.experiment_selector import (
21
+ ExperimentSelector,
22
+ )
12
23
  from supervisely.app.widgets.pretrained_models_selector.pretrained_models_selector import (
13
24
  PretrainedModelsSelector,
14
25
  )
15
26
  from supervisely.nn.experiments import get_experiment_infos
16
27
  from supervisely.nn.inference.gui.serving_gui import ServingGUI
17
28
  from supervisely.nn.utils import ModelSource, RuntimeType, _get_model_name
29
+ from supervisely.nn.experiments import ExperimentInfo
18
30
 
19
31
 
20
32
  class ServingGUITemplate(ServingGUI):
@@ -67,7 +79,7 @@ class ServingGUITemplate(ServingGUI):
67
79
  # Custom models
68
80
  if use_custom_models:
69
81
  experiments = get_experiment_infos(self.api, self.team_id, self.framework_name)
70
- self.experiment_selector = ExperimentSelector(self.team_id, experiments)
82
+ self.experiment_selector = ExperimentSelector(self.api, self.team_id, experiments)
71
83
  else:
72
84
  self.experiment_selector = None
73
85
 
@@ -132,11 +144,13 @@ class ServingGUITemplate(ServingGUI):
132
144
 
133
145
  if self.runtime_select is not None:
134
146
  self.runtime_select.value_changed(lambda _: self._update_export_message())
147
+
135
148
  if self.experiment_selector is not None:
136
- self.experiment_selector.value_changed(lambda _: self._update_export_message())
137
- for task_type in self.experiment_selector.rows:
138
- for row in self.experiment_selector.rows[task_type]:
139
- row.checkpoints_selector.value_changed(lambda _: self._update_export_message())
149
+ self.experiment_selector.selection_changed(lambda _: self._update_export_message())
150
+ self.experiment_selector.checkpoint_changed(
151
+ lambda row, _: self._update_export_message()
152
+ )
153
+
140
154
  if self.pretrained_models_table is not None:
141
155
  self.pretrained_models_table.model_changed(lambda _: self._update_export_message())
142
156
 
@@ -156,7 +170,12 @@ class ServingGUITemplate(ServingGUI):
156
170
 
157
171
  @property
158
172
  def model_info(self) -> Dict[str, Any]:
159
- return self._get_selected_row()
173
+ model_info = self._get_selected_row()
174
+ if isinstance(model_info, ExperimentInfo):
175
+ # model info requires json format
176
+ # to match types of pretrained and custom model info
177
+ model_info = model_info.to_json()
178
+ return model_info
160
179
 
161
180
  @property
162
181
  def model_name(self) -> Optional[str]:
@@ -171,7 +190,14 @@ class ServingGUITemplate(ServingGUI):
171
190
  model_meta = self.model_info.get("meta", {})
172
191
  return model_meta.get("model_files", {})
173
192
  else:
174
- return self.experiment_selector.get_model_files()
193
+ experiment_info = self.experiment_selector.get_selected_experiment_info()
194
+ artifacts_dir = experiment_info.artifacts_dir
195
+ model_files = experiment_info.model_files
196
+ full_model_files = {
197
+ name: os.path.join(artifacts_dir, file) for name, file in model_files.items()
198
+ }
199
+ full_model_files["checkpoint"] = self.experiment_selector.get_selected_checkpoint_path()
200
+ return full_model_files
175
201
 
176
202
  @property
177
203
  def runtime(self) -> str:
@@ -235,10 +261,10 @@ class ServingGUITemplate(ServingGUI):
235
261
 
236
262
  checkpoint_name = None
237
263
  if self.model_source == ModelSource.CUSTOM and self.experiment_selector is not None:
238
- selected_row = self.experiment_selector.get_selected_row()
239
- if selected_row is None:
264
+ selected_experiment_info = self.experiment_selector.get_selected_experiment_info()
265
+ if selected_experiment_info is None:
240
266
  return
241
- checkpoint_name = selected_row.get_selected_checkpoint_name()
267
+ checkpoint_name = self.experiment_selector.get_selected_checkpoint_name()
242
268
  if checkpoint_name is None:
243
269
  return
244
270
 
@@ -250,7 +276,7 @@ class ServingGUITemplate(ServingGUI):
250
276
  if key.lower().startswith(runtime.lower()):
251
277
  available = True
252
278
  break
253
- if checkpoint_name != selected_row.best_checkpoint:
279
+ if checkpoint_name != selected_experiment_info.best_checkpoint:
254
280
  available = False
255
281
 
256
282
  if available: