supervisely 6.73.419__py3-none-any.whl → 6.73.421__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. supervisely/api/api.py +10 -5
  2. supervisely/api/app_api.py +71 -4
  3. supervisely/api/module_api.py +4 -0
  4. supervisely/api/nn/deploy_api.py +15 -9
  5. supervisely/api/nn/ecosystem_models_api.py +201 -0
  6. supervisely/api/nn/neural_network_api.py +12 -3
  7. supervisely/api/project_api.py +35 -6
  8. supervisely/api/task_api.py +5 -1
  9. supervisely/app/widgets/__init__.py +8 -1
  10. supervisely/app/widgets/agent_selector/template.html +1 -0
  11. supervisely/app/widgets/deploy_model/__init__.py +0 -0
  12. supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
  13. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  14. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  15. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  16. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  17. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
  18. supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
  19. supervisely/app/widgets/fast_table/fast_table.py +402 -74
  20. supervisely/app/widgets/fast_table/script.js +364 -96
  21. supervisely/app/widgets/fast_table/style.css +24 -0
  22. supervisely/app/widgets/fast_table/template.html +43 -3
  23. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  24. supervisely/app/widgets/select/select.py +6 -4
  25. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
  26. supervisely/app/widgets/tabs/tabs.py +22 -6
  27. supervisely/app/widgets/tabs/template.html +5 -1
  28. supervisely/nn/artifacts/__init__.py +1 -1
  29. supervisely/nn/artifacts/artifacts.py +10 -2
  30. supervisely/nn/artifacts/detectron2.py +1 -0
  31. supervisely/nn/artifacts/hrda.py +1 -0
  32. supervisely/nn/artifacts/mmclassification.py +20 -0
  33. supervisely/nn/artifacts/mmdetection.py +5 -3
  34. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  35. supervisely/nn/artifacts/ritm.py +1 -0
  36. supervisely/nn/artifacts/rtdetr.py +1 -0
  37. supervisely/nn/artifacts/unet.py +1 -0
  38. supervisely/nn/artifacts/utils.py +3 -0
  39. supervisely/nn/artifacts/yolov5.py +2 -0
  40. supervisely/nn/artifacts/yolov8.py +1 -0
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  42. supervisely/nn/experiments.py +9 -0
  43. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  44. supervisely/nn/inference/inference.py +160 -94
  45. supervisely/nn/inference/predict_app/__init__.py +0 -0
  46. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  47. supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
  48. supervisely/nn/inference/predict_app/gui/gui.py +710 -0
  49. supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
  50. supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
  51. supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
  52. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  53. supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
  54. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  55. supervisely/nn/inference/predict_app/gui/utils.py +282 -0
  56. supervisely/nn/inference/predict_app/predict_app.py +184 -0
  57. supervisely/nn/inference/uploader.py +9 -5
  58. supervisely/nn/model/prediction.py +2 -0
  59. supervisely/nn/model/prediction_session.py +20 -3
  60. supervisely/nn/training/gui/gui.py +131 -44
  61. supervisely/nn/training/gui/model_selector.py +8 -6
  62. supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
  63. supervisely/nn/training/gui/training_artifacts.py +0 -5
  64. supervisely/nn/training/train_app.py +161 -44
  65. supervisely/project/project.py +211 -73
  66. supervisely/template/experiment/experiment.html.jinja +74 -17
  67. supervisely/template/experiment/experiment_generator.py +258 -112
  68. supervisely/template/experiment/header.html.jinja +31 -13
  69. supervisely/template/experiment/sly-style.css +7 -2
  70. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/METADATA +3 -1
  71. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/RECORD +75 -57
  72. supervisely/app/widgets/experiment_selector/style.css +0 -27
  73. supervisely/app/widgets/experiment_selector/template.html +0 -61
  74. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/LICENSE +0 -0
  75. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/WHEEL +0 -0
  76. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/entry_points.txt +0 -0
  77. {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,710 @@
1
+ import random
2
+ import time
3
+ from typing import Any, Callable, Dict, List, Optional
4
+
5
+ import yaml
6
+
7
+ from supervisely._utils import is_development, logger
8
+ from supervisely.annotation.annotation import Annotation
9
+ from supervisely.annotation.label import Label
10
+ from supervisely.api.api import Api
11
+ from supervisely.api.video.video_api import VideoInfo
12
+ from supervisely.app.widgets import Button, Card, Container, Stepper, Widget
13
+ from supervisely.io import env
14
+ from supervisely.nn.inference.predict_app.gui.classes_selector import ClassesSelector
15
+ from supervisely.nn.inference.predict_app.gui.input_selector import InputSelector
16
+ from supervisely.nn.inference.predict_app.gui.model_selector import ModelSelector
17
+ from supervisely.nn.inference.predict_app.gui.output_selector import OutputSelector
18
+ from supervisely.nn.inference.predict_app.gui.preview import Preview
19
+ from supervisely.nn.inference.predict_app.gui.settings_selector import (
20
+ AddPredictionsMode,
21
+ SettingsSelector,
22
+ )
23
+ from supervisely.nn.inference.predict_app.gui.tags_selector import TagsSelector
24
+ from supervisely.nn.inference.predict_app.gui.utils import (
25
+ copy_project,
26
+ disable_enable,
27
+ set_stepper_step,
28
+ wrap_button_click,
29
+ )
30
+ from supervisely.nn.model.model_api import ModelAPI
31
+ from supervisely.nn.model.prediction import Prediction
32
+ from supervisely.project.project_meta import ProjectMeta
33
+ from supervisely.video_annotation.key_id_map import KeyIdMap
34
+ from supervisely.video_annotation.video_annotation import VideoAnnotation
35
+
36
+
37
class StepFlow:
    """Wire stepper cards, their select buttons and the "unlock next card" chain.

    Usage: register every step with :meth:`register_step`, link steps with
    :meth:`set_next_steps`, optionally attach extra select/reselect actions with
    :meth:`add_on_select_actions`, then call :meth:`build` once to create the
    button wrappers and attach click handlers.
    """

    def __init__(self, stepper: Stepper):
        self.stepper = stepper
        # step name -> configuration dict built in register_step()
        self.steps = {}
        # step names indexed by their declared position; unused slots are None
        self.step_sequence = []

    def register_step(
        self,
        name: str,
        card: Card,
        button: Optional[Button] = None,
        widgets_to_disable: Optional[List[Widget]] = None,
        validation_text: Optional[Widget] = None,
        validation_func: Optional[Callable] = None,
        position: Optional[int] = None,
    ) -> "StepFlow":
        """Register (or re-register) a step under ``name``.

        :param name: unique step name used by the other StepFlow methods.
        :param card: the step's card; it is unlocked when a preceding step is selected.
        :param button: the step's select button; steps without a button get no wrapper.
        :param widgets_to_disable: widgets passed on to ``wrap_button_click``.
        :param validation_text: widget passed on to ``wrap_button_click``.
        :param validation_func: validator passed on to ``wrap_button_click``.
        :param position: index of the step in the stepper sequence. Steps
            registered without a position are never wired by :meth:`build`.
        :returns: self, so calls can be chained.
        """
        self.steps[name] = {
            "card": card,
            "button": button,
            "widgets_to_disable": widgets_to_disable or [],
            "validation_text": validation_text,
            "validation_func": validation_func,
            "position": position,
            "next_steps": [],
            "on_select_click": [],
            "on_reselect_click": [],
            "wrapper": None,
            "has_button": button is not None,
        }

        if position is not None:
            # Pad with None placeholders so that index `position` exists.
            while len(self.step_sequence) <= position:
                self.step_sequence.append(None)
            self.step_sequence[position] = name

        return self

    def set_next_steps(self, step_name: str, next_steps: List[str]) -> "StepFlow":
        """Set the steps whose cards are unlocked when ``step_name`` is selected.

        Unknown ``step_name`` values are ignored. Returns self for chaining.
        """
        if step_name in self.steps:
            self.steps[step_name]["next_steps"] = next_steps
        return self

    def add_on_select_actions(
        self, step_name: str, actions: List[Callable], is_reselect: bool = False
    ) -> "StepFlow":
        """Append callables run on select (or on reselect if ``is_reselect``).

        Unknown ``step_name`` values are ignored. Returns self for chaining.
        """
        if step_name in self.steps:
            key = "on_reselect_click" if is_reselect else "on_select_click"
            self.steps[step_name][key].extend(actions)
        return self

    def build_wrappers(self) -> Dict[str, Callable]:
        """Create a ``wrap_button_click`` wrapper for every positioned step with a button.

        :returns: mapping of step name -> wrapper for steps that got one.
        """
        valid_sequence = [s for s in self.step_sequence if s is not None and s in self.steps]

        # Walk backwards so the wrappers of the following steps already exist
        # and can be passed as `callback` to the current step's wrapper.
        for step_name in reversed(valid_sequence):
            step = self.steps[step_name]

            cards_to_unlock = [
                self.steps[next_name]["card"]
                for next_name in step["next_steps"]
                if next_name in self.steps
            ]

            # First following step that already has a wrapper and a button.
            callback = None
            if step["next_steps"] and step["has_button"]:
                for next_name in step["next_steps"]:
                    next_step = self.steps.get(next_name)
                    if (
                        next_step is not None
                        and next_step.get("wrapper")
                        and next_step["has_button"]
                    ):
                        callback = next_step["wrapper"]
                        break

            if step["has_button"]:
                step["wrapper"] = wrap_button_click(
                    button=step["button"],
                    cards_to_unlock=cards_to_unlock,
                    widgets_to_disable=step["widgets_to_disable"],
                    callback=callback,
                    validation_text=step["validation_text"],
                    validation_func=step["validation_func"],
                    on_select_click=step["on_select_click"],
                    on_reselect_click=step["on_reselect_click"],
                    collapse_card=None,
                )

        return {
            name: step["wrapper"]
            for name, step in self.steps.items()
            if step.get("wrapper") and step["has_button"]
        }

    def setup_button_handlers(self) -> None:
        """Attach a click handler to every wrapped step button.

        The handler runs the step's wrapper, then calls ``set_stepper_step``
        with the 1-based position of the following step in the sequence.
        """
        valid_sequence = [s for s in self.step_sequence if s is not None and s in self.steps]
        # 1-based stepper positions; a name listed twice keeps its last position.
        positions = {name: pos for pos, name in enumerate(valid_sequence, start=1)}

        def create_handler(btn, cb, next_pos):
            # Factory binds the loop values now, avoiding late-binding closures.
            def handler():
                cb()
                set_stepper_step(self.stepper, btn, next_pos=next_pos)

            return handler

        for step_name, step in self.steps.items():
            if step_name in positions and step.get("wrapper") and step["has_button"]:
                button = step["button"]
                button.click(create_handler(button, step["wrapper"], positions[step_name] + 1))

    def build(self) -> Dict[str, Callable]:
        """Build all wrappers, attach button handlers and return the wrappers."""
        wrappers = self.build_wrappers()
        self.setup_button_handlers()
        return wrappers
161
+
162
+
163
class PredictAppGui:
    """Step-by-step GUI of the prediction app.

    The GUI is a :class:`Stepper` with these cards, in order:
    1. input selector, 2. model selector, 3. classes selector,
    4. tags selector (optional), 5. settings selector, 6. preview (optional),
    7. output selector.

    The cards are chained with :class:`StepFlow`. :meth:`run` copies the input
    project into a new project and runs the deployed model on the copy.
    """

    def __init__(self, api: Api, static_dir: str = "static"):
        """Build all step widgets and wire their buttons and event handlers.

        :param api: Supervisely API client.
        :param static_dir: directory handed to the Preview widget; only used
            when the preview step is enabled.
        """
        self.api = api
        self.static_dir = static_dir

        # Environment variables
        self.team_id = env.team_id()
        self.workspace_id = env.workspace_id()
        self.project_id = env.project_id(raise_not_found=False)
        # -------------------------------- #

        # Flags
        # _stop_flag: set by stop(), polled by run() after every prediction.
        # _is_running: True while run() is consuming the prediction session.
        self._stop_flag = False
        self._is_running = False
        # -------------------------------- #

        # Optional steps are hard-wired on/off here.
        use_classes_selector = True
        use_tags_selector = False
        use_preview = False

        # GUI
        # Steps (cards in stepper order)
        self.steps = []

        # 1. Input selector
        self.input_selector = InputSelector(self.workspace_id)
        self.steps.append(self.input_selector.card)

        # 2. Model selector
        self.model_selector = ModelSelector(self.api, self.team_id)
        self.steps.append(self.model_selector.card)

        # 3. Classes selector
        self.classes_selector = None
        if use_classes_selector:
            self.classes_selector = ClassesSelector()
            self.steps.append(self.classes_selector.card)

        # 4. Tags selector
        self.tags_selector = None
        if use_tags_selector:
            self.tags_selector = TagsSelector()
            self.steps.append(self.tags_selector.card)

        # 5. Settings selector
        self.settings_selector = SettingsSelector()
        self.steps.append(self.settings_selector.card)

        # 6. Preview
        self.preview = None
        if use_preview:
            self.preview = Preview(api, static_dir)
            self.steps.append(self.preview.card)

        # 7. Output selector
        self.output_selector = OutputSelector(self.api)
        self.steps.append(self.output_selector.card)
        # -------------------------------- #

        # Stepper
        self.stepper = Stepper(widgets=self.steps)
        # ---------------------------- #

        # Layout
        self.layout = Container([self.stepper])
        # ---------------------------- #

        # Button Utils
        def deploy_model() -> ModelAPI:
            """Deploy the selected model; enable the Start button only on success."""
            self.model_selector.validator_text.hide()
            try:
                # Call the class-level `deploy` explicitly. The instance
                # attribute `deploy` is replaced by this function below, so
                # calling `model.deploy()` here would recurse into itself.
                model_api = type(self.model_selector.model).deploy(self.model_selector.model)
            except BaseException:
                # Any failure (including interrupts) leaves Start disabled.
                self.output_selector.start_button.disable()
                raise
            self.output_selector.start_button.enable()
            return model_api

        # Reimplement deploy method for DeployModel widget
        self.model_selector.model.deploy = deploy_model

        def set_entity_meta():
            """Fill class/tag tables and inference settings from the deployed model."""
            model_api = self.model_selector.model.model_api

            model_meta = model_api.get_model_meta()
            if self.classes_selector is not None:
                self.classes_selector.classes_table.set_project_meta(model_meta)
                self.classes_selector.classes_table.show()
            if self.tags_selector is not None:
                self.tags_selector.tags_table.set_project_meta(model_meta)
                self.tags_selector.tags_table.show()

            inference_settings = model_api.get_settings()
            self.settings_selector.set_inference_settings(inference_settings)

            if self.preview is not None:
                self.preview.inference_settings = inference_settings

        def reset_entity_meta():
            """Clear and hide class/tag tables and reset inference settings."""
            empty_meta = ProjectMeta()
            if self.classes_selector is not None:
                self.classes_selector.classes_table.set_project_meta(empty_meta)
                self.classes_selector.classes_table.hide()
            if self.tags_selector is not None:
                self.tags_selector.tags_table.set_project_meta(empty_meta)
                self.tags_selector.tags_table.hide()

            self.settings_selector.set_inference_settings("")

            if self.preview is not None:
                self.preview.inference_settings = None

        def disable_settings_editor():
            """Toggle the read-only state of the inference settings editor.

            Registered for both select and reselect of the settings step, so the
            editor is locked on select and unlocked again on reselect.
            """
            editor = self.settings_selector.inference_settings
            editor.readonly = not editor.readonly

        def generate_preview():
            """Predict on one random video frame or image and show input vs output.

            :returns: the Prediction on success. On failure the error is shown in
                the preview card and None is returned. Also returns None when the
                preview step is disabled.
            """

            def _get_frame_annotation(
                video_info: VideoInfo, frame_index: int, project_meta: ProjectMeta
            ) -> Annotation:
                """Convert the labels of one video frame into an image Annotation."""
                video_annotation = VideoAnnotation.from_json(
                    self.api.video.annotation.download(video_info.id, frame_index),
                    project_meta=project_meta,
                    key_id_map=KeyIdMap(),
                )
                frame = video_annotation.frames.get(frame_index)
                img_size = (video_info.frame_height, video_info.frame_width)
                if frame is None:
                    return Annotation(img_size)
                labels = [
                    Label(figure.geometry, figure.video_object.obj_class)
                    for figure in frame.figures
                ]
                return Annotation(img_size, labels=labels)

            if self.preview is None:
                return

            self.preview.validator_text.hide()
            self.preview.gallery.clean_up()
            self.preview.gallery.show()
            self.preview.gallery.loading = True
            try:
                items_settings = self.input_selector.get_settings()
                if "video_id" in items_settings:
                    video_id = items_settings["video_id"]
                    video_info = self.api.video.get_info_by_id(video_id)
                    video_frame = random.randint(0, video_info.frames_count - 1)
                    self.api.video.frame.download_path(
                        video_info.id, video_frame, self.preview.preview_path
                    )
                    # NOTE(review): `peview_url` looks like a typo of `preview_url`;
                    # confirm the attribute name on the Preview class.
                    img_url = self.preview.peview_url
                    project_meta = ProjectMeta.from_json(
                        self.api.project.get_meta(video_info.project_id)
                    )
                    input_ann = _get_frame_annotation(video_info, video_frame, project_meta)
                    prediction = self.model_selector.model.model_api.predict(
                        input=self.preview.preview_path, **self.settings_selector.get_settings()
                    )[0]
                    output_ann = prediction.annotation
                else:
                    if "project_id" in items_settings:
                        project_id = items_settings["project_id"]
                        dataset_infos = self.api.dataset.get_list(project_id, recursive=True)
                        dataset_infos = [ds for ds in dataset_infos if ds.items_count > 0]
                        if not dataset_infos:
                            raise ValueError("No datasets with items found in the project.")
                        dataset_info = random.choice(dataset_infos)
                    elif "dataset_ids" in items_settings:
                        dataset_ids = items_settings["dataset_ids"]
                        dataset_infos = [
                            self.api.dataset.get_info_by_id(dataset_id)
                            for dataset_id in dataset_ids
                        ]
                        dataset_infos = [ds for ds in dataset_infos if ds.items_count > 0]
                        if not dataset_infos:
                            raise ValueError("No items in selected datasets.")
                        dataset_info = random.choice(dataset_infos)
                    else:
                        raise ValueError("No valid item settings found for preview.")
                    images = self.api.image.get_list(dataset_info.id)
                    image_info = random.choice(images)
                    img_url = image_info.preview_url

                    project_meta = ProjectMeta.from_json(
                        self.api.project.get_meta(dataset_info.project_id)
                    )
                    input_ann = Annotation.from_json(
                        self.api.annotation.download(image_info.id).annotation,
                        project_meta=project_meta,
                    )
                    prediction = self.model_selector.model.model_api.predict(
                        image_id=image_info.id, **self.settings_selector.get_settings()
                    )[0]
                    output_ann = prediction.annotation

                self.preview.gallery.append(img_url, input_ann, "Input")
                self.preview.gallery.append(img_url, output_ann, "Output")
                self.preview.validator_text.hide()
                self.preview.gallery.show()
                return prediction
            except Exception as e:
                self.preview.gallery.hide()
                self.preview.validator_text.set(
                    text=f"Error during preview: {str(e)}", status="error"
                )
                self.preview.validator_text.show()
                self.preview.gallery.clean_up()
            finally:
                self.preview.gallery.loading = False

        # ---------------------------- #

        # StepFlow callbacks and wiring
        self.step_flow = StepFlow(self.stepper)
        position = 0

        # 1. Input selector
        self.step_flow.register_step(
            "input_selector",
            self.input_selector.card,
            self.input_selector.button,
            self.input_selector.widgets_to_disable,
            self.input_selector.validator_text,
            self.input_selector.validate_step,
            position=position,
        )
        position += 1

        # 2. Model selector
        self.step_flow.register_step(
            "model_selector",
            self.model_selector.card,
            self.model_selector.button,
            self.model_selector.widgets_to_disable,
            self.model_selector.validator_text,
            self.model_selector.validate_step,
            position=position,
        )
        self.step_flow.add_on_select_actions("model_selector", [set_entity_meta])
        self.step_flow.add_on_select_actions(
            "model_selector", [reset_entity_meta], is_reselect=True
        )
        position += 1

        # 3. Classes selector
        if self.classes_selector is not None:
            self.step_flow.register_step(
                "classes_selector",
                self.classes_selector.card,
                self.classes_selector.button,
                self.classes_selector.widgets_to_disable,
                self.classes_selector.validator_text,
                self.classes_selector.validate_step,
                position=position,
            )
            position += 1

        # 4. Tags selector
        if self.tags_selector is not None:
            self.step_flow.register_step(
                "tags_selector",
                self.tags_selector.card,
                self.tags_selector.button,
                self.tags_selector.widgets_to_disable,
                self.tags_selector.validator_text,
                self.tags_selector.validate_step,
                position=position,
            )
            position += 1

        # 5. Settings selector
        self.step_flow.register_step(
            "settings_selector",
            self.settings_selector.card,
            self.settings_selector.button,
            self.settings_selector.widgets_to_disable,
            self.settings_selector.validator_text,
            self.settings_selector.validate_step,
            position=position,
        )
        # Same toggle on select and reselect: lock, then unlock the editor.
        self.step_flow.add_on_select_actions("settings_selector", [disable_settings_editor])
        self.step_flow.add_on_select_actions("settings_selector", [disable_settings_editor], True)
        position += 1

        # 6. Preview
        if self.preview is not None:
            self.step_flow.register_step(
                "preview",
                self.preview.card,
                self.preview.button,
                self.preview.widgets_to_disable,
                self.preview.validator_text,
                self.preview.validate_step,
                position=position,
            ).add_on_select_actions("preview", [generate_preview])
            position += 1

        # 7. Output selector (no select button: the run is started elsewhere)
        self.step_flow.register_step(
            "output_selector",
            self.output_selector.card,
            None,
            self.output_selector.widgets_to_disable,
            self.output_selector.validator_text,
            self.output_selector.validate_step,
            position=position,
        )

        # Dependencies Chain
        has_model_selector = self.model_selector is not None
        has_classes_selector = self.classes_selector is not None
        has_tags_selector = self.tags_selector is not None
        has_preview = self.preview is not None

        # Step 1 -> Step 2
        prev_step = "input_selector"
        if has_model_selector:
            self.step_flow.set_next_steps(prev_step, ["model_selector"])
            prev_step = "model_selector"
        # Step 2 -> Step 3
        if has_classes_selector:
            self.step_flow.set_next_steps(prev_step, ["classes_selector"])
            prev_step = "classes_selector"
        # Step 3 -> Step 4
        if has_tags_selector:
            self.step_flow.set_next_steps(prev_step, ["tags_selector"])
            prev_step = "tags_selector"
        # Step 4 -> Step 5
        self.step_flow.set_next_steps(prev_step, ["settings_selector"])
        prev_step = "settings_selector"
        # Step 5 -> Step 6
        if has_preview:
            self.step_flow.set_next_steps(prev_step, ["preview"])
            prev_step = "preview"
        # Step 6 -> Step 7
        self.step_flow.set_next_steps(prev_step, ["output_selector"])

        # Create all wrappers and set button handlers
        wrappers = self.step_flow.build()

        # Steps that are disabled or have no button get None here.
        self.input_selector_cb = wrappers.get("input_selector")
        self.classes_selector_cb = wrappers.get("classes_selector")
        self.tags_selector_cb = wrappers.get("tags_selector")
        self.model_selector_cb = wrappers.get("model_selector")
        self.settings_selector_cb = wrappers.get("settings_selector")
        self.preview_cb = wrappers.get("preview")
        self.output_selector_cb = wrappers.get("output_selector")
        # ------------------------------------------------- #

        # Other Handlers
        @self.input_selector.radio.value_changed
        def input_selector_type_changed(value: str):
            self.input_selector.validator_text.hide()

        @self.input_selector.select_dataset_for_video.value_changed
        def dataset_for_video_changed(dataset_id: int):
            """Reload the video table for the newly selected dataset."""
            select_video = self.input_selector.select_video
            select_video.loading = True
            try:
                if dataset_id is None:
                    rows = []
                else:
                    dataset_info = self.api.dataset.get_info_by_id(dataset_id)
                    videos = self.api.video.get_list(dataset_id)
                    rows = [[video.id, video.name, dataset_info.name] for video in videos]
                select_video.rows = rows
            finally:
                # Clear the spinner even if an API call failed.
                select_video.loading = False

        # ------------------------------------------------- #

    def run(self, run_parameters: Optional[Dict[str, Any]] = None) -> Optional[List[Prediction]]:
        """Copy the input project and run the deployed model on the copy.

        :param run_parameters: dict with the keys "input", "settings", "output"
            and optionally "classes" (see :meth:`get_run_parameters`). Collected
            from the GUI when None. The dict is not modified.
        :returns: the predictions received before completion or before a
            stop request; None if the model could not be deployed.
        :raises ValueError: if no input project or no dataset is given.
        :raises Exception: errors from the prediction session are re-raised
            after being shown in the output card.
        """
        self.show_validator_text()
        self.set_validator_text("Preparing settings for prediction...", "info")
        # Drop a stop request left over from a stop() call made while idle
        # (e.g. shutdown_model); otherwise this run would abort after the
        # first prediction.
        self._stop_flag = False
        if run_parameters is None:
            run_parameters = self.get_run_parameters()

        if self.model_selector.model.model_api is None:
            self.model_selector.model._deploy()

        model_api = self.model_selector.model.model_api
        if model_api is None:
            logger.error("Model Deployed with an error")
            self.set_validator_text("Model Deployed with an error", "error")
            return None

        # Input
        # Input would be newly created project
        input_parameters = run_parameters["input"]
        input_project_id = input_parameters.get("project_id", None)
        if input_project_id is None:
            raise ValueError("Input project ID is required for prediction.")
        input_dataset_ids = input_parameters.get("dataset_ids", [])
        if not input_dataset_ids:
            raise ValueError("At least one dataset must be selected for prediction.")

        # Settings
        # Copy before popping so the caller's run_parameters stay intact.
        settings = dict(run_parameters["settings"])
        prediction_mode = settings.pop("predictions_mode")
        upload_mode = None
        with_annotations = None
        if prediction_mode == AddPredictionsMode.REPLACE_EXISTING_LABELS:
            upload_mode = "replace"
            with_annotations = False
        elif prediction_mode == AddPredictionsMode.MERGE_WITH_EXISTING_LABELS:
            upload_mode = "append"
            with_annotations = True
        elif prediction_mode == AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS:
            upload_mode = "replace"
            with_annotations = True
        kwargs = dict(settings)
        kwargs["upload_mode"] = upload_mode

        # Classes
        # "classes" is only present when the classes selector is enabled.
        classes = run_parameters.get("classes")
        if classes:
            kwargs["classes"] = classes

        # Output
        # Always create new project
        # But the actual inference will happen inplace
        output_parameters = run_parameters["output"]
        project_name = output_parameters.get("project_name")
        if not project_name:
            input_project_info = self.api.project.get_info_by_id(input_project_id)
            project_name = input_project_info.name + " [Predictions]"
            logger.warning("Project name is empty, using auto-generated name: " + project_name)

        # Copy project
        self.set_validator_text("Copying project...", "info")
        created_project = copy_project(
            self.api,
            project_name,
            self.workspace_id,
            input_project_id,
            input_dataset_ids,
            with_annotations,
            self.output_selector.progress,
        )
        # ------------------------ #

        # Run prediction
        self.set_validator_text("Running prediction...", "info")
        predictions = []
        self._is_running = True
        try:
            with model_api.predict_detached(
                project_id=created_project.id,
                tqdm=self.output_selector.progress(),
                **kwargs,
            ) as session:
                self.output_selector.progress.show()
                for prediction in session:
                    predictions.append(prediction)
                    if self._stop_flag:
                        logger.info("Prediction stopped by user.")
                        break
                self.output_selector.progress.hide()
        except Exception as e:
            self.output_selector.progress.hide()
            logger.error(f"Error during prediction: {str(e)}")
            self.set_validator_text(f"Error during prediction: {str(e)}", "error")
            disable_enable(self.output_selector.widgets_to_disable, False)
            raise
        finally:
            self._is_running = False
            self._stop_flag = False
        # ------------------------ #

        # Set result thumbnail
        self.set_validator_text("Project successfully processed", "success")
        self.output_selector.set_result_thumbnail(created_project.id)
        # ------------------------ #
        return predictions

    def stop(self):
        """Request the running prediction loop to stop after the current item."""
        logger.info("Stopping prediction...")
        self._stop_flag = True

    def wait_for_stop(self, timeout: Optional[float] = None):
        """Block until :meth:`run` is no longer iterating predictions.

        :param timeout: maximum number of seconds to wait; None waits forever.
        :raises TimeoutError: if ``timeout`` elapses while still running.
        """
        if timeout is None:
            logger.info("Waiting for prediction to stop...")
        else:
            logger.info(f"Waiting {timeout} seconds for prediction to stop...")
        started_at = time.monotonic()
        while self._is_running:
            if timeout is not None and time.monotonic() - started_at > timeout:
                raise TimeoutError("Timeout while waiting for stop.")
            time.sleep(0.1)
        logger.info("Prediction stopped.")

    def shutdown_model(self):
        """Stop any running prediction (waiting up to 10 s), then stop the model.

        The model is stopped even if waiting times out; the TimeoutError is
        still propagated to the caller.
        """
        self.stop()
        try:
            self.wait_for_stop(10)
        finally:
            self.model_selector.model.stop()

    def get_run_parameters(self) -> Dict[str, Any]:
        """Collect the current GUI state as a run-parameters dict.

        Keys: "model", "settings", "input", "output", plus "classes" and "tags"
        when the corresponding selectors are enabled.
        """
        settings = {
            "model": self.model_selector.model.get_deploy_parameters(),
            "settings": self.settings_selector.get_settings(),
            "input": self.input_selector.get_settings(),
            "output": self.output_selector.get_settings(),
        }
        if self.classes_selector is not None:
            settings["classes"] = self.classes_selector.get_selected_classes()
        if self.tags_selector is not None:
            settings["tags"] = self.tags_selector.get_selected_tags()
        return settings

    def load_from_json(self, data: Dict[str, Any]):
        """Restore every enabled step's widgets from a run-parameters-like dict.

        Missing keys are passed to the step loaders as empty dicts. The steps'
        select callbacks are not triggered.
        """
        # 1. Input selector
        self.input_selector.load_from_json(data.get("input", {}))

        # 2. Model selector
        self.model_selector.model.load_from_json(data.get("model", {}))

        # 3. Classes selector
        if self.classes_selector is not None:
            self.classes_selector.load_from_json(data.get("classes", {}))

        # 4. Tags selector
        if self.tags_selector is not None:
            self.tags_selector.load_from_json(data.get("tags", {}))

        # 5. Settings selector
        self.settings_selector.load_from_json(data.get("settings", {}))

        # 6. Preview (may not be needed)
        if self.preview is not None:
            self.preview.load_from_json(data.get("preview", {}))

        # 7. Output selector
        self.output_selector.load_from_json(data.get("output", {}))

    def set_validator_text(self, text: str, status: str = "text"):
        """Set the message shown in the output step's validator text."""
        self.output_selector.validator_text.set(text=text, status=status)

    def show_validator_text(self):
        """Show the output step's validator text."""
        self.output_selector.validator_text.show()

    def hide_validator_text(self):
        """Hide the output step's validator text."""
        self.output_selector.validator_text.hide()