supervisely 6.73.456__py3-none-any.whl → 6.73.458__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. supervisely/__init__.py +24 -1
  2. supervisely/api/image_api.py +4 -0
  3. supervisely/api/video/video_annotation_api.py +4 -2
  4. supervisely/api/video/video_api.py +41 -1
  5. supervisely/app/v1/app_service.py +18 -2
  6. supervisely/app/v1/constants.py +7 -1
  7. supervisely/app/widgets/card/card.py +20 -0
  8. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  9. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  10. supervisely/app/widgets/fast_table/fast_table.py +45 -11
  11. supervisely/app/widgets/fast_table/template.html +1 -1
  12. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  13. supervisely/app/widgets/radio_tabs/template.html +1 -0
  14. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +63 -7
  15. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  16. supervisely/nn/inference/cache.py +2 -2
  17. supervisely/nn/inference/inference.py +364 -73
  18. supervisely/nn/inference/inference_request.py +3 -2
  19. supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
  20. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  21. supervisely/nn/inference/predict_app/gui/input_selector.py +178 -25
  22. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  23. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  24. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  25. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  26. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  27. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  28. supervisely/nn/model/model_api.py +9 -0
  29. supervisely/nn/tracker/base_tracker.py +11 -1
  30. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  31. supervisely/nn/tracker/botsort_tracker.py +14 -7
  32. supervisely/nn/tracker/visualize.py +70 -72
  33. supervisely/video/video.py +15 -1
  34. supervisely/worker_api/agent_rpc.py +24 -1
  35. supervisely/worker_api/rpc_servicer.py +31 -7
  36. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/METADATA +3 -2
  37. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/RECORD +41 -41
  38. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/LICENSE +0 -0
  39. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/WHEEL +0 -0
  40. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/entry_points.txt +0 -0
  41. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/top_level.txt +0 -0
@@ -1,163 +1,156 @@
1
- import random
1
+ import json
2
2
  import time
3
+ from pathlib import Path
3
4
  from typing import Any, Callable, Dict, List, Optional
4
5
 
5
6
  import yaml
6
7
 
7
8
  from supervisely._utils import is_development, logger
8
- from supervisely.annotation.annotation import Annotation
9
- from supervisely.annotation.label import Label
10
9
  from supervisely.api.api import Api
10
+ from supervisely.api.image_api import ImageInfo
11
11
  from supervisely.api.video.video_api import VideoInfo
12
12
  from supervisely.app.widgets import Button, Card, Container, Stepper, Widget
13
+ from supervisely.geometry.any_geometry import AnyGeometry
13
14
  from supervisely.io import env
15
+ from supervisely.nn.inference.inference import update_meta_and_ann_for_video_annotation
14
16
  from supervisely.nn.inference.predict_app.gui.classes_selector import ClassesSelector
15
17
  from supervisely.nn.inference.predict_app.gui.input_selector import InputSelector
16
18
  from supervisely.nn.inference.predict_app.gui.model_selector import ModelSelector
17
19
  from supervisely.nn.inference.predict_app.gui.output_selector import OutputSelector
18
- from supervisely.nn.inference.predict_app.gui.preview import Preview
19
20
  from supervisely.nn.inference.predict_app.gui.settings_selector import (
20
21
  AddPredictionsMode,
21
22
  SettingsSelector,
22
23
  )
23
24
  from supervisely.nn.inference.predict_app.gui.tags_selector import TagsSelector
24
25
  from supervisely.nn.inference.predict_app.gui.utils import (
25
- copy_project,
26
+ copy_items_to_project,
27
+ create_project,
26
28
  disable_enable,
27
- set_stepper_step,
28
- wrap_button_click,
29
+ update_custom_button_params,
30
+ video_annotation_from_predictions,
29
31
  )
30
32
  from supervisely.nn.model.model_api import ModelAPI
31
33
  from supervisely.nn.model.prediction import Prediction
32
34
  from supervisely.project.project_meta import ProjectMeta
35
+ from supervisely.project.project_type import ProjectType
33
36
  from supervisely.video_annotation.key_id_map import KeyIdMap
34
37
  from supervisely.video_annotation.video_annotation import VideoAnnotation
35
38
 
36
39
 
37
40
  class StepFlow:
38
-
39
- def __init__(self, stepper: Stepper):
40
- self.stepper = stepper
41
+ def __init__(self):
42
+ self._stepper = None
41
43
  self.steps = {}
42
- self.step_sequence = []
44
+ self.steps_sequence = []
43
45
 
44
- def register_step(
46
+ def add_step(
45
47
  self,
46
48
  name: str,
47
- card: Card,
49
+ widget: Widget,
50
+ on_select: Optional[Callable] = None,
51
+ on_reactivate: Optional[Callable] = None,
52
+ depends_on: Optional[List[Widget]] = None,
53
+ on_lock: Optional[Callable] = None,
54
+ on_unlock: Optional[Callable] = None,
48
55
  button: Optional[Button] = None,
49
- widgets_to_disable: Optional[List[Widget]] = None,
50
- validation_text: Optional[Widget] = None,
51
- validation_func: Optional[Callable] = None,
52
56
  position: Optional[int] = None,
53
- ) -> "StepFlow":
57
+ ):
58
+ if depends_on is None:
59
+ depends_on = []
54
60
  self.steps[name] = {
55
- "card": card,
61
+ "widget": widget,
62
+ "on_select": on_select,
63
+ "on_reactivate": on_reactivate,
64
+ "depends_on": depends_on,
65
+ "on_lock": on_lock,
66
+ "on_unlock": on_unlock,
56
67
  "button": button,
57
- "widgets_to_disable": widgets_to_disable or [],
58
- "validation_text": validation_text,
59
- "validation_func": validation_func,
60
- "position": position,
61
- "next_steps": [],
62
- "on_select_click": [],
63
- "on_reselect_click": [],
64
- "wrapper": None,
65
- "has_button": button is not None,
68
+ "is_selected": False,
69
+ "is_locked": False,
66
70
  }
67
-
71
+ if button is not None:
72
+ self._wrap_button(button, name)
68
73
  if position is not None:
69
- while len(self.step_sequence) <= position:
70
- self.step_sequence.append(None)
71
- self.step_sequence[position] = name
72
-
73
- return self
74
-
75
- def set_next_steps(self, step_name: str, next_steps: List[str]) -> "StepFlow":
76
- if step_name in self.steps:
77
- self.steps[step_name]["next_steps"] = next_steps
78
- return self
79
-
80
- def add_on_select_actions(
81
- self, step_name: str, actions: List[Callable], is_reselect: bool = False
82
- ) -> "StepFlow":
83
- if step_name in self.steps:
84
- key = "on_reselect_click" if is_reselect else "on_select_click"
85
- self.steps[step_name][key].extend(actions)
86
- return self
87
-
88
- def build_wrappers(self) -> Dict[str, Callable]:
89
- valid_sequence = [s for s in self.step_sequence if s is not None and s in self.steps]
74
+ self.steps_sequence.insert(position, name)
75
+ else:
76
+ self.steps_sequence.append(name)
77
+ self.update_locks()
90
78
 
