supervisely 6.73.457__py3-none-any.whl → 6.73.458__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +24 -1
- supervisely/api/image_api.py +4 -0
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +41 -1
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +45 -11
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +63 -7
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/inference/inference.py +364 -73
- supervisely/nn/inference/inference_request.py +3 -2
- supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +178 -25
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort_tracker.py +14 -7
- supervisely/nn/tracker/visualize.py +70 -72
- supervisely/video/video.py +15 -1
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/METADATA +3 -2
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/RECORD +40 -40
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/LICENSE +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/WHEEL +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.457.dist-info → supervisely-6.73.458.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,54 @@
|
|
1
|
-
|
1
|
+
import os
|
2
|
+
import random
|
3
|
+
import shutil
|
4
|
+
import subprocess
|
5
|
+
import threading
|
6
|
+
from contextlib import contextmanager, nullcontext
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any, Callable, Dict, List
|
2
9
|
|
10
|
+
import cv2
|
3
11
|
import yaml
|
4
12
|
|
13
|
+
from supervisely._utils import logger
|
14
|
+
from supervisely.annotation.annotation import Annotation
|
15
|
+
from supervisely.annotation.label import Label
|
16
|
+
from supervisely.api.api import Api
|
17
|
+
from supervisely.api.video.video_api import VideoInfo
|
5
18
|
from supervisely.app.widgets import (
|
6
19
|
Button,
|
7
20
|
Card,
|
8
|
-
Checkbox,
|
9
21
|
Container,
|
10
22
|
Editor,
|
11
23
|
Field,
|
24
|
+
GridGallery,
|
12
25
|
Input,
|
26
|
+
InputNumber,
|
27
|
+
OneOf,
|
28
|
+
Progress,
|
13
29
|
Select,
|
14
30
|
Text,
|
31
|
+
VideoPlayer,
|
15
32
|
)
|
33
|
+
from supervisely.app.widgets.checkbox.checkbox import Checkbox
|
34
|
+
from supervisely.app.widgets.empty.empty import Empty
|
35
|
+
from supervisely.app.widgets.widget import Widget
|
36
|
+
from supervisely.nn.inference.inference import (
|
37
|
+
_filter_duplicated_predictions_from_ann,
|
38
|
+
update_meta_and_ann,
|
39
|
+
update_meta_and_ann_for_video_annotation,
|
40
|
+
)
|
41
|
+
from supervisely.nn.inference.predict_app.gui.input_selector import InputSelector
|
42
|
+
from supervisely.nn.inference.predict_app.gui.model_selector import ModelSelector
|
43
|
+
from supervisely.nn.inference.predict_app.gui.utils import (
|
44
|
+
video_annotation_from_predictions,
|
45
|
+
)
|
46
|
+
from supervisely.nn.model.model_api import ModelAPI, Prediction
|
47
|
+
from supervisely.nn.tracker import TrackingVisualizer
|
48
|
+
from supervisely.project import ProjectMeta
|
49
|
+
from supervisely.project.project_meta import ProjectType
|
50
|
+
from supervisely.video.video import VideoFrameReader
|
51
|
+
from supervisely.video_annotation.video_annotation import KeyIdMap, VideoAnnotation
|
16
52
|
|
17
53
|
|
18
54
|
class InferenceMode:
|
@@ -21,26 +57,547 @@ class InferenceMode:
|
|
21
57
|
|
22
58
|
|
23
59
|
class AddPredictionsMode:
|
24
|
-
|
25
|
-
|
60
|
+
APPEND = "Merge with existing labels"
|
61
|
+
REPLACE = "Replace existing labels"
|
62
|
+
IOU_MERGE = "Merge by IoU threshold"
|
26
63
|
REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS = "Replace existing labels and save image tags"
|
27
64
|
|
28
65
|
|
66
|
+
class Preview:
|
67
|
+
lock_message = "Select previous step to unlock"
|
68
|
+
|
69
|
+
def __init__(
|
70
|
+
self,
|
71
|
+
api: Api,
|
72
|
+
preview_dir: str,
|
73
|
+
get_model_api_fn: Callable[[], ModelAPI],
|
74
|
+
get_input_settings_fn: Callable[[], Dict[str, Any]],
|
75
|
+
get_settings_fn: Callable[[], Dict[str, Any]],
|
76
|
+
):
|
77
|
+
self.api = api
|
78
|
+
self.preview_dir = preview_dir
|
79
|
+
self.get_model_api_fn = get_model_api_fn
|
80
|
+
self.get_input_settings_fn = get_input_settings_fn
|
81
|
+
self.get_settings_fn = get_settings_fn
|
82
|
+
os.makedirs(self.preview_dir, exist_ok=True)
|
83
|
+
os.makedirs(Path(self.preview_dir, "annotated"), exist_ok=True)
|
84
|
+
self.image_preview_path = None
|
85
|
+
self.image_peview_url = None
|
86
|
+
self.video_preview_path = None
|
87
|
+
self.video_preview_annotated_path = None
|
88
|
+
self.video_peview_url = None
|
89
|
+
|
90
|
+
self.progress_widget = Progress(show_percents=True, hide_on_finish=True)
|
91
|
+
self.download_error = Text("", status="warning")
|
92
|
+
self.download_error.hide()
|
93
|
+
self.progress_container = Container(widgets=[self.download_error, self.progress_widget])
|
94
|
+
self.loading_container = Container(widgets=[self.download_error, Text("Loading...")])
|
95
|
+
|
96
|
+
self.image_gallery = GridGallery(
|
97
|
+
2,
|
98
|
+
sync_views=True,
|
99
|
+
enable_zoom=True,
|
100
|
+
resize_on_zoom=True,
|
101
|
+
empty_message="",
|
102
|
+
)
|
103
|
+
self.image_preview_container = Container(widgets=[self.image_gallery])
|
104
|
+
|
105
|
+
self.video_player = VideoPlayer()
|
106
|
+
self.video_preview_container = Container(widgets=[self.video_player])
|
107
|
+
|
108
|
+
self.locked_text = Text("Select input and model to unlock", status="info")
|
109
|
+
self.empty_text = Text("Click preview to visualize predictions")
|
110
|
+
self.error_text = Text("Failed to generate preview", status="error")
|
111
|
+
|
112
|
+
self.select = Select(
|
113
|
+
items=[
|
114
|
+
Select.Item("locked", content=self.locked_text),
|
115
|
+
Select.Item("empty", content=self.empty_text),
|
116
|
+
Select.Item(ProjectType.IMAGES.value, content=self.image_preview_container),
|
117
|
+
Select.Item(ProjectType.VIDEOS.value, content=self.video_preview_container),
|
118
|
+
Select.Item("error", content=self.error_text),
|
119
|
+
Select.Item("loading", content=self.loading_container),
|
120
|
+
Select.Item("progress", content=self.progress_container),
|
121
|
+
]
|
122
|
+
)
|
123
|
+
self.select.set_value("empty")
|
124
|
+
self.oneof = OneOf(self.select)
|
125
|
+
|
126
|
+
self.run_button = Button("Preview", icon="zmdi zmdi-slideshow")
|
127
|
+
self.run_button.disable()
|
128
|
+
self.card = Card(
|
129
|
+
title="Preview",
|
130
|
+
description="Preview model predictions on a random image or video from the selected input source.",
|
131
|
+
content=self.oneof,
|
132
|
+
content_top_right=self.run_button,
|
133
|
+
lock_message=self.lock_message,
|
134
|
+
)
|
135
|
+
|
136
|
+
@self.run_button.click
|
137
|
+
def _run_preview():
|
138
|
+
self.run_preview()
|
139
|
+
|
140
|
+
def lock(self):
|
141
|
+
self.run_button.disable()
|
142
|
+
self.card.lock(self.lock_message)
|
143
|
+
|
144
|
+
def unlock(self):
|
145
|
+
self.run_button.enable()
|
146
|
+
self.card.unlock()
|
147
|
+
|
148
|
+
@contextmanager
|
149
|
+
def progress(self, message: str, total: int, **kwargs):
|
150
|
+
current_item = self.select.get_value()
|
151
|
+
try:
|
152
|
+
with self.progress_widget(message=message, total=total, **kwargs) as pbar:
|
153
|
+
self.select_item("progress")
|
154
|
+
yield pbar
|
155
|
+
finally:
|
156
|
+
self.select_item(current_item)
|
157
|
+
|
158
|
+
def select_item(self, item: str):
|
159
|
+
self.select.set_value(item)
|
160
|
+
|
161
|
+
def _download_video_by_frames(
|
162
|
+
self, video_info: VideoInfo, save_path: str, frames_number=150, progress_cb=None
|
163
|
+
):
|
164
|
+
if Path(save_path).exists():
|
165
|
+
Path(save_path).unlink()
|
166
|
+
tmp_dir = Path(self.preview_dir, "tmp_frames")
|
167
|
+
if tmp_dir.exists():
|
168
|
+
shutil.rmtree(tmp_dir)
|
169
|
+
os.makedirs(tmp_dir, exist_ok=True)
|
170
|
+
self.api.video.download_frames(
|
171
|
+
video_info.id,
|
172
|
+
frames=list(range(frames_number)),
|
173
|
+
paths=[str(tmp_dir / f"frame_{i}.jpg") for i in range(frames_number)],
|
174
|
+
progress_cb=progress_cb,
|
175
|
+
)
|
176
|
+
fps = int(video_info.frames_count / video_info.duration)
|
177
|
+
fourcc = cv2.VideoWriter.fourcc(*"mp4v") # or 'avc1', 'XVID', 'H264'
|
178
|
+
out = cv2.VideoWriter(
|
179
|
+
save_path, fourcc, fps, (video_info.frame_width, video_info.frame_height)
|
180
|
+
)
|
181
|
+
for i in range(frames_number):
|
182
|
+
frame_path = tmp_dir / f"frame_{i}.jpg"
|
183
|
+
if not frame_path.exists():
|
184
|
+
continue
|
185
|
+
img = cv2.imread(str(frame_path))
|
186
|
+
out.write(img)
|
187
|
+
out.release()
|
188
|
+
shutil.rmtree(tmp_dir)
|
189
|
+
|
190
|
+
def _download_full_video(
|
191
|
+
self, video_id: int, save_path: str, duration: int = 5, progress_cb=None
|
192
|
+
):
|
193
|
+
if Path(save_path).exists():
|
194
|
+
Path(save_path).unlink()
|
195
|
+
temp = Path(self.preview_dir) / f"temp_{video_id}.mp4"
|
196
|
+
if temp.exists():
|
197
|
+
temp.unlink()
|
198
|
+
self.api.video.download_path(video_id, temp, progress_cb=progress_cb)
|
199
|
+
minutes = duration // 60
|
200
|
+
hours = minutes // 60
|
201
|
+
minutes = minutes % 60
|
202
|
+
seconds = duration % 60
|
203
|
+
duration_str = f"{hours:02}:{minutes:02}:{seconds:02}"
|
204
|
+
try:
|
205
|
+
process = subprocess.Popen(
|
206
|
+
[
|
207
|
+
"ffmpeg",
|
208
|
+
"-y",
|
209
|
+
"-i",
|
210
|
+
str(temp),
|
211
|
+
"-c",
|
212
|
+
"copy",
|
213
|
+
"-t",
|
214
|
+
duration_str,
|
215
|
+
save_path,
|
216
|
+
],
|
217
|
+
stderr=subprocess.PIPE,
|
218
|
+
)
|
219
|
+
process.wait()
|
220
|
+
logger.debug("FFmpeg exited with code: " + str(process.returncode))
|
221
|
+
logger.debug(f"FFmpeg stderr: {process.stderr.read().decode()}")
|
222
|
+
if len(VideoFrameReader(save_path).read_frames()) == 0:
|
223
|
+
raise RuntimeError("No frames read from the video")
|
224
|
+
temp.unlink()
|
225
|
+
except Exception as e:
|
226
|
+
if Path(save_path).exists():
|
227
|
+
Path(save_path).unlink()
|
228
|
+
shutil.copy(temp, save_path)
|
229
|
+
temp.unlink()
|
230
|
+
logger.warning(f"FFmpeg trimming failed: {str(e)}", exc_info=True)
|
231
|
+
|
232
|
+
def _download_video_preview(self, video_info: VideoInfo, with_progress=True):
|
233
|
+
video_id = video_info.id
|
234
|
+
duration = 5
|
235
|
+
video_path = Path(self.preview_dir, video_info.name)
|
236
|
+
self.video_preview_path = video_path
|
237
|
+
self.video_preview_annotated_path = Path(self.preview_dir, "annotated") / Path(
|
238
|
+
self.video_preview_path
|
239
|
+
).relative_to(self.preview_dir)
|
240
|
+
success = False
|
241
|
+
try:
|
242
|
+
try:
|
243
|
+
size = int(video_info.file_meta["size"])
|
244
|
+
size = int(size / video_info.duration * duration)
|
245
|
+
except:
|
246
|
+
size = None
|
247
|
+
with (
|
248
|
+
self.progress("Downloading video part:", total=size, unit="B", unit_scale=True)
|
249
|
+
if with_progress and size
|
250
|
+
else nullcontext()
|
251
|
+
) as pbar:
|
252
|
+
success = self._partial_download(
|
253
|
+
video_id, duration, str(self.video_preview_path), progress_cb=pbar.update
|
254
|
+
)
|
255
|
+
except Exception as e:
|
256
|
+
logger.warning(f"Partial download failed: {str(e)}", exc_info=True)
|
257
|
+
success = False
|
258
|
+
if success:
|
259
|
+
return
|
260
|
+
|
261
|
+
video_length_threshold = 120 # seconds
|
262
|
+
if video_info.duration > video_length_threshold:
|
263
|
+
self.download_error.text = (
|
264
|
+
f"Partial download failed. Will Download separate video frames"
|
265
|
+
)
|
266
|
+
self.download_error.show()
|
267
|
+
|
268
|
+
fps = int(video_info.frames_count / video_info.duration)
|
269
|
+
frames_number = min(video_info.frames_count, int(fps * duration))
|
270
|
+
with (
|
271
|
+
self.progress(
|
272
|
+
"Downloading video frames:", total=frames_number, unit="it", unit_scale=False
|
273
|
+
)
|
274
|
+
if with_progress
|
275
|
+
else nullcontext()
|
276
|
+
) as pbar:
|
277
|
+
self._download_video_by_frames(
|
278
|
+
video_info,
|
279
|
+
str(self.video_preview_path),
|
280
|
+
frames_number=frames_number,
|
281
|
+
progress_cb=pbar.update,
|
282
|
+
)
|
283
|
+
else:
|
284
|
+
self.download_error.text = f"Partial download failed. Will Download full video"
|
285
|
+
self.download_error.show()
|
286
|
+
size = int(video_info.file_meta["size"])
|
287
|
+
with (
|
288
|
+
self.progress("Downloading video:", total=size, unit="B", unit_scale=True)
|
289
|
+
if with_progress
|
290
|
+
else nullcontext()
|
291
|
+
) as pbar:
|
292
|
+
self._download_full_video(
|
293
|
+
video_info.id,
|
294
|
+
str(self.video_preview_path),
|
295
|
+
duration=duration,
|
296
|
+
progress_cb=pbar.update,
|
297
|
+
)
|
298
|
+
|
299
|
+
def _partial_download(self, video_id: int, duration: int, save_path: str, progress_cb=None):
|
300
|
+
if Path(save_path).exists():
|
301
|
+
Path(save_path).unlink()
|
302
|
+
duration_minutes = duration // 60
|
303
|
+
duration_hours = duration_minutes // 60
|
304
|
+
duration_minutes = duration_minutes % 60
|
305
|
+
duration_seconds = duration % 60
|
306
|
+
duration_str = f"{duration_hours:02}:{duration_minutes:02}:{duration_seconds:02}"
|
307
|
+
response = self.api.video._download(video_id, is_stream=True)
|
308
|
+
process = subprocess.Popen(
|
309
|
+
[
|
310
|
+
"ffmpeg",
|
311
|
+
"-y",
|
312
|
+
"-t",
|
313
|
+
duration_str,
|
314
|
+
"-probesize",
|
315
|
+
"50M",
|
316
|
+
"-analyzeduration",
|
317
|
+
"50M",
|
318
|
+
"-i",
|
319
|
+
"pipe:0",
|
320
|
+
"-movflags",
|
321
|
+
"frag_keyframe+empty_moov+default_base_moof",
|
322
|
+
"-c",
|
323
|
+
"copy",
|
324
|
+
save_path,
|
325
|
+
],
|
326
|
+
stdin=subprocess.PIPE,
|
327
|
+
stderr=subprocess.PIPE,
|
328
|
+
)
|
329
|
+
|
330
|
+
bytes_written = 0
|
331
|
+
try:
|
332
|
+
for chunk in response.iter_content(chunk_size=8192):
|
333
|
+
process.stdin.write(chunk)
|
334
|
+
bytes_written += len(chunk)
|
335
|
+
if progress_cb:
|
336
|
+
progress_cb(len(chunk))
|
337
|
+
except (BrokenPipeError, IOError):
|
338
|
+
logger.debug("FFmpeg process closed the pipe, stopping download.", exc_info=True)
|
339
|
+
pass
|
340
|
+
finally:
|
341
|
+
process.stdin.close()
|
342
|
+
process.wait()
|
343
|
+
response.close()
|
344
|
+
logger.debug("FFmpeg exited with code: " + str(process.returncode))
|
345
|
+
logger.debug(f"FFmpeg stderr: {process.stderr.read().decode()}")
|
346
|
+
logger.debug(f"Total bytes written: {bytes_written}")
|
347
|
+
try:
|
348
|
+
with VideoFrameReader(save_path) as reader:
|
349
|
+
if len(reader.read_frames()) == 0:
|
350
|
+
return False
|
351
|
+
return True
|
352
|
+
except Exception as e:
|
353
|
+
return False
|
354
|
+
|
355
|
+
def _download_preview_item(self, with_progress: bool = False):
|
356
|
+
input_settings = self.get_input_settings_fn()
|
357
|
+
video_ids = input_settings.get("video_ids", None)
|
358
|
+
if video_ids is None:
|
359
|
+
project_id = input_settings.get("project_id", None)
|
360
|
+
dataset_ids = input_settings.get("dataset_ids", None)
|
361
|
+
if dataset_ids:
|
362
|
+
images = []
|
363
|
+
candidate_ids = list(dataset_ids)
|
364
|
+
random.shuffle(candidate_ids)
|
365
|
+
dataset_id = None
|
366
|
+
for ds_id in candidate_ids:
|
367
|
+
images = self.api.image.get_list(ds_id)
|
368
|
+
if images:
|
369
|
+
dataset_id = ds_id
|
370
|
+
break
|
371
|
+
if not images:
|
372
|
+
raise RuntimeError("No images found in the selected datasets")
|
373
|
+
else:
|
374
|
+
datasets = self.api.dataset.get_list(project_id)
|
375
|
+
total_items = sum(ds.items_count for ds in datasets)
|
376
|
+
if total_items == 0:
|
377
|
+
raise RuntimeError("No images found in the selected datasets")
|
378
|
+
images = []
|
379
|
+
while not images:
|
380
|
+
dataset_id = random.choice(datasets).id
|
381
|
+
images = self.api.image.get_list(dataset_id)
|
382
|
+
image_id = random.choice(images).id
|
383
|
+
image_info = self.api.image.get_info_by_id(image_id)
|
384
|
+
self.image_preview_path = Path(self.preview_dir, image_info.name)
|
385
|
+
self.api.image.download_path(image_id, self.image_preview_path)
|
386
|
+
self._current_item_id = image_id
|
387
|
+
ann_info = self.api.annotation.download(image_id)
|
388
|
+
self._project_meta = ProjectMeta.from_json(
|
389
|
+
self.api.project.get_meta(image_info.project_id)
|
390
|
+
)
|
391
|
+
self._image_annotation = Annotation.from_json(ann_info.annotation, self._project_meta)
|
392
|
+
self.image_peview_url = f"./static/preview/{image_info.name}"
|
393
|
+
elif len(video_ids) == 0:
|
394
|
+
self._current_item_id = None
|
395
|
+
self.video_preview_path = None
|
396
|
+
self.video_peview_url = None
|
397
|
+
self.video_preview_annotated_path = None
|
398
|
+
else:
|
399
|
+
video_id = random.choice(video_ids)
|
400
|
+
video_id = video_ids[0]
|
401
|
+
video_info = self.api.video.get_info_by_id(video_id)
|
402
|
+
self._download_video_preview(video_info, with_progress)
|
403
|
+
self._current_item_id = video_id
|
404
|
+
self.video_peview_url = f"./static/preview/annotated/{video_info.name}"
|
405
|
+
self._project_meta = ProjectMeta.from_json(
|
406
|
+
self.api.project.get_meta(video_info.project_id)
|
407
|
+
)
|
408
|
+
self._video_annotation = VideoAnnotation.from_json(
|
409
|
+
self.api.video.annotation.download(video_id), self._project_meta, KeyIdMap()
|
410
|
+
)
|
411
|
+
|
412
|
+
def set_image_preview(self):
|
413
|
+
|
414
|
+
def _maybe_merge_annotations(
|
415
|
+
source: Annotation,
|
416
|
+
pred: Annotation,
|
417
|
+
predictions_mode: str,
|
418
|
+
model_prediction_suffix: str,
|
419
|
+
iou_threshold: float = None,
|
420
|
+
):
|
421
|
+
project_meta, pred, _ = update_meta_and_ann(
|
422
|
+
self._project_meta, pred, model_prediction_suffix
|
423
|
+
)
|
424
|
+
if predictions_mode == AddPredictionsMode.REPLACE:
|
425
|
+
return pred
|
426
|
+
elif predictions_mode == AddPredictionsMode.IOU_MERGE:
|
427
|
+
iou_threshold = iou_threshold if iou_threshold is not None else 0.9
|
428
|
+
pred = _filter_duplicated_predictions_from_ann(source, pred, iou_threshold)
|
429
|
+
return source.merge(pred)
|
430
|
+
elif predictions_mode in [
|
431
|
+
AddPredictionsMode.APPEND,
|
432
|
+
AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS,
|
433
|
+
]:
|
434
|
+
return source.merge(pred)
|
435
|
+
else:
|
436
|
+
raise RuntimeError(f"Unknown predictions mode: {predictions_mode}")
|
437
|
+
|
438
|
+
self.image_gallery.clean_up()
|
439
|
+
if not self._current_item_id:
|
440
|
+
self._download_preview_item(with_progress=True)
|
441
|
+
image_id = self._current_item_id
|
442
|
+
model_api = self.get_model_api_fn()
|
443
|
+
settings = self.get_settings_fn()
|
444
|
+
inference_settings = settings.get("inference_settings", {})
|
445
|
+
with self.progress("Running Model:", total=1) as pbar:
|
446
|
+
prediction = model_api.predict(
|
447
|
+
image_id=image_id, inference_settings=inference_settings, tqdm=pbar
|
448
|
+
)[0]
|
449
|
+
prediction_annotation = _maybe_merge_annotations(
|
450
|
+
source=self._image_annotation,
|
451
|
+
pred=prediction.annotation,
|
452
|
+
predictions_mode=settings.get("predictions_mode", AddPredictionsMode.APPEND),
|
453
|
+
model_prediction_suffix=settings.get("model_prediction_suffix", ""),
|
454
|
+
iou_threshold=inference_settings.get("existing_objects_iou_thresh"),
|
455
|
+
)
|
456
|
+
self.image_gallery.append(
|
457
|
+
self.image_peview_url, title="Source", annotation=self._image_annotation
|
458
|
+
)
|
459
|
+
self.image_gallery.append(
|
460
|
+
self.image_peview_url, title="Prediction", annotation=prediction_annotation
|
461
|
+
)
|
462
|
+
self.select_item(ProjectType.IMAGES.value)
|
463
|
+
|
464
|
+
def set_video_preview(
|
465
|
+
self,
|
466
|
+
):
|
467
|
+
self.video_player.set_video(None)
|
468
|
+
input_settings = self.get_input_settings_fn()
|
469
|
+
video_ids = input_settings.get("video_ids", None)
|
470
|
+
if not video_ids:
|
471
|
+
raise RuntimeError("No videos selected")
|
472
|
+
if not self._current_item_id:
|
473
|
+
self._download_preview_item(with_progress=True)
|
474
|
+
video_id = self._current_item_id
|
475
|
+
|
476
|
+
frame_start = 0
|
477
|
+
seconds = 5
|
478
|
+
video_info = self.api.video.get_info_by_id(video_id)
|
479
|
+
fps = int(video_info.frames_count / video_info.duration)
|
480
|
+
frames_number = min(video_info.frames_count, int(fps * seconds))
|
481
|
+
model_api = self.get_model_api_fn()
|
482
|
+
project_meta = ProjectMeta.from_json(self.api.project.get_meta(video_info.project_id))
|
483
|
+
|
484
|
+
settings = self.get_settings_fn()
|
485
|
+
inference_settings = settings.get("inference_settings", {})
|
486
|
+
tracking = settings.get("tracking", False)
|
487
|
+
with self.progress("Running model:", total=frames_number) as pbar:
|
488
|
+
with model_api.predict_detached(
|
489
|
+
video_id=video_id,
|
490
|
+
inference_settings=inference_settings,
|
491
|
+
tracking=tracking,
|
492
|
+
start_frame=frame_start,
|
493
|
+
num_frames=frames_number,
|
494
|
+
tqdm=pbar,
|
495
|
+
) as session:
|
496
|
+
predictions: List[Prediction] = list(session)
|
497
|
+
|
498
|
+
if os.path.exists(self.video_preview_annotated_path):
|
499
|
+
os.remove(self.video_preview_annotated_path)
|
500
|
+
if tracking:
|
501
|
+
pred_video_annotation = session.final_result.get("video_ann", {})
|
502
|
+
if pred_video_annotation is None:
|
503
|
+
raise RuntimeError("Model did not return video annotation")
|
504
|
+
pred_video_annotation = VideoAnnotation.from_json(
|
505
|
+
pred_video_annotation, project_meta=project_meta
|
506
|
+
)
|
507
|
+
_, pred_video_annotation, _ = update_meta_and_ann_for_video_annotation(
|
508
|
+
self._project_meta,
|
509
|
+
pred_video_annotation,
|
510
|
+
settings.get("model_prediction_suffix", ""),
|
511
|
+
)
|
512
|
+
visualizer = TrackingVisualizer(
|
513
|
+
output_fps=fps,
|
514
|
+
box_thickness=video_info.frame_height // 110,
|
515
|
+
text_scale=video_info.frame_height / 900,
|
516
|
+
trajectory_thickness=video_info.frame_height // 110,
|
517
|
+
)
|
518
|
+
else:
|
519
|
+
pred_video_annotation = video_annotation_from_predictions(
|
520
|
+
predictions,
|
521
|
+
model_api.get_model_meta(),
|
522
|
+
frame_size=(video_info.frame_height, video_info.frame_width),
|
523
|
+
)
|
524
|
+
visualizer = TrackingVisualizer(
|
525
|
+
output_fps=fps,
|
526
|
+
box_thickness=video_info.frame_height // 110,
|
527
|
+
text_scale=video_info.frame_height / 900,
|
528
|
+
show_trajectories=False,
|
529
|
+
)
|
530
|
+
_, pred_video_annotation, _ = update_meta_and_ann_for_video_annotation(
|
531
|
+
self._project_meta,
|
532
|
+
pred_video_annotation,
|
533
|
+
settings.get("model_prediction_suffix", ""),
|
534
|
+
)
|
535
|
+
visualizer.visualize_video_annotation(
|
536
|
+
pred_video_annotation,
|
537
|
+
source=self.video_preview_path,
|
538
|
+
output_path=self.video_preview_annotated_path,
|
539
|
+
)
|
540
|
+
self.video_player.set_video(self.video_peview_url)
|
541
|
+
self.select_item(ProjectType.VIDEOS.value)
|
542
|
+
|
543
|
+
def set_error(self, text: str):
|
544
|
+
self.error_text.text = text
|
545
|
+
self.select_item("error")
|
546
|
+
|
547
|
+
def run_preview(self):
|
548
|
+
self.download_error.hide()
|
549
|
+
self.select_item("loading")
|
550
|
+
try:
|
551
|
+
input_settings = self.get_input_settings_fn()
|
552
|
+
video_ids = input_settings.get("video_ids", None)
|
553
|
+
if video_ids is None:
|
554
|
+
self.set_image_preview()
|
555
|
+
elif len(video_ids) == 0:
|
556
|
+
self.set_error("No videos selected")
|
557
|
+
else:
|
558
|
+
self.set_video_preview()
|
559
|
+
except Exception as e:
|
560
|
+
logger.error(f"Failed to generate preview: {str(e)}", exc_info=True)
|
561
|
+
self.set_error("Failed to generate preview: " + str(e))
|
562
|
+
|
563
|
+
def _preload_item(self):
|
564
|
+
threading.Thread(
|
565
|
+
target=self._download_preview_item, kwargs={"with_progress": False}, daemon=True
|
566
|
+
).start()
|
567
|
+
|
568
|
+
def update_item_type(self, item_type: str):
|
569
|
+
self.select_item("empty")
|
570
|
+
self._current_item_id = None
|
571
|
+
self.download_error.hide()
|
572
|
+
# self._preload_item() # need to handle race condition with run_preview and multiple clicks
|
573
|
+
|
574
|
+
|
29
575
|
class SettingsSelector:
|
30
|
-
title = "
|
576
|
+
title = "Inference (settings + preview)"
|
31
577
|
description = "Select additional settings for model inference"
|
32
578
|
lock_message = "Select previous step to unlock"
|
33
579
|
|
34
|
-
def __init__(
|
580
|
+
def __init__(
|
581
|
+
self,
|
582
|
+
api: Api,
|
583
|
+
static_dir: str,
|
584
|
+
input_selector: InputSelector,
|
585
|
+
model_selector: ModelSelector,
|
586
|
+
):
|
35
587
|
# Init Step
|
588
|
+
self.api = api
|
589
|
+
self.static_dir = static_dir
|
590
|
+
self.input_selector = input_selector
|
591
|
+
self.model_selector = model_selector
|
36
592
|
self.display_widgets: List[Any] = []
|
37
593
|
# -------------------------------- #
|
38
594
|
|
39
595
|
# Init Base Widgets
|
40
596
|
self.validator_text = None
|
41
597
|
self.button = None
|
598
|
+
self.run_button = None
|
42
599
|
self.container = None
|
43
|
-
self.
|
600
|
+
self.cards = None
|
44
601
|
# -------------------------------- #
|
45
602
|
|
46
603
|
# Init Step Widgets
|
@@ -51,28 +608,72 @@ class SettingsSelector:
|
|
51
608
|
# self.model_prediction_suffix_checkbox = None
|
52
609
|
self.predictions_mode_selector = None
|
53
610
|
self.predictions_mode_field = None
|
54
|
-
self.
|
611
|
+
self.inference_settings_editor = None
|
55
612
|
# -------------------------------- #
|
56
613
|
|
57
|
-
|
58
|
-
self.
|
59
|
-
self.
|
60
|
-
|
614
|
+
self.settings_widgets = []
|
615
|
+
self.image_settings_widgets = []
|
616
|
+
self.video_settings_widgets = []
|
617
|
+
|
618
|
+
# Prediction Mode
|
619
|
+
self.prediction_modes = [
|
620
|
+
AddPredictionsMode.APPEND,
|
621
|
+
AddPredictionsMode.REPLACE,
|
622
|
+
AddPredictionsMode.IOU_MERGE,
|
623
|
+
# AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS, # @TODO: Implement later
|
624
|
+
]
|
625
|
+
self.iou_merge_input = InputNumber(value=0.9, min=0.0, max=1.0, step=0.05, controls=False)
|
626
|
+
self.iou_merge_input_field = Field(
|
627
|
+
content=self.iou_merge_input,
|
628
|
+
title="IoU Threshold",
|
629
|
+
description="IoU threshold for merging predictions with existing labels. Predictions with IoU above this threshold will be considered duplicates and removed.",
|
61
630
|
)
|
62
|
-
self.
|
63
|
-
self.
|
64
|
-
|
65
|
-
|
66
|
-
|
631
|
+
self.prediction_modes_contents = [Empty(), Empty(), self.iou_merge_input_field]
|
632
|
+
self.predictions_mode_selector = Select(
|
633
|
+
items=[
|
634
|
+
Select.Item(mode, content=content)
|
635
|
+
for mode, content in zip(self.prediction_modes, self.prediction_modes_contents)
|
636
|
+
]
|
637
|
+
)
|
638
|
+
self.predictions_mode_selector.set_value(self.prediction_modes[0])
|
639
|
+
self.predicitons_mode_one_of = OneOf(self.predictions_mode_selector)
|
640
|
+
self.predictions_mode_field = Field(
|
641
|
+
content=Container(
|
642
|
+
widgets=[self.predictions_mode_selector, self.predicitons_mode_one_of]
|
643
|
+
),
|
644
|
+
title="Add predictions mode",
|
645
|
+
description="Select how to add predictions to the project: by merging with existing labels or by replacing them.",
|
67
646
|
)
|
68
647
|
# Add widgets to display ------------ #
|
69
|
-
self.
|
648
|
+
self.image_settings_widgets.extend([self.predictions_mode_field])
|
649
|
+
# ----------------------------------- #
|
650
|
+
|
651
|
+
# Tracking
|
652
|
+
self.tracking_checkbox = Checkbox(content="Enable tracking", checked=True)
|
653
|
+
self.tracking_checkbox_field = Field(
|
654
|
+
content=self.tracking_checkbox,
|
655
|
+
title="Tracking",
|
656
|
+
description="Enable tracking for video predictions. The tracking algorithm is BoT-SORT version improved by Supervisely team.",
|
657
|
+
)
|
658
|
+
# Add widgets to display ------------ #
|
659
|
+
self.video_settings_widgets.extend([self.tracking_checkbox_field])
|
660
|
+
self.image_settings_container = Container(widgets=self.image_settings_widgets, gap=15)
|
661
|
+
self.video_settings_container = Container(widgets=self.video_settings_widgets, gap=15)
|
662
|
+
self.image_or_video_container = Container(
|
663
|
+
widgets=[self.image_settings_container, self.video_settings_container], gap=0
|
664
|
+
)
|
665
|
+
self.video_settings_container.hide()
|
666
|
+
self.settings_widgets.extend([self.image_or_video_container])
|
70
667
|
# ----------------------------------- #
|
71
668
|
|
72
669
|
# Class / Tag Suffix
|
73
670
|
self.model_prediction_suffix_input = Input(
|
74
671
|
value="_model", minlength=1, placeholder="Enter suffix e.g: _model"
|
75
672
|
)
|
673
|
+
self.model_meta_has_conflicting_names_text = Text(
|
674
|
+
text="Project and Model metas have conflicting names. This suffix will be added to conflicting class and tag names of model predictions",
|
675
|
+
status="info",
|
676
|
+
)
|
76
677
|
self.model_prediction_suffix_field = Field(
|
77
678
|
content=self.model_prediction_suffix_input,
|
78
679
|
title="Class and tag suffix",
|
@@ -82,38 +683,40 @@ class SettingsSelector:
|
|
82
683
|
"then suffix will be added to the model predictions to avoid conflicts. E.g. 'person_model'."
|
83
684
|
),
|
84
685
|
)
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
# Prediction Mode
|
93
|
-
self.prediction_modes = [
|
94
|
-
AddPredictionsMode.MERGE_WITH_EXISTING_LABELS,
|
95
|
-
AddPredictionsMode.REPLACE_EXISTING_LABELS,
|
96
|
-
# AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS, # @TODO: Implement later
|
97
|
-
]
|
98
|
-
self.predictions_mode_selector = Select(
|
99
|
-
items=[Select.Item(mode) for mode in self.prediction_modes]
|
100
|
-
)
|
101
|
-
self.predictions_mode_selector.set_value(self.prediction_modes[0])
|
102
|
-
self.predictions_mode_field = Field(
|
103
|
-
content=self.predictions_mode_selector,
|
104
|
-
title="Add predictions mode",
|
105
|
-
description="Select how to add predictions to the project: by merging with existing labels or by replacing them.",
|
686
|
+
self.model_prediction_suffix_container = Container(
|
687
|
+
widgets=[
|
688
|
+
self.model_meta_has_conflicting_names_text,
|
689
|
+
self.model_prediction_suffix_field,
|
690
|
+
],
|
691
|
+
gap=5,
|
106
692
|
)
|
693
|
+
self.model_prediction_suffix_container.hide()
|
107
694
|
# Add widgets to display ------------ #
|
108
|
-
self.
|
695
|
+
self.settings_widgets.extend([self.model_prediction_suffix_container])
|
109
696
|
# ----------------------------------- #
|
110
697
|
|
111
698
|
# Inference Settings
|
112
|
-
self.
|
699
|
+
self.inference_settings_editor = Editor("", language_mode="yaml", height_px=300)
|
700
|
+
self.inference_settings_field = Field(
|
701
|
+
content=self.inference_settings_editor,
|
702
|
+
title="Inference and Tracking Settings",
|
703
|
+
)
|
113
704
|
# Add widgets to display ------------ #
|
114
|
-
self.
|
705
|
+
self.settings_widgets.extend([self.inference_settings_field])
|
115
706
|
# ----------------------------------- #
|
116
707
|
|
708
|
+
# Preview
|
709
|
+
self.preview_dir = os.path.join(self.static_dir, "preview")
|
710
|
+
self.preview = Preview(
|
711
|
+
api=self.api,
|
712
|
+
preview_dir=self.preview_dir,
|
713
|
+
get_model_api_fn=lambda: self.model_selector.model.model_api,
|
714
|
+
get_input_settings_fn=self.input_selector.get_settings,
|
715
|
+
get_settings_fn=self.get_settings,
|
716
|
+
)
|
717
|
+
|
718
|
+
self.settings_container = Container(widgets=self.settings_widgets, gap=15)
|
719
|
+
self.display_widgets.extend([self.settings_container])
|
117
720
|
# Base Widgets
|
118
721
|
self.validator_text = Text("")
|
119
722
|
self.validator_text.hide()
|
@@ -124,49 +727,122 @@ class SettingsSelector:
|
|
124
727
|
|
125
728
|
# Card Layout
|
126
729
|
self.container = Container(self.display_widgets)
|
127
|
-
self.
|
730
|
+
self.settings_card = Card(
|
128
731
|
title=self.title,
|
129
732
|
description=self.description,
|
130
733
|
content=self.container,
|
131
734
|
lock_message=self.lock_message,
|
132
735
|
)
|
133
|
-
self.card
|
736
|
+
self.cards = [self.settings_card, self.preview.card]
|
737
|
+
self.cards_container = Container(
|
738
|
+
widgets=self.cards,
|
739
|
+
gap=15,
|
740
|
+
direction="horizontal",
|
741
|
+
fractions=[3, 7],
|
742
|
+
)
|
134
743
|
# ----------------------------------- #
|
135
744
|
|
745
|
+
def lock(self):
|
746
|
+
self.settings_card.lock(self.lock_message)
|
747
|
+
self.preview.lock()
|
748
|
+
|
749
|
+
def unlock(self):
|
750
|
+
self.settings_card.unlock()
|
751
|
+
self.preview.unlock()
|
752
|
+
|
753
|
+
def disable(self):
|
754
|
+
for widget in self.widgets_to_disable:
|
755
|
+
if isinstance(widget, Editor):
|
756
|
+
widget.readonly = True
|
757
|
+
else:
|
758
|
+
widget.disable()
|
759
|
+
|
760
|
+
def enable(self):
|
761
|
+
for widget in self.widgets_to_disable:
|
762
|
+
if isinstance(widget, Editor):
|
763
|
+
widget.readonly = False
|
764
|
+
else:
|
765
|
+
widget.enable()
|
766
|
+
|
136
767
|
@property
|
137
|
-
def widgets_to_disable(self) ->
|
768
|
+
def widgets_to_disable(self) -> List[Widget]:
|
138
769
|
return [
|
139
|
-
self.
|
140
|
-
self.model_prediction_suffix_input,
|
141
|
-
# self.model_prediction_suffix_checkbox,
|
770
|
+
self.tracking_checkbox,
|
142
771
|
self.predictions_mode_selector,
|
143
|
-
self.
|
772
|
+
self.model_prediction_suffix_input,
|
773
|
+
self.inference_settings_editor,
|
144
774
|
]
|
145
775
|
|
146
776
|
def set_inference_settings(self, settings: Dict[str, Any]):
|
777
|
+
settings = "# Inference settings\n" + settings
|
147
778
|
if isinstance(settings, str):
|
148
|
-
self.
|
779
|
+
self.inference_settings_editor.set_text(settings)
|
149
780
|
else:
|
150
|
-
self.
|
781
|
+
self.inference_settings_editor.set_text(yaml.safe_dump(settings))
|
782
|
+
|
783
|
+
def set_tracking_settings(self, settings: Dict[str, Any]):
|
784
|
+
if self.input_selector.radio.get_value() != ProjectType.VIDEOS.value:
|
785
|
+
return
|
786
|
+
|
787
|
+
current_settings = self.inference_settings_editor.get_text()
|
788
|
+
if isinstance(settings, str):
|
789
|
+
all_settings = current_settings + "\n\n# Tracking settings\n" + settings
|
790
|
+
self.inference_settings_editor.set_text(all_settings)
|
791
|
+
else:
|
792
|
+
all_settings = current_settings + "\n\n# Tracking settings\n" + yaml.safe_dump(settings)
|
793
|
+
self.inference_settings_editor.set_text(all_settings)
|
794
|
+
|
795
|
+
def set_default_tracking_settings(self):
|
796
|
+
nn_dir = Path(__file__).parents[3]
|
797
|
+
config_path = nn_dir / "tracker" / "botsort" / "botsort_config.yaml"
|
798
|
+
|
799
|
+
with open(config_path, "r", encoding="utf-8") as file:
|
800
|
+
botsort_config = yaml.safe_load(file)
|
801
|
+
self.set_tracking_settings(botsort_config)
|
151
802
|
|
152
803
|
def get_inference_settings(self) -> Dict:
|
153
|
-
|
154
|
-
|
155
|
-
|
804
|
+
text = self.inference_settings_editor.get_text()
|
805
|
+
inference_settings_text = text.split("# Tracking settings")[0]
|
806
|
+
settings = yaml.safe_load(inference_settings_text)
|
807
|
+
settings = settings if settings is not None else {}
|
808
|
+
if (
|
809
|
+
self.input_selector.radio.get_value() == ProjectType.IMAGES.value
|
810
|
+
and self.predictions_mode_selector.get_value() == AddPredictionsMode.IOU_MERGE
|
811
|
+
):
|
812
|
+
settings["existing_objects_iou_thresh"] = self.iou_merge_input.get_value()
|
813
|
+
return settings
|
814
|
+
|
815
|
+
def get_tracking_settings(self) -> Dict:
|
816
|
+
if self.input_selector.radio.get_value() != ProjectType.VIDEOS.value:
|
817
|
+
return {}
|
818
|
+
|
819
|
+
text = self.inference_settings_editor.get_text()
|
820
|
+
text_parts = text.split("# Tracking settings")
|
821
|
+
if len(text_parts) > 1:
|
822
|
+
tracking_settings_text = text_parts[1]
|
823
|
+
settings = yaml.safe_load(tracking_settings_text)
|
824
|
+
if settings:
|
825
|
+
return settings
|
156
826
|
return {}
|
157
827
|
|
158
828
|
def get_settings(self) -> Dict[str, Any]:
|
159
|
-
|
160
|
-
"inference_mode": self.inference_mode_selector.get_value(),
|
829
|
+
settings = {
|
830
|
+
# "inference_mode": self.inference_mode_selector.get_value(),
|
831
|
+
"inference_mode": InferenceMode.FULL_IMAGE,
|
161
832
|
"model_prediction_suffix": self.model_prediction_suffix_input.get_value(),
|
162
833
|
"predictions_mode": self.predictions_mode_selector.get_value(),
|
163
834
|
"inference_settings": self.get_inference_settings(),
|
164
835
|
}
|
836
|
+
if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
|
837
|
+
settings["tracking_settings"] = self.get_tracking_settings()
|
838
|
+
if self.input_selector.get_settings().get("video_ids", None) is not None:
|
839
|
+
settings["tracking"] = self.tracking_checkbox.is_checked()
|
840
|
+
return settings
|
165
841
|
|
166
842
|
def load_from_json(self, data):
|
167
|
-
inference_mode = data.get("inference_mode", None)
|
168
|
-
if inference_mode:
|
169
|
-
|
843
|
+
# inference_mode = data.get("inference_mode", None)
|
844
|
+
# if inference_mode:
|
845
|
+
# self.inference_mode_selector.set_value(inference_mode)
|
170
846
|
|
171
847
|
model_prediction_suffix = data.get("model_prediction_suffix", None)
|
172
848
|
if model_prediction_suffix is not None:
|
@@ -180,5 +856,26 @@ class SettingsSelector:
|
|
180
856
|
if inference_settings is not None:
|
181
857
|
self.set_inference_settings(inference_settings)
|
182
858
|
|
859
|
+
tracking_settings = data.get("tracking_settings", None)
|
860
|
+
if tracking_settings is not None:
|
861
|
+
self.set_tracking_settings(tracking_settings)
|
862
|
+
|
863
|
+
tracking = data.get("tracking", None)
|
864
|
+
if tracking == True:
|
865
|
+
self.tracking_checkbox.check()
|
866
|
+
elif tracking == False:
|
867
|
+
self.tracking_checkbox.uncheck()
|
868
|
+
|
869
|
+
def update_item_type(self, item_type: str):
|
870
|
+
if item_type == ProjectType.IMAGES.value:
|
871
|
+
self.video_settings_container.hide()
|
872
|
+
self.image_settings_container.show()
|
873
|
+
elif item_type == ProjectType.VIDEOS.value:
|
874
|
+
self.image_settings_container.hide()
|
875
|
+
self.video_settings_container.show()
|
876
|
+
else:
|
877
|
+
raise ValueError(f"Unsupported item type: {item_type}")
|
878
|
+
self.preview.update_item_type(item_type)
|
879
|
+
|
183
880
|
def validate_step(self) -> bool:
|
184
881
|
return True
|