supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (190) hide show
  1. supervisely/__init__.py +136 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/annotation/label.py +80 -3
  5. supervisely/api/annotation_api.py +9 -9
  6. supervisely/api/api.py +67 -43
  7. supervisely/api/app_api.py +72 -5
  8. supervisely/api/dataset_api.py +108 -33
  9. supervisely/api/entity_annotation/figure_api.py +113 -49
  10. supervisely/api/image_api.py +82 -0
  11. supervisely/api/module_api.py +10 -0
  12. supervisely/api/nn/deploy_api.py +15 -9
  13. supervisely/api/nn/ecosystem_models_api.py +201 -0
  14. supervisely/api/nn/neural_network_api.py +12 -3
  15. supervisely/api/pointcloud/pointcloud_api.py +38 -0
  16. supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
  17. supervisely/api/project_api.py +213 -6
  18. supervisely/api/task_api.py +11 -1
  19. supervisely/api/video/video_annotation_api.py +4 -2
  20. supervisely/api/video/video_api.py +79 -1
  21. supervisely/api/video/video_figure_api.py +24 -11
  22. supervisely/api/volume/volume_api.py +38 -0
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +14 -6
  25. supervisely/app/fastapi/__init__.py +1 -0
  26. supervisely/app/fastapi/custom_static_files.py +1 -1
  27. supervisely/app/fastapi/multi_user.py +88 -0
  28. supervisely/app/fastapi/subapp.py +175 -42
  29. supervisely/app/fastapi/templating.py +1 -1
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +11 -1
  35. supervisely/app/widgets/agent_selector/template.html +1 -0
  36. supervisely/app/widgets/card/card.py +20 -0
  37. supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
  38. supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
  39. supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
  40. supervisely/app/widgets/dialog/dialog.py +12 -0
  41. supervisely/app/widgets/dialog/template.html +2 -1
  42. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  43. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  44. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  45. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  46. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
  47. supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
  48. supervisely/app/widgets/fast_table/fast_table.py +713 -126
  49. supervisely/app/widgets/fast_table/script.js +492 -95
  50. supervisely/app/widgets/fast_table/style.css +54 -0
  51. supervisely/app/widgets/fast_table/template.html +45 -5
  52. supervisely/app/widgets/heatmap/__init__.py +0 -0
  53. supervisely/app/widgets/heatmap/heatmap.py +523 -0
  54. supervisely/app/widgets/heatmap/script.js +378 -0
  55. supervisely/app/widgets/heatmap/style.css +227 -0
  56. supervisely/app/widgets/heatmap/template.html +21 -0
  57. supervisely/app/widgets/input_tag/input_tag.py +102 -15
  58. supervisely/app/widgets/input_tag_list/__init__.py +0 -0
  59. supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
  60. supervisely/app/widgets/input_tag_list/template.html +70 -0
  61. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  62. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  63. supervisely/app/widgets/radio_tabs/template.html +1 -0
  64. supervisely/app/widgets/select/select.py +6 -4
  65. supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
  66. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
  67. supervisely/app/widgets/table/table.py +68 -13
  68. supervisely/app/widgets/tabs/tabs.py +22 -6
  69. supervisely/app/widgets/tabs/template.html +5 -1
  70. supervisely/app/widgets/transfer/style.css +3 -0
  71. supervisely/app/widgets/transfer/template.html +3 -1
  72. supervisely/app/widgets/transfer/transfer.py +48 -45
  73. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  74. supervisely/convert/image/csv/csv_converter.py +24 -15
  75. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
  76. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
  77. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
  78. supervisely/convert/video/video_converter.py +2 -2
  79. supervisely/geometry/polyline_3d.py +110 -0
  80. supervisely/io/env.py +161 -1
  81. supervisely/nn/artifacts/__init__.py +1 -1
  82. supervisely/nn/artifacts/artifacts.py +10 -2
  83. supervisely/nn/artifacts/detectron2.py +1 -0
  84. supervisely/nn/artifacts/hrda.py +1 -0
  85. supervisely/nn/artifacts/mmclassification.py +20 -0
  86. supervisely/nn/artifacts/mmdetection.py +5 -3
  87. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  88. supervisely/nn/artifacts/ritm.py +1 -0
  89. supervisely/nn/artifacts/rtdetr.py +1 -0
  90. supervisely/nn/artifacts/unet.py +1 -0
  91. supervisely/nn/artifacts/utils.py +3 -0
  92. supervisely/nn/artifacts/yolov5.py +2 -0
  93. supervisely/nn/artifacts/yolov8.py +1 -0
  94. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  95. supervisely/nn/experiments.py +9 -0
  96. supervisely/nn/inference/cache.py +37 -17
  97. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  98. supervisely/nn/inference/inference.py +953 -211
  99. supervisely/nn/inference/inference_request.py +15 -8
  100. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  101. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  102. supervisely/nn/inference/predict_app/__init__.py +0 -0
  103. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  104. supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
  105. supervisely/nn/inference/predict_app/gui/gui.py +915 -0
  106. supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
  107. supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
  108. supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
  109. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  110. supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
  111. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  112. supervisely/nn/inference/predict_app/gui/utils.py +399 -0
  113. supervisely/nn/inference/predict_app/predict_app.py +176 -0
  114. supervisely/nn/inference/session.py +47 -39
  115. supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
  116. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  117. supervisely/nn/inference/tracking/tracker_interface.py +4 -0
  118. supervisely/nn/inference/uploader.py +9 -5
  119. supervisely/nn/model/model_api.py +44 -22
  120. supervisely/nn/model/prediction.py +15 -1
  121. supervisely/nn/model/prediction_session.py +70 -14
  122. supervisely/nn/prediction_dto.py +7 -0
  123. supervisely/nn/tracker/__init__.py +6 -8
  124. supervisely/nn/tracker/base_tracker.py +54 -0
  125. supervisely/nn/tracker/botsort/__init__.py +1 -0
  126. supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
  127. supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
  128. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  129. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  130. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  131. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  132. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  133. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  134. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  135. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  136. supervisely/nn/tracker/botsort_tracker.py +273 -0
  137. supervisely/nn/tracker/calculate_metrics.py +264 -0
  138. supervisely/nn/tracker/utils.py +273 -0
  139. supervisely/nn/tracker/visualize.py +520 -0
  140. supervisely/nn/training/gui/gui.py +152 -49
  141. supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
  142. supervisely/nn/training/gui/model_selector.py +8 -6
  143. supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
  144. supervisely/nn/training/gui/training_artifacts.py +3 -1
  145. supervisely/nn/training/train_app.py +225 -46
  146. supervisely/project/pointcloud_episode_project.py +12 -8
  147. supervisely/project/pointcloud_project.py +12 -8
  148. supervisely/project/project.py +221 -75
  149. supervisely/template/experiment/experiment.html.jinja +105 -55
  150. supervisely/template/experiment/experiment_generator.py +258 -112
  151. supervisely/template/experiment/header.html.jinja +31 -13
  152. supervisely/template/experiment/sly-style.css +7 -2
  153. supervisely/versions.json +3 -1
  154. supervisely/video/sampling.py +42 -20
  155. supervisely/video/video.py +41 -12
  156. supervisely/video_annotation/video_figure.py +38 -4
  157. supervisely/volume/stl_converter.py +2 -0
  158. supervisely/worker_api/agent_rpc.py +24 -1
  159. supervisely/worker_api/rpc_servicer.py +31 -7
  160. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
  161. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
  162. supervisely_lib/__init__.py +6 -1
  163. supervisely/app/widgets/experiment_selector/style.css +0 -27
  164. supervisely/app/widgets/experiment_selector/template.html +0 -61
  165. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  166. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  167. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  168. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  169. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  170. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  171. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  172. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  173. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  174. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  175. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  176. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  177. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  178. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  179. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  180. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  181. supervisely/nn/tracker/tracker.py +0 -285
  182. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  183. supervisely/nn/tracking/__init__.py +0 -1
  184. supervisely/nn/tracking/boxmot.py +0 -114
  185. supervisely/nn/tracking/tracking.py +0 -24
  186. /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
  187. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
  188. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
  189. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
  190. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,915 @@