91
- for step_name in reversed(valid_sequence):
79
+ def _create_stepper(self):
80
+ widgets = []
81
+ for step_name in self.steps_sequence:
92
82
  step = self.steps[step_name]
83
+ widgets.append(step["widget"])
84
+ self._stepper = Stepper(widgets=widgets)
93
85
 
94
- cards_to_unlock = []
95
- for next_step_name in step["next_steps"]:
96
- if next_step_name in self.steps:
97
- cards_to_unlock.append(self.steps[next_step_name]["card"])
98
-
99
- callback = None
100
- if step["next_steps"] and step["has_button"]:
101
- for next_step_name in step["next_steps"]:
102
- if (
103
- next_step_name in self.steps
104
- and self.steps[next_step_name].get("wrapper")
105
- and self.steps[next_step_name]["has_button"]
106
- ):
107
- callback = self.steps[next_step_name]["wrapper"]
108
- break
109
-
110
- if step["has_button"]:
111
- wrapper = wrap_button_click(
112
- button=step["button"],
113
- cards_to_unlock=cards_to_unlock,
114
- widgets_to_disable=step["widgets_to_disable"],
115
- callback=callback,
116
- validation_text=step["validation_text"],
117
- validation_func=step["validation_func"],
118
- on_select_click=step["on_select_click"],
119
- on_reselect_click=step["on_reselect_click"],
120
- collapse_card=None,
121
- )
122
-
123
- step["wrapper"] = wrapper
124
-
125
- return {
126
- name: self.steps[name]["wrapper"]
127
- for name in self.steps
128
- if self.steps[name].get("wrapper") and self.steps[name]["has_button"]
129
- }
130
-
131
- def setup_button_handlers(self) -> None:
132
- positions = {}
133
- pos = 1
134
-
135
- for i, step_name in enumerate(self.step_sequence):
136
- if step_name is not None and step_name in self.steps:
137
- positions[step_name] = pos
138
- pos += 1
139
-
140
- for step_name, step in self.steps.items():
141
- if step_name in positions and step.get("wrapper") and step["has_button"]:
86
+ @property
87
+ def stepper(self):
88
+ if self._stepper is None:
89
+ self._create_stepper()
90
+ return self._stepper
142
91
 
143
- button = step["button"]
144
- wrapper = step["wrapper"]
145
- position = positions[step_name]
146
- next_position = position + 1
147
-
148
- def create_handler(btn, cb, next_pos):
149
- def handler():
150
- cb()
151
- set_stepper_step(self.stepper, btn, next_pos=next_pos)
152
-
153
- return handler
92
+ def update_stepper(self):
93
+ for i, step_name in enumerate(self.steps_sequence):
94
+ step = self.steps[step_name]
95
+ if not step["is_selected"]:
96
+ self.stepper.set_active_step(i + 1)
97
+ return
154
98
 
155
- button.click(create_handler(button, wrapper, next_position))
99
+ def update_locks(self):
100
+ for step in self.steps.values():
101
+ should_lock = False
102
+ for dep_name in step["depends_on"]:
103
+ dep = self.steps[dep_name]
104
+ if not dep["is_selected"]:
105
+ should_lock = True
106
+ break
107
+ if should_lock and not step["is_locked"]:
108
+ if step["on_lock"] is not None:
109
+ step["on_lock"]()
110
+ step["is_locked"] = True
111
+ if not should_lock and step["is_locked"]:
112
+ if step["on_unlock"]:
113
+ step["on_unlock"]()
114
+ step["is_locked"] = False
115
+
116
+ def _reactivate_dependents(self, step_name: str, visited=None):
117
+ if visited is None:
118
+ visited = set()
119
+ for dep_name, step in self.steps.items():
120
+ if step_name in step["depends_on"] and not dep_name in visited:
121
+ self._reactivate_step(dep_name, visited)
122
+
123
+ def _reactivate_step(self, step_name: str, visited=None):
124
+ step = self.steps[step_name]
125
+ if step["on_reactivate"] is not None:
126
+ step["on_reactivate"]()
127
+ step["is_selected"] = False
128
+ if visited is None:
129
+ visited = set()
130
+ self._reactivate_dependents(step_name, visited)
131
+
132
+ def reactivate_step(self, step_name: str):
133
+ self._reactivate_step(step_name)
134
+ self.update_stepper()
135
+ self.update_locks()
136
+
137
+ def select_step(self, step_name: str):
138
+ step = self.steps[step_name]
139
+ if step["on_select"] is not None:
140
+ step["on_select"]()
141
+ step["is_selected"] = True
142
+ self.update_stepper()
143
+ self.update_locks()
144
+
145
+ def select_or_reactivate(self, step_name: str):
146
+ step = self.steps[step_name]
147
+ if step["is_selected"]:
148
+ self.reactivate_step(step_name)
149
+ else:
150
+ self.select_step(step_name)
156
151
 
157
- def build(self) -> Dict[str, Callable]:
158
- wrappers = self.build_wrappers()
159
- self.setup_button_handlers()
160
- return wrappers
152
+ def _wrap_button(self, button: Button, step_name: str):
153
+ button.click(lambda: self.select_or_reactivate(step_name))
161
154
 
162
155
 
163
156
  class PredictAppGui:
@@ -169,6 +162,9 @@ class PredictAppGui:
169
162
  self.team_id = env.team_id()
170
163
  self.workspace_id = env.workspace_id()
171
164
  self.project_id = env.project_id(raise_not_found=False)
165
+ self.project_meta = None
166
+ if self.project_id:
167
+ self.project_meta = ProjectMeta.from_json(self.api.project.get_meta(self.project_id))
172
168
  # -------------------------------- #
173
169
 
174
170
  # Flags
@@ -178,73 +174,178 @@ class PredictAppGui:
178
174
 
179
175
  # GUI
180
176
  # Steps
181
- self.steps = []
177
+ self.step_flow = StepFlow()
178
+ select_params = {"icon": None, "plain": False, "text": "Select"}
179
+ reselect_params = {"icon": "zmdi zmdi-refresh", "plain": True, "text": "Reselect"}
182
180
 
183
181
  # 1. Input selector
184
- self.input_selector = InputSelector(self.workspace_id)
185
- self.steps.append(self.input_selector.card)
182
+ self.input_selector = InputSelector(self.workspace_id, self.api)
183
+
184
+ def _on_input_select():
185
+ valid = self.input_selector.validate_step()
186
+ if not valid:
187
+ return
188
+ current_item_type = self.input_selector.radio.get_value()
189
+ self.update_item_type()
190
+ if self.model_api:
191
+ if current_item_type == self.input_selector.radio.get_value():
192
+ inference_settings = self.model_api.get_settings()
193
+ self.settings_selector.set_inference_settings(inference_settings)
194
+
195
+ if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
196
+ try:
197
+ tracking_settings = self.model_api.get_tracking_settings()
198
+ self.settings_selector.set_tracking_settings(tracking_settings)
199
+ except Exception as e:
200
+ logger.warning(
201
+ "Unable to get tracking settings from the model. Settings defaults"
202
+ )
203
+ self.settings_selector.set_default_tracking_settings()
204
+ self.input_selector.disable()
205
+
206
+ self.project_id = self.input_selector.get_project_id()
207
+ if self.project_id:
208
+ self.project_meta = ProjectMeta.from_json(self.api.project.get_meta(self.project_id))
209
+ update_custom_button_params(self.input_selector.button, reselect_params)
210
+
211
+ def _on_input_reactivate():
212
+ self.input_selector.enable()
213
+ update_custom_button_params(self.input_selector.button, select_params)
214
+
215
+ self.step_flow.add_step(
216
+ name="input_selector",
217
+ widget=self.input_selector.card,
218
+ on_select=_on_input_select,
219
+ on_reactivate=_on_input_reactivate,
220
+ button=self.input_selector.button,
221
+ )
186
222
 
