supervisely 6.73.457__py3-none-any.whl → 6.73.458__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +24 -1
- supervisely/api/image_api.py +4 -0
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +41 -1
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +45 -11
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +63 -7
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/inference/inference.py +364 -73
- supervisely/nn/inference/inference_request.py +3 -2
- supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +178 -25
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort_tracker.py +14 -7
- supervisely/nn/tracker/visualize.py +70 -72
- supervisely/video/video.py +15 -1
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/METADATA +3 -2
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/RECORD +40 -40
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/LICENSE +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/WHEEL +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/top_level.txt +0 -0
@@ -1,163 +1,156 @@
|
|
1
|
-
import
|
1
|
+
import json
|
2
2
|
import time
|
3
|
+
from pathlib import Path
|
3
4
|
from typing import Any, Callable, Dict, List, Optional
|
4
5
|
|
5
6
|
import yaml
|
6
7
|
|
7
8
|
from supervisely._utils import is_development, logger
|
8
|
-
from supervisely.annotation.annotation import Annotation
|
9
|
-
from supervisely.annotation.label import Label
|
10
9
|
from supervisely.api.api import Api
|
10
|
+
from supervisely.api.image_api import ImageInfo
|
11
11
|
from supervisely.api.video.video_api import VideoInfo
|
12
12
|
from supervisely.app.widgets import Button, Card, Container, Stepper, Widget
|
13
|
+
from supervisely.geometry.any_geometry import AnyGeometry
|
13
14
|
from supervisely.io import env
|
15
|
+
from supervisely.nn.inference.inference import update_meta_and_ann_for_video_annotation
|
14
16
|
from supervisely.nn.inference.predict_app.gui.classes_selector import ClassesSelector
|
15
17
|
from supervisely.nn.inference.predict_app.gui.input_selector import InputSelector
|
16
18
|
from supervisely.nn.inference.predict_app.gui.model_selector import ModelSelector
|
17
19
|
from supervisely.nn.inference.predict_app.gui.output_selector import OutputSelector
|
18
|
-
from supervisely.nn.inference.predict_app.gui.preview import Preview
|
19
20
|
from supervisely.nn.inference.predict_app.gui.settings_selector import (
|
20
21
|
AddPredictionsMode,
|
21
22
|
SettingsSelector,
|
22
23
|
)
|
23
24
|
from supervisely.nn.inference.predict_app.gui.tags_selector import TagsSelector
|
24
25
|
from supervisely.nn.inference.predict_app.gui.utils import (
|
25
|
-
|
26
|
+
copy_items_to_project,
|
27
|
+
create_project,
|
26
28
|
disable_enable,
|
27
|
-
|
28
|
-
|
29
|
+
update_custom_button_params,
|
30
|
+
video_annotation_from_predictions,
|
29
31
|
)
|
30
32
|
from supervisely.nn.model.model_api import ModelAPI
|
31
33
|
from supervisely.nn.model.prediction import Prediction
|
32
34
|
from supervisely.project.project_meta import ProjectMeta
|
35
|
+
from supervisely.project.project_type import ProjectType
|
33
36
|
from supervisely.video_annotation.key_id_map import KeyIdMap
|
34
37
|
from supervisely.video_annotation.video_annotation import VideoAnnotation
|
35
38
|
|
36
39
|
|
37
40
|
class StepFlow:
|
38
|
-
|
39
|
-
|
40
|
-
self.stepper = stepper
|
41
|
+
def __init__(self):
|
42
|
+
self._stepper = None
|
41
43
|
self.steps = {}
|
42
|
-
self.
|
44
|
+
self.steps_sequence = []
|
43
45
|
|
44
|
-
def
|
46
|
+
def add_step(
|
45
47
|
self,
|
46
48
|
name: str,
|
47
|
-
|
49
|
+
widget: Widget,
|
50
|
+
on_select: Optional[Callable] = None,
|
51
|
+
on_reactivate: Optional[Callable] = None,
|
52
|
+
depends_on: Optional[List[Widget]] = None,
|
53
|
+
on_lock: Optional[Callable] = None,
|
54
|
+
on_unlock: Optional[Callable] = None,
|
48
55
|
button: Optional[Button] = None,
|
49
|
-
widgets_to_disable: Optional[List[Widget]] = None,
|
50
|
-
validation_text: Optional[Widget] = None,
|
51
|
-
validation_func: Optional[Callable] = None,
|
52
56
|
position: Optional[int] = None,
|
53
|
-
)
|
57
|
+
):
|
58
|
+
if depends_on is None:
|
59
|
+
depends_on = []
|
54
60
|
self.steps[name] = {
|
55
|
-
"
|
61
|
+
"widget": widget,
|
62
|
+
"on_select": on_select,
|
63
|
+
"on_reactivate": on_reactivate,
|
64
|
+
"depends_on": depends_on,
|
65
|
+
"on_lock": on_lock,
|
66
|
+
"on_unlock": on_unlock,
|
56
67
|
"button": button,
|
57
|
-
"
|
58
|
-
"
|
59
|
-
"validation_func": validation_func,
|
60
|
-
"position": position,
|
61
|
-
"next_steps": [],
|
62
|
-
"on_select_click": [],
|
63
|
-
"on_reselect_click": [],
|
64
|
-
"wrapper": None,
|
65
|
-
"has_button": button is not None,
|
68
|
+
"is_selected": False,
|
69
|
+
"is_locked": False,
|
66
70
|
}
|
67
|
-
|
71
|
+
if button is not None:
|
72
|
+
self._wrap_button(button, name)
|
68
73
|
if position is not None:
|
69
|
-
|
70
|
-
|
71
|
-
self.
|
72
|
-
|
73
|
-
return self
|
74
|
-
|
75
|
-
def set_next_steps(self, step_name: str, next_steps: List[str]) -> "StepFlow":
|
76
|
-
if step_name in self.steps:
|
77
|
-
self.steps[step_name]["next_steps"] = next_steps
|
78
|
-
return self
|
79
|
-
|
80
|
-
def add_on_select_actions(
|
81
|
-
self, step_name: str, actions: List[Callable], is_reselect: bool = False
|
82
|
-
) -> "StepFlow":
|
83
|
-
if step_name in self.steps:
|
84
|
-
key = "on_reselect_click" if is_reselect else "on_select_click"
|
85
|
-
self.steps[step_name][key].extend(actions)
|
86
|
-
return self
|
87
|
-
|
88
|
-
def build_wrappers(self) -> Dict[str, Callable]:
|
89
|
-
valid_sequence = [s for s in self.step_sequence if s is not None and s in self.steps]
|
74
|
+
self.steps_sequence.insert(position, name)
|
75
|
+
else:
|
76
|
+
self.steps_sequence.append(name)
|
77
|
+
self.update_locks()
|
90
78
|
|
91
|
-
|
79
|
+
def _create_stepper(self):
|
80
|
+
widgets = []
|
81
|
+
for step_name in self.steps_sequence:
|
92
82
|
step = self.steps[step_name]
|
83
|
+
widgets.append(step["widget"])
|
84
|
+
self._stepper = Stepper(widgets=widgets)
|
93
85
|
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
callback = None
|
100
|
-
if step["next_steps"] and step["has_button"]:
|
101
|
-
for next_step_name in step["next_steps"]:
|
102
|
-
if (
|
103
|
-
next_step_name in self.steps
|
104
|
-
and self.steps[next_step_name].get("wrapper")
|
105
|
-
and self.steps[next_step_name]["has_button"]
|
106
|
-
):
|
107
|
-
callback = self.steps[next_step_name]["wrapper"]
|
108
|
-
break
|
109
|
-
|
110
|
-
if step["has_button"]:
|
111
|
-
wrapper = wrap_button_click(
|
112
|
-
button=step["button"],
|
113
|
-
cards_to_unlock=cards_to_unlock,
|
114
|
-
widgets_to_disable=step["widgets_to_disable"],
|
115
|
-
callback=callback,
|
116
|
-
validation_text=step["validation_text"],
|
117
|
-
validation_func=step["validation_func"],
|
118
|
-
on_select_click=step["on_select_click"],
|
119
|
-
on_reselect_click=step["on_reselect_click"],
|
120
|
-
collapse_card=None,
|
121
|
-
)
|
122
|
-
|
123
|
-
step["wrapper"] = wrapper
|
124
|
-
|
125
|
-
return {
|
126
|
-
name: self.steps[name]["wrapper"]
|
127
|
-
for name in self.steps
|
128
|
-
if self.steps[name].get("wrapper") and self.steps[name]["has_button"]
|
129
|
-
}
|
130
|
-
|
131
|
-
def setup_button_handlers(self) -> None:
|
132
|
-
positions = {}
|
133
|
-
pos = 1
|
134
|
-
|
135
|
-
for i, step_name in enumerate(self.step_sequence):
|
136
|
-
if step_name is not None and step_name in self.steps:
|
137
|
-
positions[step_name] = pos
|
138
|
-
pos += 1
|
139
|
-
|
140
|
-
for step_name, step in self.steps.items():
|
141
|
-
if step_name in positions and step.get("wrapper") and step["has_button"]:
|
86
|
+
@property
|
87
|
+
def stepper(self):
|
88
|
+
if self._stepper is None:
|
89
|
+
self._create_stepper()
|
90
|
+
return self._stepper
|
142
91
|
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
def handler():
|
150
|
-
cb()
|
151
|
-
set_stepper_step(self.stepper, btn, next_pos=next_pos)
|
152
|
-
|
153
|
-
return handler
|
92
|
+
def update_stepper(self):
|
93
|
+
for i, step_name in enumerate(self.steps_sequence):
|
94
|
+
step = self.steps[step_name]
|
95
|
+
if not step["is_selected"]:
|
96
|
+
self.stepper.set_active_step(i + 1)
|
97
|
+
return
|
154
98
|
|
155
|
-
|
99
|
+
def update_locks(self):
|
100
|
+
for step in self.steps.values():
|
101
|
+
should_lock = False
|
102
|
+
for dep_name in step["depends_on"]:
|
103
|
+
dep = self.steps[dep_name]
|
104
|
+
if not dep["is_selected"]:
|
105
|
+
should_lock = True
|
106
|
+
break
|
107
|
+
if should_lock and not step["is_locked"]:
|
108
|
+
if step["on_lock"] is not None:
|
109
|
+
step["on_lock"]()
|
110
|
+
step["is_locked"] = True
|
111
|
+
if not should_lock and step["is_locked"]:
|
112
|
+
if step["on_unlock"]:
|
113
|
+
step["on_unlock"]()
|
114
|
+
step["is_locked"] = False
|
115
|
+
|
116
|
+
def _reactivate_dependents(self, step_name: str, visited=None):
|
117
|
+
if visited is None:
|
118
|
+
visited = set()
|
119
|
+
for dep_name, step in self.steps.items():
|
120
|
+
if step_name in step["depends_on"] and not dep_name in visited:
|
121
|
+
self._reactivate_step(dep_name, visited)
|
122
|
+
|
123
|
+
def _reactivate_step(self, step_name: str, visited=None):
|
124
|
+
step = self.steps[step_name]
|
125
|
+
if step["on_reactivate"] is not None:
|
126
|
+
step["on_reactivate"]()
|
127
|
+
step["is_selected"] = False
|
128
|
+
if visited is None:
|
129
|
+
visited = set()
|
130
|
+
self._reactivate_dependents(step_name, visited)
|
131
|
+
|
132
|
+
def reactivate_step(self, step_name: str):
|
133
|
+
self._reactivate_step(step_name)
|
134
|
+
self.update_stepper()
|
135
|
+
self.update_locks()
|
136
|
+
|
137
|
+
def select_step(self, step_name: str):
|
138
|
+
step = self.steps[step_name]
|
139
|
+
if step["on_select"] is not None:
|
140
|
+
step["on_select"]()
|
141
|
+
step["is_selected"] = True
|
142
|
+
self.update_stepper()
|
143
|
+
self.update_locks()
|
144
|
+
|
145
|
+
def select_or_reactivate(self, step_name: str):
|
146
|
+
step = self.steps[step_name]
|
147
|
+
if step["is_selected"]:
|
148
|
+
self.reactivate_step(step_name)
|
149
|
+
else:
|
150
|
+
self.select_step(step_name)
|
156
151
|
|
157
|
-
def
|
158
|
-
|
159
|
-
self.setup_button_handlers()
|
160
|
-
return wrappers
|
152
|
+
def _wrap_button(self, button: Button, step_name: str):
|
153
|
+
button.click(lambda: self.select_or_reactivate(step_name))
|
161
154
|
|
162
155
|
|
163
156
|
class PredictAppGui:
|
@@ -169,6 +162,9 @@ class PredictAppGui:
|
|
169
162
|
self.team_id = env.team_id()
|
170
163
|
self.workspace_id = env.workspace_id()
|
171
164
|
self.project_id = env.project_id(raise_not_found=False)
|
165
|
+
self.project_meta = None
|
166
|
+
if self.project_id:
|
167
|
+
self.project_meta = ProjectMeta.from_json(self.api.project.get_meta(self.project_id))
|
172
168
|
# -------------------------------- #
|
173
169
|
|
174
170
|
# Flags
|
@@ -178,73 +174,178 @@ class PredictAppGui:
|
|
178
174
|
|
179
175
|
# GUI
|
180
176
|
# Steps
|
181
|
-
self.
|
177
|
+
self.step_flow = StepFlow()
|
178
|
+
select_params = {"icon": None, "plain": False, "text": "Select"}
|
179
|
+
reselect_params = {"icon": "zmdi zmdi-refresh", "plain": True, "text": "Reselect"}
|
182
180
|
|
183
181
|
# 1. Input selector
|
184
|
-
self.input_selector = InputSelector(self.workspace_id)
|
185
|
-
|
182
|
+
self.input_selector = InputSelector(self.workspace_id, self.api)
|
183
|
+
|
184
|
+
def _on_input_select():
|
185
|
+
valid = self.input_selector.validate_step()
|
186
|
+
if not valid:
|
187
|
+
return
|
188
|
+
current_item_type = self.input_selector.radio.get_value()
|
189
|
+
self.update_item_type()
|
190
|
+
if self.model_api:
|
191
|
+
if current_item_type == self.input_selector.radio.get_value():
|
192
|
+
inference_settings = self.model_api.get_settings()
|
193
|
+
self.settings_selector.set_inference_settings(inference_settings)
|
194
|
+
|
195
|
+
if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
|
196
|
+
try:
|
197
|
+
tracking_settings = self.model_api.get_tracking_settings()
|
198
|
+
self.settings_selector.set_tracking_settings(tracking_settings)
|
199
|
+
except Exception as e:
|
200
|
+
logger.warning(
|
201
|
+
"Unable to get tracking settings from the model. Settings defaults"
|
202
|
+
)
|
203
|
+
self.settings_selector.set_default_tracking_settings()
|
204
|
+
self.input_selector.disable()
|
205
|
+
|
206
|
+
self.project_id = self.input_selector.get_project_id()
|
207
|
+
if self.project_id:
|
208
|
+
self.project_meta = ProjectMeta.from_json(self.api.project.get_meta(self.project_id))
|
209
|
+
update_custom_button_params(self.input_selector.button, reselect_params)
|
210
|
+
|
211
|
+
def _on_input_reactivate():
|
212
|
+
self.input_selector.enable()
|
213
|
+
update_custom_button_params(self.input_selector.button, select_params)
|
214
|
+
|
215
|
+
self.step_flow.add_step(
|
216
|
+
name="input_selector",
|
217
|
+
widget=self.input_selector.card,
|
218
|
+
on_select=_on_input_select,
|
219
|
+
on_reactivate=_on_input_reactivate,
|
220
|
+
button=self.input_selector.button,
|
221
|
+
)
|
186
222
|
|
187
223
|
# 2. Model selector
|
188
224
|
self.model_selector = ModelSelector(self.api, self.team_id)
|
189
|
-
|
225
|
+
|
226
|
+
self.step_flow.add_step(
|
227
|
+
name="model_selector",
|
228
|
+
widget=self.model_selector.card,
|
229
|
+
)
|
190
230
|
|
191
231
|
# 3. Classes selector
|
192
|
-
self.classes_selector =
|
193
|
-
|
194
|
-
|
195
|
-
self.
|
232
|
+
self.classes_selector = ClassesSelector()
|
233
|
+
|
234
|
+
def _on_classes_select():
|
235
|
+
valid = self.classes_selector.validate_step()
|
236
|
+
if not valid:
|
237
|
+
return
|
238
|
+
self.classes_selector.classes_table.disable()
|
239
|
+
|
240
|
+
# Find conflict between project meta and model meta
|
241
|
+
selected_classes_names = self.classes_selector.get_selected_classes()
|
242
|
+
project_meta = self.project_meta
|
243
|
+
model_meta = self.model_api.get_model_meta()
|
244
|
+
|
245
|
+
has_conflict = False
|
246
|
+
for class_name in selected_classes_names:
|
247
|
+
project_obj_class = project_meta.get_obj_class(class_name)
|
248
|
+
if project_obj_class is None:
|
249
|
+
continue
|
250
|
+
|
251
|
+
model_obj_class = model_meta.get_obj_class(class_name)
|
252
|
+
if model_obj_class.geometry_type.name() == AnyGeometry.name():
|
253
|
+
continue
|
254
|
+
|
255
|
+
if project_obj_class.geometry_type.name() == model_obj_class.geometry_type.name():
|
256
|
+
continue
|
257
|
+
|
258
|
+
has_conflict = True
|
259
|
+
break
|
260
|
+
|
261
|
+
if has_conflict:
|
262
|
+
self.settings_selector.model_prediction_suffix_container.show()
|
263
|
+
else:
|
264
|
+
self.settings_selector.model_prediction_suffix_container.hide()
|
265
|
+
# ------------------------------------------------ #
|
266
|
+
|
267
|
+
update_custom_button_params(self.classes_selector.button, reselect_params)
|
268
|
+
|
269
|
+
def _on_classes_reactivate():
|
270
|
+
self.classes_selector.classes_table.enable()
|
271
|
+
update_custom_button_params(self.classes_selector.button, select_params)
|
272
|
+
|
273
|
+
self.step_flow.add_step(
|
274
|
+
name="classes_selector",
|
275
|
+
widget=self.classes_selector.card,
|
276
|
+
on_select=_on_classes_select,
|
277
|
+
on_reactivate=_on_classes_reactivate,
|
278
|
+
depends_on=["input_selector", "model_selector"],
|
279
|
+
on_lock=self.classes_selector.lock,
|
280
|
+
on_unlock=self.classes_selector.unlock,
|
281
|
+
button=self.classes_selector.button,
|
282
|
+
)
|
196
283
|
|
197
284
|
# 4. Tags selector
|
198
285
|
self.tags_selector = None
|
199
286
|
if False:
|
200
287
|
self.tags_selector = TagsSelector()
|
201
|
-
self.
|
202
|
-
|
203
|
-
# 5. Settings selector
|
204
|
-
self.settings_selector = SettingsSelector(
|
205
|
-
|
288
|
+
self.step_flow.add_step("tags_selector", self.tags_selector.card)
|
289
|
+
|
290
|
+
# 5. Settings selector & Preview
|
291
|
+
self.settings_selector = SettingsSelector(
|
292
|
+
api=self.api,
|
293
|
+
static_dir=self.static_dir,
|
294
|
+
model_selector=self.model_selector,
|
295
|
+
input_selector=self.input_selector,
|
296
|
+
)
|
206
297
|
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
self.
|
298
|
+
def _on_settings_select():
|
299
|
+
valid = self.settings_selector.validate_step()
|
300
|
+
if not valid:
|
301
|
+
return
|
302
|
+
self.settings_selector.disable()
|
303
|
+
update_custom_button_params(self.settings_selector.button, reselect_params)
|
304
|
+
|
305
|
+
def _on_settings_reactivate():
|
306
|
+
self.settings_selector.enable()
|
307
|
+
update_custom_button_params(self.settings_selector.button, select_params)
|
308
|
+
|
309
|
+
self.step_flow.add_step(
|
310
|
+
name="settings_selector",
|
311
|
+
widget=self.settings_selector.cards_container,
|
312
|
+
on_select=_on_settings_select,
|
313
|
+
on_reactivate=_on_settings_reactivate,
|
314
|
+
depends_on=["input_selector", "model_selector", "classes_selector"],
|
315
|
+
on_lock=self.settings_selector.lock,
|
316
|
+
on_unlock=self.settings_selector.unlock,
|
317
|
+
button=self.settings_selector.button,
|
318
|
+
)
|
319
|
+
self.settings_selector.preview.run_button.disable()
|
212
320
|
|
213
|
-
#
|
321
|
+
# 6. Output selector
|
214
322
|
self.output_selector = OutputSelector(self.api)
|
215
|
-
self.steps.append(self.output_selector.card)
|
216
|
-
# -------------------------------- #
|
217
323
|
|
218
|
-
|
219
|
-
|
220
|
-
|
324
|
+
self.step_flow.add_step(
|
325
|
+
"output_selector",
|
326
|
+
self.output_selector.card,
|
327
|
+
depends_on=[
|
328
|
+
"input_selector",
|
329
|
+
"model_selector",
|
330
|
+
"classes_selector",
|
331
|
+
# "tags_selector",
|
332
|
+
"settings_selector",
|
333
|
+
],
|
334
|
+
on_lock=self.output_selector.lock,
|
335
|
+
on_unlock=self.output_selector.unlock,
|
336
|
+
)
|
337
|
+
# -------------------------------- #
|
221
338
|
|
222
339
|
# Layout
|
223
|
-
self.layout = Container([self.stepper])
|
340
|
+
self.layout = Container([self.step_flow.stepper])
|
224
341
|
# ---------------------------- #
|
225
342
|
|
226
|
-
# Button Utils
|
227
|
-
def deploy_model() -> ModelAPI:
|
228
|
-
self.model_selector.validator_text.hide()
|
229
|
-
model_api = None
|
230
|
-
try:
|
231
|
-
model_api = type(self.model_selector.model).deploy(self.model_selector.model)
|
232
|
-
except:
|
233
|
-
self.output_selector.start_button.disable()
|
234
|
-
raise
|
235
|
-
else:
|
236
|
-
self.output_selector.start_button.enable()
|
237
|
-
return model_api
|
238
|
-
|
239
|
-
# Reimplement deploy method for DeployModel widget
|
240
|
-
self.model_selector.model.deploy = deploy_model
|
241
|
-
|
242
343
|
def set_entity_meta():
|
243
344
|
model_api = self.model_selector.model.model_api
|
244
345
|
|
245
346
|
model_meta = model_api.get_model_meta()
|
246
347
|
if self.classes_selector is not None:
|
247
|
-
self.classes_selector.
|
348
|
+
self.classes_selector.set_project_meta(model_meta)
|
248
349
|
self.classes_selector.classes_table.show()
|
249
350
|
if self.tags_selector is not None:
|
250
351
|
self.tags_selector.tags_table.set_project_meta(model_meta)
|
@@ -253,13 +354,20 @@ class PredictAppGui:
|
|
253
354
|
inference_settings = model_api.get_settings()
|
254
355
|
self.settings_selector.set_inference_settings(inference_settings)
|
255
356
|
|
256
|
-
if self.
|
257
|
-
|
357
|
+
if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
|
358
|
+
try:
|
359
|
+
tracking_settings = model_api.get_tracking_settings()
|
360
|
+
self.settings_selector.set_tracking_settings(tracking_settings)
|
361
|
+
except Exception as e:
|
362
|
+
logger.warning(
|
363
|
+
"Unable to get tracking settings from the model. Settings defaults"
|
364
|
+
)
|
365
|
+
self.settings_selector.set_default_tracking_settings()
|
258
366
|
|
259
367
|
def reset_entity_meta():
|
260
368
|
empty_meta = ProjectMeta()
|
261
369
|
if self.classes_selector is not None:
|
262
|
-
self.classes_selector.
|
370
|
+
self.classes_selector.set_project_meta(empty_meta)
|
263
371
|
self.classes_selector.classes_table.hide()
|
264
372
|
if self.tags_selector is not None:
|
265
373
|
self.tags_selector.tags_table.set_project_meta(empty_meta)
|
@@ -267,394 +375,478 @@ class PredictAppGui:
|
|
267
375
|
|
268
376
|
self.settings_selector.set_inference_settings("")
|
269
377
|
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
self.
|
378
|
+
def deploy_and_set_step():
|
379
|
+
self.model_selector.validator_text.hide()
|
380
|
+
model_api = type(self.model_selector.model).deploy(self.model_selector.model)
|
381
|
+
if model_api is not None:
|
382
|
+
set_entity_meta()
|
383
|
+
self.step_flow.select_step("model_selector")
|
276
384
|
else:
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
def _get_frame_annotation(
|
281
|
-
video_info: VideoInfo, frame_index: int, project_meta: ProjectMeta
|
282
|
-
) -> Annotation:
|
283
|
-
video_annotation = VideoAnnotation.from_json(
|
284
|
-
self.api.video.annotation.download(video_info.id, frame_index),
|
285
|
-
project_meta=project_meta,
|
286
|
-
key_id_map=KeyIdMap(),
|
287
|
-
)
|
288
|
-
frame = video_annotation.frames.get(frame_index)
|
289
|
-
img_size = (video_info.frame_height, video_info.frame_width)
|
290
|
-
if frame is None:
|
291
|
-
return Annotation(img_size)
|
292
|
-
labels = []
|
293
|
-
for figure in frame.figures:
|
294
|
-
labels.append(Label(figure.geometry, figure.video_object.obj_class))
|
295
|
-
ann = Annotation(img_size, labels=labels)
|
296
|
-
return ann
|
297
|
-
|
298
|
-
if self.preview is None:
|
299
|
-
return
|
385
|
+
reset_entity_meta()
|
386
|
+
self.step_flow.reactivate_step("model_selector")
|
387
|
+
return model_api
|
300
388
|
|
301
|
-
|
302
|
-
self.
|
303
|
-
self.
|
304
|
-
|
305
|
-
try:
|
306
|
-
items_settings = self.input_selector.get_settings()
|
307
|
-
if "video_id" in items_settings:
|
308
|
-
video_id = items_settings["video_id"]
|
309
|
-
video_info = self.api.video.get_info_by_id(video_id)
|
310
|
-
video_frame = random.randint(0, video_info.frames_count - 1)
|
311
|
-
self.api.video.frame.download_path(
|
312
|
-
video_info.id, video_frame, self.preview.preview_path
|
313
|
-
)
|
314
|
-
img_url = self.preview.peview_url
|
315
|
-
project_meta = ProjectMeta.from_json(
|
316
|
-
self.api.project.get_meta(video_info.project_id)
|
317
|
-
)
|
318
|
-
input_ann = _get_frame_annotation(video_info, video_frame, project_meta)
|
319
|
-
prediction = self.model_selector.model.model_api.predict(
|
320
|
-
input=self.preview.preview_path, **self.settings_selector.get_settings()
|
321
|
-
)[0]
|
322
|
-
output_ann = prediction.annotation
|
323
|
-
else:
|
324
|
-
if "project_id" in items_settings:
|
325
|
-
project_id = items_settings["project_id"]
|
326
|
-
dataset_infos = self.api.dataset.get_list(project_id, recursive=True)
|
327
|
-
dataset_infos = [ds for ds in dataset_infos if ds.items_count > 0]
|
328
|
-
if not dataset_infos:
|
329
|
-
raise ValueError("No datasets with items found in the project.")
|
330
|
-
dataset_info = random.choice(dataset_infos)
|
331
|
-
elif "dataset_ids" in items_settings:
|
332
|
-
dataset_ids = items_settings["dataset_ids"]
|
333
|
-
dataset_infos = [
|
334
|
-
self.api.dataset.get_info_by_id(dataset_id)
|
335
|
-
for dataset_id in dataset_ids
|
336
|
-
]
|
337
|
-
dataset_infos = [ds for ds in dataset_infos if ds.items_count > 0]
|
338
|
-
if not dataset_infos:
|
339
|
-
raise ValueError("No items in selected datasets.")
|
340
|
-
dataset_info = random.choice(dataset_infos)
|
341
|
-
else:
|
342
|
-
raise ValueError("No valid item settings found for preview.")
|
343
|
-
images = self.api.image.get_list(dataset_info.id)
|
344
|
-
image_info = random.choice(images)
|
345
|
-
img_url = image_info.preview_url
|
389
|
+
def stop_and_reset_step():
|
390
|
+
type(self.model_selector.model).stop(self.model_selector.model)
|
391
|
+
self.step_flow.reactivate_step("model_selector")
|
392
|
+
reset_entity_meta()
|
346
393
|
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
self.api.annotation.download(image_info.id).annotation,
|
352
|
-
project_meta=project_meta,
|
353
|
-
)
|
354
|
-
prediction = self.model_selector.model.model_api.predict(
|
355
|
-
image_id=image_info.id, **self.settings_selector.get_settings()
|
356
|
-
)[0]
|
357
|
-
output_ann = prediction.annotation
|
358
|
-
|
359
|
-
self.preview.gallery.append(img_url, input_ann, "Input")
|
360
|
-
self.preview.gallery.append(img_url, output_ann, "Output")
|
361
|
-
self.preview.validator_text.hide()
|
362
|
-
self.preview.gallery.show()
|
363
|
-
return prediction
|
364
|
-
except Exception as e:
|
365
|
-
self.preview.gallery.hide()
|
366
|
-
self.preview.validator_text.set(
|
367
|
-
text=f"Error during preview: {str(e)}", status="error"
|
368
|
-
)
|
369
|
-
self.preview.validator_text.show()
|
370
|
-
self.preview.gallery.clean_up()
|
371
|
-
finally:
|
372
|
-
self.preview.gallery.loading = False
|
394
|
+
def disconnect_and_reset_step():
|
395
|
+
type(self.model_selector.model).disconnect(self.model_selector.model)
|
396
|
+
self.step_flow.reactivate_step("model_selector")
|
397
|
+
reset_entity_meta()
|
373
398
|
|
374
|
-
#
|
399
|
+
# Replace deploy methods for DeployModel widget
|
400
|
+
self.model_selector.model.deploy = deploy_and_set_step
|
401
|
+
self.model_selector.model.stop = stop_and_reset_step
|
402
|
+
self.model_selector.model.disconnect = disconnect_and_reset_step
|
375
403
|
|
376
|
-
#
|
377
|
-
self.step_flow = StepFlow(self.stepper)
|
378
|
-
position = 0
|
404
|
+
# ------------------------------------------------- #
|
379
405
|
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
self.input_selector.card,
|
384
|
-
self.input_selector.button,
|
385
|
-
self.input_selector.widgets_to_disable,
|
386
|
-
self.input_selector.validator_text,
|
387
|
-
self.input_selector.validate_step,
|
388
|
-
position=position,
|
389
|
-
)
|
390
|
-
position += 1
|
406
|
+
@property
|
407
|
+
def model_api(self) -> Optional[ModelAPI]:
|
408
|
+
return self.model_selector.model.model_api
|
391
409
|
|
392
|
-
|
393
|
-
self.
|
394
|
-
|
395
|
-
|
396
|
-
self.model_selector.button,
|
397
|
-
self.model_selector.widgets_to_disable,
|
398
|
-
self.model_selector.validator_text,
|
399
|
-
self.model_selector.validate_step,
|
400
|
-
position=position,
|
401
|
-
)
|
402
|
-
self.step_flow.add_on_select_actions("model_selector", [set_entity_meta])
|
403
|
-
self.step_flow.add_on_select_actions(
|
404
|
-
"model_selector", [reset_entity_meta], is_reselect=True
|
405
|
-
)
|
406
|
-
position += 1
|
410
|
+
def update_item_type(self):
|
411
|
+
item_type = self.input_selector.radio.get_value()
|
412
|
+
self.settings_selector.update_item_type(item_type)
|
413
|
+
self.output_selector.update_item_type(item_type)
|
407
414
|
|
408
|
-
|
409
|
-
if self.
|
410
|
-
self.
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
self.classes_selector.validator_text,
|
416
|
-
self.classes_selector.validate_step,
|
417
|
-
position=position,
|
418
|
-
)
|
419
|
-
position += 1
|
415
|
+
def _run_videos(self, run_parameters: Dict[str, Any]) -> List[Prediction]:
|
416
|
+
if self.model_api is None:
|
417
|
+
self.set_validator_text("Deploying model...", "info")
|
418
|
+
self.model_selector.model._deploy()
|
419
|
+
if self.model_api is None:
|
420
|
+
logger.error("Model Deployed with an error")
|
421
|
+
raise RuntimeError("Model Deployed with an error")
|
420
422
|
|
421
|
-
|
422
|
-
if
|
423
|
-
self.
|
424
|
-
"tags_selector",
|
425
|
-
self.tags_selector.card,
|
426
|
-
self.tags_selector.button,
|
427
|
-
self.tags_selector.widgets_to_disable,
|
428
|
-
self.tags_selector.validator_text,
|
429
|
-
self.tags_selector.validate_step,
|
430
|
-
position=position,
|
431
|
-
)
|
432
|
-
position += 1
|
433
|
-
|
434
|
-
# 5. Settings selector
|
435
|
-
self.step_flow.register_step(
|
436
|
-
"settings_selector",
|
437
|
-
self.settings_selector.card,
|
438
|
-
self.settings_selector.button,
|
439
|
-
self.settings_selector.widgets_to_disable,
|
440
|
-
self.settings_selector.validator_text,
|
441
|
-
self.settings_selector.validate_step,
|
442
|
-
position=position,
|
443
|
-
)
|
444
|
-
self.step_flow.add_on_select_actions("settings_selector", [disable_settings_editor])
|
445
|
-
self.step_flow.add_on_select_actions("settings_selector", [disable_settings_editor], True)
|
446
|
-
position += 1
|
447
|
-
|
448
|
-
# 6. Preview
|
449
|
-
if self.preview is not None:
|
450
|
-
self.step_flow.register_step(
|
451
|
-
"preview",
|
452
|
-
self.preview.card,
|
453
|
-
self.preview.button,
|
454
|
-
self.preview.widgets_to_disable,
|
455
|
-
self.preview.validator_text,
|
456
|
-
self.preview.validate_step,
|
457
|
-
position=position,
|
458
|
-
).add_on_select_actions("preview", [generate_preview])
|
459
|
-
position += 1
|
460
|
-
|
461
|
-
# 7. Output selector
|
462
|
-
self.step_flow.register_step(
|
463
|
-
"output_selector",
|
464
|
-
self.output_selector.card,
|
465
|
-
None,
|
466
|
-
self.output_selector.widgets_to_disable,
|
467
|
-
self.output_selector.validator_text,
|
468
|
-
self.output_selector.validate_step,
|
469
|
-
position=position,
|
470
|
-
)
|
423
|
+
self.set_validator_text("Preparing settings for prediction...", "info")
|
424
|
+
if run_parameters is None:
|
425
|
+
run_parameters = self.get_run_parameters()
|
471
426
|
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
has_preview = self.preview is not None
|
477
|
-
|
478
|
-
# Step 1 -> Step 2
|
479
|
-
prev_step = "input_selector"
|
480
|
-
if has_model_selector:
|
481
|
-
self.step_flow.set_next_steps(prev_step, ["model_selector"])
|
482
|
-
prev_step = "model_selector"
|
483
|
-
# Step 2 -> Step 3
|
484
|
-
if has_classes_selector:
|
485
|
-
self.step_flow.set_next_steps(prev_step, ["classes_selector"])
|
486
|
-
prev_step = "classes_selector"
|
487
|
-
# Step 3 -> Step 4
|
488
|
-
if has_tags_selector:
|
489
|
-
self.step_flow.set_next_steps(prev_step, ["tags_selector"])
|
490
|
-
prev_step = "tags_selector"
|
491
|
-
# Step 4 -> Step 5
|
492
|
-
self.step_flow.set_next_steps(prev_step, ["settings_selector"])
|
493
|
-
prev_step = "settings_selector"
|
494
|
-
# Step 5 -> Step 6
|
495
|
-
if has_preview:
|
496
|
-
self.step_flow.set_next_steps(prev_step, ["preview"])
|
497
|
-
prev_step = "preview"
|
498
|
-
# Step 6 -> Step 7
|
499
|
-
self.step_flow.set_next_steps(prev_step, ["output_selector"])
|
500
|
-
|
501
|
-
# Create all wrappers and set button handlers
|
502
|
-
wrappers = self.step_flow.build()
|
503
|
-
|
504
|
-
self.input_selector_cb = wrappers.get("input_selector")
|
505
|
-
self.classes_selector_cb = wrappers.get("classes_selector")
|
506
|
-
self.tags_selector_cb = wrappers.get("tags_selector")
|
507
|
-
self.model_selector_cb = wrappers.get("model_selector")
|
508
|
-
self.settings_selector_cb = wrappers.get("settings_selector")
|
509
|
-
self.preview_cb = wrappers.get("preview")
|
510
|
-
self.output_selector_cb = wrappers.get("output_selector")
|
511
|
-
# ------------------------------------------------- #
|
427
|
+
input_parameters = run_parameters["input"]
|
428
|
+
input_video_ids = input_parameters["video_ids"]
|
429
|
+
if not input_video_ids:
|
430
|
+
raise ValueError("No video IDs provided for video prediction.")
|
512
431
|
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
432
|
+
predict_kwargs = {}
|
433
|
+
# Settings
|
434
|
+
settings = run_parameters["settings"]
|
435
|
+
model_prediction_suffix = settings.pop("model_prediction_suffix", "")
|
436
|
+
prediction_mode = settings.pop("predictions_mode")
|
437
|
+
tracking = settings.pop("tracking", False)
|
438
|
+
predict_kwargs.update(settings)
|
517
439
|
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
rows = []
|
523
|
-
else:
|
524
|
-
dataset_info = self.api.dataset.get_info_by_id(dataset_id)
|
525
|
-
videos = self.api.video.get_list(dataset_id)
|
526
|
-
rows = [[video.id, video.name, dataset_info.name] for video in videos]
|
527
|
-
self.input_selector.select_video.rows = rows
|
528
|
-
self.input_selector.select_video.loading = False
|
440
|
+
# Classes
|
441
|
+
classes = run_parameters["classes"]
|
442
|
+
if classes:
|
443
|
+
predict_kwargs["classes"] = classes
|
529
444
|
|
530
|
-
|
445
|
+
output_parameters = run_parameters["output"]
|
446
|
+
project_name = output_parameters.get("project_name", "")
|
447
|
+
upload_to_source_project = output_parameters.get("upload_to_source_project", False)
|
448
|
+
skip_project_versioning = output_parameters.get("skip_project_versioning", False)
|
449
|
+
skip_annotated = output_parameters.get("skip_annotated", False)
|
450
|
+
|
451
|
+
video_infos_by_project_id: Dict[int, List[VideoInfo]] = {}
|
452
|
+
video_infos_by_dataset_id: Dict[int, List[VideoInfo]] = {}
|
453
|
+
for info in self.api.video.get_info_by_id_batch(input_video_ids):
|
454
|
+
video_infos_by_project_id.setdefault(info.project_id, []).append(info)
|
455
|
+
video_infos_by_dataset_id.setdefault(info.dataset_id, []).append(info)
|
456
|
+
src_project_metas: Dict[int, ProjectMeta] = {}
|
457
|
+
for project_id in video_infos_by_project_id.keys():
|
458
|
+
src_project_metas[project_id] = ProjectMeta.from_json(
|
459
|
+
self.api.project.get_meta(project_id)
|
460
|
+
)
|
531
461
|
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
462
|
+
video_ids_to_skip = set()
|
463
|
+
if skip_annotated:
|
464
|
+
self.set_validator_text("Checking for already annotated videos...", "info")
|
465
|
+
secondary_pbar = self.output_selector.secondary_progress(
|
466
|
+
message="Checking for already annotated videos...", total=len(input_video_ids)
|
467
|
+
)
|
468
|
+
self.output_selector.secondary_progress.show()
|
469
|
+
for dataset_id, video_infos in video_infos_by_dataset_id.items():
|
470
|
+
annotations = self.api.video.annotation.download_bulk(
|
471
|
+
dataset_id, [info.id for info in video_infos]
|
472
|
+
)
|
473
|
+
for ann_json, video_info in zip(annotations, video_infos):
|
474
|
+
if ann_json:
|
475
|
+
project_meta = src_project_metas[video_info.project_id]
|
476
|
+
ann = VideoAnnotation.from_json(ann_json, project_meta=project_meta)
|
477
|
+
if len(ann.figures) > 0:
|
478
|
+
video_ids_to_skip.add(video_info.id)
|
479
|
+
secondary_pbar.update()
|
480
|
+
self.output_selector.secondary_progress.hide()
|
481
|
+
if video_ids_to_skip:
|
482
|
+
video_infos_by_project_id = {
|
483
|
+
pid: [info for info in infos if info.id not in video_ids_to_skip]
|
484
|
+
for pid, infos in video_infos_by_project_id.items()
|
485
|
+
}
|
486
|
+
|
487
|
+
main_pbar_str = "Processing videos..."
|
488
|
+
if video_ids_to_skip:
|
489
|
+
main_pbar_str += f" (Skipped {len(video_ids_to_skip)} already annotated videos)"
|
490
|
+
total_videos = sum(len(v) for v in video_infos_by_project_id.values())
|
491
|
+
if total_videos == 0:
|
492
|
+
self.set_validator_text(
|
493
|
+
f"No videos to process. Skipped {len(video_ids_to_skip)} already annotated videos",
|
494
|
+
"warning",
|
495
|
+
)
|
496
|
+
return []
|
497
|
+
main_pbar = self.output_selector.progress(message=main_pbar_str, total=total_videos)
|
498
|
+
self.output_selector.progress.show()
|
499
|
+
all_predictictions: List[Prediction] = []
|
500
|
+
for src_project_id, src_video_infos in video_infos_by_project_id.items():
|
501
|
+
if len(src_video_infos) == 0:
|
502
|
+
continue
|
503
|
+
project_info = self.api.project.get_info_by_id(src_project_id)
|
504
|
+
project_validator_text_str = (
|
505
|
+
f"Processing project: {project_info.name} [id: {src_project_id}]"
|
506
|
+
)
|
507
|
+
if upload_to_source_project:
|
508
|
+
if not skip_project_versioning and not is_development():
|
509
|
+
logger.info("Creating new project version...")
|
510
|
+
self.set_validator_text(
|
511
|
+
project_validator_text_str + ": Creating project version",
|
512
|
+
"info",
|
513
|
+
)
|
514
|
+
version_id = self.api.project.version.create(
|
515
|
+
project_info,
|
516
|
+
"Created by Predict App. Task Id: " + str(env.task_id()),
|
517
|
+
)
|
518
|
+
logger.info("New project version created: " + str(version_id))
|
519
|
+
output_project_id = src_project_id
|
520
|
+
output_videos: List[VideoInfo] = src_video_infos
|
521
|
+
else:
|
522
|
+
self.set_validator_text(
|
523
|
+
project_validator_text_str + ": Creating project...", "info"
|
524
|
+
)
|
525
|
+
if not project_name:
|
526
|
+
project_name = project_info.name + " [Predictions]"
|
527
|
+
logger.warning(
|
528
|
+
"Project name is empty, using auto-generated name: " + project_name
|
529
|
+
)
|
530
|
+
with_annotations = prediction_mode in [
|
531
|
+
AddPredictionsMode.APPEND,
|
532
|
+
AddPredictionsMode.IOU_MERGE,
|
533
|
+
]
|
534
|
+
created_project = create_project(
|
535
|
+
api=self.api,
|
536
|
+
project_id=src_project_id,
|
537
|
+
project_name=project_name,
|
538
|
+
workspace_id=self.workspace_id,
|
539
|
+
copy_meta=with_annotations,
|
540
|
+
project_type=ProjectType.VIDEOS,
|
541
|
+
)
|
542
|
+
output_project_id = created_project.id
|
543
|
+
output_videos: List[VideoInfo] = copy_items_to_project(
|
544
|
+
api=self.api,
|
545
|
+
src_project_id=src_project_id,
|
546
|
+
items=src_video_infos,
|
547
|
+
dst_project_id=created_project.id,
|
548
|
+
with_annotations=with_annotations,
|
549
|
+
ds_progress=self.output_selector.secondary_progress,
|
550
|
+
project_type=ProjectType.VIDEOS,
|
551
|
+
)
|
537
552
|
|
538
|
-
|
539
|
-
|
553
|
+
self.set_validator_text(
|
554
|
+
project_validator_text_str + ": Merging project meta",
|
555
|
+
"info",
|
556
|
+
)
|
557
|
+
project_meta = src_project_metas[src_project_id]
|
558
|
+
for src_video_info, output_video_info in zip(src_video_infos, output_videos):
|
559
|
+
video_validator_text_str = (
|
560
|
+
project_validator_text_str
|
561
|
+
+ f", video: {src_video_info.name} [id: {src_video_info.id}]"
|
562
|
+
)
|
563
|
+
self.set_validator_text(
|
564
|
+
video_validator_text_str + ": Predicting",
|
565
|
+
"info",
|
566
|
+
)
|
567
|
+
frames_predictions: List[Prediction] = []
|
568
|
+
with self.model_api.predict_detached(
|
569
|
+
video_id=src_video_info.id,
|
570
|
+
tqdm=self.output_selector.secondary_progress(),
|
571
|
+
tracking=tracking,
|
572
|
+
**predict_kwargs,
|
573
|
+
) as session:
|
574
|
+
self.output_selector.secondary_progress.show()
|
575
|
+
for prediction in session:
|
576
|
+
if self._stop_flag:
|
577
|
+
logger.info("Prediction stopped by user.")
|
578
|
+
raise StopIteration("Stopped by user.")
|
579
|
+
frames_predictions.append(prediction)
|
580
|
+
all_predictictions.extend(frames_predictions)
|
581
|
+
if tracking:
|
582
|
+
prediction_video_annotation: VideoAnnotation = VideoAnnotation.from_json(
|
583
|
+
session.final_result["video_ann"],
|
584
|
+
project_meta=self.model_api.get_model_meta(),
|
585
|
+
)
|
586
|
+
else:
|
587
|
+
prediction_video_annotation = video_annotation_from_predictions(
|
588
|
+
frames_predictions,
|
589
|
+
project_meta,
|
590
|
+
frame_size=(src_video_info.frame_height, src_video_info.frame_width),
|
591
|
+
)
|
592
|
+
if prediction_video_annotation is None:
|
593
|
+
logger.warning(
|
594
|
+
f"No predictions were made for video {src_video_info.name} [id: {src_video_info.id}]"
|
595
|
+
)
|
596
|
+
main_pbar.update()
|
597
|
+
continue
|
598
|
+
self.set_validator_text(
|
599
|
+
video_validator_text_str + ": Uploading predictions",
|
600
|
+
"info",
|
601
|
+
)
|
602
|
+
project_meta, prediction_video_annotation, meta_changed = (
|
603
|
+
update_meta_and_ann_for_video_annotation(
|
604
|
+
meta=project_meta,
|
605
|
+
ann=prediction_video_annotation,
|
606
|
+
model_prediction_suffix=model_prediction_suffix,
|
607
|
+
)
|
608
|
+
)
|
609
|
+
if meta_changed:
|
610
|
+
self.api.project.update_meta(output_project_id, project_meta)
|
611
|
+
if upload_to_source_project:
|
612
|
+
if prediction_mode in [
|
613
|
+
AddPredictionsMode.REPLACE,
|
614
|
+
AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS,
|
615
|
+
]:
|
616
|
+
self.output_selector.secondary_progress.hide()
|
617
|
+
with open("/tmp/prediction_video_annotation.json", "w") as f:
|
618
|
+
json.dump(prediction_video_annotation.to_json(), f)
|
619
|
+
self.api.video.annotation.upload_paths(
|
620
|
+
video_ids=[src_video_info.id],
|
621
|
+
paths=["/tmp/prediction_video_annotation.json"],
|
622
|
+
project_meta=project_meta,
|
623
|
+
)
|
624
|
+
else:
|
625
|
+
secondary_pbar = self.output_selector.secondary_progress(
|
626
|
+
message="Uploading annotations...",
|
627
|
+
total=len(prediction_video_annotation.figures),
|
628
|
+
)
|
629
|
+
self.output_selector.secondary_progress.show()
|
630
|
+
self.api.video.annotation.append(
|
631
|
+
video_id=src_video_info.id,
|
632
|
+
ann=prediction_video_annotation,
|
633
|
+
key_id_map=KeyIdMap(),
|
634
|
+
progress_cb=secondary_pbar.update,
|
635
|
+
)
|
636
|
+
else:
|
637
|
+
secondary_pbar = self.output_selector.secondary_progress(
|
638
|
+
message="Uploading annotations...",
|
639
|
+
total=len(prediction_video_annotation.figures),
|
640
|
+
)
|
641
|
+
self.output_selector.secondary_progress.show()
|
642
|
+
self.api.video.annotation.append(
|
643
|
+
video_id=output_video_info.id,
|
644
|
+
ann=prediction_video_annotation,
|
645
|
+
key_id_map=KeyIdMap(),
|
646
|
+
progress_cb=secondary_pbar.update,
|
647
|
+
)
|
648
|
+
main_pbar.update()
|
649
|
+
self.set_validator_text("Project successfully processed", "success")
|
650
|
+
self.output_selector.set_result_thumbnail(output_project_id)
|
651
|
+
return all_predictictions
|
540
652
|
|
541
|
-
|
542
|
-
if model_api is None:
|
653
|
+
def _run_images(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
|
654
|
+
if self.model_api is None:
|
655
|
+
self.set_validator_text("Deploying model...", "info")
|
656
|
+
self.model_selector.model._deploy()
|
657
|
+
if self.model_api is None:
|
543
658
|
logger.error("Model Deployed with an error")
|
544
|
-
|
545
|
-
return
|
659
|
+
raise RuntimeError("Model Deployed with an error")
|
546
660
|
|
547
|
-
|
661
|
+
self.set_validator_text("Preparing settings for prediction...", "info")
|
662
|
+
if run_parameters is None:
|
663
|
+
run_parameters = self.get_run_parameters()
|
548
664
|
|
665
|
+
predict_kwargs = {}
|
549
666
|
# Input
|
550
|
-
# Input would be newely created project
|
551
667
|
input_args = {}
|
552
668
|
input_parameters = run_parameters["input"]
|
553
669
|
input_project_id = input_parameters.get("project_id", None)
|
554
|
-
if input_project_id is None:
|
555
|
-
raise ValueError("Input project ID is required for prediction.")
|
556
670
|
input_dataset_ids = input_parameters.get("dataset_ids", [])
|
557
671
|
input_image_ids = input_parameters.get("image_ids", [])
|
558
|
-
if not (input_dataset_ids or input_image_ids):
|
559
|
-
raise ValueError("At least one dataset must be selected for prediction.")
|
560
672
|
if input_image_ids:
|
561
673
|
input_args["image_ids"] = input_image_ids
|
562
674
|
elif input_dataset_ids:
|
563
675
|
input_args["dataset_ids"] = input_dataset_ids
|
564
|
-
|
676
|
+
elif input_project_id:
|
565
677
|
input_args["project_id"] = input_project_id
|
678
|
+
else:
|
679
|
+
raise ValueError("No valid input parameters found for prediction.")
|
566
680
|
|
567
681
|
# Settings
|
568
682
|
settings = run_parameters["settings"]
|
569
683
|
prediction_mode = settings.pop("predictions_mode")
|
570
684
|
upload_mode = None
|
571
685
|
with_annotations = None
|
572
|
-
if prediction_mode == AddPredictionsMode.
|
686
|
+
if prediction_mode == AddPredictionsMode.REPLACE:
|
573
687
|
upload_mode = "replace"
|
574
688
|
with_annotations = False
|
575
|
-
elif prediction_mode == AddPredictionsMode.
|
689
|
+
elif prediction_mode == AddPredictionsMode.APPEND:
|
576
690
|
upload_mode = "append"
|
577
691
|
with_annotations = True
|
692
|
+
elif prediction_mode == AddPredictionsMode.IOU_MERGE:
|
693
|
+
upload_mode = "iou_merge"
|
694
|
+
with_annotations = True
|
578
695
|
elif prediction_mode == AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS:
|
579
696
|
upload_mode = "replace"
|
580
697
|
with_annotations = True
|
581
|
-
|
582
|
-
|
698
|
+
predict_kwargs.update(settings)
|
699
|
+
predict_kwargs["upload_mode"] = upload_mode
|
583
700
|
|
584
701
|
# Classes
|
585
702
|
classes = run_parameters["classes"]
|
586
703
|
if classes:
|
587
|
-
|
704
|
+
predict_kwargs["classes"] = classes
|
588
705
|
|
589
706
|
# Output
|
590
|
-
# Always create new project
|
591
|
-
# But the actual inference will happen inplace
|
592
707
|
output_parameters = run_parameters["output"]
|
593
|
-
project_name = output_parameters.get("project_name",
|
708
|
+
project_name = output_parameters.get("project_name", None)
|
594
709
|
upload_to_source_project = output_parameters.get("upload_to_source_project", False)
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
710
|
+
skip_project_versioning = output_parameters.get("skip_project_versioning", False)
|
711
|
+
skip_annotated = output_parameters.get("skip_annotated", False)
|
712
|
+
|
713
|
+
image_infos = []
|
714
|
+
if input_image_ids:
|
715
|
+
image_infos = self.api.image.get_info_by_id_batch(input_image_ids)
|
716
|
+
elif input_dataset_ids:
|
717
|
+
for dataset_id in input_dataset_ids:
|
718
|
+
image_infos.extend(self.api.image.get_list(dataset_id))
|
719
|
+
elif input_project_id:
|
720
|
+
datasets = self.api.dataset.get_list(input_project_id, recursive=True)
|
721
|
+
for dataset in datasets:
|
722
|
+
image_infos.extend(self.api.image.get_list(dataset.id))
|
723
|
+
if len(image_infos) == 0:
|
724
|
+
raise ValueError("No images found for the given input parameters.")
|
725
|
+
|
726
|
+
to_skip = []
|
727
|
+
if skip_annotated:
|
728
|
+
to_skip = [image_info.id for image_info in image_infos if image_info.labels_count == 0]
|
729
|
+
if to_skip:
|
730
|
+
image_infos = [info for info in image_infos if info.id not in to_skip]
|
731
|
+
if len(image_infos) == 0:
|
732
|
+
self.set_validator_text(
|
733
|
+
f"All images are already annotated. Nothing to predict.", "warning"
|
734
|
+
)
|
735
|
+
return []
|
736
|
+
|
737
|
+
image_infos_by_project_id: Dict[int, List[ImageInfo]] = {}
|
738
|
+
image_infos_by_dataset_id: Dict[int, List[ImageInfo]] = {}
|
739
|
+
ds_project_mapping: Dict[int, int] = {}
|
740
|
+
for info in image_infos:
|
741
|
+
image_infos_by_dataset_id.setdefault(info.dataset_id, []).append(info)
|
742
|
+
if info.dataset_id not in ds_project_mapping:
|
743
|
+
ds_info = self.api.dataset.get_info_by_id(info.dataset_id)
|
744
|
+
ds_project_mapping[info.dataset_id] = ds_info.project_id
|
745
|
+
project_id = ds_project_mapping[info.dataset_id]
|
746
|
+
image_infos_by_project_id.setdefault(project_id, []).append(info)
|
747
|
+
|
748
|
+
src_project_metas: Dict[int, ProjectMeta] = {}
|
749
|
+
for project_id in image_infos_by_project_id.keys():
|
750
|
+
src_project_metas[project_id] = ProjectMeta.from_json(
|
751
|
+
self.api.project.get_meta(project_id)
|
613
752
|
)
|
614
|
-
output_project_id = created_project.id
|
615
|
-
input_args = {
|
616
|
-
"project_id": output_project_id,
|
617
|
-
}
|
618
753
|
|
619
|
-
|
754
|
+
self.output_selector.progress.show()
|
755
|
+
total_items = sum(len(v) for v in image_infos_by_project_id.values())
|
756
|
+
main_pbar = self.output_selector.progress(message=f"Copying images...", total=total_items)
|
757
|
+
for src_project_id, infos in image_infos_by_project_id.items():
|
758
|
+
if len(infos) == 0:
|
759
|
+
continue
|
760
|
+
project_info = self.api.project.get_info_by_id(src_project_id)
|
761
|
+
project_validator_text_str = (
|
762
|
+
f"Processing project: {project_info.name} [id: {src_project_id}]"
|
763
|
+
)
|
764
|
+
if upload_to_source_project:
|
765
|
+
if not skip_project_versioning and not is_development():
|
766
|
+
logger.info("Creating new project version...")
|
767
|
+
self.set_validator_text(
|
768
|
+
project_validator_text_str + ": Creating project version", "info"
|
769
|
+
)
|
770
|
+
version_id = self.api.project.version.create(
|
771
|
+
project_info,
|
772
|
+
"Created by Predict App. Task Id: " + str(env.task_id()),
|
773
|
+
)
|
774
|
+
logger.info("New project version created: " + str(version_id))
|
775
|
+
output_project_id = src_project_id
|
776
|
+
output_image_infos: List[ImageInfo] = infos
|
777
|
+
else:
|
778
|
+
self.set_validator_text(
|
779
|
+
project_validator_text_str + ": Creating project...", "info"
|
780
|
+
)
|
781
|
+
if not project_name:
|
782
|
+
project_name = project_info.name + " [Predictions]"
|
783
|
+
logger.warning(
|
784
|
+
"Project name is empty, using auto-generated name: " + project_name
|
785
|
+
)
|
786
|
+
created_project = create_project(
|
787
|
+
api=self.api,
|
788
|
+
project_id=src_project_id,
|
789
|
+
project_name=project_name,
|
790
|
+
workspace_id=self.workspace_id,
|
791
|
+
copy_meta=with_annotations,
|
792
|
+
project_type=ProjectType.IMAGES,
|
793
|
+
)
|
794
|
+
output_project_id = created_project.id
|
795
|
+
output_image_infos: List[ImageInfo] = copy_items_to_project(
|
796
|
+
api=self.api,
|
797
|
+
src_project_id=src_project_id,
|
798
|
+
items=infos,
|
799
|
+
dst_project_id=created_project.id,
|
800
|
+
with_annotations=with_annotations,
|
801
|
+
ds_progress=self.output_selector.secondary_progress,
|
802
|
+
progress_cb=main_pbar.update,
|
803
|
+
project_type=ProjectType.IMAGES,
|
804
|
+
)
|
620
805
|
|
621
806
|
# Run prediction
|
622
807
|
self.set_validator_text("Running prediction...", "info")
|
623
|
-
predictions = []
|
808
|
+
predictions: List[Prediction] = []
|
624
809
|
self._is_running = True
|
810
|
+
with self.model_api.predict_detached(
|
811
|
+
image_ids=[info.id for info in output_image_infos],
|
812
|
+
**predict_kwargs,
|
813
|
+
tqdm=self.output_selector.progress(),
|
814
|
+
) as session:
|
815
|
+
for prediction in session:
|
816
|
+
if self._stop_flag:
|
817
|
+
logger.info("Prediction stopped by user.")
|
818
|
+
raise StopIteration("Stopped by user.")
|
819
|
+
predictions.append(prediction)
|
820
|
+
self.set_validator_text("Project successfully processed", "success")
|
821
|
+
self.output_selector.set_result_thumbnail(output_project_id)
|
822
|
+
return predictions
|
823
|
+
|
824
|
+
def run(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
|
825
|
+
self.show_validator_text()
|
826
|
+
if run_parameters is None:
|
827
|
+
run_parameters = self.get_run_parameters()
|
828
|
+
input_parameters = run_parameters["input"]
|
829
|
+
video_ids = input_parameters.get("video_ids", None)
|
625
830
|
try:
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
)
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
i += 1
|
636
|
-
if self._stop_flag:
|
637
|
-
logger.info("Prediction stopped by user.")
|
638
|
-
break
|
639
|
-
self.output_selector.progress.hide()
|
831
|
+
if video_ids:
|
832
|
+
run_f = self._run_videos
|
833
|
+
else:
|
834
|
+
run_f = self._run_images
|
835
|
+
return run_f(run_parameters)
|
836
|
+
except StopIteration:
|
837
|
+
logger.info("Prediction stopped by user.")
|
838
|
+
self.set_validator_text("Prediction stopped by user.", "warning")
|
839
|
+
raise
|
640
840
|
except Exception as e:
|
641
|
-
self.output_selector.progress.hide()
|
642
841
|
logger.error(f"Error during prediction: {str(e)}")
|
643
842
|
self.set_validator_text(f"Error during prediction: {str(e)}", "error")
|
644
843
|
disable_enable(self.output_selector.widgets_to_disable, False)
|
645
|
-
|
646
|
-
self._stop_flag = False
|
647
|
-
raise e
|
844
|
+
raise
|
648
845
|
finally:
|
846
|
+
self.output_selector.secondary_progress.hide()
|
847
|
+
self.output_selector.progress.hide()
|
649
848
|
self._is_running = False
|
650
849
|
self._stop_flag = False
|
651
|
-
# ------------------------ #
|
652
|
-
|
653
|
-
# Set result thumbnail
|
654
|
-
self.set_validator_text("Project successfully processed", "success")
|
655
|
-
self.output_selector.set_result_thumbnail(output_project_id)
|
656
|
-
# ------------------------ #
|
657
|
-
return predictions
|
658
850
|
|
659
851
|
def stop(self):
|
660
852
|
logger.info("Stopping prediction...")
|
@@ -707,14 +899,10 @@ class PredictAppGui:
|
|
707
899
|
if self.tags_selector is not None:
|
708
900
|
self.tags_selector.load_from_json(data.get("tags", {}))
|
709
901
|
|
710
|
-
# 5. Settings selector
|
902
|
+
# 5. Settings selector & Preview
|
711
903
|
self.settings_selector.load_from_json(data.get("settings", {}))
|
712
904
|
|
713
|
-
# 6.
|
714
|
-
if self.preview is not None:
|
715
|
-
self.preview.load_from_json(data.get("preview", {}))
|
716
|
-
|
717
|
-
# 7. Output selector
|
905
|
+
# 6. Output selector
|
718
906
|
self.output_selector.load_from_json(data.get("output", {}))
|
719
907
|
|
720
908
|
def set_validator_text(self, text: str, status: str = "text"):
|