1
+ import json
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Dict, List, Optional
5
+
6
+ import yaml
7
+
8
+ from supervisely._utils import is_development, logger
9
+ from supervisely.api.api import Api
10
+ from supervisely.api.image_api import ImageInfo
11
+ from supervisely.api.video.video_api import VideoInfo
12
+ from supervisely.app.widgets import Button, Card, Container, Stepper, Widget
13
+ from supervisely.geometry.any_geometry import AnyGeometry
14
+ from supervisely.io import env
15
+ from supervisely.nn.inference.inference import update_meta_and_ann_for_video_annotation
16
+ from supervisely.nn.inference.predict_app.gui.classes_selector import ClassesSelector
17
+ from supervisely.nn.inference.predict_app.gui.input_selector import InputSelector
18
+ from supervisely.nn.inference.predict_app.gui.model_selector import ModelSelector
19
+ from supervisely.nn.inference.predict_app.gui.output_selector import OutputSelector
20
+ from supervisely.nn.inference.predict_app.gui.settings_selector import (
21
+ AddPredictionsMode,
22
+ SettingsSelector,
23
+ )
24
+ from supervisely.nn.inference.predict_app.gui.tags_selector import TagsSelector
25
+ from supervisely.nn.inference.predict_app.gui.utils import (
26
+ copy_items_to_project,
27
+ create_project,
28
+ disable_enable,
29
+ update_custom_button_params,
30
+ video_annotation_from_predictions,
31
+ )
32
+ from supervisely.nn.model.model_api import ModelAPI
33
+ from supervisely.nn.model.prediction import Prediction
34
+ from supervisely.project.project_meta import ProjectMeta
35
+ from supervisely.project.project_type import ProjectType
36
+ from supervisely.video_annotation.key_id_map import KeyIdMap
37
+ from supervisely.video_annotation.video_annotation import VideoAnnotation
38
+
39
+
40
class StepFlow:
    """Bookkeeping for a sequence of GUI steps with dependency-based locking.

    Each step is stored in ``self.steps`` under its name as a dict holding the
    step widget, optional callbacks (``on_select``, ``on_reactivate``,
    ``on_lock``, ``on_unlock``), the names of the steps it depends on, an
    optional button, and two flags: ``is_selected`` and ``is_locked``.
    ``self.steps_sequence`` keeps the display order of the step names.
    """

    def __init__(self):
        # The Stepper widget is built lazily on first access of ``stepper``.
        self._stepper = None
        self.steps = {}
        self.steps_sequence = []

    def add_step(
        self,
        name: str,
        widget: Widget,
        on_select: Optional[Callable] = None,
        on_reactivate: Optional[Callable] = None,
        depends_on: Optional[List[str]] = None,
        on_lock: Optional[Callable] = None,
        on_unlock: Optional[Callable] = None,
        button: Optional[Button] = None,
        position: Optional[int] = None,
    ):
        """Register a step and refresh the lock state of all steps.

        Args:
            name: Unique key of the step in ``self.steps``.
            widget: Widget handed to the Stepper for this step.
            on_select: Called (without arguments) when the step gets selected.
            on_reactivate: Called (without arguments) when the step is reset.
            depends_on: Names of steps that must be selected before this one
                is unlocked. They must already be registered.
            on_lock: Called when the step transitions to locked.
            on_unlock: Called when the step transitions to unlocked.
            button: If given, clicking it toggles the step between selected
                and reactivated.
            position: Index in ``steps_sequence`` to insert at; appended when
                omitted.

        Raises:
            KeyError: If ``depends_on`` names a step that is not registered,
                because locks are recomputed immediately.
        """
        if depends_on is None:
            depends_on = []
        self.steps[name] = {
            "widget": widget,
            "on_select": on_select,
            "on_reactivate": on_reactivate,
            "depends_on": depends_on,
            "on_lock": on_lock,
            "on_unlock": on_unlock,
            "button": button,
            "is_selected": False,
            "is_locked": False,
        }
        if button is not None:
            self._wrap_button(button, name)
        if position is not None:
            self.steps_sequence.insert(position, name)
        else:
            self.steps_sequence.append(name)
        self.update_locks()

    def _create_stepper(self):
        """Build the Stepper from the step widgets in sequence order."""
        widgets = [self.steps[step_name]["widget"] for step_name in self.steps_sequence]
        self._stepper = Stepper(widgets=widgets)

    @property
    def stepper(self):
        """Stepper widget, created on first access from the steps known then."""
        if self._stepper is None:
            self._create_stepper()
        return self._stepper

    def update_stepper(self):
        """Activate the first step (1-based) that is not selected yet.

        Nothing happens when every step is already selected.
        """
        for i, step_name in enumerate(self.steps_sequence):
            if not self.steps[step_name]["is_selected"]:
                self.stepper.set_active_step(i + 1)
                return

    def update_locks(self):
        """Lock steps with an unselected dependency and unlock the others.

        ``on_lock`` / ``on_unlock`` fire only on an actual state transition.
        """
        for step in self.steps.values():
            should_lock = any(
                not self.steps[dep_name]["is_selected"] for dep_name in step["depends_on"]
            )
            if should_lock and not step["is_locked"]:
                if step["on_lock"] is not None:
                    step["on_lock"]()
                step["is_locked"] = True
            elif not should_lock and step["is_locked"]:
                if step["on_unlock"] is not None:
                    step["on_unlock"]()
                step["is_locked"] = False

    def _reactivate_dependents(self, step_name: str, visited=None):
        """Reactivate every not-yet-visited step that depends on ``step_name``."""
        if visited is None:
            visited = set()
        for dep_name, step in self.steps.items():
            if step_name in step["depends_on"] and dep_name not in visited:
                self._reactivate_step(dep_name, visited)

    def _reactivate_step(self, step_name: str, visited=None):
        """Reset ``step_name`` and, transitively, the steps depending on it.

        ``visited`` collects the names already handled during one traversal so
        that a step reachable through several dependency paths is reset (and
        its ``on_reactivate`` called) exactly once, and cycles terminate.
        """
        if visited is None:
            visited = set()
        if step_name in visited:
            return
        visited.add(step_name)
        step = self.steps[step_name]
        if step["on_reactivate"] is not None:
            step["on_reactivate"]()
        step["is_selected"] = False
        self._reactivate_dependents(step_name, visited)

    def reactivate_step(self, step_name: str):
        """Reset a step with its dependents, then refresh stepper and locks."""
        self._reactivate_step(step_name)
        self.update_stepper()
        self.update_locks()

    def select_step(self, step_name: str):
        """Run the step's ``on_select`` and mark it selected.

        NOTE(review): the step is marked selected regardless of what
        ``on_select`` does or returns, so a callback that bails out early
        still leaves the step selected.
        """
        step = self.steps[step_name]
        if step["on_select"] is not None:
            step["on_select"]()
        step["is_selected"] = True
        self.update_stepper()
        self.update_locks()

    def select_or_reactivate(self, step_name: str):
        """Toggle a step: reactivate it if selected, select it otherwise."""
        if self.steps[step_name]["is_selected"]:
            self.reactivate_step(step_name)
        else:
            self.select_step(step_name)

    def _wrap_button(self, button: Button, step_name: str):
        """Make a click on ``button`` toggle the given step."""
        button.click(lambda: self.select_or_reactivate(step_name))