187
223
  # 2. Model selector
188
224
  self.model_selector = ModelSelector(self.api, self.team_id)
189
- self.steps.append(self.model_selector.card)
225
+
226
+ self.step_flow.add_step(
227
+ name="model_selector",
228
+ widget=self.model_selector.card,
229
+ )
190
230
 
191
231
  # 3. Classes selector
192
- self.classes_selector = None
193
- if True:
194
- self.classes_selector = ClassesSelector()
195
- self.steps.append(self.classes_selector.card)
232
+ self.classes_selector = ClassesSelector()
233
+
234
+ def _on_classes_select():
235
+ valid = self.classes_selector.validate_step()
236
+ if not valid:
237
+ return
238
+ self.classes_selector.classes_table.disable()
239
+
240
+ # Find conflict between project meta and model meta
241
+ selected_classes_names = self.classes_selector.get_selected_classes()
242
+ project_meta = self.project_meta
243
+ model_meta = self.model_api.get_model_meta()
244
+
245
+ has_conflict = False
246
+ for class_name in selected_classes_names:
247
+ project_obj_class = project_meta.get_obj_class(class_name)
248
+ if project_obj_class is None:
249
+ continue
250
+
251
+ model_obj_class = model_meta.get_obj_class(class_name)
252
+ if model_obj_class.geometry_type.name() == AnyGeometry.name():
253
+ continue
254
+
255
+ if project_obj_class.geometry_type.name() == model_obj_class.geometry_type.name():
256
+ continue
257
+
258
+ has_conflict = True
259
+ break
260
+
261
+ if has_conflict:
262
+ self.settings_selector.model_prediction_suffix_container.show()
263
+ else:
264
+ self.settings_selector.model_prediction_suffix_container.hide()
265
+ # ------------------------------------------------ #
266
+
267
+ update_custom_button_params(self.classes_selector.button, reselect_params)
268
+
269
+ def _on_classes_reactivate():
270
+ self.classes_selector.classes_table.enable()
271
+ update_custom_button_params(self.classes_selector.button, select_params)
272
+
273
+ self.step_flow.add_step(
274
+ name="classes_selector",
275
+ widget=self.classes_selector.card,
276
+ on_select=_on_classes_select,
277
+ on_reactivate=_on_classes_reactivate,
278
+ depends_on=["input_selector", "model_selector"],
279
+ on_lock=self.classes_selector.lock,
280
+ on_unlock=self.classes_selector.unlock,
281
+ button=self.classes_selector.button,
282
+ )
196
283
 
197
284
  # 4. Tags selector
198
285
  self.tags_selector = None
199
286
  if False:
200
287
  self.tags_selector = TagsSelector()
201
- self.steps.append(self.tags_selector.card)
202
-
203
- # 5. Settings selector
204
- self.settings_selector = SettingsSelector()
205
- self.steps.append(self.settings_selector.card)
288
+ self.step_flow.add_step("tags_selector", self.tags_selector.card)
289
+
290
+ # 5. Settings selector & Preview
291
+ self.settings_selector = SettingsSelector(
292
+ api=self.api,
293
+ static_dir=self.static_dir,
294
+ model_selector=self.model_selector,
295
+ input_selector=self.input_selector,
296
+ )
206
297
 
207
- # 6. Preview
208
- self.preview = None
209
- if False:
210
- self.preview = Preview(api, static_dir)
211
- self.steps.append(self.preview.card)
298
+ def _on_settings_select():
299
+ valid = self.settings_selector.validate_step()
300
+ if not valid:
301
+ return
302
+ self.settings_selector.disable()
303
+ update_custom_button_params(self.settings_selector.button, reselect_params)
304
+
305
+ def _on_settings_reactivate():
306
+ self.settings_selector.enable()
307
+ update_custom_button_params(self.settings_selector.button, select_params)
308
+
309
+ self.step_flow.add_step(
310
+ name="settings_selector",
311
+ widget=self.settings_selector.cards_container,
312
+ on_select=_on_settings_select,
313
+ on_reactivate=_on_settings_reactivate,
314
+ depends_on=["input_selector", "model_selector", "classes_selector"],
315
+ on_lock=self.settings_selector.lock,
316
+ on_unlock=self.settings_selector.unlock,
317
+ button=self.settings_selector.button,
318
+ )
319
+ self.settings_selector.preview.run_button.disable()
212
320
 
213
- # 7. Output selector
321
+ # 6. Output selector
214
322
  self.output_selector = OutputSelector(self.api)
215
- self.steps.append(self.output_selector.card)
216
- # -------------------------------- #
217
323
 
218
- # Stepper
219
- self.stepper = Stepper(widgets=self.steps)
220
- # ---------------------------- #
324
+ self.step_flow.add_step(
325
+ "output_selector",
326
+ self.output_selector.card,
327
+ depends_on=[
328
+ "input_selector",
329
+ "model_selector",
330
+ "classes_selector",
331
+ # "tags_selector",
332
+ "settings_selector",
333
+ ],
334
+ on_lock=self.output_selector.lock,
335
+ on_unlock=self.output_selector.unlock,
336
+ )
337
+ # -------------------------------- #
221
338
 
222
339
  # Layout
223
- self.layout = Container([self.stepper])
340
+ self.layout = Container([self.step_flow.stepper])
224
341
  # ---------------------------- #
225
342
 
226
- # Button Utils
227
- def deploy_model() -> ModelAPI:
228
- self.model_selector.validator_text.hide()
229
- model_api = None
230
- try:
231
- model_api = type(self.model_selector.model).deploy(self.model_selector.model)
232
- except:
233
- self.output_selector.start_button.disable()
234
- raise
235
- else:
236
- self.output_selector.start_button.enable()
237
- return model_api
238
-
239
- # Reimplement deploy method for DeployModel widget
240
- self.model_selector.model.deploy = deploy_model
241
-
242
343
  def set_entity_meta():
243
344
  model_api = self.model_selector.model.model_api
244
345
 
245
346
  model_meta = model_api.get_model_meta()
246
347
  if self.classes_selector is not None:
247
- self.classes_selector.classes_table.set_project_meta(model_meta)
348
+ self.classes_selector.set_project_meta(model_meta)
248
349
  self.classes_selector.classes_table.show()
249
350
  if self.tags_selector is not None:
250
351
  self.tags_selector.tags_table.set_project_meta(model_meta)
@@ -253,13 +354,20 @@ class PredictAppGui:
253
354
  inference_settings = model_api.get_settings()
254
355
  self.settings_selector.set_inference_settings(inference_settings)
