supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/__init__.py +136 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/json_geometries_map.py +2 -0
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +9 -9
- supervisely/api/api.py +67 -43
- supervisely/api/app_api.py +72 -5
- supervisely/api/dataset_api.py +108 -33
- supervisely/api/entity_annotation/figure_api.py +113 -49
- supervisely/api/image_api.py +82 -0
- supervisely/api/module_api.py +10 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/pointcloud/pointcloud_api.py +38 -0
- supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
- supervisely/api/project_api.py +213 -6
- supervisely/api/task_api.py +11 -1
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +79 -1
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/api/volume/volume_api.py +38 -0
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +14 -6
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +88 -0
- supervisely/app/fastapi/subapp.py +175 -42
- supervisely/app/fastapi/templating.py +1 -1
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +11 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
- supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
- supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
- supervisely/app/widgets/fast_table/fast_table.py +713 -126
- supervisely/app/widgets/fast_table/script.js +492 -95
- supervisely/app/widgets/fast_table/style.css +54 -0
- supervisely/app/widgets/fast_table/template.html +45 -5
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +523 -0
- supervisely/app/widgets/heatmap/script.js +378 -0
- supervisely/app/widgets/heatmap/style.css +227 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/input_tag/input_tag.py +102 -15
- supervisely/app/widgets/input_tag_list/__init__.py +0 -0
- supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
- supervisely/app/widgets/input_tag_list/template.html +70 -0
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/app/widgets/transfer/style.css +3 -0
- supervisely/app/widgets/transfer/template.html +3 -1
- supervisely/app/widgets/transfer/transfer.py +48 -45
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
- supervisely/convert/video/video_converter.py +2 -2
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/io/env.py +161 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/cache.py +37 -17
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +953 -211
- supervisely/nn/inference/inference_request.py +15 -8
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
- supervisely/nn/inference/predict_app/gui/gui.py +915 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +399 -0
- supervisely/nn/inference/predict_app/predict_app.py +176 -0
- supervisely/nn/inference/session.py +47 -39
- supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +4 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/model_api.py +44 -22
- supervisely/nn/model/prediction.py +15 -1
- supervisely/nn/model/prediction_session.py +70 -14
- supervisely/nn/prediction_dto.py +7 -0
- supervisely/nn/tracker/__init__.py +6 -8
- supervisely/nn/tracker/base_tracker.py +54 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
- supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +273 -0
- supervisely/nn/tracker/calculate_metrics.py +264 -0
- supervisely/nn/tracker/utils.py +273 -0
- supervisely/nn/tracker/visualize.py +520 -0
- supervisely/nn/training/gui/gui.py +152 -49
- supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
- supervisely/nn/training/gui/training_artifacts.py +3 -1
- supervisely/nn/training/train_app.py +225 -46
- supervisely/project/pointcloud_episode_project.py +12 -8
- supervisely/project/pointcloud_project.py +12 -8
- supervisely/project/project.py +221 -75
- supervisely/template/experiment/experiment.html.jinja +105 -55
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- supervisely/versions.json +3 -1
- supervisely/video/sampling.py +42 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
- supervisely_lib/__init__.py +6 -1
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,915 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from supervisely._utils import is_development, logger
|
|
9
|
+
from supervisely.api.api import Api
|
|
10
|
+
from supervisely.api.image_api import ImageInfo
|
|
11
|
+
from supervisely.api.video.video_api import VideoInfo
|
|
12
|
+
from supervisely.app.widgets import Button, Card, Container, Stepper, Widget
|
|
13
|
+
from supervisely.geometry.any_geometry import AnyGeometry
|
|
14
|
+
from supervisely.io import env
|
|
15
|
+
from supervisely.nn.inference.inference import update_meta_and_ann_for_video_annotation
|
|
16
|
+
from supervisely.nn.inference.predict_app.gui.classes_selector import ClassesSelector
|
|
17
|
+
from supervisely.nn.inference.predict_app.gui.input_selector import InputSelector
|
|
18
|
+
from supervisely.nn.inference.predict_app.gui.model_selector import ModelSelector
|
|
19
|
+
from supervisely.nn.inference.predict_app.gui.output_selector import OutputSelector
|
|
20
|
+
from supervisely.nn.inference.predict_app.gui.settings_selector import (
|
|
21
|
+
AddPredictionsMode,
|
|
22
|
+
SettingsSelector,
|
|
23
|
+
)
|
|
24
|
+
from supervisely.nn.inference.predict_app.gui.tags_selector import TagsSelector
|
|
25
|
+
from supervisely.nn.inference.predict_app.gui.utils import (
|
|
26
|
+
copy_items_to_project,
|
|
27
|
+
create_project,
|
|
28
|
+
disable_enable,
|
|
29
|
+
update_custom_button_params,
|
|
30
|
+
video_annotation_from_predictions,
|
|
31
|
+
)
|
|
32
|
+
from supervisely.nn.model.model_api import ModelAPI
|
|
33
|
+
from supervisely.nn.model.prediction import Prediction
|
|
34
|
+
from supervisely.project.project_meta import ProjectMeta
|
|
35
|
+
from supervisely.project.project_type import ProjectType
|
|
36
|
+
from supervisely.video_annotation.key_id_map import KeyIdMap
|
|
37
|
+
from supervisely.video_annotation.video_annotation import VideoAnnotation
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class StepFlow:
    """Ordered collection of GUI steps with selection state and dependency locks.

    Every step wraps a widget shown in a ``Stepper``. A step is *selected* once
    confirmed and *reactivated* when reopened for editing. A step is *locked*
    while any step listed in its ``depends_on`` is not selected. Reactivating a
    step also reactivates every step that directly or transitively depends on
    it, each exactly once.
    """

    def __init__(self):
        self._stepper = None  # built lazily by the ``stepper`` property
        self.steps = {}  # step name -> dict with widget, callbacks and state flags
        self.steps_sequence = []  # step names in display order

    def add_step(
        self,
        name: str,
        widget: "Widget",
        on_select: Optional[Callable] = None,
        on_reactivate: Optional[Callable] = None,
        depends_on: Optional[List[str]] = None,
        on_lock: Optional[Callable] = None,
        on_unlock: Optional[Callable] = None,
        button: Optional["Button"] = None,
        position: Optional[int] = None,
    ):
        """Register a step and recompute locks.

        :param name: Unique step name; other steps refer to it in ``depends_on``.
        :param widget: Widget shown for this step in the stepper.
        :param on_select: Called by :meth:`select_step`. If it returns ``False``
            the step is not marked as selected; any other result (including
            ``None``) selects it.
        :param on_reactivate: Called when this step is reactivated, either
            directly or because a step it depends on was reactivated.
        :param depends_on: Names of steps that must all be selected for this step
            to be unlocked. They must already be registered, otherwise
            :meth:`update_locks` raises ``KeyError``.
        :param on_lock: Called when the step switches to locked.
        :param on_unlock: Called when the step switches to unlocked.
        :param button: If given, clicking it toggles select / reactivate.
        :param position: Index in the display order; appended when ``None``.
        """
        if depends_on is None:
            depends_on = []
        self.steps[name] = {
            "widget": widget,
            "on_select": on_select,
            "on_reactivate": on_reactivate,
            "depends_on": depends_on,
            "on_lock": on_lock,
            "on_unlock": on_unlock,
            "button": button,
            "is_selected": False,
            "is_locked": False,
        }
        if button is not None:
            self._wrap_button(button, name)
        if position is not None:
            self.steps_sequence.insert(position, name)
        else:
            self.steps_sequence.append(name)
        self.update_locks()

    def _create_stepper(self):
        """Build the Stepper from the widgets of the currently registered steps."""
        widgets = [self.steps[step_name]["widget"] for step_name in self.steps_sequence]
        self._stepper = Stepper(widgets=widgets)

    @property
    def stepper(self):
        """The Stepper widget, created on first access.

        Steps registered after the first access are not added to it.
        """
        if self._stepper is None:
            self._create_stepper()
        return self._stepper

    def update_stepper(self):
        """Make the first unselected step the active one.

        The step's 1-based position is passed to ``set_active_step``. If every
        step is selected, the active step is left unchanged.
        """
        for i, step_name in enumerate(self.steps_sequence):
            if not self.steps[step_name]["is_selected"]:
                self.stepper.set_active_step(i + 1)
                return

    def update_locks(self):
        """Lock steps with an unselected dependency and unlock the rest.

        ``on_lock`` / ``on_unlock`` fire only on an actual state change.
        """
        for step in self.steps.values():
            should_lock = any(
                not self.steps[dep_name]["is_selected"] for dep_name in step["depends_on"]
            )
            if should_lock and not step["is_locked"]:
                if step["on_lock"] is not None:
                    step["on_lock"]()
                step["is_locked"] = True
            if not should_lock and step["is_locked"]:
                if step["on_unlock"] is not None:
                    step["on_unlock"]()
                step["is_locked"] = False

    def _reactivate_dependents(self, step_name: str, visited=None):
        """Reactivate every not-yet-visited step that lists ``step_name`` in ``depends_on``."""
        if visited is None:
            visited = set()
        for dep_name, step in self.steps.items():
            if step_name in step["depends_on"] and dep_name not in visited:
                self._reactivate_step(dep_name, visited)

    def _reactivate_step(self, step_name: str, visited=None):
        """Reactivate ``step_name`` and, recursively, its dependents.

        ``visited`` collects already-reactivated steps. Each step's
        ``on_reactivate`` therefore fires once per reactivation, even when it is
        reachable through several dependency paths, and cycles terminate.
        """
        if visited is None:
            visited = set()
        # Mark before recursing so diamond-shaped or cyclic dependencies are not revisited.
        visited.add(step_name)
        step = self.steps[step_name]
        if step["on_reactivate"] is not None:
            step["on_reactivate"]()
        step["is_selected"] = False
        self._reactivate_dependents(step_name, visited)

    def reactivate_step(self, step_name: str):
        """Reopen a step (and its dependents) for editing, then refresh stepper and locks."""
        self._reactivate_step(step_name)
        self.update_stepper()
        self.update_locks()

    def select_step(self, step_name: str):
        """Run the step's ``on_select`` callback and mark the step selected.

        If ``on_select`` returns ``False`` (for example, on failed validation),
        the step stays unselected and neither the stepper nor the locks change.
        """
        step = self.steps[step_name]
        if step["on_select"] is not None:
            if step["on_select"]() is False:
                return
        step["is_selected"] = True
        self.update_stepper()
        self.update_locks()

    def select_or_reactivate(self, step_name: str):
        """Toggle: reactivate a selected step, otherwise select it."""
        if self.steps[step_name]["is_selected"]:
            self.reactivate_step(step_name)
        else:
            self.select_step(step_name)

    def _wrap_button(self, button: "Button", step_name: str):
        """Bind ``button`` clicks to :meth:`select_or_reactivate` for ``step_name``."""
        button.click(lambda: self.select_or_reactivate(step_name))
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class PredictAppGui:
|
|
157
|
+
def __init__(self, api: Api, static_dir: str = "static"):
    """Build the predict-app GUI.

    Reads team/workspace/project ids from the environment, then creates the
    step widgets (input, model, classes, optional tags, settings & preview,
    output). Each step is registered in a ``StepFlow`` with its
    select/reactivate/lock callbacks and dependencies. Finally, the model
    widget's deploy/stop/disconnect methods are overridden so that they also
    update the step flow.

    :param api: Supervisely API client; passed to the selectors and used to fetch project meta.
    :param static_dir: Directory passed to ``SettingsSelector`` (default ``"static"``).
    """
    self.api = api
    self.static_dir = static_dir

    # Environment variables
    self.team_id = env.team_id()
    self.workspace_id = env.workspace_id()
    # project_id is optional (raise_not_found=False); project meta is fetched only when it is set
    self.project_id = env.project_id(raise_not_found=False)
    self.project_meta = None
    if self.project_id:
        self.project_meta = ProjectMeta.from_json(self.api.project.get_meta(self.project_id))
    # -------------------------------- #

    # Flags
    self._stop_flag = False
    self._is_running = False
    # -------------------------------- #

    # GUI
    # Steps
    self.step_flow = StepFlow()
    # Button appearance for a step awaiting selection ("Select") vs. an already selected one ("Reselect")
    select_params = {"icon": None, "plain": False, "text": "Select"}
    reselect_params = {"icon": "zmdi zmdi-refresh", "plain": True, "text": "Reselect"}

    # 1. Input selector
    self.input_selector = InputSelector(self.workspace_id, self.api)

    def _on_input_select():
        # NOTE(review): returning None here does not stop StepFlow.select_step from
        # marking the step selected after a failed validation — confirm intended
        valid = self.input_selector.validate_step()
        if not valid:
            return
        current_item_type = self.input_selector.radio.get_value()
        self.update_item_type()
        # self.model_api / self.settings_selector are resolved at call time
        # (settings_selector is created further down in __init__)
        if self.model_api:
            # NOTE(review): both sides read the same radio value, so this is always
            # true unless update_item_type() changes the radio — confirm intent
            if current_item_type == self.input_selector.radio.get_value():
                inference_settings = self.model_api.get_settings()
                self.settings_selector.set_inference_settings(inference_settings)

            if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
                try:
                    tracking_settings = self.model_api.get_tracking_settings()
                    self.settings_selector.set_tracking_settings(tracking_settings)
                except Exception as e:
                    # Fall back to default tracking settings when the model cannot provide them
                    logger.warning(
                        "Unable to get tracking settings from the model. Settings defaults"
                    )
                    self.settings_selector.set_default_tracking_settings()
        self.input_selector.disable()

        # Refresh project id/meta from the selected input
        self.project_id = self.input_selector.get_project_id()
        if self.project_id:
            self.project_meta = ProjectMeta.from_json(self.api.project.get_meta(self.project_id))
        update_custom_button_params(self.input_selector.button, reselect_params)

    def _on_input_reactivate():
        self.input_selector.enable()
        update_custom_button_params(self.input_selector.button, select_params)

    self.step_flow.add_step(
        name="input_selector",
        widget=self.input_selector.card,
        on_select=_on_input_select,
        on_reactivate=_on_input_reactivate,
        button=self.input_selector.button,
    )

    # 2. Model selector
    self.model_selector = ModelSelector(self.api, self.team_id)

    # No button or callbacks: this step is selected/reactivated by the
    # deploy/stop/disconnect overrides installed at the end of __init__
    self.step_flow.add_step(
        name="model_selector",
        widget=self.model_selector.card,
    )

    # 3. Classes selector
    self.classes_selector = ClassesSelector()

    def _on_classes_select():
        valid = self.classes_selector.validate_step()
        if not valid:
            return
        self.classes_selector.classes_table.disable()

        # Find conflict between project meta and model meta
        selected_classes_names = self.classes_selector.get_selected_classes()
        # NOTE(review): self.project_meta stays None when no project id was resolved;
        # get_obj_class below would then raise — confirm this step is unreachable in that state
        project_meta = self.project_meta
        model_meta = self.model_api.get_model_meta()

        # Conflict = a selected class present in the project whose geometry type differs
        # from the model's class, unless the model class uses AnyGeometry.
        # NOTE(review): assumes every selected class exists in model_meta (not None-checked)
        has_conflict = False
        for class_name in selected_classes_names:
            project_obj_class = project_meta.get_obj_class(class_name)
            if project_obj_class is None:
                continue

            model_obj_class = model_meta.get_obj_class(class_name)
            if model_obj_class.geometry_type.name() == AnyGeometry.name():
                continue

            if project_obj_class.geometry_type.name() == model_obj_class.geometry_type.name():
                continue

            has_conflict = True
            break

        # The model-prediction suffix container is shown only when a conflict was found
        if has_conflict:
            self.settings_selector.model_prediction_suffix_container.show()
        else:
            self.settings_selector.model_prediction_suffix_container.hide()
        # ------------------------------------------------ #

        update_custom_button_params(self.classes_selector.button, reselect_params)

    def _on_classes_reactivate():
        self.classes_selector.classes_table.enable()
        update_custom_button_params(self.classes_selector.button, select_params)

    self.step_flow.add_step(
        name="classes_selector",
        widget=self.classes_selector.card,
        on_select=_on_classes_select,
        on_reactivate=_on_classes_reactivate,
        depends_on=["input_selector", "model_selector"],
        on_lock=self.classes_selector.lock,
        on_unlock=self.classes_selector.unlock,
        button=self.classes_selector.button,
    )

    # 4. Tags selector
    # The tags step is currently disabled: this branch never runs, so tags_selector stays None
    self.tags_selector = None
    if False:
        self.tags_selector = TagsSelector()
        self.step_flow.add_step("tags_selector", self.tags_selector.card)

    # 5. Settings selector & Preview
    self.settings_selector = SettingsSelector(
        api=self.api,
        static_dir=self.static_dir,
        model_selector=self.model_selector,
        input_selector=self.input_selector,
    )

    def _on_settings_select():
        valid = self.settings_selector.validate_step()
        if not valid:
            return
        self.settings_selector.disable()
        update_custom_button_params(self.settings_selector.button, reselect_params)

    def _on_settings_reactivate():
        self.settings_selector.enable()
        update_custom_button_params(self.settings_selector.button, select_params)

    self.step_flow.add_step(
        name="settings_selector",
        widget=self.settings_selector.cards_container,
        on_select=_on_settings_select,
        on_reactivate=_on_settings_reactivate,
        depends_on=["input_selector", "model_selector", "classes_selector"],
        on_lock=self.settings_selector.lock,
        on_unlock=self.settings_selector.unlock,
        button=self.settings_selector.button,
    )
    # Preview's run button starts disabled
    self.settings_selector.preview.run_button.disable()

    # 6. Output selector
    self.output_selector = OutputSelector(self.api)

    self.step_flow.add_step(
        "output_selector",
        self.output_selector.card,
        depends_on=[
            "input_selector",
            "model_selector",
            "classes_selector",
            # "tags_selector",
            "settings_selector",
        ],
        on_lock=self.output_selector.lock,
        on_unlock=self.output_selector.unlock,
    )
    # -------------------------------- #

    # Layout
    # First access to step_flow.stepper builds the Stepper from all steps registered above
    self.layout = Container([self.step_flow.stepper])
    # ---------------------------- #

    def set_entity_meta():
        # Populate class/tag tables and settings from the deployed model
        model_api = self.model_selector.model.model_api

        model_meta = model_api.get_model_meta()
        if self.classes_selector is not None:
            self.classes_selector.set_project_meta(model_meta)
            self.classes_selector.classes_table.show()
        if self.tags_selector is not None:
            self.tags_selector.tags_table.set_project_meta(model_meta)
            self.tags_selector.tags_table.show()

        inference_settings = model_api.get_settings()
        self.settings_selector.set_inference_settings(inference_settings)

        if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
            try:
                tracking_settings = model_api.get_tracking_settings()
                self.settings_selector.set_tracking_settings(tracking_settings)
            except Exception as e:
                # Fall back to default tracking settings when the model cannot provide them
                logger.warning(
                    "Unable to get tracking settings from the model. Settings defaults"
                )
                self.settings_selector.set_default_tracking_settings()

    def reset_entity_meta():
        # Clear and hide class/tag tables and clear inference settings
        empty_meta = ProjectMeta()
        if self.classes_selector is not None:
            self.classes_selector.set_project_meta(empty_meta)
            self.classes_selector.classes_table.hide()
        if self.tags_selector is not None:
            self.tags_selector.tags_table.set_project_meta(empty_meta)
            self.tags_selector.tags_table.hide()

        self.settings_selector.set_inference_settings("")

    def deploy_and_set_step():
        self.model_selector.validator_text.hide()
        # The instance attribute `deploy` is overridden below, so call the class-level method
        model_api = type(self.model_selector.model).deploy(self.model_selector.model)
        if model_api is not None:
            set_entity_meta()
            self.step_flow.select_step("model_selector")
        else:
            reset_entity_meta()
            self.step_flow.reactivate_step("model_selector")
        return model_api

    def stop_and_reset_step():
        # Class-level call: the instance attribute `stop` is overridden below
        type(self.model_selector.model).stop(self.model_selector.model)
        self.step_flow.reactivate_step("model_selector")
        reset_entity_meta()

    def disconnect_and_reset_step():
        # Class-level call: the instance attribute `disconnect` is overridden below
        type(self.model_selector.model).disconnect(self.model_selector.model)
        self.step_flow.reactivate_step("model_selector")
        reset_entity_meta()

    # Replace deploy methods for DeployModel widget
    self.model_selector.model.deploy = deploy_and_set_step
    self.model_selector.model.stop = stop_and_reset_step
    self.model_selector.model.disconnect = disconnect_and_reset_step