154
+
155
+
156
+ class PredictAppGui:
157
+ def __init__(self, api: Api, static_dir: str = "static"):
158
+ self.api = api
159
+ self.static_dir = static_dir
160
+
161
+ # Environment variables
162
+ self.team_id = env.team_id()
163
+ self.workspace_id = env.workspace_id()
164
+ self.project_id = env.project_id(raise_not_found=False)
165
+ self.project_meta = None
166
+ if self.project_id:
167
+ self.project_meta = ProjectMeta.from_json(self.api.project.get_meta(self.project_id))
168
+ # -------------------------------- #
169
+
170
+ # Flags
171
+ self._stop_flag = False
172
+ self._is_running = False
173
+ # -------------------------------- #
174
+
175
+ # GUI
176
+ # Steps
177
+ self.step_flow = StepFlow()
178
+ select_params = {"icon": None, "plain": False, "text": "Select"}
179
+ reselect_params = {"icon": "zmdi zmdi-refresh", "plain": True, "text": "Reselect"}
180
+
181
+ # 1. Input selector
182
+ self.input_selector = InputSelector(self.workspace_id, self.api)
183
+
184
+ def _on_input_select():
185
+ valid = self.input_selector.validate_step()
186
+ if not valid:
187
+ return
188
+ current_item_type = self.input_selector.radio.get_value()
189
+ self.update_item_type()
190
+ if self.model_api:
191
+ if current_item_type == self.input_selector.radio.get_value():
192
+ inference_settings = self.model_api.get_settings()
193
+ self.settings_selector.set_inference_settings(inference_settings)
194
+
195
+ if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
196
+ try:
197
+ tracking_settings = self.model_api.get_tracking_settings()
198
+ self.settings_selector.set_tracking_settings(tracking_settings)
199
+ except Exception as e:
200
+ logger.warning(
201
+ "Unable to get tracking settings from the model. Settings defaults"
202
+ )
203
+ self.settings_selector.set_default_tracking_settings()
204
+ self.input_selector.disable()
205
+
206
+ self.project_id = self.input_selector.get_project_id()
207
+ if self.project_id:
208
+ self.project_meta = ProjectMeta.from_json(self.api.project.get_meta(self.project_id))
209
+ update_custom_button_params(self.input_selector.button, reselect_params)
210
+
211
+ def _on_input_reactivate():
212
+ self.input_selector.enable()
213
+ update_custom_button_params(self.input_selector.button, select_params)
214
+
215
+ self.step_flow.add_step(
216
+ name="input_selector",
217
+ widget=self.input_selector.card,
218
+ on_select=_on_input_select,
219
+ on_reactivate=_on_input_reactivate,
220
+ button=self.input_selector.button,
221
+ )
222
+
223
+ # 2. Model selector
224
+ self.model_selector = ModelSelector(self.api, self.team_id)
225
+
226
+ self.step_flow.add_step(
227
+ name="model_selector",
228
+ widget=self.model_selector.card,
229
+ )
230
+
231
+ # 3. Classes selector
232
+ self.classes_selector = ClassesSelector()
233
+
234
+ def _on_classes_select():
235
+ valid = self.classes_selector.validate_step()
236
+ if not valid:
237
+ return
238
+ self.classes_selector.classes_table.disable()
239
+
240
+ # Find conflict between project meta and model meta
241
+ selected_classes_names = self.classes_selector.get_selected_classes()
242
+ project_meta = self.project_meta
243
+ model_meta = self.model_api.get_model_meta()
244
+
245
+ has_conflict = False
246
+ for class_name in selected_classes_names:
247
+ project_obj_class = project_meta.get_obj_class(class_name)
248
+ if project_obj_class is None:
249
+ continue
250
+
251
+ model_obj_class = model_meta.get_obj_class(class_name)
252
+ if model_obj_class.geometry_type.name() == AnyGeometry.name():
253
+ continue
254
+
255
+ if project_obj_class.geometry_type.name() == model_obj_class.geometry_type.name():
256
+ continue
257
+
258
+ has_conflict = True
259
+ break
260
+
261
+ if has_conflict:
262
+ self.settings_selector.model_prediction_suffix_container.show()
263
+ else:
264
+ self.settings_selector.model_prediction_suffix_container.hide()
265
+ # ------------------------------------------------ #
266
+
267
+ update_custom_button_params(self.classes_selector.button, reselect_params)
268
+
269
+ def _on_classes_reactivate():
270
+ self.classes_selector.classes_table.enable()
271
+ update_custom_button_params(self.classes_selector.button, select_params)
272
+
273
+ self.step_flow.add_step(
274
+ name="classes_selector",
275
+ widget=self.classes_selector.card,
276
+ on_select=_on_classes_select,
277
+ on_reactivate=_on_classes_reactivate,
278
+ depends_on=["input_selector", "model_selector"],
279
+ on_lock=self.classes_selector.lock,
280
+ on_unlock=self.classes_selector.unlock,
281
+ button=self.classes_selector.button,
282
+ )
283
+
284
+ # 4. Tags selector
285
+ self.tags_selector = None
286
+ if False:
287
+ self.tags_selector = TagsSelector()
288
+ self.step_flow.add_step("tags_selector", self.tags_selector.card)
289
+
290
+ # 5. Settings selector & Preview
291
+ self.settings_selector = SettingsSelector(
292
+ api=self.api,
293
+ static_dir=self.static_dir,
294
+ model_selector=self.model_selector,
295
+ input_selector=self.input_selector,
296
+ )
297
+
298
+ def _on_settings_select():
299
+ valid = self.settings_selector.validate_step()
300
+ if not valid:
301
+ return
302
+ self.settings_selector.disable()
303
+ update_custom_button_params(self.settings_selector.button, reselect_params)
304
+
305
+ def _on_settings_reactivate():
306
+ self.settings_selector.enable()
307
+ update_custom_button_params(self.settings_selector.button, select_params)
308
+
309
+ self.step_flow.add_step(
310
+ name="settings_selector",
311
+ widget=self.settings_selector.cards_container,
312
+ on_select=_on_settings_select,
313
+ on_reactivate=_on_settings_reactivate,
314
+ depends_on=["input_selector", "model_selector", "classes_selector"],
315
+ on_lock=self.settings_selector.lock,
316
+ on_unlock=self.settings_selector.unlock,
317
+ button=self.settings_selector.button,
318
+ )
319
+ self.settings_selector.preview.run_button.disable()
320
+
321
+ # 6. Output selector
322
+ self.output_selector = OutputSelector(self.api)
323
+
324
+ self.step_flow.add_step(
325
+ "output_selector",
326
+ self.output_selector.card,
327
+ depends_on=[
328
+ "input_selector",
329
+ "model_selector",
330
+ "classes_selector",
331
+ # "tags_selector",
332
+ "settings_selector",
333
+ ],
334
+ on_lock=self.output_selector.lock,
335
+ on_unlock=self.output_selector.unlock,
336
+ )
337
+ # -------------------------------- #
338
+
339
+ # Layout
340
+ self.layout = Container([self.step_flow.stepper])
341
+ # ---------------------------- #
342
+
343
+ def set_entity_meta():
344
+ model_api = self.model_selector.model.model_api
345
+
346
+ model_meta = model_api.get_model_meta()
347
+ if self.classes_selector is not None:
348
+ self.classes_selector.set_project_meta(model_meta)
349
+ self.classes_selector.classes_table.show()
350
+ if self.tags_selector is not None:
351
+ self.tags_selector.tags_table.set_project_meta(model_meta)
352
+ self.tags_selector.tags_table.show()
353
+
354
+ inference_settings = model_api.get_settings()
355
+ self.settings_selector.set_inference_settings(inference_settings)
356
+
357
+ if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
358
+ try:
359
+ tracking_settings = model_api.get_tracking_settings()
360
+ self.settings_selector.set_tracking_settings(tracking_settings)
361
+ except Exception as e:
362
+ logger.warning(
363
+ "Unable to get tracking settings from the model. Settings defaults"
364
+ )
365
+ self.settings_selector.set_default_tracking_settings()
366
+
367
+ def reset_entity_meta():
368
+ empty_meta = ProjectMeta()
369
+ if self.classes_selector is not None:
370
+ self.classes_selector.set_project_meta(empty_meta)
371
+ self.classes_selector.classes_table.hide()
372
+ if self.tags_selector is not None:
373
+ self.tags_selector.tags_table.set_project_meta(empty_meta)
374
+ self.tags_selector.tags_table.hide()
375
+
376
+ self.settings_selector.set_inference_settings("")
377
+
378
+ def deploy_and_set_step():
379
+ self.model_selector.validator_text.hide()
380
+ model_api = type(self.model_selector.model).deploy(self.model_selector.model)
381
+ if model_api is not None:
382
+ set_entity_meta()
383
+ self.step_flow.select_step("model_selector")
384
+ else:
385
+ reset_entity_meta()
386
+ self.step_flow.reactivate_step("model_selector")
387
+ return model_api
388
+
389
+ def stop_and_reset_step():
390
+ type(self.model_selector.model).stop(self.model_selector.model)
391
+ self.step_flow.reactivate_step("model_selector")
392
+ reset_entity_meta()
393
+
394
+ def disconnect_and_reset_step():
395
+ type(self.model_selector.model).disconnect(self.model_selector.model)
396
+ self.step_flow.reactivate_step("model_selector")
397
+ reset_entity_meta()
398
+
399
+ # Replace deploy methods for DeployModel widget
400
+ self.model_selector.model.deploy = deploy_and_set_step
401
+ self.model_selector.model.stop = stop_and_reset_step
402
+ self.model_selector.model.disconnect = disconnect_and_reset_step
403
+
404
+ # ------------------------------------------------- #
405
+
406
+ @property
407
+ def model_api(self) -> Optional[ModelAPI]:
408
+ return self.model_selector.model.model_api
409
+
410
+ def update_item_type(self):
411
+ item_type = self.input_selector.radio.get_value()
412
+ self.settings_selector.update_item_type(item_type)
413
+ self.output_selector.update_item_type(item_type)
414
+
415
+ def _run_videos(self, run_parameters: Dict[str, Any]) -> List[Prediction]:
416
+ if self.model_api is None:
417
+ self.set_validator_text("Deploying model...", "info")
418
+ self.model_selector.model._deploy()
419
+ if self.model_api is None:
420
+ logger.error("Model Deployed with an error")
421
+ raise RuntimeError("Model Deployed with an error")
422
+
423
+ self.set_validator_text("Preparing settings for prediction...", "info")
424
+ if run_parameters is None:
425
+ run_parameters = self.get_run_parameters()
426
+
427
+ input_parameters = run_parameters["input"]
428
+ input_video_ids = input_parameters["video_ids"]
429
+ if not input_video_ids:
430
+ raise ValueError("No video IDs provided for video prediction.")
431
+
432
+ predict_kwargs = {}
433
+ # Settings
434
+ settings = run_parameters["settings"]
435
+ model_prediction_suffix = settings.pop("model_prediction_suffix", "")
436
+ prediction_mode = settings.pop("predictions_mode")
437
+ tracking = settings.pop("tracking", False)
438
+ predict_kwargs.update(settings)
439
+
440
+ # Classes
441
+ classes = run_parameters["classes"]
442
+ if classes:
443
+ predict_kwargs["classes"] = classes
444
+
445
+ output_parameters = run_parameters["output"]
446
+ project_name = output_parameters.get("project_name", "")
447
+ upload_to_source_project = output_parameters.get("upload_to_source_project", False)
448
+ skip_project_versioning = output_parameters.get("skip_project_versioning", False)
449
+ skip_annotated = output_parameters.get("skip_annotated", False)
450
+
451
+ video_infos_by_project_id: Dict[int, List[VideoInfo]] = {}
452
+ video_infos_by_dataset_id: Dict[int, List[VideoInfo]] = {}
453
+ for info in self.api.video.get_info_by_id_batch(input_video_ids):
454
+ video_infos_by_project_id.setdefault(info.project_id, []).append(info)
455
+ video_infos_by_dataset_id.setdefault(info.dataset_id, []).append(info)
456
+ src_project_metas: Dict[int, ProjectMeta] = {}
457
+ for project_id in video_infos_by_project_id.keys():
458
+ src_project_metas[project_id] = ProjectMeta.from_json(
459
+ self.api.project.get_meta(project_id)
460
+ )
461
+
462
+ video_ids_to_skip = set()
463
+ if skip_annotated:
464
+ self.set_validator_text("Checking for already annotated videos...", "info")
465
+ secondary_pbar = self.output_selector.secondary_progress(
466
+ message="Checking for already annotated videos...", total=len(input_video_ids)
467
+ )
468
+ self.output_selector.secondary_progress.show()
469
+ for dataset_id, video_infos in video_infos_by_dataset_id.items():
470
+ annotations = self.api.video.annotation.download_bulk(
471
+ dataset_id, [info.id for info in video_infos]
472
+ )
473
+ for ann_json, video_info in zip(annotations, video_infos):
474
+ if ann_json:
475
+ project_meta = src_project_metas[video_info.project_id]
476
+ ann = VideoAnnotation.from_json(ann_json, project_meta=project_meta)
477
+ if len(ann.figures) > 0:
478
+ video_ids_to_skip.add(video_info.id)
479
+ secondary_pbar.update()
480
+ self.output_selector.secondary_progress.hide()
481
+ if video_ids_to_skip:
482
+ video_infos_by_project_id = {
483
+ pid: [info for info in infos if info.id not in video_ids_to_skip]
484
+ for pid, infos in video_infos_by_project_id.items()
485
+ }
486
+
487
+ main_pbar_str = "Processing videos..."
488
+ if video_ids_to_skip:
489
+ main_pbar_str += f" (Skipped {len(video_ids_to_skip)} already annotated videos)"
490
+ total_videos = sum(len(v) for v in video_infos_by_project_id.values())
491
+ if total_videos == 0:
492
+ self.set_validator_text(
493
+ f"No videos to process. Skipped {len(video_ids_to_skip)} already annotated videos",
494
+ "warning",
495
+ )
496
+ return []
497
+ main_pbar = self.output_selector.progress(message=main_pbar_str, total=total_videos)
498
+ self.output_selector.progress.show()
499
+ all_predictictions: List[Prediction] = []
500
+ for src_project_id, src_video_infos in video_infos_by_project_id.items():
501
+ if len(src_video_infos) == 0:
502
+ continue
503
+ project_info = self.api.project.get_info_by_id(src_project_id)
504
+ project_validator_text_str = (
505
+ f"Processing project: {project_info.name} [id: {src_project_id}]"
506
+ )
507
+ if upload_to_source_project:
508
+ if not skip_project_versioning and not is_development():
509
+ logger.info("Creating new project version...")
510
+ self.set_validator_text(
511
+ project_validator_text_str + ": Creating project version",
512
+ "info",
513
+ )
514
+ version_id = self.api.project.version.create(
515
+ project_info,
516
+ "Created by Predict App. Task Id: " + str(env.task_id()),
517
+ )
518
+ logger.info("New project version created: " + str(version_id))
519
+ output_project_id = src_project_id
520
+ output_videos: List[VideoInfo] = src_video_infos
521
+ else:
522
+ self.set_validator_text(
523
+ project_validator_text_str + ": Creating project...", "info"
524
+ )
525
+ if not project_name:
526
+ project_name = project_info.name + " [Predictions]"
527
+ logger.warning(
528
+ "Project name is empty, using auto-generated name: " + project_name
529
+ )
530
+ with_annotations = prediction_mode in [
531
+ AddPredictionsMode.APPEND,
532
+ AddPredictionsMode.IOU_MERGE,
533
+ ]
534
+ created_project = create_project(
535
+ api=self.api,
536
+ project_id=src_project_id,
537
+ project_name=project_name,
538
+ workspace_id=self.workspace_id,
539
+ copy_meta=with_annotations,
540
+ project_type=ProjectType.VIDEOS,
541
+ )
542
+ output_project_id = created_project.id
543
+ output_videos: List[VideoInfo] = copy_items_to_project(
544
+ api=self.api,
545
+ src_project_id=src_project_id,
546
+ items=src_video_infos,
547
+ dst_project_id=created_project.id,
548
+ with_annotations=with_annotations,
549
+ ds_progress=self.output_selector.secondary_progress,
550
+ project_type=ProjectType.VIDEOS,
551
+ )
552
+
553
+ self.set_validator_text(
554
+ project_validator_text_str + ": Merging project meta",
555
+ "info",
556
+ )
557
+ project_meta = src_project_metas[src_project_id]
558
+ for src_video_info, output_video_info in zip(src_video_infos, output_videos):
559
+ video_validator_text_str = (
560
+ project_validator_text_str
561
+ + f", video: {src_video_info.name} [id: {src_video_info.id}]"
562
+ )
563
+ self.set_validator_text(
564
+ video_validator_text_str + ": Predicting",
565
+ "info",
566
+ )
567
+ frames_predictions: List[Prediction] = []
568
+ with self.model_api.predict_detached(
569
+ video_id=src_video_info.id,
570
+ tqdm=self.output_selector.secondary_progress(),
571
+ tracking=tracking,
572
+ **predict_kwargs,
573
+ ) as session:
574
+ self.output_selector.secondary_progress.show()
575
+ for prediction in session:
576
+ if self._stop_flag:
577
+ logger.info("Prediction stopped by user.")
578
+ raise StopIteration("Stopped by user.")
579
+ frames_predictions.append(prediction)
580
+ all_predictictions.extend(frames_predictions)
581
+ if tracking:
582
+ prediction_video_annotation: VideoAnnotation = VideoAnnotation.from_json(
583
+ session.final_result["video_ann"],
584
+ project_meta=self.model_api.get_model_meta(),
585
+ )
586
+ else:
587
+ prediction_video_annotation = video_annotation_from_predictions(
588
+ frames_predictions,
589
+ project_meta,
590
+ frame_size=(src_video_info.frame_height, src_video_info.frame_width),
591
+ )
592
+ if prediction_video_annotation is None:
593
+ logger.warning(
594
+ f"No predictions were made for video {src_video_info.name} [id: {src_video_info.id}]"
595
+ )
596
+ main_pbar.update()
597
+ continue
598
+ self.set_validator_text(
599
+ video_validator_text_str + ": Uploading predictions",
600
+ "info",
601
+ )
602
+ project_meta, prediction_video_annotation, meta_changed = (
603
+ update_meta_and_ann_for_video_annotation(
604
+ meta=project_meta,
605
+ ann=prediction_video_annotation,
606
+ model_prediction_suffix=model_prediction_suffix,
607
+ )
608
+ )
609
+ if meta_changed:
610
+ self.api.project.update_meta(output_project_id, project_meta)
611
+ if upload_to_source_project:
612
+ if prediction_mode in [
613
+ AddPredictionsMode.REPLACE,
614
+ AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS,
615
+ ]:
616
+ self.output_selector.secondary_progress.hide()
617
+ with open("/tmp/prediction_video_annotation.json", "w") as f:
618
+ json.dump(prediction_video_annotation.to_json(), f)
619
+ self.api.video.annotation.upload_paths(
620
+ video_ids=[src_video_info.id],
621
+ paths=["/tmp/prediction_video_annotation.json"],
622
+ project_meta=project_meta,
623
+ )
624
+ else:
625
+ secondary_pbar = self.output_selector.secondary_progress(
626
+ message="Uploading annotations...",
627
+ total=len(prediction_video_annotation.figures),
628
+ )
629
+ self.output_selector.secondary_progress.show()
630
+ self.api.video.annotation.append(
631
+ video_id=src_video_info.id,
632
+ ann=prediction_video_annotation,
633
+ key_id_map=KeyIdMap(),
634
+ progress_cb=secondary_pbar.update,
635
+ )
636
+ else:
637
+ secondary_pbar = self.output_selector.secondary_progress(
638
+ message="Uploading annotations...",
639
+ total=len(prediction_video_annotation.figures),
640
+ )
641
+ self.output_selector.secondary_progress.show()
642
+ self.api.video.annotation.append(
643
+ video_id=output_video_info.id,
644
+ ann=prediction_video_annotation,
645
+ key_id_map=KeyIdMap(),
646
+ progress_cb=secondary_pbar.update,
647
+ )
648
+ main_pbar.update()
649
+ self.set_validator_text("Project successfully processed", "success")
650
+ self.output_selector.set_result_thumbnail(output_project_id)
651
+ return all_predictictions
652
+
653
def _run_images(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
    """Run prediction on the selected images and return the collected predictions.

    Deploys the model when ``self.model_api`` is not set yet, resolves the input
    images (explicit image ids take precedence over dataset ids, which take
    precedence over a project id), prepares the output for every source project
    (a new project version, or a newly created project with copied images) and
    then reads predictions from ``self.model_api.predict_detached``.

    :param run_parameters: Parameters with the ``"input"``, ``"settings"`` and
        ``"output"`` sections and an optional ``"classes"`` entry. When ``None``,
        ``self.get_run_parameters()`` is used.
    :return: Predictions collected for all processed projects; an empty list when
        ``skip_annotated`` filters out every image.
    :raises RuntimeError: If ``self.model_api`` is still ``None`` after deployment.
    :raises ValueError: If no input is given or no images are found.
    :raises StopIteration: If ``self._stop_flag`` is set while predictions are read.
    """
    if self.model_api is None:
        self.set_validator_text("Deploying model...", "info")
        self.model_selector.model._deploy()
    if self.model_api is None:
        logger.error("Model Deployed with an error")
        raise RuntimeError("Model Deployed with an error")

    self.set_validator_text("Preparing settings for prediction...", "info")
    if run_parameters is None:
        run_parameters = self.get_run_parameters()

    predict_kwargs = {}

    # Input
    input_parameters = run_parameters["input"]
    input_project_id = input_parameters.get("project_id", None)
    input_dataset_ids = input_parameters.get("dataset_ids", [])
    input_image_ids = input_parameters.get("image_ids", [])
    if not (input_image_ids or input_dataset_ids or input_project_id):
        raise ValueError("No valid input parameters found for prediction.")

    # Settings. Work on a copy so popping the mode does not mutate the caller's dict.
    settings = dict(run_parameters["settings"])
    prediction_mode = settings.pop("predictions_mode")
    upload_mode = None
    with_annotations = None
    if prediction_mode == AddPredictionsMode.REPLACE:
        upload_mode = "replace"
        with_annotations = False
    elif prediction_mode == AddPredictionsMode.APPEND:
        upload_mode = "append"
        with_annotations = True
    elif prediction_mode == AddPredictionsMode.IOU_MERGE:
        upload_mode = "iou_merge"
        with_annotations = True
    elif prediction_mode == AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS:
        upload_mode = "replace"
        with_annotations = True
    predict_kwargs.update(settings)
    predict_kwargs["upload_mode"] = upload_mode

    # Classes. The key is absent when the GUI has no classes selector.
    classes = run_parameters.get("classes")
    if classes:
        predict_kwargs["classes"] = classes

    # Output
    output_parameters = run_parameters["output"]
    project_name = output_parameters.get("project_name", None)
    upload_to_source_project = output_parameters.get("upload_to_source_project", False)
    skip_project_versioning = output_parameters.get("skip_project_versioning", False)
    skip_annotated = output_parameters.get("skip_annotated", False)

    image_infos = []
    if input_image_ids:
        image_infos = self.api.image.get_info_by_id_batch(input_image_ids)
    elif input_dataset_ids:
        for dataset_id in input_dataset_ids:
            image_infos.extend(self.api.image.get_list(dataset_id))
    elif input_project_id:
        datasets = self.api.dataset.get_list(input_project_id, recursive=True)
        for dataset in datasets:
            image_infos.extend(self.api.image.get_list(dataset.id))
    if len(image_infos) == 0:
        raise ValueError("No images found for the given input parameters.")

    if skip_annotated:
        # Keep only images without labels: the annotated ones are the ones to skip.
        image_infos = [info for info in image_infos if not info.labels_count]
        if len(image_infos) == 0:
            self.set_validator_text(
                "All images are already annotated. Nothing to predict.", "warning"
            )
            return []

    # Group images by source project; dataset -> project lookups are cached.
    image_infos_by_project_id: Dict[int, List[ImageInfo]] = {}
    ds_project_mapping: Dict[int, int] = {}
    for info in image_infos:
        if info.dataset_id not in ds_project_mapping:
            ds_info = self.api.dataset.get_info_by_id(info.dataset_id)
            ds_project_mapping[info.dataset_id] = ds_info.project_id
        project_id = ds_project_mapping[info.dataset_id]
        image_infos_by_project_id.setdefault(project_id, []).append(info)

    self.output_selector.progress.show()
    total_items = sum(len(infos) for infos in image_infos_by_project_id.values())
    main_pbar = self.output_selector.progress(message="Copying images...", total=total_items)
    output_project_id = None
    # Output images of every prepared project, so none of them is lost for prediction.
    prepared_outputs: List[List[ImageInfo]] = []
    for src_project_id, infos in image_infos_by_project_id.items():
        if not infos:
            continue
        project_info = self.api.project.get_info_by_id(src_project_id)
        project_validator_text_str = (
            f"Processing project: {project_info.name} [id: {src_project_id}]"
        )
        if upload_to_source_project:
            if not skip_project_versioning and not is_development():
                logger.info("Creating new project version...")
                self.set_validator_text(
                    project_validator_text_str + ": Creating project version", "info"
                )
                version_id = self.api.project.version.create(
                    project_info,
                    "Created by Predict App. Task Id: " + str(env.task_id()),
                )
                logger.info("New project version created: " + str(version_id))
            output_project_id = src_project_id
            output_image_infos: List[ImageInfo] = infos
        else:
            self.set_validator_text(
                project_validator_text_str + ": Creating project...", "info"
            )
            # Use a per-project name so an auto-generated name of one source
            # project is not reused for the next one.
            output_project_name = project_name
            if not output_project_name:
                output_project_name = project_info.name + " [Predictions]"
                logger.warning(
                    "Project name is empty, using auto-generated name: " + output_project_name
                )
            created_project = create_project(
                api=self.api,
                project_id=src_project_id,
                project_name=output_project_name,
                workspace_id=self.workspace_id,
                copy_meta=with_annotations,
                project_type=ProjectType.IMAGES,
            )
            output_project_id = created_project.id
            output_image_infos = copy_items_to_project(
                api=self.api,
                src_project_id=src_project_id,
                items=infos,
                dst_project_id=created_project.id,
                with_annotations=with_annotations,
                ds_progress=self.output_selector.secondary_progress,
                progress_cb=main_pbar.update,
                project_type=ProjectType.IMAGES,
            )
        prepared_outputs.append(output_image_infos)

    # Run prediction
    self.set_validator_text("Running prediction...", "info")
    predictions: List[Prediction] = []
    self._is_running = True
    for output_image_infos in prepared_outputs:
        with self.model_api.predict_detached(
            image_ids=[info.id for info in output_image_infos],
            **predict_kwargs,
            tqdm=self.output_selector.progress(),
        ) as session:
            for prediction in session:
                if self._stop_flag:
                    logger.info("Prediction stopped by user.")
                    raise StopIteration("Stopped by user.")
                predictions.append(prediction)
    self.set_validator_text("Project successfully processed", "success")
    self.output_selector.set_result_thumbnail(output_project_id)
    return predictions
823
+
824
def run(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
    """Run prediction for the current (or given) parameters.

    Dispatches to ``self._run_videos`` when the ``"input"`` section contains
    ``"video_ids"`` and to ``self._run_images`` otherwise. The progress widgets
    are hidden and the running/stop flags are reset no matter how the run ends.

    :param run_parameters: Run parameters; ``self.get_run_parameters()`` is used
        when ``None``.
    :return: Whatever the selected runner returns.
    :raises StopIteration: Re-raised when the runner raises it.
    :raises Exception: Any other runner error is reported and re-raised.
    """
    self.show_validator_text()
    params = self.get_run_parameters() if run_parameters is None else run_parameters
    video_ids = params["input"].get("video_ids", None)
    try:
        runner = self._run_videos if video_ids else self._run_images
        return runner(params)
    except StopIteration:
        stop_message = "Prediction stopped by user."
        logger.info(stop_message)
        self.set_validator_text(stop_message, "warning")
        raise
    except Exception as e:
        error_message = f"Error during prediction: {str(e)}"
        logger.error(error_message)
        self.set_validator_text(error_message, "error")
        disable_enable(self.output_selector.widgets_to_disable, False)
        raise
    finally:
        selector = self.output_selector
        selector.secondary_progress.hide()
        selector.progress.hide()
        self._is_running = False
        self._stop_flag = False
850
+
851
def stop(self):
    """Request the running prediction to stop by raising the stop flag."""
    logger.info("Stopping prediction...")
    # The prediction loops check this flag between predictions.
    setattr(self, "_stop_flag", True)
854
+
855
def wait_for_stop(self, timeout: int = None):
    """Block until the running prediction has finished.

    Polls ``self._is_running`` every 0.1 seconds.

    :param timeout: Maximum number of seconds to wait; ``None`` waits without a limit.
    :raises TimeoutError: If the prediction is still running after ``timeout`` seconds.
    """
    # Build the message in two steps: a conditional expression binds looser than
    # "+", so writing it inline would log only one half of the sentence.
    duration = "" if timeout is None else f"{timeout} seconds "
    logger.info("Waiting " + duration + "for prediction to stop...")
    started_at = time.monotonic()
    while self._is_running:
        if timeout is not None and time.monotonic() - started_at > timeout:
            raise TimeoutError("Timeout while waiting for stop.")
        time.sleep(0.1)
    logger.info("Prediction stopped.")
867
+
868
def shutdown_model(self):
    """Stop the running prediction, wait up to 10 seconds for it, then stop the model."""
    stop_timeout = 10
    self.stop()
    self.wait_for_stop(stop_timeout)
    model = self.model_selector.model
    model.stop()
872
+
873
def get_run_parameters(self) -> Dict[str, Any]:
    """Collect the current state of all selectors into one parameters dict.

    :return: Dict with the ``"model"``, ``"settings"``, ``"input"`` and ``"output"``
        sections, plus ``"classes"`` and ``"tags"`` when the matching selector exists.
    """
    run_parameters: Dict[str, Any] = {}
    run_parameters["model"] = self.model_selector.model.get_deploy_parameters()
    run_parameters["settings"] = self.settings_selector.get_settings()
    run_parameters["input"] = self.input_selector.get_settings()
    run_parameters["output"] = self.output_selector.get_settings()

    # Classes and tags selectors are optional parts of the GUI.
    classes_selector = self.classes_selector
    if classes_selector is not None:
        run_parameters["classes"] = classes_selector.get_selected_classes()
    tags_selector = self.tags_selector
    if tags_selector is not None:
        run_parameters["tags"] = tags_selector.get_selected_tags()
    return run_parameters
885
+
886
def load_from_json(self, data):
    """Restore the state of every selector from a parameters dict.

    Sections are applied in a fixed order: input, model, classes, tags,
    settings, output. A missing section is passed to its selector as ``{}``.

    :param data: Dict with optional ``"input"``, ``"model"``, ``"classes"``,
        ``"tags"``, ``"settings"`` and ``"output"`` sections.
    """

    def section(key):
        return data.get(key, {})

    # Input first, then the model.
    self.input_selector.load_from_json(section("input"))
    self.model_selector.model.load_from_json(section("model"))

    # Classes and tags selectors are optional.
    if self.classes_selector is not None:
        self.classes_selector.load_from_json(section("classes"))
    if self.tags_selector is not None:
        self.tags_selector.load_from_json(section("tags"))

    # Settings (with preview) and finally the output.
    self.settings_selector.load_from_json(section("settings"))
    self.output_selector.load_from_json(section("output"))
907
+
908
def set_validator_text(self, text: str, status: str = "text"):
    """Set the text and status of the output selector's validator text widget.

    :param text: Message to display.
    :param status: Status passed to the widget together with the text.
    """
    validator = self.output_selector.validator_text
    validator.set(text=text, status=status)
910
+
911
def show_validator_text(self):
    """Show the output selector's validator text widget."""
    validator = self.output_selector.validator_text
    validator.show()
913
+
914
def hide_validator_text(self):
    """Hide the output selector's validator text widget."""
    validator = self.output_selector.validator_text
    validator.hide()