255
356
 
256
- if self.preview is not None:
257
- self.preview.inference_settings = inference_settings
357
+ if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
358
+ try:
359
+ tracking_settings = model_api.get_tracking_settings()
360
+ self.settings_selector.set_tracking_settings(tracking_settings)
361
+ except Exception as e:
362
+ logger.warning(
363
+ "Unable to get tracking settings from the model. Settings defaults"
364
+ )
365
+ self.settings_selector.set_default_tracking_settings()
258
366
 
259
367
  def reset_entity_meta():
260
368
  empty_meta = ProjectMeta()
261
369
  if self.classes_selector is not None:
262
- self.classes_selector.classes_table.set_project_meta(empty_meta)
370
+ self.classes_selector.set_project_meta(empty_meta)
263
371
  self.classes_selector.classes_table.hide()
264
372
  if self.tags_selector is not None:
265
373
  self.tags_selector.tags_table.set_project_meta(empty_meta)
@@ -267,394 +375,478 @@ class PredictAppGui:
267
375
 
268
376
  self.settings_selector.set_inference_settings("")
269
377
 
270
- if self.preview is not None:
271
- self.preview.inference_settings = None
272
-
273
- def disable_settings_editor():
274
- if self.settings_selector.inference_settings.readonly:
275
- self.settings_selector.inference_settings.readonly = False
378
+ def deploy_and_set_step():
379
+ self.model_selector.validator_text.hide()
380
+ model_api = type(self.model_selector.model).deploy(self.model_selector.model)
381
+ if model_api is not None:
382
+ set_entity_meta()
383
+ self.step_flow.select_step("model_selector")
276
384
  else:
277
- self.settings_selector.inference_settings.readonly = True
278
-
279
- def generate_preview():
280
- def _get_frame_annotation(
281
- video_info: VideoInfo, frame_index: int, project_meta: ProjectMeta
282
- ) -> Annotation:
283
- video_annotation = VideoAnnotation.from_json(
284
- self.api.video.annotation.download(video_info.id, frame_index),
285
- project_meta=project_meta,
286
- key_id_map=KeyIdMap(),
287
- )
288
- frame = video_annotation.frames.get(frame_index)
289
- img_size = (video_info.frame_height, video_info.frame_width)
290
- if frame is None:
291
- return Annotation(img_size)
292
- labels = []
293
- for figure in frame.figures:
294
- labels.append(Label(figure.geometry, figure.video_object.obj_class))
295
- ann = Annotation(img_size, labels=labels)
296
- return ann
297
-
298
- if self.preview is None:
299
- return
385
+ reset_entity_meta()
386
+ self.step_flow.reactivate_step("model_selector")
387
+ return model_api
300
388
 
301
- self.preview.validator_text.hide()
302
- self.preview.gallery.clean_up()
303
- self.preview.gallery.show()
304
- self.preview.gallery.loading = True
305
- try:
306
- items_settings = self.input_selector.get_settings()
307
- if "video_id" in items_settings:
308
- video_id = items_settings["video_id"]
309
- video_info = self.api.video.get_info_by_id(video_id)
310
- video_frame = random.randint(0, video_info.frames_count - 1)
311
- self.api.video.frame.download_path(
312
- video_info.id, video_frame, self.preview.preview_path
313
- )
314
- img_url = self.preview.peview_url
315
- project_meta = ProjectMeta.from_json(
316
- self.api.project.get_meta(video_info.project_id)
317
- )
318
- input_ann = _get_frame_annotation(video_info, video_frame, project_meta)
319
- prediction = self.model_selector.model.model_api.predict(
320
- input=self.preview.preview_path, **self.settings_selector.get_settings()
321
- )[0]
322
- output_ann = prediction.annotation
323
- else:
324
- if "project_id" in items_settings:
325
- project_id = items_settings["project_id"]
326
- dataset_infos = self.api.dataset.get_list(project_id, recursive=True)
327
- dataset_infos = [ds for ds in dataset_infos if ds.items_count > 0]
328
- if not dataset_infos:
329
- raise ValueError("No datasets with items found in the project.")
330
- dataset_info = random.choice(dataset_infos)
331
- elif "dataset_ids" in items_settings:
332
- dataset_ids = items_settings["dataset_ids"]
333
- dataset_infos = [
334
- self.api.dataset.get_info_by_id(dataset_id)
335
- for dataset_id in dataset_ids
336
- ]
337
- dataset_infos = [ds for ds in dataset_infos if ds.items_count > 0]
338
- if not dataset_infos:
339
- raise ValueError("No items in selected datasets.")
340
- dataset_info = random.choice(dataset_infos)
341
- else:
342
- raise ValueError("No valid item settings found for preview.")
343
- images = self.api.image.get_list(dataset_info.id)
344
- image_info = random.choice(images)
345
- img_url = image_info.preview_url
389
+ def stop_and_reset_step():
390
+ type(self.model_selector.model).stop(self.model_selector.model)
391
+ self.step_flow.reactivate_step("model_selector")
392
+ reset_entity_meta()
346
393
 
347
- project_meta = ProjectMeta.from_json(
348
- self.api.project.get_meta(dataset_info.project_id)
349
- )
350
- input_ann = Annotation.from_json(
351
- self.api.annotation.download(image_info.id).annotation,
352
- project_meta=project_meta,
353
- )
354
- prediction = self.model_selector.model.model_api.predict(
355
- image_id=image_info.id, **self.settings_selector.get_settings()
356
- )[0]
357
- output_ann = prediction.annotation
358
-
359
- self.preview.gallery.append(img_url, input_ann, "Input")
360
- self.preview.gallery.append(img_url, output_ann, "Output")
361
- self.preview.validator_text.hide()
362
- self.preview.gallery.show()
363
- return prediction
364
- except Exception as e:
365
- self.preview.gallery.hide()
366
- self.preview.validator_text.set(
367
- text=f"Error during preview: {str(e)}", status="error"
368
- )
369
- self.preview.validator_text.show()
370
- self.preview.gallery.clean_up()
371
- finally:
372
- self.preview.gallery.loading = False
394
+ def disconnect_and_reset_step():
395
+ type(self.model_selector.model).disconnect(self.model_selector.model)
396
+ self.step_flow.reactivate_step("model_selector")
397
+ reset_entity_meta()
373
398
 
374
- # ---------------------------- #
399
+ # Replace deploy methods for DeployModel widget
400
+ self.model_selector.model.deploy = deploy_and_set_step
401
+ self.model_selector.model.stop = stop_and_reset_step
402
+ self.model_selector.model.disconnect = disconnect_and_reset_step
375
403
 
376
- # StepFlow callbacks and wiring
377
- self.step_flow = StepFlow(self.stepper)
378
- position = 0
404
+ # ------------------------------------------------- #
379
405
 
380
- # 1. Input selector
381
- self.step_flow.register_step(
382
- "input_selector",
383
- self.input_selector.card,
384
- self.input_selector.button,
385
- self.input_selector.widgets_to_disable,
386
- self.input_selector.validator_text,
387
- self.input_selector.validate_step,
388
- position=position,
389
- )
390
- position += 1
406
+ @property
407
+ def model_api(self) -> Optional[ModelAPI]:
408
+ return self.model_selector.model.model_api
391
409
 