|
|
403
|
+
|
|
404
|
+
# ------------------------------------------------- #
|
|
405
|
+
|
|
406
|
+
@property
def model_api(self) -> Optional[ModelAPI]:
    """Model API exposed by the model selector's deploy widget (``None`` if unavailable).

    Forwards ``self.model_selector.model.model_api`` without caching.
    """
    deploy_widget = self.model_selector.model
    return deploy_widget.model_api
|
|
409
|
+
|
|
410
|
+
def update_item_type(self):
    """Forward the input selector's chosen item type to the settings, then the output selector."""
    selected_type = self.input_selector.radio.get_value()
    for selector in (self.settings_selector, self.output_selector):
        selector.update_item_type(selected_type)
|
|
414
|
+
|
|
415
|
+
def _run_videos(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
    """Predict every video listed in ``run_parameters["input"]["video_ids"]``.

    The selected model is deployed first if ``self.model_api`` is ``None``. Videos are
    grouped by source project. For each project, predictions go either back to the source
    videos (``upload_to_source_project``) or to copies in a newly created project.

    :param run_parameters: Parameters in the shape returned by :meth:`get_run_parameters`.
        When ``None``, they are collected from the widgets. The dict is not modified.
    :returns: All per-frame predictions produced for the processed videos.
    :raises RuntimeError: If the model could not be deployed.
    :raises ValueError: If no video IDs are provided.
    :raises StopIteration: If a stop was requested via :meth:`stop` during prediction.
    """
    import os
    import tempfile

    if self.model_api is None:
        self.set_validator_text("Deploying model...", "info")
        self.model_selector.model._deploy()
    if self.model_api is None:
        logger.error("Model Deployed with an error")
        raise RuntimeError("Model Deployed with an error")

    self.set_validator_text("Preparing settings for prediction...", "info")
    if run_parameters is None:
        run_parameters = self.get_run_parameters()

    input_parameters = run_parameters["input"]
    input_video_ids = input_parameters["video_ids"]
    if not input_video_ids:
        raise ValueError("No video IDs provided for video prediction.")

    predict_kwargs = {}
    # Settings. Work on a copy: popping keys from the caller's dict would make a second run
    # with the same parameters fail with KeyError on "predictions_mode".
    settings = dict(run_parameters["settings"])
    model_prediction_suffix = settings.pop("model_prediction_suffix", "")
    prediction_mode = settings.pop("predictions_mode")
    tracking = settings.pop("tracking", False)
    predict_kwargs.update(settings)

    # Classes. get_run_parameters() omits "classes" when there is no classes selector.
    classes = run_parameters.get("classes")
    if classes:
        predict_kwargs["classes"] = classes

    output_parameters = run_parameters["output"]
    project_name = output_parameters.get("project_name", "")
    upload_to_source_project = output_parameters.get("upload_to_source_project", False)
    skip_project_versioning = output_parameters.get("skip_project_versioning", False)
    skip_annotated = output_parameters.get("skip_annotated", False)

    video_infos_by_project_id: Dict[int, List[VideoInfo]] = {}
    video_infos_by_dataset_id: Dict[int, List[VideoInfo]] = {}
    for info in self.api.video.get_info_by_id_batch(input_video_ids):
        video_infos_by_project_id.setdefault(info.project_id, []).append(info)
        video_infos_by_dataset_id.setdefault(info.dataset_id, []).append(info)
    src_project_metas: Dict[int, ProjectMeta] = {
        project_id: ProjectMeta.from_json(self.api.project.get_meta(project_id))
        for project_id in video_infos_by_project_id
    }

    # A video counts as "already annotated" when its annotation has at least one figure.
    video_ids_to_skip = set()
    if skip_annotated:
        self.set_validator_text("Checking for already annotated videos...", "info")
        secondary_pbar = self.output_selector.secondary_progress(
            message="Checking for already annotated videos...", total=len(input_video_ids)
        )
        self.output_selector.secondary_progress.show()
        for dataset_id, video_infos in video_infos_by_dataset_id.items():
            annotations = self.api.video.annotation.download_bulk(
                dataset_id, [info.id for info in video_infos]
            )
            for ann_json, video_info in zip(annotations, video_infos):
                if ann_json:
                    project_meta = src_project_metas[video_info.project_id]
                    ann = VideoAnnotation.from_json(ann_json, project_meta=project_meta)
                    if len(ann.figures) > 0:
                        video_ids_to_skip.add(video_info.id)
                secondary_pbar.update()
        self.output_selector.secondary_progress.hide()
        if video_ids_to_skip:
            video_infos_by_project_id = {
                pid: [info for info in infos if info.id not in video_ids_to_skip]
                for pid, infos in video_infos_by_project_id.items()
            }

    main_pbar_str = "Processing videos..."
    if video_ids_to_skip:
        main_pbar_str += f" (Skipped {len(video_ids_to_skip)} already annotated videos)"
    total_videos = sum(len(v) for v in video_infos_by_project_id.values())
    if total_videos == 0:
        self.set_validator_text(
            f"No videos to process. Skipped {len(video_ids_to_skip)} already annotated videos",
            "warning",
        )
        return []
    main_pbar = self.output_selector.progress(message=main_pbar_str, total=total_videos)
    self.output_selector.progress.show()

    replace_modes = (
        AddPredictionsMode.REPLACE,
        AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS,
    )
    all_predictions: List[Prediction] = []
    output_project_id = None
    for src_project_id, src_video_infos in video_infos_by_project_id.items():
        if len(src_video_infos) == 0:
            continue
        project_info = self.api.project.get_info_by_id(src_project_id)
        project_validator_text_str = (
            f"Processing project: {project_info.name} [id: {src_project_id}]"
        )
        if upload_to_source_project:
            if not skip_project_versioning and not is_development():
                logger.info("Creating new project version...")
                self.set_validator_text(
                    project_validator_text_str + ": Creating project version",
                    "info",
                )
                version_id = self.api.project.version.create(
                    project_info,
                    "Created by Predict App. Task Id: " + str(env.task_id()),
                )
                logger.info("New project version created: " + str(version_id))
            output_project_id = src_project_id
            output_videos: List[VideoInfo] = src_video_infos
        else:
            self.set_validator_text(
                project_validator_text_str + ": Creating project...", "info"
            )
            # Resolve the name per project so an auto-generated name for one source project
            # is not reused for the next one.
            dst_project_name = project_name
            if not dst_project_name:
                dst_project_name = project_info.name + " [Predictions]"
                logger.warning(
                    "Project name is empty, using auto-generated name: " + dst_project_name
                )
            with_annotations = prediction_mode in [
                AddPredictionsMode.APPEND,
                AddPredictionsMode.IOU_MERGE,
            ]
            created_project = create_project(
                api=self.api,
                project_id=src_project_id,
                project_name=dst_project_name,
                workspace_id=self.workspace_id,
                copy_meta=with_annotations,
                project_type=ProjectType.VIDEOS,
            )
            output_project_id = created_project.id
            output_videos: List[VideoInfo] = copy_items_to_project(
                api=self.api,
                src_project_id=src_project_id,
                items=src_video_infos,
                dst_project_id=created_project.id,
                with_annotations=with_annotations,
                ds_progress=self.output_selector.secondary_progress,
                project_type=ProjectType.VIDEOS,
            )

        self.set_validator_text(
            project_validator_text_str + ": Merging project meta",
            "info",
        )
        project_meta = src_project_metas[src_project_id]
        for src_video_info, output_video_info in zip(src_video_infos, output_videos):
            video_validator_text_str = (
                project_validator_text_str
                + f", video: {src_video_info.name} [id: {src_video_info.id}]"
            )
            self.set_validator_text(
                video_validator_text_str + ": Predicting",
                "info",
            )
            frames_predictions: List[Prediction] = []
            with self.model_api.predict_detached(
                video_id=src_video_info.id,
                tqdm=self.output_selector.secondary_progress(),
                tracking=tracking,
                **predict_kwargs,
            ) as session:
                self.output_selector.secondary_progress.show()
                for prediction in session:
                    if self._stop_flag:
                        logger.info("Prediction stopped by user.")
                        raise StopIteration("Stopped by user.")
                    frames_predictions.append(prediction)
            all_predictions.extend(frames_predictions)
            if tracking:
                # With tracking, the server returns a ready video annotation built on the model meta.
                prediction_video_annotation: VideoAnnotation = VideoAnnotation.from_json(
                    session.final_result["video_ann"],
                    project_meta=self.model_api.get_model_meta(),
                )
            else:
                prediction_video_annotation = video_annotation_from_predictions(
                    frames_predictions,
                    project_meta,
                    frame_size=(src_video_info.frame_height, src_video_info.frame_width),
                )
            if prediction_video_annotation is None:
                logger.warning(
                    f"No predictions were made for video {src_video_info.name} [id: {src_video_info.id}]"
                )
                main_pbar.update()
                continue
            self.set_validator_text(
                video_validator_text_str + ": Uploading predictions",
                "info",
            )
            project_meta, prediction_video_annotation, meta_changed = (
                update_meta_and_ann_for_video_annotation(
                    meta=project_meta,
                    ann=prediction_video_annotation,
                    model_prediction_suffix=model_prediction_suffix,
                )
            )
            if meta_changed:
                self.api.project.update_meta(output_project_id, project_meta)
            if upload_to_source_project and prediction_mode in replace_modes:
                self.output_selector.secondary_progress.hide()
                # Use a unique temporary file. A fixed /tmp path can be overwritten by concurrent
                # runs and does not exist on every OS.
                with tempfile.NamedTemporaryFile(
                    mode="w", suffix=".json", delete=False
                ) as tmp_file:
                    json.dump(prediction_video_annotation.to_json(), tmp_file)
                    ann_path = tmp_file.name
                try:
                    self.api.video.annotation.upload_paths(
                        video_ids=[src_video_info.id],
                        paths=[ann_path],
                        project_meta=project_meta,
                    )
                finally:
                    os.remove(ann_path)
            else:
                # When uploading to the source project, output_videos is src_video_infos,
                # so output_video_info is the source video itself.
                secondary_pbar = self.output_selector.secondary_progress(
                    message="Uploading annotations...",
                    total=len(prediction_video_annotation.figures),
                )
                self.output_selector.secondary_progress.show()
                self.api.video.annotation.append(
                    video_id=output_video_info.id,
                    ann=prediction_video_annotation,
                    key_id_map=KeyIdMap(),
                    progress_cb=secondary_pbar.update,
                )
            main_pbar.update()
    self.set_validator_text("Project successfully processed", "success")
    self.output_selector.set_result_thumbnail(output_project_id)
    return all_predictions
|
|
652
|
+
|
|
653
|
+
def _run_images(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
    """Predict images selected by image IDs, dataset IDs or a whole project.

    The selected model is deployed first if ``self.model_api`` is ``None``. Image IDs take
    precedence over dataset IDs, which take precedence over a project ID. Images are grouped
    by source project. Each project is either predicted in place
    (``upload_to_source_project``) or first copied into a newly created project. All target
    images are then predicted in one detached session. The ``upload_mode`` derived from
    ``settings["predictions_mode"]`` is passed to ``predict_detached``.

    :param run_parameters: Parameters in the shape returned by :meth:`get_run_parameters`.
        When ``None``, they are collected from the widgets. The dict is not modified.
    :returns: All predictions yielded by the prediction session (an empty list when every
        image is skipped as already annotated).
    :raises RuntimeError: If the model could not be deployed.
    :raises ValueError: If no input is given or no images are found.
    :raises StopIteration: If a stop was requested via :meth:`stop` during prediction.
    """
    if self.model_api is None:
        self.set_validator_text("Deploying model...", "info")
        self.model_selector.model._deploy()
    if self.model_api is None:
        logger.error("Model Deployed with an error")
        raise RuntimeError("Model Deployed with an error")

    self.set_validator_text("Preparing settings for prediction...", "info")
    if run_parameters is None:
        run_parameters = self.get_run_parameters()

    predict_kwargs = {}
    # Input
    input_parameters = run_parameters["input"]
    input_project_id = input_parameters.get("project_id", None)
    input_dataset_ids = input_parameters.get("dataset_ids", [])
    input_image_ids = input_parameters.get("image_ids", [])
    if not (input_image_ids or input_dataset_ids or input_project_id):
        raise ValueError("No valid input parameters found for prediction.")

    # Settings. Work on a copy: popping keys from the caller's dict would make a second run
    # with the same parameters fail with KeyError on "predictions_mode".
    settings = dict(run_parameters["settings"])
    prediction_mode = settings.pop("predictions_mode")
    upload_mode = None
    with_annotations = None
    if prediction_mode == AddPredictionsMode.REPLACE:
        upload_mode = "replace"
        with_annotations = False
    elif prediction_mode == AddPredictionsMode.APPEND:
        upload_mode = "append"
        with_annotations = True
    elif prediction_mode == AddPredictionsMode.IOU_MERGE:
        upload_mode = "iou_merge"
        with_annotations = True
    elif prediction_mode == AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS:
        upload_mode = "replace"
        with_annotations = True
    predict_kwargs.update(settings)
    predict_kwargs["upload_mode"] = upload_mode

    # Classes. get_run_parameters() omits "classes" when there is no classes selector.
    classes = run_parameters.get("classes")
    if classes:
        predict_kwargs["classes"] = classes

    # Output
    output_parameters = run_parameters["output"]
    project_name = output_parameters.get("project_name", None)
    upload_to_source_project = output_parameters.get("upload_to_source_project", False)
    skip_project_versioning = output_parameters.get("skip_project_versioning", False)
    skip_annotated = output_parameters.get("skip_annotated", False)

    image_infos = []
    if input_image_ids:
        image_infos = self.api.image.get_info_by_id_batch(input_image_ids)
    elif input_dataset_ids:
        for dataset_id in input_dataset_ids:
            image_infos.extend(self.api.image.get_list(dataset_id))
    elif input_project_id:
        datasets = self.api.dataset.get_list(input_project_id, recursive=True)
        for dataset in datasets:
            image_infos.extend(self.api.image.get_list(dataset.id))
    if len(image_infos) == 0:
        raise ValueError("No images found for the given input parameters.")

    if skip_annotated:
        # Skip images that already carry labels; a set keeps the filter linear.
        to_skip = {info.id for info in image_infos if info.labels_count > 0}
        if to_skip:
            image_infos = [info for info in image_infos if info.id not in to_skip]
            if len(image_infos) == 0:
                self.set_validator_text(
                    "All images are already annotated. Nothing to predict.", "warning"
                )
                return []

    image_infos_by_project_id: Dict[int, List[ImageInfo]] = {}
    image_infos_by_dataset_id: Dict[int, List[ImageInfo]] = {}
    ds_project_mapping: Dict[int, int] = {}
    for info in image_infos:
        image_infos_by_dataset_id.setdefault(info.dataset_id, []).append(info)
        if info.dataset_id not in ds_project_mapping:
            ds_info = self.api.dataset.get_info_by_id(info.dataset_id)
            ds_project_mapping[info.dataset_id] = ds_info.project_id
        project_id = ds_project_mapping[info.dataset_id]
        image_infos_by_project_id.setdefault(project_id, []).append(info)

    src_project_metas: Dict[int, ProjectMeta] = {}
    for project_id in image_infos_by_project_id:
        src_project_metas[project_id] = ProjectMeta.from_json(
            self.api.project.get_meta(project_id)
        )

    self.output_selector.progress.show()
    total_items = sum(len(v) for v in image_infos_by_project_id.values())
    main_pbar = self.output_selector.progress(message="Copying images...", total=total_items)
    # Accumulate target images across ALL source projects. Overwriting per project would
    # predict only the last project's images.
    output_image_infos: List[ImageInfo] = []
    output_project_id = None
    for src_project_id, infos in image_infos_by_project_id.items():
        if len(infos) == 0:
            continue
        project_info = self.api.project.get_info_by_id(src_project_id)
        project_validator_text_str = (
            f"Processing project: {project_info.name} [id: {src_project_id}]"
        )
        if upload_to_source_project:
            if not skip_project_versioning and not is_development():
                logger.info("Creating new project version...")
                self.set_validator_text(
                    project_validator_text_str + ": Creating project version", "info"
                )
                version_id = self.api.project.version.create(
                    project_info,
                    "Created by Predict App. Task Id: " + str(env.task_id()),
                )
                logger.info("New project version created: " + str(version_id))
            output_project_id = src_project_id
            output_image_infos.extend(infos)
        else:
            self.set_validator_text(
                project_validator_text_str + ": Creating project...", "info"
            )
            # Resolve the name per project so an auto-generated name for one source project
            # is not reused for the next one.
            dst_project_name = project_name
            if not dst_project_name:
                dst_project_name = project_info.name + " [Predictions]"
                logger.warning(
                    "Project name is empty, using auto-generated name: " + dst_project_name
                )
            created_project = create_project(
                api=self.api,
                project_id=src_project_id,
                project_name=dst_project_name,
                workspace_id=self.workspace_id,
                copy_meta=with_annotations,
                project_type=ProjectType.IMAGES,
            )
            output_project_id = created_project.id
            copied_infos: List[ImageInfo] = copy_items_to_project(
                api=self.api,
                src_project_id=src_project_id,
                items=infos,
                dst_project_id=created_project.id,
                with_annotations=with_annotations,
                ds_progress=self.output_selector.secondary_progress,
                progress_cb=main_pbar.update,
                project_type=ProjectType.IMAGES,
            )
            output_image_infos.extend(copied_infos)

    # Run prediction
    self.set_validator_text("Running prediction...", "info")
    predictions: List[Prediction] = []
    self._is_running = True
    with self.model_api.predict_detached(
        image_ids=[info.id for info in output_image_infos],
        **predict_kwargs,
        tqdm=self.output_selector.progress(),
    ) as session:
        for prediction in session:
            if self._stop_flag:
                logger.info("Prediction stopped by user.")
                raise StopIteration("Stopped by user.")
            predictions.append(prediction)
    self.set_validator_text("Project successfully processed", "success")
    self.output_selector.set_result_thumbnail(output_project_id)
    return predictions
|
|
823
|
+
|
|
824
|
+
def run(self, run_parameters: Dict[str, Any] = None) -> List[Prediction]:
    """Run prediction and dispatch to video or image mode.

    Video mode is used when ``run_parameters["input"]["video_ids"]`` is non-empty; otherwise
    image mode is used. Progress bars are hidden and the run/stop flags are reset in every
    case.

    :param run_parameters: Parameters in the shape returned by :meth:`get_run_parameters`.
        When ``None``, they are collected from the widgets.
    :returns: The predictions returned by the selected run method.
    :raises StopIteration: Re-raised when the user stopped the prediction.
    :raises Exception: Any other error is shown in the validator text, the output widgets
        are re-enabled, and the error is re-raised.
    """
    self.show_validator_text()
    if run_parameters is None:
        run_parameters = self.get_run_parameters()
    input_parameters = run_parameters["input"]
    video_ids = input_parameters.get("video_ids", None)
    try:
        # Mark the run active for its whole duration (deploy, copy, predict, upload) so that
        # wait_for_stop()/shutdown_model() wait for it; the video branch never sets it itself.
        self._is_running = True
        run_f = self._run_videos if video_ids else self._run_images
        return run_f(run_parameters)
    except StopIteration:
        logger.info("Prediction stopped by user.")
        self.set_validator_text("Prediction stopped by user.", "warning")
        raise
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        self.set_validator_text(f"Error during prediction: {str(e)}", "error")
        disable_enable(self.output_selector.widgets_to_disable, False)
        raise
    finally:
        self.output_selector.secondary_progress.hide()
        self.output_selector.progress.hide()
        self._is_running = False
        self._stop_flag = False
|
|
850
|
+
|
|
851
|
+
def stop(self):
    """Ask the running prediction loop to stop.

    Only raises ``self._stop_flag``; the prediction loops check the flag between results.
    """
    message = "Stopping prediction..."
    logger.info(message)
    self._stop_flag = True
|
|
854
|
+
|
|
855
|
+
def wait_for_stop(self, timeout: Optional[float] = None):
    """Block until ``self._is_running`` becomes false.

    The flag is polled every 0.1 seconds.

    :param timeout: Maximum number of seconds to wait; ``None`` waits indefinitely.
    :raises TimeoutError: If the run is still active after ``timeout`` seconds.
    """
    # Build the optional part separately. Written inline without parentheses, the
    # conditional expression swallowed the whole concatenation and garbled the message.
    timeout_str = "" if timeout is None else f"{timeout} seconds "
    logger.info("Waiting " + timeout_str + "for prediction to stop...")
    started_at = time.monotonic()
    while self._is_running:
        if timeout is not None and time.monotonic() - started_at > timeout:
            raise TimeoutError("Timeout while waiting for stop.")
        time.sleep(0.1)
    logger.info("Prediction stopped.")
|
|
867
|
+
|
|
868
|
+
def shutdown_model(self):
    """Stop any running prediction, then stop the model held by the model selector.

    Requests a stop, waits up to 10 seconds for the current run to finish, and stops the
    model. The model is stopped even when waiting times out; the ``TimeoutError`` raised by
    :meth:`wait_for_stop` is then propagated to the caller.
    """
    self.stop()
    try:
        self.wait_for_stop(10)
    finally:
        # Always release the model; otherwise a timed-out shutdown leaves it deployed.
        self.model_selector.model.stop()
|
|
872
|
+
|
|
873
|
+
def get_run_parameters(self) -> Dict[str, Any]:
    """Snapshot the current widget state as a run-parameters dict.

    :returns: A dict with ``"model"``, ``"settings"``, ``"input"`` and ``"output"`` entries.
        ``"classes"`` and ``"tags"`` are added only when the corresponding selector exists.
    """
    params: Dict[str, Any] = {}
    params["model"] = self.model_selector.model.get_deploy_parameters()
    params["settings"] = self.settings_selector.get_settings()
    params["input"] = self.input_selector.get_settings()
    params["output"] = self.output_selector.get_settings()
    classes_selector = self.classes_selector
    if classes_selector is not None:
        params["classes"] = classes_selector.get_selected_classes()
    tags_selector = self.tags_selector
    if tags_selector is not None:
        params["tags"] = tags_selector.get_selected_tags()
    return params
|
|
885
|
+
|
|
886
|
+
def load_from_json(self, data):
    """Restore widget state from a dict shaped like :meth:`get_run_parameters` output.

    Sections are applied in a fixed order: input, model, classes, tags, settings, output.
    A missing section is passed as an empty dict. The classes and tags sections are applied
    only when the corresponding selector exists.

    :param data: Mapping with optional ``"input"``, ``"model"``, ``"classes"``, ``"tags"``,
        ``"settings"`` and ``"output"`` entries.
    """
    section = data.get

    def _load_optional(selector, key):
        # Optional selectors may be None when the app was built without them.
        if selector is not None:
            selector.load_from_json(section(key, {}))

    self.input_selector.load_from_json(section("input", {}))
    self.model_selector.model.load_from_json(section("model", {}))
    _load_optional(self.classes_selector, "classes")
    _load_optional(self.tags_selector, "tags")
    # The settings selector also drives the preview.
    self.settings_selector.load_from_json(section("settings", {}))
    self.output_selector.load_from_json(section("output", {}))
|
|
907
|
+
|
|
908
|
+
def set_validator_text(self, text: str, status: str = "text"):
    """Show ``text`` in the output selector's validator widget with the given ``status``.

    :param text: Message to display.
    :param status: Status passed to the widget (the callers use ``"info"``, ``"warning"``,
        ``"error"`` and ``"success"``); defaults to ``"text"``.
    """
    validator = self.output_selector.validator_text
    validator.set(text=text, status=status)
|
|
910
|
+
|
|
911
|
+
def show_validator_text(self):
    """Make the output selector's validator widget visible."""
    validator = self.output_selector.validator_text
    validator.show()
|
|
913
|
+
|
|
914
|
+
def hide_validator_text(self):
    """Hide the output selector's validator widget."""
    validator = self.output_selector.validator_text
    validator.hide()
|