supervisely 6.73.456__py3-none-any.whl → 6.73.458__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. supervisely/__init__.py +24 -1
  2. supervisely/api/image_api.py +4 -0
  3. supervisely/api/video/video_annotation_api.py +4 -2
  4. supervisely/api/video/video_api.py +41 -1
  5. supervisely/app/v1/app_service.py +18 -2
  6. supervisely/app/v1/constants.py +7 -1
  7. supervisely/app/widgets/card/card.py +20 -0
  8. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  9. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  10. supervisely/app/widgets/fast_table/fast_table.py +45 -11
  11. supervisely/app/widgets/fast_table/template.html +1 -1
  12. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  13. supervisely/app/widgets/radio_tabs/template.html +1 -0
  14. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +63 -7
  15. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  16. supervisely/nn/inference/cache.py +2 -2
  17. supervisely/nn/inference/inference.py +364 -73
  18. supervisely/nn/inference/inference_request.py +3 -2
  19. supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
  20. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  21. supervisely/nn/inference/predict_app/gui/input_selector.py +178 -25
  22. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  23. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  24. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  25. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  26. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  27. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  28. supervisely/nn/model/model_api.py +9 -0
  29. supervisely/nn/tracker/base_tracker.py +11 -1
  30. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  31. supervisely/nn/tracker/botsort_tracker.py +14 -7
  32. supervisely/nn/tracker/visualize.py +70 -72
  33. supervisely/video/video.py +15 -1
  34. supervisely/worker_api/agent_rpc.py +24 -1
  35. supervisely/worker_api/rpc_servicer.py +31 -7
  36. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/METADATA +3 -2
  37. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/RECORD +41 -41
  38. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/LICENSE +0 -0
  39. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/WHEEL +0 -0
  40. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/entry_points.txt +0 -0
  41. {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,54 @@
1
- from typing import Any, Dict, List
1
+ import os
2
+ import random
3
+ import shutil
4
+ import subprocess
5
+ import threading
6
+ from contextlib import contextmanager, nullcontext
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, List
2
9
 
10
+ import cv2
3
11
  import yaml
4
12
 
13
+ from supervisely._utils import logger
14
+ from supervisely.annotation.annotation import Annotation
15
+ from supervisely.annotation.label import Label
16
+ from supervisely.api.api import Api
17
+ from supervisely.api.video.video_api import VideoInfo
5
18
  from supervisely.app.widgets import (
6
19
  Button,
7
20
  Card,
8
- Checkbox,
9
21
  Container,
10
22
  Editor,
11
23
  Field,
24
+ GridGallery,
12
25
  Input,
26
+ InputNumber,
27
+ OneOf,
28
+ Progress,
13
29
  Select,
14
30
  Text,
31
+ VideoPlayer,
15
32
  )
33
+ from supervisely.app.widgets.checkbox.checkbox import Checkbox
34
+ from supervisely.app.widgets.empty.empty import Empty
35
+ from supervisely.app.widgets.widget import Widget
36
+ from supervisely.nn.inference.inference import (
37
+ _filter_duplicated_predictions_from_ann,
38
+ update_meta_and_ann,
39
+ update_meta_and_ann_for_video_annotation,
40
+ )
41
+ from supervisely.nn.inference.predict_app.gui.input_selector import InputSelector
42
+ from supervisely.nn.inference.predict_app.gui.model_selector import ModelSelector
43
+ from supervisely.nn.inference.predict_app.gui.utils import (
44
+ video_annotation_from_predictions,
45
+ )
46
+ from supervisely.nn.model.model_api import ModelAPI, Prediction
47
+ from supervisely.nn.tracker import TrackingVisualizer
48
+ from supervisely.project import ProjectMeta
49
+ from supervisely.project.project_meta import ProjectType
50
+ from supervisely.video.video import VideoFrameReader
51
+ from supervisely.video_annotation.video_annotation import KeyIdMap, VideoAnnotation
16
52
 
17
53
 
18
54
  class InferenceMode:
@@ -21,26 +57,547 @@ class InferenceMode:
21
57
 
22
58
 
23
59
class AddPredictionsMode:
    """Display names of the ways predictions can be combined with existing labels.

    The string values double as the items shown in the "Add predictions mode"
    selector, so they must not be changed.
    """

    APPEND = "Merge with existing labels"
    REPLACE = "Replace existing labels"
    IOU_MERGE = "Merge by IoU threshold"
    REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS = "Replace existing labels and save image tags"

    # Backward-compatible aliases: these were the public names before APPEND /
    # REPLACE were introduced; keeping them avoids AttributeError in callers
    # that still reference the old constants.
    MERGE_WITH_EXISTING_LABELS = APPEND
    REPLACE_EXISTING_LABELS = REPLACE
27
64
 
28
65
 
66
+ class Preview:
67
+ lock_message = "Select previous step to unlock"
68
+
69
+ def __init__(
70
+ self,
71
+ api: Api,
72
+ preview_dir: str,
73
+ get_model_api_fn: Callable[[], ModelAPI],
74
+ get_input_settings_fn: Callable[[], Dict[str, Any]],
75
+ get_settings_fn: Callable[[], Dict[str, Any]],
76
+ ):
77
+ self.api = api
78
+ self.preview_dir = preview_dir
79
+ self.get_model_api_fn = get_model_api_fn
80
+ self.get_input_settings_fn = get_input_settings_fn
81
+ self.get_settings_fn = get_settings_fn
82
+ os.makedirs(self.preview_dir, exist_ok=True)
83
+ os.makedirs(Path(self.preview_dir, "annotated"), exist_ok=True)
84
+ self.image_preview_path = None
85
+ self.image_peview_url = None
86
+ self.video_preview_path = None
87
+ self.video_preview_annotated_path = None
88
+ self.video_peview_url = None
89
+
90
+ self.progress_widget = Progress(show_percents=True, hide_on_finish=True)
91
+ self.download_error = Text("", status="warning")
92
+ self.download_error.hide()
93
+ self.progress_container = Container(widgets=[self.download_error, self.progress_widget])
94
+ self.loading_container = Container(widgets=[self.download_error, Text("Loading...")])
95
+
96
+ self.image_gallery = GridGallery(
97
+ 2,
98
+ sync_views=True,
99
+ enable_zoom=True,
100
+ resize_on_zoom=True,
101
+ empty_message="",
102
+ )
103
+ self.image_preview_container = Container(widgets=[self.image_gallery])
104
+
105
+ self.video_player = VideoPlayer()
106
+ self.video_preview_container = Container(widgets=[self.video_player])
107
+
108
+ self.locked_text = Text("Select input and model to unlock", status="info")
109
+ self.empty_text = Text("Click preview to visualize predictions")
110
+ self.error_text = Text("Failed to generate preview", status="error")
111
+
112
+ self.select = Select(
113
+ items=[
114
+ Select.Item("locked", content=self.locked_text),
115
+ Select.Item("empty", content=self.empty_text),
116
+ Select.Item(ProjectType.IMAGES.value, content=self.image_preview_container),
117
+ Select.Item(ProjectType.VIDEOS.value, content=self.video_preview_container),
118
+ Select.Item("error", content=self.error_text),
119
+ Select.Item("loading", content=self.loading_container),
120
+ Select.Item("progress", content=self.progress_container),
121
+ ]
122
+ )
123
+ self.select.set_value("empty")
124
+ self.oneof = OneOf(self.select)
125
+
126
+ self.run_button = Button("Preview", icon="zmdi zmdi-slideshow")
127
+ self.run_button.disable()
128
+ self.card = Card(
129
+ title="Preview",
130
+ description="Preview model predictions on a random image or video from the selected input source.",
131
+ content=self.oneof,
132
+ content_top_right=self.run_button,
133
+ lock_message=self.lock_message,
134
+ )
135
+
136
+ @self.run_button.click
137
+ def _run_preview():
138
+ self.run_preview()
139
+
140
+ def lock(self):
141
+ self.run_button.disable()
142
+ self.card.lock(self.lock_message)
143
+
144
+ def unlock(self):
145
+ self.run_button.enable()
146
+ self.card.unlock()
147
+
148
+ @contextmanager
149
+ def progress(self, message: str, total: int, **kwargs):
150
+ current_item = self.select.get_value()
151
+ try:
152
+ with self.progress_widget(message=message, total=total, **kwargs) as pbar:
153
+ self.select_item("progress")
154
+ yield pbar
155
+ finally:
156
+ self.select_item(current_item)
157
+
158
+ def select_item(self, item: str):
159
+ self.select.set_value(item)
160
+
161
+ def _download_video_by_frames(
162
+ self, video_info: VideoInfo, save_path: str, frames_number=150, progress_cb=None
163
+ ):
164
+ if Path(save_path).exists():
165
+ Path(save_path).unlink()
166
+ tmp_dir = Path(self.preview_dir, "tmp_frames")
167
+ if tmp_dir.exists():
168
+ shutil.rmtree(tmp_dir)
169
+ os.makedirs(tmp_dir, exist_ok=True)
170
+ self.api.video.download_frames(
171
+ video_info.id,
172
+ frames=list(range(frames_number)),
173
+ paths=[str(tmp_dir / f"frame_{i}.jpg") for i in range(frames_number)],
174
+ progress_cb=progress_cb,
175
+ )
176
+ fps = int(video_info.frames_count / video_info.duration)
177
+ fourcc = cv2.VideoWriter.fourcc(*"mp4v") # or 'avc1', 'XVID', 'H264'
178
+ out = cv2.VideoWriter(
179
+ save_path, fourcc, fps, (video_info.frame_width, video_info.frame_height)
180
+ )
181
+ for i in range(frames_number):
182
+ frame_path = tmp_dir / f"frame_{i}.jpg"
183
+ if not frame_path.exists():
184
+ continue
185
+ img = cv2.imread(str(frame_path))
186
+ out.write(img)
187
+ out.release()
188
+ shutil.rmtree(tmp_dir)
189
+
190
+ def _download_full_video(
191
+ self, video_id: int, save_path: str, duration: int = 5, progress_cb=None
192
+ ):
193
+ if Path(save_path).exists():
194
+ Path(save_path).unlink()
195
+ temp = Path(self.preview_dir) / f"temp_{video_id}.mp4"
196
+ if temp.exists():
197
+ temp.unlink()
198
+ self.api.video.download_path(video_id, temp, progress_cb=progress_cb)
199
+ minutes = duration // 60
200
+ hours = minutes // 60
201
+ minutes = minutes % 60
202
+ seconds = duration % 60
203
+ duration_str = f"{hours:02}:{minutes:02}:{seconds:02}"
204
+ try:
205
+ process = subprocess.Popen(
206
+ [
207
+ "ffmpeg",
208
+ "-y",
209
+ "-i",
210
+ str(temp),
211
+ "-c",
212
+ "copy",
213
+ "-t",
214
+ duration_str,
215
+ save_path,
216
+ ],
217
+ stderr=subprocess.PIPE,
218
+ )
219
+ process.wait()
220
+ logger.debug("FFmpeg exited with code: " + str(process.returncode))
221
+ logger.debug(f"FFmpeg stderr: {process.stderr.read().decode()}")
222
+ if len(VideoFrameReader(save_path).read_frames()) == 0:
223
+ raise RuntimeError("No frames read from the video")
224
+ temp.unlink()
225
+ except Exception as e:
226
+ if Path(save_path).exists():
227
+ Path(save_path).unlink()
228
+ shutil.copy(temp, save_path)
229
+ temp.unlink()
230
+ logger.warning(f"FFmpeg trimming failed: {str(e)}", exc_info=True)
231
+
232
+ def _download_video_preview(self, video_info: VideoInfo, with_progress=True):
233
+ video_id = video_info.id
234
+ duration = 5
235
+ video_path = Path(self.preview_dir, video_info.name)
236
+ self.video_preview_path = video_path
237
+ self.video_preview_annotated_path = Path(self.preview_dir, "annotated") / Path(
238
+ self.video_preview_path
239
+ ).relative_to(self.preview_dir)
240
+ success = False
241
+ try:
242
+ try:
243
+ size = int(video_info.file_meta["size"])
244
+ size = int(size / video_info.duration * duration)
245
+ except:
246
+ size = None
247
+ with (
248
+ self.progress("Downloading video part:", total=size, unit="B", unit_scale=True)
249
+ if with_progress and size
250
+ else nullcontext()
251
+ ) as pbar:
252
+ success = self._partial_download(
253
+ video_id, duration, str(self.video_preview_path), progress_cb=pbar.update
254
+ )
255
+ except Exception as e:
256
+ logger.warning(f"Partial download failed: {str(e)}", exc_info=True)
257
+ success = False
258
+ if success:
259
+ return
260
+
261
+ video_length_threshold = 120 # seconds
262
+ if video_info.duration > video_length_threshold:
263
+ self.download_error.text = (
264
+ f"Partial download failed. Will Download separate video frames"
265
+ )
266
+ self.download_error.show()
267
+
268
+ fps = int(video_info.frames_count / video_info.duration)
269
+ frames_number = min(video_info.frames_count, int(fps * duration))
270
+ with (
271
+ self.progress(
272
+ "Downloading video frames:", total=frames_number, unit="it", unit_scale=False
273
+ )
274
+ if with_progress
275
+ else nullcontext()
276
+ ) as pbar:
277
+ self._download_video_by_frames(
278
+ video_info,
279
+ str(self.video_preview_path),
280
+ frames_number=frames_number,
281
+ progress_cb=pbar.update,
282
+ )
283
+ else:
284
+ self.download_error.text = f"Partial download failed. Will Download full video"
285
+ self.download_error.show()
286
+ size = int(video_info.file_meta["size"])
287
+ with (
288
+ self.progress("Downloading video:", total=size, unit="B", unit_scale=True)
289
+ if with_progress
290
+ else nullcontext()
291
+ ) as pbar:
292
+ self._download_full_video(
293
+ video_info.id,
294
+ str(self.video_preview_path),
295
+ duration=duration,
296
+ progress_cb=pbar.update,
297
+ )
298
+
299
+ def _partial_download(self, video_id: int, duration: int, save_path: str, progress_cb=None):
300
+ if Path(save_path).exists():
301
+ Path(save_path).unlink()
302
+ duration_minutes = duration // 60
303
+ duration_hours = duration_minutes // 60
304
+ duration_minutes = duration_minutes % 60
305
+ duration_seconds = duration % 60
306
+ duration_str = f"{duration_hours:02}:{duration_minutes:02}:{duration_seconds:02}"
307
+ response = self.api.video._download(video_id, is_stream=True)
308
+ process = subprocess.Popen(
309
+ [
310
+ "ffmpeg",
311
+ "-y",
312
+ "-t",
313
+ duration_str,
314
+ "-probesize",
315
+ "50M",
316
+ "-analyzeduration",
317
+ "50M",
318
+ "-i",
319
+ "pipe:0",
320
+ "-movflags",
321
+ "frag_keyframe+empty_moov+default_base_moof",
322
+ "-c",
323
+ "copy",
324
+ save_path,
325
+ ],
326
+ stdin=subprocess.PIPE,
327
+ stderr=subprocess.PIPE,
328
+ )
329
+
330
+ bytes_written = 0
331
+ try:
332
+ for chunk in response.iter_content(chunk_size=8192):
333
+ process.stdin.write(chunk)
334
+ bytes_written += len(chunk)
335
+ if progress_cb:
336
+ progress_cb(len(chunk))
337
+ except (BrokenPipeError, IOError):
338
+ logger.debug("FFmpeg process closed the pipe, stopping download.", exc_info=True)
339
+ pass
340
+ finally:
341
+ process.stdin.close()
342
+ process.wait()
343
+ response.close()
344
+ logger.debug("FFmpeg exited with code: " + str(process.returncode))
345
+ logger.debug(f"FFmpeg stderr: {process.stderr.read().decode()}")
346
+ logger.debug(f"Total bytes written: {bytes_written}")
347
+ try:
348
+ with VideoFrameReader(save_path) as reader:
349
+ if len(reader.read_frames()) == 0:
350
+ return False
351
+ return True
352
+ except Exception as e:
353
+ return False
354
+
355
+ def _download_preview_item(self, with_progress: bool = False):
356
+ input_settings = self.get_input_settings_fn()
357
+ video_ids = input_settings.get("video_ids", None)
358
+ if video_ids is None:
359
+ project_id = input_settings.get("project_id", None)
360
+ dataset_ids = input_settings.get("dataset_ids", None)
361
+ if dataset_ids:
362
+ images = []
363
+ candidate_ids = list(dataset_ids)
364
+ random.shuffle(candidate_ids)
365
+ dataset_id = None
366
+ for ds_id in candidate_ids:
367
+ images = self.api.image.get_list(ds_id)
368
+ if images:
369
+ dataset_id = ds_id
370
+ break
371
+ if not images:
372
+ raise RuntimeError("No images found in the selected datasets")
373
+ else:
374
+ datasets = self.api.dataset.get_list(project_id)
375
+ total_items = sum(ds.items_count for ds in datasets)
376
+ if total_items == 0:
377
+ raise RuntimeError("No images found in the selected datasets")
378
+ images = []
379
+ while not images:
380
+ dataset_id = random.choice(datasets).id
381
+ images = self.api.image.get_list(dataset_id)
382
+ image_id = random.choice(images).id
383
+ image_info = self.api.image.get_info_by_id(image_id)
384
+ self.image_preview_path = Path(self.preview_dir, image_info.name)
385
+ self.api.image.download_path(image_id, self.image_preview_path)
386
+ self._current_item_id = image_id
387
+ ann_info = self.api.annotation.download(image_id)
388
+ self._project_meta = ProjectMeta.from_json(
389
+ self.api.project.get_meta(image_info.project_id)
390
+ )
391
+ self._image_annotation = Annotation.from_json(ann_info.annotation, self._project_meta)
392
+ self.image_peview_url = f"./static/preview/{image_info.name}"
393
+ elif len(video_ids) == 0:
394
+ self._current_item_id = None
395
+ self.video_preview_path = None
396
+ self.video_peview_url = None
397
+ self.video_preview_annotated_path = None
398
+ else:
399
+ video_id = random.choice(video_ids)
400
+ video_id = video_ids[0]
401
+ video_info = self.api.video.get_info_by_id(video_id)
402
+ self._download_video_preview(video_info, with_progress)
403
+ self._current_item_id = video_id
404
+ self.video_peview_url = f"./static/preview/annotated/{video_info.name}"
405
+ self._project_meta = ProjectMeta.from_json(
406
+ self.api.project.get_meta(video_info.project_id)
407
+ )
408
+ self._video_annotation = VideoAnnotation.from_json(
409
+ self.api.video.annotation.download(video_id), self._project_meta, KeyIdMap()
410
+ )
411
+
412
+ def set_image_preview(self):
413
+
414
+ def _maybe_merge_annotations(
415
+ source: Annotation,
416
+ pred: Annotation,
417
+ predictions_mode: str,
418
+ model_prediction_suffix: str,
419
+ iou_threshold: float = None,
420
+ ):
421
+ project_meta, pred, _ = update_meta_and_ann(
422
+ self._project_meta, pred, model_prediction_suffix
423
+ )
424
+ if predictions_mode == AddPredictionsMode.REPLACE:
425
+ return pred
426
+ elif predictions_mode == AddPredictionsMode.IOU_MERGE:
427
+ iou_threshold = iou_threshold if iou_threshold is not None else 0.9
428
+ pred = _filter_duplicated_predictions_from_ann(source, pred, iou_threshold)
429
+ return source.merge(pred)
430
+ elif predictions_mode in [
431
+ AddPredictionsMode.APPEND,
432
+ AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS,
433
+ ]:
434
+ return source.merge(pred)
435
+ else:
436
+ raise RuntimeError(f"Unknown predictions mode: {predictions_mode}")
437
+
438
+ self.image_gallery.clean_up()
439
+ if not self._current_item_id:
440
+ self._download_preview_item(with_progress=True)
441
+ image_id = self._current_item_id
442
+ model_api = self.get_model_api_fn()
443
+ settings = self.get_settings_fn()
444
+ inference_settings = settings.get("inference_settings", {})
445
+ with self.progress("Running Model:", total=1) as pbar:
446
+ prediction = model_api.predict(
447
+ image_id=image_id, inference_settings=inference_settings, tqdm=pbar
448
+ )[0]
449
+ prediction_annotation = _maybe_merge_annotations(
450
+ source=self._image_annotation,
451
+ pred=prediction.annotation,
452
+ predictions_mode=settings.get("predictions_mode", AddPredictionsMode.APPEND),
453
+ model_prediction_suffix=settings.get("model_prediction_suffix", ""),
454
+ iou_threshold=inference_settings.get("existing_objects_iou_thresh"),
455
+ )
456
+ self.image_gallery.append(
457
+ self.image_peview_url, title="Source", annotation=self._image_annotation
458
+ )
459
+ self.image_gallery.append(
460
+ self.image_peview_url, title="Prediction", annotation=prediction_annotation
461
+ )
462
+ self.select_item(ProjectType.IMAGES.value)
463
+
464
+ def set_video_preview(
465
+ self,
466
+ ):
467
+ self.video_player.set_video(None)
468
+ input_settings = self.get_input_settings_fn()
469
+ video_ids = input_settings.get("video_ids", None)
470
+ if not video_ids:
471
+ raise RuntimeError("No videos selected")
472
+ if not self._current_item_id:
473
+ self._download_preview_item(with_progress=True)
474
+ video_id = self._current_item_id
475
+
476
+ frame_start = 0
477
+ seconds = 5
478
+ video_info = self.api.video.get_info_by_id(video_id)
479
+ fps = int(video_info.frames_count / video_info.duration)
480
+ frames_number = min(video_info.frames_count, int(fps * seconds))
481
+ model_api = self.get_model_api_fn()
482
+ project_meta = ProjectMeta.from_json(self.api.project.get_meta(video_info.project_id))
483
+
484
+ settings = self.get_settings_fn()
485
+ inference_settings = settings.get("inference_settings", {})
486
+ tracking = settings.get("tracking", False)
487
+ with self.progress("Running model:", total=frames_number) as pbar:
488
+ with model_api.predict_detached(
489
+ video_id=video_id,
490
+ inference_settings=inference_settings,
491
+ tracking=tracking,
492
+ start_frame=frame_start,
493
+ num_frames=frames_number,
494
+ tqdm=pbar,
495
+ ) as session:
496
+ predictions: List[Prediction] = list(session)
497
+
498
+ if os.path.exists(self.video_preview_annotated_path):
499
+ os.remove(self.video_preview_annotated_path)
500
+ if tracking:
501
+ pred_video_annotation = session.final_result.get("video_ann", {})
502
+ if pred_video_annotation is None:
503
+ raise RuntimeError("Model did not return video annotation")
504
+ pred_video_annotation = VideoAnnotation.from_json(
505
+ pred_video_annotation, project_meta=project_meta
506
+ )
507
+ _, pred_video_annotation, _ = update_meta_and_ann_for_video_annotation(
508
+ self._project_meta,
509
+ pred_video_annotation,
510
+ settings.get("model_prediction_suffix", ""),
511
+ )
512
+ visualizer = TrackingVisualizer(
513
+ output_fps=fps,
514
+ box_thickness=video_info.frame_height // 110,
515
+ text_scale=video_info.frame_height / 900,
516
+ trajectory_thickness=video_info.frame_height // 110,
517
+ )
518
+ else:
519
+ pred_video_annotation = video_annotation_from_predictions(
520
+ predictions,
521
+ model_api.get_model_meta(),
522
+ frame_size=(video_info.frame_height, video_info.frame_width),
523
+ )
524
+ visualizer = TrackingVisualizer(
525
+ output_fps=fps,
526
+ box_thickness=video_info.frame_height // 110,
527
+ text_scale=video_info.frame_height / 900,
528
+ show_trajectories=False,
529
+ )
530
+ _, pred_video_annotation, _ = update_meta_and_ann_for_video_annotation(
531
+ self._project_meta,
532
+ pred_video_annotation,
533
+ settings.get("model_prediction_suffix", ""),
534
+ )
535
+ visualizer.visualize_video_annotation(
536
+ pred_video_annotation,
537
+ source=self.video_preview_path,
538
+ output_path=self.video_preview_annotated_path,
539
+ )
540
+ self.video_player.set_video(self.video_peview_url)
541
+ self.select_item(ProjectType.VIDEOS.value)
542
+
543
+ def set_error(self, text: str):
544
+ self.error_text.text = text
545
+ self.select_item("error")
546
+
547
+ def run_preview(self):
548
+ self.download_error.hide()
549
+ self.select_item("loading")
550
+ try:
551
+ input_settings = self.get_input_settings_fn()
552
+ video_ids = input_settings.get("video_ids", None)
553
+ if video_ids is None:
554
+ self.set_image_preview()
555
+ elif len(video_ids) == 0:
556
+ self.set_error("No videos selected")
557
+ else:
558
+ self.set_video_preview()
559
+ except Exception as e:
560
+ logger.error(f"Failed to generate preview: {str(e)}", exc_info=True)
561
+ self.set_error("Failed to generate preview: " + str(e))
562
+
563
+ def _preload_item(self):
564
+ threading.Thread(
565
+ target=self._download_preview_item, kwargs={"with_progress": False}, daemon=True
566
+ ).start()
567
+
568
+ def update_item_type(self, item_type: str):
569
+ self.select_item("empty")
570
+ self._current_item_id = None
571
+ self.download_error.hide()
572
+ # self._preload_item() # need to handle race condition with run_preview and multiple clicks
573
+
574
+
29
575
  class SettingsSelector:
30
- title = "Settings Selector"
576
+ title = "Inference (settings + preview)"
31
577
  description = "Select additional settings for model inference"
32
578
  lock_message = "Select previous step to unlock"
33
579
 
34
- def __init__(self):
580
+ def __init__(
581
+ self,
582
+ api: Api,
583
+ static_dir: str,
584
+ input_selector: InputSelector,
585
+ model_selector: ModelSelector,
586
+ ):
35
587
  # Init Step
588
+ self.api = api
589
+ self.static_dir = static_dir
590
+ self.input_selector = input_selector
591
+ self.model_selector = model_selector
36
592
  self.display_widgets: List[Any] = []
37
593
  # -------------------------------- #
38
594
 
39
595
  # Init Base Widgets
40
596
  self.validator_text = None
41
597
  self.button = None
598
+ self.run_button = None
42
599
  self.container = None
43
- self.card = None
600
+ self.cards = None
44
601
  # -------------------------------- #
45
602
 
46
603
  # Init Step Widgets
@@ -51,28 +608,72 @@ class SettingsSelector:
51
608
  # self.model_prediction_suffix_checkbox = None
52
609
  self.predictions_mode_selector = None
53
610
  self.predictions_mode_field = None
54
- self.inference_settings = None
611
+ self.inference_settings_editor = None
55
612
  # -------------------------------- #
56
613
 
57
- # Inference Mode
58
- self.inference_modes = [InferenceMode.FULL_IMAGE, InferenceMode.SLIDING_WINDOW]
59
- self.inference_mode_selector = Select(
60
- items=[Select.Item(mode) for mode in self.inference_modes]
614
+ self.settings_widgets = []
615
+ self.image_settings_widgets = []
616
+ self.video_settings_widgets = []
617
+
618
+ # Prediction Mode
619
+ self.prediction_modes = [
620
+ AddPredictionsMode.APPEND,
621
+ AddPredictionsMode.REPLACE,
622
+ AddPredictionsMode.IOU_MERGE,
623
+ # AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS, # @TODO: Implement later
624
+ ]
625
+ self.iou_merge_input = InputNumber(value=0.9, min=0.0, max=1.0, step=0.05, controls=False)
626
+ self.iou_merge_input_field = Field(
627
+ content=self.iou_merge_input,
628
+ title="IoU Threshold",
629
+ description="IoU threshold for merging predictions with existing labels. Predictions with IoU above this threshold will be considered duplicates and removed.",
61
630
  )
62
- self.inference_mode_selector.set_value(self.inference_modes[0])
63
- self.inference_mode_field = Field(
64
- content=self.inference_mode_selector,
65
- title="Inference mode",
66
- description="Select how to process images: full images or using sliding window.",
631
+ self.prediction_modes_contents = [Empty(), Empty(), self.iou_merge_input_field]
632
+ self.predictions_mode_selector = Select(
633
+ items=[
634
+ Select.Item(mode, content=content)
635
+ for mode, content in zip(self.prediction_modes, self.prediction_modes_contents)
636
+ ]
637
+ )
638
+ self.predictions_mode_selector.set_value(self.prediction_modes[0])
639
+ self.predicitons_mode_one_of = OneOf(self.predictions_mode_selector)
640
+ self.predictions_mode_field = Field(
641
+ content=Container(
642
+ widgets=[self.predictions_mode_selector, self.predicitons_mode_one_of]
643
+ ),
644
+ title="Add predictions mode",
645
+ description="Select how to add predictions to the project: by merging with existing labels or by replacing them.",
67
646
  )
68
647
  # Add widgets to display ------------ #
69
- self.display_widgets.extend([self.inference_mode_field])
648
+ self.image_settings_widgets.extend([self.predictions_mode_field])
649
+ # ----------------------------------- #
650
+
651
+ # Tracking
652
+ self.tracking_checkbox = Checkbox(content="Enable tracking", checked=True)
653
+ self.tracking_checkbox_field = Field(
654
+ content=self.tracking_checkbox,
655
+ title="Tracking",
656
+ description="Enable tracking for video predictions. The tracking algorithm is BoT-SORT version improved by Supervisely team.",
657
+ )
658
+ # Add widgets to display ------------ #
659
+ self.video_settings_widgets.extend([self.tracking_checkbox_field])
660
+ self.image_settings_container = Container(widgets=self.image_settings_widgets, gap=15)
661
+ self.video_settings_container = Container(widgets=self.video_settings_widgets, gap=15)
662
+ self.image_or_video_container = Container(
663
+ widgets=[self.image_settings_container, self.video_settings_container], gap=0
664
+ )
665
+ self.video_settings_container.hide()
666
+ self.settings_widgets.extend([self.image_or_video_container])
70
667
  # ----------------------------------- #
71
668
 
72
669
  # Class / Tag Suffix
73
670
  self.model_prediction_suffix_input = Input(
74
671
  value="_model", minlength=1, placeholder="Enter suffix e.g: _model"
75
672
  )
673
+ self.model_meta_has_conflicting_names_text = Text(
674
+ text="Project and Model metas have conflicting names. This suffix will be added to conflicting class and tag names of model predictions",
675
+ status="info",
676
+ )
76
677
  self.model_prediction_suffix_field = Field(
77
678
  content=self.model_prediction_suffix_input,
78
679
  title="Class and tag suffix",
@@ -82,38 +683,40 @@ class SettingsSelector:
82
683
  "then suffix will be added to the model predictions to avoid conflicts. E.g. 'person_model'."
83
684
  ),
84
685
  )
85
- # self.model_prediction_suffix_checkbox = Checkbox("Always add suffix to model predictions")
86
- # Add widgets to display ------------ #
87
- self.display_widgets.extend(
88
- [self.model_prediction_suffix_field] # , self.model_prediction_suffix_checkbox]
89
- )
90
- # ----------------------------------- #
91
-
92
- # Prediction Mode
93
- self.prediction_modes = [
94
- AddPredictionsMode.MERGE_WITH_EXISTING_LABELS,
95
- AddPredictionsMode.REPLACE_EXISTING_LABELS,
96
- # AddPredictionsMode.REPLACE_EXISTING_LABELS_AND_SAVE_IMAGE_TAGS, # @TODO: Implement later
97
- ]
98
- self.predictions_mode_selector = Select(
99
- items=[Select.Item(mode) for mode in self.prediction_modes]
100
- )
101
- self.predictions_mode_selector.set_value(self.prediction_modes[0])
102
- self.predictions_mode_field = Field(
103
- content=self.predictions_mode_selector,
104
- title="Add predictions mode",
105
- description="Select how to add predictions to the project: by merging with existing labels or by replacing them.",
686
+ self.model_prediction_suffix_container = Container(
687
+ widgets=[
688
+ self.model_meta_has_conflicting_names_text,
689
+ self.model_prediction_suffix_field,
690
+ ],
691
+ gap=5,
106
692
  )
693
+ self.model_prediction_suffix_container.hide()
107
694
  # Add widgets to display ------------ #
108
- self.display_widgets.extend([self.predictions_mode_field])
695
+ self.settings_widgets.extend([self.model_prediction_suffix_container])
109
696
  # ----------------------------------- #
110
697
 
111
698
  # Inference Settings
112
- self.inference_settings = Editor("", language_mode="yaml", height_px=300)
699
+ self.inference_settings_editor = Editor("", language_mode="yaml", height_px=300)
700
+ self.inference_settings_field = Field(
701
+ content=self.inference_settings_editor,
702
+ title="Inference and Tracking Settings",
703
+ )
113
704
  # Add widgets to display ------------ #
114
- self.display_widgets.extend([self.inference_settings])
705
+ self.settings_widgets.extend([self.inference_settings_field])
115
706
  # ----------------------------------- #
116
707
 
708
+ # Preview
709
+ self.preview_dir = os.path.join(self.static_dir, "preview")
710
+ self.preview = Preview(
711
+ api=self.api,
712
+ preview_dir=self.preview_dir,
713
+ get_model_api_fn=lambda: self.model_selector.model.model_api,
714
+ get_input_settings_fn=self.input_selector.get_settings,
715
+ get_settings_fn=self.get_settings,
716
+ )
717
+
718
+ self.settings_container = Container(widgets=self.settings_widgets, gap=15)
719
+ self.display_widgets.extend([self.settings_container])
117
720
  # Base Widgets
118
721
  self.validator_text = Text("")
119
722
  self.validator_text.hide()
@@ -124,49 +727,122 @@ class SettingsSelector:
124
727
 
125
728
  # Card Layout
126
729
  self.container = Container(self.display_widgets)
127
- self.card = Card(
730
+ self.settings_card = Card(
128
731
  title=self.title,
129
732
  description=self.description,
130
733
  content=self.container,
131
734
  lock_message=self.lock_message,
132
735
  )
133
- self.card.lock()
736
+ self.cards = [self.settings_card, self.preview.card]
737
+ self.cards_container = Container(
738
+ widgets=self.cards,
739
+ gap=15,
740
+ direction="horizontal",
741
+ fractions=[3, 7],
742
+ )
134
743
  # ----------------------------------- #
135
744
 
745
def lock(self):
    """Lock the settings card, passing ``self.lock_message``, then lock the preview."""
    card = self.settings_card
    card.lock(self.lock_message)
    self.preview.lock()
748
+
749
def unlock(self):
    """Unlock the settings card first, then the preview."""
    for attr_name in ("settings_card", "preview"):
        getattr(self, attr_name).unlock()
752
+
753
def disable(self):
    """Make every widget from ``widgets_to_disable`` non-editable.

    ``Editor`` widgets are switched to read-only through their ``readonly``
    attribute; every other widget has its ``disable()`` method called.
    """
    for item in self.widgets_to_disable:
        if not isinstance(item, Editor):
            item.disable()
            continue
        item.readonly = True
759
+
760
def enable(self):
    """Make every widget from ``widgets_to_disable`` editable again.

    ``Editor`` widgets get ``readonly`` reset to ``False``; every other widget
    has its ``enable()`` method called.
    """
    for item in self.widgets_to_disable:
        if not isinstance(item, Editor):
            item.enable()
            continue
        item.readonly = False
766
+
136
767
@property
def widgets_to_disable(self) -> List[Widget]:
    """Widgets that ``disable()`` / ``enable()`` iterate over, in a fixed order."""
    attr_names = (
        "tracking_checkbox",
        "predictions_mode_selector",
        "model_prediction_suffix_input",
        "inference_settings_editor",
    )
    return [getattr(self, attr_name) for attr_name in attr_names]
145
775
 
146
776
def set_inference_settings(self, settings: Dict[str, Any]):
    """Show inference settings in the editor under an ``# Inference settings`` header.

    :param settings: Either YAML text to display as-is, or a mapping that is
        serialized with ``yaml.safe_dump`` before being displayed.
    """
    # Serialize first, then prepend the header. Prepending the header before
    # the type check raised TypeError for mapping input and made the
    # ``yaml.safe_dump`` branch unreachable.
    if not isinstance(settings, str):
        settings = yaml.safe_dump(settings)
    self.inference_settings_editor.set_text("# Inference settings\n" + settings)
782
+
783
def set_tracking_settings(self, settings: Dict[str, Any]):
    """Put tracking settings into the editor under a ``# Tracking settings`` header.

    Does nothing unless the input selector's radio value is the videos
    project type.

    :param settings: Either YAML text to display as-is, or a mapping that is
        serialized with ``yaml.safe_dump`` before being displayed.
    """
    if self.input_selector.radio.get_value() != ProjectType.VIDEOS.value:
        return

    marker = "# Tracking settings"
    separator = "\n\n"
    tracking_text = settings if isinstance(settings, str) else yaml.safe_dump(settings)

    current_settings = self.inference_settings_editor.get_text()
    head, found, _ = current_settings.partition(marker)
    if found:
        # A tracking section is already present. Replace it instead of stacking
        # a second one: get_tracking_settings() only reads the text after the
        # first marker, so an appended section would be silently ignored.
        if head.endswith(separator):
            head = head[: -len(separator)]
        current_settings = head

    self.inference_settings_editor.set_text(
        current_settings + separator + marker + "\n" + tracking_text
    )
794
+
795
def set_default_tracking_settings(self):
    """Load the bundled BoT-SORT YAML config and pass it to ``set_tracking_settings``.

    The config path is resolved relative to this file: three directory levels
    above the file's directory, then ``tracker/botsort/botsort_config.yaml``.
    """
    config_file = Path(__file__).parents[3].joinpath(
        "tracker", "botsort", "botsort_config.yaml"
    )
    with open(config_file, "r", encoding="utf-8") as stream:
        default_config = yaml.safe_load(stream)
    self.set_tracking_settings(default_config)
151
802
 
152
803
def get_inference_settings(self) -> Dict:
    """Parse the inference part of the editor text into a dict.

    Only the text before the first ``# Tracking settings`` marker is parsed.
    A ``None`` parse result is replaced by an empty dict. For image input
    with the IoU-merge predictions mode, the value of ``iou_merge_input`` is
    stored under ``existing_objects_iou_thresh``.
    """
    editor_text = self.inference_settings_editor.get_text()
    inference_part = editor_text.partition("# Tracking settings")[0]

    parsed = yaml.safe_load(inference_part)
    if parsed is None:
        parsed = {}

    if self.input_selector.radio.get_value() == ProjectType.IMAGES.value:
        if self.predictions_mode_selector.get_value() == AddPredictionsMode.IOU_MERGE:
            parsed["existing_objects_iou_thresh"] = self.iou_merge_input.get_value()
    return parsed
814
+
815
def get_tracking_settings(self) -> Dict:
    """Parse the tracking part of the editor text into a dict.

    Returns an empty dict when the input type is not videos, when the editor
    text has no ``# Tracking settings`` marker, or when the section parses to
    a falsy value. The section is the text between the first marker and the
    next marker (or the end of the text).
    """
    if self.input_selector.radio.get_value() != ProjectType.VIDEOS.value:
        return {}

    marker = "# Tracking settings"
    editor_text = self.inference_settings_editor.get_text()
    _, found, remainder = editor_text.partition(marker)
    if not found:
        return {}

    section = remainder.partition(marker)[0]
    parsed = yaml.safe_load(section)
    return parsed if parsed else {}
157
827
 
158
828
def get_settings(self) -> Dict[str, Any]:
    """Collect the current settings of this step into one dict.

    Always contains ``inference_mode``, ``model_prediction_suffix``,
    ``predictions_mode`` and ``inference_settings``. ``tracking_settings`` is
    added for video input, and ``tracking`` is added when the input settings
    carry a non-``None`` ``video_ids`` entry.
    """
    collected = {}
    # NOTE: the inference-mode selector is not consulted here; the mode is
    # fixed to full-image inference.
    collected["inference_mode"] = InferenceMode.FULL_IMAGE
    collected["model_prediction_suffix"] = self.model_prediction_suffix_input.get_value()
    collected["predictions_mode"] = self.predictions_mode_selector.get_value()
    collected["inference_settings"] = self.get_inference_settings()

    if self.input_selector.radio.get_value() == ProjectType.VIDEOS.value:
        collected["tracking_settings"] = self.get_tracking_settings()

    input_settings = self.input_selector.get_settings()
    if input_settings.get("video_ids") is not None:
        collected["tracking"] = self.tracking_checkbox.is_checked()
    return collected
165
841
 
166
842
  def load_from_json(self, data):
167
- inference_mode = data.get("inference_mode", None)
168
- if inference_mode:
169
- self.inference_mode_selector.set_value(inference_mode)
843
+ # inference_mode = data.get("inference_mode", None)
844
+ # if inference_mode:
845
+ # self.inference_mode_selector.set_value(inference_mode)
170
846
 
171
847
  model_prediction_suffix = data.get("model_prediction_suffix", None)
172
848
  if model_prediction_suffix is not None:
@@ -180,5 +856,26 @@ class SettingsSelector:
180
856
  if inference_settings is not None:
181
857
  self.set_inference_settings(inference_settings)
182
858
 
859
+ tracking_settings = data.get("tracking_settings", None)
860
+ if tracking_settings is not None:
861
+ self.set_tracking_settings(tracking_settings)
862
+
863
+ tracking = data.get("tracking", None)
864
+ if tracking == True:
865
+ self.tracking_checkbox.check()
866
+ elif tracking == False:
867
+ self.tracking_checkbox.uncheck()
868
+
869
def update_item_type(self, item_type: str):
    """Show the settings container matching ``item_type`` and hide the other one.

    The preview is then told about the new item type as well.

    :param item_type: Value of the images or videos project type.
    :raises ValueError: If ``item_type`` is neither of the two.
    """
    if item_type == ProjectType.IMAGES.value:
        hidden_attr, shown_attr = "video_settings_container", "image_settings_container"
    elif item_type == ProjectType.VIDEOS.value:
        hidden_attr, shown_attr = "image_settings_container", "video_settings_container"
    else:
        raise ValueError(f"Unsupported item type: {item_type}")

    getattr(self, hidden_attr).hide()
    getattr(self, shown_attr).show()
    self.preview.update_item_type(item_type)
879
+
183
880
def validate_step(self) -> bool:
    """Report this step as valid; no checks are performed."""
    is_valid = True
    return is_valid