392
- # 2. Model selector
393
- self.step_flow.register_step(
394
- "model_selector",
395
- self.model_selector.card,
396
- self.model_selector.button,
397
- self.model_selector.widgets_to_disable,
398
- self.model_selector.validator_text,
399
- self.model_selector.validate_step,
400
- position=position,
401
- )
402
- self.step_flow.add_on_select_actions("model_selector", [set_entity_meta])
403
- self.step_flow.add_on_select_actions(
404
- "model_selector", [reset_entity_meta], is_reselect=True
405
- )
406
- position += 1
410
+ def update_item_type(self):
411
+ item_type = self.input_selector.radio.get_value()
412
+ self.settings_selector.update_item_type(item_type)
413
+ self.output_selector.update_item_type(item_type)
407
414
 
408
- # 3. Classes selector
409
- if self.classes_selector is not None:
410
- self.step_flow.register_step(
411
- "classes_selector",
412
- self.classes_selector.card,
413
- self.classes_selector.button,
414
- self.classes_selector.widgets_to_disable,
415
- self.classes_selector.validator_text,
416
- self.classes_selector.validate_step,
417
- position=position,
418
- )
419
- position += 1
415
+ def _run_videos(self, run_parameters: Dict[str, Any]) -> List[Prediction]:
416
+ if self.model_api is None:
417
+ self.set_validator_text("Deploying model...", "info")
418
+ self.model_selector.model._deploy()
419
+ if self.model_api is None:
420
+ logger.error("Model Deployed with an error")
421
+ raise RuntimeError("Model Deployed with an error")
420
422
 
421
- # 4. Tags selector
422
- if self.tags_selector is not None:
423
- self.step_flow.register_step(
424
- "tags_selector",
425
- self.tags_selector.card,
426
- self.tags_selector.button,
427
- self.tags_selector.widgets_to_disable,
428
- self.tags_selector.validator_text,
429
- self.tags_selector.validate_step,
430
- position=position,
431
- )
432
- position += 1
433
-
434
- # 5. Settings selector
435
- self.step_flow.register_step(
436
- "settings_selector",
437
- self.settings_selector.card,
438
- self.settings_selector.button,
439
- self.settings_selector.widgets_to_disable,
440
- self.settings_selector.validator_text,
441
- self.settings_selector.validate_step,
442
- position=position,
443
- )
444
- self.step_flow.add_on_select_actions("settings_selector", [disable_settings_editor])
445
- self.step_flow.add_on_select_actions("settings_selector", [disable_settings_editor], True)
446
- position += 1
447
-
448
- # 6. Preview
449
- if self.preview is not None:
450
- self.step_flow.register_step(
451
- "preview",
452
- self.preview.card,
453
- self.preview.button,
454
- self.preview.widgets_to_disable,
455
- self.preview.validator_text,
456
- self.preview.validate_step,
457
- position=position,
458
- ).add_on_select_actions("preview", [generate_preview])
459
- position += 1
460
-
461
- # 7. Output selector
462
- self.step_flow.register_step(
463
- "output_selector",
464
- self.output_selector.card,
465
- None,
466
- self.output_selector.widgets_to_disable,
467
- self.output_selector.validator_text,
468
- self.output_selector.validate_step,
469
- position=position,
470
- )
423
+ self.set_validator_text("Preparing settings for prediction...", "info")
424
+ if run_parameters is None:
425
+ run_parameters = self.get_run_parameters()
471
426
 
472
- # Dependencies Chain
473
- has_model_selector = self.model_selector is not None
474
- has_classes_selector = self.classes_selector is not None
475
- has_tags_selector = self.tags_selector is not None
476
- has_preview = self.preview is not None
477
-
478
- # Step 1 -> Step 2
479
- prev_step = "input_selector"
480
- if has_model_selector:
481
- self.step_flow.set_next_steps(prev_step, ["model_selector"])
482
- prev_step = "model_selector"
483
- # Step 2 -> Step 3
484
- if has_classes_selector:
485
- self.step_flow.set_next_steps(prev_step, ["classes_selector"])
486
- prev_step = "classes_selector"
487
- # Step 3 -> Step 4
488
- if has_tags_selector:
489
- self.step_flow.set_next_steps(prev_step, ["tags_selector"])
490
- prev_step = "tags_selector"
491
- # Step 4 -> Step 5
492
- self.step_flow.set_next_steps(prev_step, ["settings_selector"])
493
- prev_step = "settings_selector"
494
- # Step 5 -> Step 6
495
- if has_preview:
496
- self.step_flow.set_next_steps(prev_step, ["preview"])
497
- prev_step = "preview"
498
- # Step 6 -> Step 7
499
- self.step_flow.set_next_steps(prev_step, ["output_selector"])
500
-
501
- # Create all wrappers and set button handlers
502
- wrappers = self.step_flow.build()
503
-
504
- self.input_selector_cb = wrappers.get("input_selector")
505
- self.classes_selector_cb = wrappers.get("classes_selector")
506
- self.tags_selector_cb = wrappers.get("tags_selector")
507
- self.model_selector_cb = wrappers.get("model_selector")
508
- self.settings_selector_cb = wrappers.get("settings_selector")
509
- self.preview_cb = wrappers.get("preview")
510
- self.output_selector_cb = wrappers.get("output_selector")
511
- # ------------------------------------------------- #
427
+ input_parameters = run_parameters["input"]
428
+ input_video_ids = input_parameters["video_ids"]
429
+ if not input_video_ids:
430
+ raise ValueError("No video IDs provided for video prediction.")
512
431
 
513
- # Other Handlers
514
- @self.input_selector.radio.value_changed
515
- def input_selector_type_changed(value: str):
516
- self.input_selector.validator_text.hide()
432
+ predict_kwargs = {}
433
+ # Settings
434
+ settings = run_parameters["settings"]
435
+ model_prediction_suffix = settings.pop("model_prediction_suffix", "")
436
+ prediction_mode = settings.pop("predictions_mode")
437
+ tracking = settings.pop("tracking", False)
438
+ predict_kwargs.update(settings)
517
439
 
518
- @self.input_selector.select_dataset_for_video.value_changed
519
- def dataset_for_video_changed(dataset_id: int):
520
- self.input_selector.select_video.loading = True
521
- if dataset_id is None:
522
- rows = []
523
- else:
524
- dataset_info = self.api.dataset.get_info_by_id(dataset_id)
525
- videos = self.api.video.get_list(dataset_id)
526
- rows = [[video.id, video.name, dataset_info.name] for video in videos]
527
- self.input_selector.select_video.rows = rows
528
- self.input_selector.select_video.loading = False
440
+ # Classes
441
+ classes = run_parameters["classes"]
442
+ if classes:
443
+ predict_kwargs["classes"] = classes
529
444
 
530
- # ------------------------------------------------- #
445
+ output_parameters = run_parameters["output"]
446
+ project_name = output_parameters.get("project_name", "")
447
+ upload_to_source_project = output_parameters.get("upload_to_source_project", False)
448
+ skip_project_versioning = output_parameters.get("skip_project_versioning", False)
449
+ skip_annotated = output_parameters.get("skip_annotated", False)
450
+
451
+ video_infos_by_project_id: Dict[int, List[VideoInfo]] = {}
452
+ video_infos_by_dataset_id: Dict[int, List[VideoInfo]] = {}
453
+ for info in self.api.video.get_info_by_id_batch(input_video_ids):
454
+ video_infos_by_project_id.setdefault(info.project_id, []).append(info)
455
+ video_infos_by_dataset_id.setdefault(info.dataset_id, []).append(info)
456
+ src_project_metas: Dict[int, ProjectMeta] = {}
457
+ for project_id in video_infos_by_project_id.keys():
458
+ src_project_metas[project_id] = ProjectMeta.from_json(
459
+ self.api.project.get_meta(project_id)
460
+ )
531
461
 
532
- def run(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
533
- self.show_validator_text()
534
- self.set_validator_text("Preparing settings for prediction...", "info")
535
- if run_parameters is None:
536
- run_parameters = self.get_run_parameters()
462
+ video_ids_to_skip = set()
463
+ if skip_annotated:
464
+ self.set_validator_text("Checking for already annotated videos...", "info")
465
+ secondary_pbar = self.output_selector.secondary_progress(
466
+ message="Checking for already annotated videos...", total=len(input_video_ids)
467
+ )
468
+ self.output_selector.secondary_progress.show()
469
+ for dataset_id, video_infos in video_infos_by_dataset_id.items():
470
+ annotations = self.api.video.annotation.download_bulk(
471
+ dataset_id, [info.id for info in video_infos]
472
+ )
473
+ for ann_json, video_info in zip(annotations, video_infos):
474
+ if ann_json:
475
+ project_meta = src_project_metas[video_info.project_id]
476
+ ann = VideoAnnotation.from_json(ann_json, project_meta=project_meta)
477
+ if len(ann.figures) > 0:
478
+ video_ids_to_skip.add(video_info.id)
479
+ secondary_pbar.update()
480
+ self.output_selector.secondary_progress.hide()
481
+ if video_ids_to_skip:
482
+ video_infos_by_project_id = {
483
+ pid: [info for info in infos if info.id not in video_ids_to_skip]
484
+ for pid, infos in video_infos_by_project_id.items()
485
+ }
486
+
487
+ main_pbar_str = "Processing videos..."
488
+ if video_ids_to_skip:
489
+ main_pbar_str += f" (Skipped {len(video_ids_to_skip)} already annotated videos)"
490
+ total_videos = sum(len(v) for v in video_infos_by_project_id.values())
491
+ if total_videos == 0:
492
+ self.set_validator_text(
493
+ f"No videos to process. Skipped {len(video_ids_to_skip)} already annotated videos",
494
+ "warning",
495
+ )
496
+ return []
497
+ main_pbar = self.output_selector.progress(message=main_pbar_str, total=total_videos)
498
+ self.output_selector.progress.show()
499
+ all_predictictions: List[Prediction] = []
500
+ for src_project_id, src_video_infos in video_infos_by_project_id.items():
501
+ if len(src_video_infos) == 0:
502
+ continue
503
+ project_info = self.api.project.get_info_by_id(src_project_id)
504
+ project_validator_text_str = (
505
+ f"Processing project: {project_info.name} [id: {src_project_id}]"
506
+ )
507
+ if upload_to_source_project:
508
+ if not skip_project_versioning and not is_development():
509
+ logger.info("Creating new project version...")
510
+ self.set_validator_text(
511
+ project_validator_text_str + ": Creating project version",
512
+ "info",
513
+ )
514
+ version_id = self.api.project.version.create(
515
+ project_info,
516
+ "Created by Predict App. Task Id: " + str(env.task_id()),
517
+ )
518
+ logger.info("New project version created: " + str(version_id))
519
+ output_project_id = src_project_id
520
+ output_videos: List[VideoInfo] = src_video_infos
521
+ else:
522
+ self.set_validator_text(
523
+ project_validator_text_str + ": Creating project...", "info"
524
+ )
525
+ if not project_name:
526
+ project_name = project_info.name + " [Predictions]"
527
+ logger.warning(
528
+ "Project name is empty, using auto-generated name: " + project_name
529
+ )
530
+ with_annotations = prediction_mode in [
531
+ AddPredictionsMode.APPEND,
532
+ AddPredictionsMode.IOU_MERGE,
533
+ ]
534
+ created_project = create_project(
535
+ api=self.api,
536
+ project_id=src_project_id,
537
+ project_name=project_name,
538
+ workspace_id=self.workspace_id,
539
+ copy_meta=with_annotations,
540
+ project_type=ProjectType.VIDEOS,
541
+ )
542
+ output_project_id = created_project.id
543
+ output_videos: List[VideoInfo] = copy_items_to_project(
544
+ api=self.api,
545
+ src_project_id=src_project_id,
546
+ items=src_video_infos,
547
+ dst_project_id=created_project.id,
548
+ with_annotations=with_annotations,
549
+ ds_progress=self.output_selector.secondary_progress,
550
+ project_type=ProjectType.VIDEOS,
551
+ )
537
552
 
538
- if self.model_selector.model.model_api is None:
539
- self.model_selector.model._deploy()
553
+ self.set_validator_text(
554
+ project_validator_text_str + ": Merging project meta",
555
+ "info",
556
+ )
557
+ project_meta = src_project_metas[src_project_id]
558
+ for src_video_info, output_video_info in zip(src_video_infos, output_videos):
559
+ video_validator_text_str = (
560
+ project_validator_text_str
561
+ + f", video: {src_video_info.name} [id: {src_video_info.id}]"
562
+ )
563
+ self.set_validator_text(
564
+ video_validator_text_str + ": Predicting",
565
+ "info",
566
+ )
567
+ frames_predictions: List[Prediction] = []
568
+ with self.model_api.predict_detached(
569
+ video_id=src_video_info.id,
570
+ tqdm=self.output_selector.secondary_progress(),
571
+ tracking=tracking,
572
+ **predict_kwargs,
573
+ ) as session:
574
+ self.output_selector.secondary_progress.show()
575
+ for prediction in session:
576
+ if self._stop_flag:
577
+ logger.info("Prediction stopped by user.")
578
+ raise StopIteration("Stopped by user.")
579
+ frames_predictions.append(prediction)
580
+ all_predictictions.extend(frames_predictions)
581
+ if tracking:
582
+ prediction_video_annotation: VideoAnnotation = VideoAnnotation.from_json(
583
+ session.final_result["video_ann"],
584
+ project_meta=self.model_api.get_model_meta(),
585
+ )
586
+ else:
587
+ prediction_video_annotation = video_annotation_from_predictions(
588
+ frames_predictions,
589
+ project_meta,
590
+ frame_size=(src_video_info.frame_height, src_video_info.frame_width),
591
+ )
592
+ if prediction_video_annotation is None:
593
+ logger.warning(
594
+ f"No predictions were made for video {src_video_info.name} [id: {src_video_info.id}]"
595
+ )
596
+ main_pbar.update()
597
+ continue
598
+ self.set_validator_text(
599
+ video_validator_text_str + ": Uploading predictions",
600
+ "info",
601
+ )
602
+ project_meta, prediction_video_annotation, meta_changed = (
603
+ update_meta_and_ann_for_video_annotation(
604
+ meta=project_meta,
605
+ ann=prediction_video_annotation,
606
+ model_prediction_suffix=model_prediction_suffix,
607
+ )
608
+ )
609
+ if meta_changed:
610
+ self.api.project.update_meta(output_project_id, project_meta)
611
+ if upload_to_source_project:
612
+ if prediction_mode in [
613
+ AddPredictionsMode.REPLACE,
614
+ AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS,
615
+ ]:
616
+ self.output_selector.secondary_progress.hide()
617
+ with open("/tmp/prediction_video_annotation.json", "w") as f:
618
+ json.dump(prediction_video_annotation.to_json(), f)
619
+ self.api.video.annotation.upload_paths(
620
+ video_ids=[src_video_info.id],
621
+ paths=["/tmp/prediction_video_annotation.json"],
622
+ project_meta=project_meta,
623
+ )
624
+ else:
625
+ secondary_pbar = self.output_selector.secondary_progress(
626
+ message="Uploading annotations...",
627
+ total=len(prediction_video_annotation.figures),
628
+ )
629
+ self.output_selector.secondary_progress.show()
630
+ self.api.video.annotation.append(
631
+ video_id=src_video_info.id,
632
+ ann=prediction_video_annotation,
633
+ key_id_map=KeyIdMap(),
634
+ progress_cb=secondary_pbar.update,
635
+ )
636
+ else:
637
+ secondary_pbar = self.output_selector.secondary_progress(
638
+ message="Uploading annotations...",
639
+ total=len(prediction_video_annotation.figures),
640
+ )
641
+ self.output_selector.secondary_progress.show()
642
+ self.api.video.annotation.append(
643
+ video_id=output_video_info.id,
644
+ ann=prediction_video_annotation,
645
+ key_id_map=KeyIdMap(),
646
+ progress_cb=secondary_pbar.update,
647
+ )
648
+ main_pbar.update()
649
+ self.set_validator_text("Project successfully processed", "success")
650
+ self.output_selector.set_result_thumbnail(output_project_id)
651
+ return all_predictictions
540
652
 
541
- model_api = self.model_selector.model.model_api
542
- if model_api is None:
653
+ def _run_images(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
654
+ if self.model_api is None:
655
+ self.set_validator_text("Deploying model...", "info")
656
+ self.model_selector.model._deploy()
657
+ if self.model_api is None:
543
658
  logger.error("Model Deployed with an error")
544
- self.set_validator_text("Model Deployed with an error", "error")
545
- return
659
+ raise RuntimeError("Model Deployed with an error")
546
660
 
547
- kwargs = {}
661
+ self.set_validator_text("Preparing settings for prediction...", "info")
662
+ if run_parameters is None:
663
+ run_parameters = self.get_run_parameters()
548
664
 
665
+ predict_kwargs = {}
549
666
  # Input
550
- # Input would be newely created project
551
667
  input_args = {}
552
668
  input_parameters = run_parameters["input"]
553
669
  input_project_id = input_parameters.get("project_id", None)
554
- if input_project_id is None:
555
- raise ValueError("Input project ID is required for prediction.")
556
670
  input_dataset_ids = input_parameters.get("dataset_ids", [])
557
671
  input_image_ids = input_parameters.get("image_ids", [])
558
- if not (input_dataset_ids or input_image_ids):
559
- raise ValueError("At least one dataset must be selected for prediction.")
560
672
  if input_image_ids:
561
673
  input_args["image_ids"] = input_image_ids
562
674
  elif input_dataset_ids:
563
675
  input_args["dataset_ids"] = input_dataset_ids
564
- else:
676
+ elif input_project_id:
565
677
  input_args["project_id"] = input_project_id
678
+ else:
679
+ raise ValueError("No valid input parameters found for prediction.")
566
680
 
567
681
  # Settings
568
682
  settings = run_parameters["settings"]
569
683
  prediction_mode = settings.pop("predictions_mode")
570
684
  upload_mode = None
571
685
  with_annotations = None
572
- if prediction_mode == AddPredictionsMode.REPLACE_EXISTING_LABELS:
686
+ if prediction_mode == AddPredictionsMode.REPLACE:
573
687
  upload_mode = "replace"
574
688
  with_annotations = False
575
- elif prediction_mode == AddPredictionsMode.MERGE_WITH_EXISTING_LABELS:
689
+ elif prediction_mode == AddPredictionsMode.APPEND:
576
690
  upload_mode = "append"
577
691
  with_annotations = True
692
+ elif prediction_mode == AddPredictionsMode.IOU_MERGE:
693
+ upload_mode = "iou_merge"
694
+ with_annotations = True
578
695
  elif prediction_mode == AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS:
579
696
  upload_mode = "replace"
580
697
  with_annotations = True
581
- kwargs.update(settings)
582
- kwargs["upload_mode"] = upload_mode
698
+ predict_kwargs.update(settings)
699
+ predict_kwargs["upload_mode"] = upload_mode
583
700
 
584
701
  # Classes
585
702
  classes = run_parameters["classes"]
586
703
  if classes:
587
- kwargs["classes"] = classes
704
+ predict_kwargs["classes"] = classes
588
705
 
589
706
  # Output
590
- # Always create new project
591
- # But the actual inference will happen inplace
592
707
  output_parameters = run_parameters["output"]
593
- project_name = output_parameters.get("project_name", "")
708
+ project_name = output_parameters.get("project_name", None)
594
709
  upload_to_source_project = output_parameters.get("upload_to_source_project", False)
595
- if upload_to_source_project:
596
- output_project_id = input_project_id
597
- else:
598
- if not project_name:
599
- input_project_info = self.api.project.get_info_by_id(input_project_id)
600
- project_name = input_project_info.name + " [Predictions]"
601
- logger.warning("Project name is empty, using auto-generated name: " + project_name)
602
-
603
- # Copy project
604
- self.set_validator_text("Copying project...", "info")
605
- created_project = copy_project(
606
- self.api,
607
- project_name,
608
- self.workspace_id,
609
- input_project_id,
610
- input_dataset_ids,
611
- with_annotations,
612
- self.output_selector.progress,
710
+ skip_project_versioning = output_parameters.get("skip_project_versioning", False)
711
+ skip_annotated = output_parameters.get("skip_annotated", False)
712
+
713
+ image_infos = []
714
+ if input_image_ids:
715
+ image_infos = self.api.image.get_info_by_id_batch(input_image_ids)
716
+ elif input_dataset_ids:
717
+ for dataset_id in input_dataset_ids:
718
+ image_infos.extend(self.api.image.get_list(dataset_id))
719
+ elif input_project_id:
720
+ datasets = self.api.dataset.get_list(input_project_id, recursive=True)
721
+ for dataset in datasets:
722
+ image_infos.extend(self.api.image.get_list(dataset.id))
723
+ if len(image_infos) == 0:
724
+ raise ValueError("No images found for the given input parameters.")
725
+
726
+ to_skip = []
727
+ if skip_annotated:
728
+ to_skip = [image_info.id for image_info in image_infos if image_info.labels_count == 0]
729
+ if to_skip:
730
+ image_infos = [info for info in image_infos if info.id not in to_skip]
731
+ if len(image_infos) == 0:
732
+ self.set_validator_text(
733
+ f"All images are already annotated. Nothing to predict.", "warning"
734
+ )
735
+ return []
736
+
737
+ image_infos_by_project_id: Dict[int, List[ImageInfo]] = {}
738
+ image_infos_by_dataset_id: Dict[int, List[ImageInfo]] = {}
739
+ ds_project_mapping: Dict[int, int] = {}
740
+ for info in image_infos:
741
+ image_infos_by_dataset_id.setdefault(info.dataset_id, []).append(info)
742
+ if info.dataset_id not in ds_project_mapping:
743
+ ds_info = self.api.dataset.get_info_by_id(info.dataset_id)
744
+ ds_project_mapping[info.dataset_id] = ds_info.project_id
745
+ project_id = ds_project_mapping[info.dataset_id]
746
+ image_infos_by_project_id.setdefault(project_id, []).append(info)
747
+
748
+ src_project_metas: Dict[int, ProjectMeta] = {}
749
+ for project_id in image_infos_by_project_id.keys():
750
+ src_project_metas[project_id] = ProjectMeta.from_json(
751
+ self.api.project.get_meta(project_id)
613
752
  )
614
- output_project_id = created_project.id
615
- input_args = {
616
- "project_id": output_project_id,
617
- }
618
753
 
619
- # ------------------------ #
754
+ self.output_selector.progress.show()
755
+ total_items = sum(len(v) for v in image_infos_by_project_id.values())
756
+ main_pbar = self.output_selector.progress(message=f"Copying images...", total=total_items)
757
+ for src_project_id, infos in image_infos_by_project_id.items():
758
+ if len(infos) == 0:
759
+ continue
760
+ project_info = self.api.project.get_info_by_id(src_project_id)
761
+ project_validator_text_str = (
762
+ f"Processing project: {project_info.name} [id: {src_project_id}]"
763
+ )
764
+ if upload_to_source_project:
765
+ if not skip_project_versioning and not is_development():
766
+ logger.info("Creating new project version...")
767
+ self.set_validator_text(
768
+ project_validator_text_str + ": Creating project version", "info"
769
+ )
770
+ version_id = self.api.project.version.create(
771
+ project_info,
772
+ "Created by Predict App. Task Id: " + str(env.task_id()),
773
+ )
774
+ logger.info("New project version created: " + str(version_id))
775
+ output_project_id = src_project_id
776
+ output_image_infos: List[ImageInfo] = infos
777
+ else:
778
+ self.set_validator_text(
779
+ project_validator_text_str + ": Creating project...", "info"
780
+ )
781
+ if not project_name:
782
+ project_name = project_info.name + " [Predictions]"
783
+ logger.warning(
784
+ "Project name is empty, using auto-generated name: " + project_name
785
+ )
786
+ created_project = create_project(
787
+ api=self.api,
788
+ project_id=src_project_id,
789
+ project_name=project_name,
790
+ workspace_id=self.workspace_id,
791
+ copy_meta=with_annotations,
792
+ project_type=ProjectType.IMAGES,
793
+ )
794
+ output_project_id = created_project.id
795
+ output_image_infos: List[ImageInfo] = copy_items_to_project(
796
+ api=self.api,
797
+ src_project_id=src_project_id,
798
+ items=infos,
799
+ dst_project_id=created_project.id,
800
+ with_annotations=with_annotations,
801
+ ds_progress=self.output_selector.secondary_progress,
802
+ progress_cb=main_pbar.update,
803
+ project_type=ProjectType.IMAGES,
804
+ )
620
805
 
621
806
  # Run prediction
622
807
  self.set_validator_text("Running prediction...", "info")
623
- predictions = []
808
+ predictions: List[Prediction] = []
624
809
  self._is_running = True
810
+ with self.model_api.predict_detached(
811
+ image_ids=[info.id for info in output_image_infos],
812
+ **predict_kwargs,
813
+ tqdm=self.output_selector.progress(),
814
+ ) as session:
815
+ for prediction in session:
816
+ if self._stop_flag:
817
+ logger.info("Prediction stopped by user.")
818
+ raise StopIteration("Stopped by user.")
819
+ predictions.append(prediction)
820
+ self.set_validator_text("Project successfully processed", "success")
821
+ self.output_selector.set_result_thumbnail(output_project_id)
822
+ return predictions
823
+
824
+ def run(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
825
+ self.show_validator_text()
826
+ if run_parameters is None:
827
+ run_parameters = self.get_run_parameters()
828
+ input_parameters = run_parameters["input"]
829
+ video_ids = input_parameters.get("video_ids", None)
625
830
  try:
626
- with model_api.predict_detached(
627
- **input_args,
628
- tqdm=self.output_selector.progress(),
629
- **kwargs,
630
- ) as session:
631
- self.output_selector.progress.show()
632
- i = 0
633
- for prediction in session:
634
- predictions.append(prediction)
635
- i += 1
636
- if self._stop_flag:
637
- logger.info("Prediction stopped by user.")
638
- break
639
- self.output_selector.progress.hide()
831
+ if video_ids:
832
+ run_f = self._run_videos
833
+ else:
834
+ run_f = self._run_images
835
+ return run_f(run_parameters)
836
+ except StopIteration:
837
+ logger.info("Prediction stopped by user.")
838
+ self.set_validator_text("Prediction stopped by user.", "warning")
839
+ raise
640
840
  except Exception as e:
641
- self.output_selector.progress.hide()
642
841
  logger.error(f"Error during prediction: {str(e)}")
643
842
  self.set_validator_text(f"Error during prediction: {str(e)}", "error")
644
843
  disable_enable(self.output_selector.widgets_to_disable, False)
645
- self._is_running = False
646
- self._stop_flag = False
647
- raise e
844
+ raise
648
845
  finally:
846
+ self.output_selector.secondary_progress.hide()
847
+ self.output_selector.progress.hide()
649
848
  self._is_running = False
650
849
  self._stop_flag = False
651
- # ------------------------ #
652
-
653
- # Set result thumbnail
654
- self.set_validator_text("Project successfully processed", "success")
655
- self.output_selector.set_result_thumbnail(output_project_id)
656
- # ------------------------ #
657
- return predictions
658
850
 
659
851
  def stop(self):
660
852
  logger.info("Stopping prediction...")
@@ -707,14 +899,10 @@ class PredictAppGui:
707
899
  if self.tags_selector is not None:
708
900
  self.tags_selector.load_from_json(data.get("tags", {}))
709
901
 
710
- # 5. Settings selector
902
+ # 5. Settings selector & Preview
711
903
  self.settings_selector.load_from_json(data.get("settings", {}))
712
904
 
713
- # 6. Preview (No need?)
714
- if self.preview is not None:
715
- self.preview.load_from_json(data.get("preview", {}))
716
-
717
- # 7. Output selector
905
+ # 6. Output selector
718
906
  self.output_selector.load_from_json(data.get("output", {}))
719
907
 
720
908
  def set_validator_text(self, text: str, status: str = "text"):