supervisely-6.73.459-py3-none-any.whl → supervisely-6.73.468-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. More details were available via a link that was not preserved in this copy.

@@ -11,6 +11,7 @@ import subprocess
11
11
  import tempfile
12
12
  import threading
13
13
  import time
14
+ import uuid
14
15
  from collections import OrderedDict, defaultdict
15
16
  from concurrent.futures import ThreadPoolExecutor
16
17
  from dataclasses import asdict, dataclass
@@ -52,6 +53,7 @@ from supervisely.annotation.tag_meta import TagMeta, TagValueType
52
53
  from supervisely.api.api import Api, ApiField
53
54
  from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
54
55
  from supervisely.api.image_api import ImageInfo
56
+ from supervisely.api.video.video_api import VideoInfo
55
57
  from supervisely.app.content import get_data_dir
56
58
  from supervisely.app.fastapi.subapp import (
57
59
  Application,
@@ -102,6 +104,11 @@ from supervisely.video_annotation.video_figure import VideoFigure
102
104
  from supervisely.video_annotation.video_object import VideoObject
103
105
  from supervisely.video_annotation.video_object_collection import VideoObjectCollection
104
106
  from supervisely.video_annotation.video_tag_collection import VideoTagCollection
107
+ from supervisely.video_annotation.key_id_map import KeyIdMap
108
+ from supervisely.video_annotation.video_object_collection import (
109
+ VideoObject,
110
+ VideoObjectCollection,
111
+ )
105
112
 
106
113
  try:
107
114
  from typing import Literal
@@ -435,7 +442,7 @@ class Inference:
435
442
 
436
443
  device = "cuda" if torch.cuda.is_available() else "cpu"
437
444
  except Exception as e:
438
- logger.warn(
445
+ logger.warning(
439
446
  f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
440
447
  )
441
448
  device = "cpu"
@@ -1367,6 +1374,7 @@ class Inference:
1367
1374
 
1368
1375
  if tracker == "botsort":
1369
1376
  from supervisely.nn.tracker import BotSortTracker
1377
+
1370
1378
  device = tracker_settings.get("device", self.device)
1371
1379
  logger.debug(f"Initializing BotSort tracker with device: {device}")
1372
1380
  return BotSortTracker(settings=tracker_settings, device=device)
@@ -1383,15 +1391,15 @@ class Inference:
1383
1391
  if classes is not None:
1384
1392
  num_classes = len(classes)
1385
1393
  except NotImplementedError:
1386
- logger.warn(f"get_classes() function not implemented for {type(self)} object.")
1394
+ logger.warning(f"get_classes() function not implemented for {type(self)} object.")
1387
1395
  except AttributeError:
1388
- logger.warn("Probably, get_classes() function not working without model deploy.")
1396
+ logger.warning("Probably, get_classes() function not working without model deploy.")
1389
1397
  except Exception as exc:
1390
- logger.warn("Unknown exception. Please, contact support")
1398
+ logger.warning("Unknown exception. Please, contact support")
1391
1399
  logger.exception(exc)
1392
1400
 
1393
1401
  if num_classes is None:
1394
- logger.warn(f"get_classes() function return {classes}; skip classes processing.")
1402
+ logger.warning(f"get_classes() function return {classes}; skip classes processing.")
1395
1403
 
1396
1404
  return {
1397
1405
  "app_name": get_name_from_env(default="Neural Network Serving"),
@@ -1412,7 +1420,7 @@ class Inference:
1412
1420
  def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
1413
1421
  """
1414
1422
  Get default parameters for all available tracking algorithms.
1415
-
1423
+
1416
1424
  Returns:
1417
1425
  {"botsort": {"track_high_thresh": 0.6, ...}}
1418
1426
  Empty dict if tracking not supported.
@@ -1430,6 +1438,7 @@ class Inference:
1430
1438
  try:
1431
1439
  if tracker_name == "botsort":
1432
1440
  from supervisely.nn.tracker import BotSortTracker
1441
+
1433
1442
  trackers_params[tracker_name] = BotSortTracker.get_default_params()
1434
1443
  # Add other trackers here as elif blocks
1435
1444
  else:
@@ -1441,7 +1450,7 @@ class Inference:
1441
1450
  for tracker_name, params in trackers_params.items():
1442
1451
  trackers_params[tracker_name] = {
1443
1452
  k: v for k, v in params.items() if k not in INTERNAL_FIELDS
1444
- }
1453
+ }
1445
1454
  return trackers_params
1446
1455
 
1447
1456
  def get_human_readable_info(self, replace_none_with: Optional[str] = None):
@@ -2270,8 +2279,8 @@ class Inference:
2270
2279
  frame_index=frame_index,
2271
2280
  video_id=video_info.id,
2272
2281
  dataset_id=video_info.dataset_id,
2273
- project_id=video_info.project_id,
2274
- )
2282
+ project_id=video_info.project_id,
2283
+ )
2275
2284
  for ann, frame_index in zip(anns, batch)
2276
2285
  ]
2277
2286
  for pred, this_slides_data in zip(predictions, slides_data):
@@ -2289,6 +2298,162 @@ class Inference:
2289
2298
  inference_request.final_result = {"video_ann": video_ann_json}
2290
2299
  return video_ann_json
2291
2300
 
2301
+ def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
2302
+ logger.debug("Inferring video_id...", extra={"state": state})
2303
+ inference_settings = self._get_inference_settings(state)
2304
+ logger.debug(f"Inference settings:", extra=inference_settings)
2305
+ batch_size = self._get_batch_size_from_state(state)
2306
+ video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
2307
+ if video_id is None:
2308
+ raise ValueError("Video id is not provided")
2309
+ video_info = api.video.get_info_by_id(video_id)
2310
+ start_frame_index = get_value_for_keys(
2311
+ state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
2312
+ )
2313
+ if start_frame_index is None:
2314
+ start_frame_index = 0
2315
+ step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
2316
+ if step is None:
2317
+ step = 1
2318
+ end_frame_index = get_value_for_keys(
2319
+ state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
2320
+ )
2321
+ duration = state.get("duration", None)
2322
+ frames_count = get_value_for_keys(
2323
+ state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
2324
+ )
2325
+ tracking = state.get("tracker", None)
2326
+ direction = state.get("direction", "forward")
2327
+ direction = 1 if direction == "forward" else -1
2328
+ track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
2329
+
2330
+ if frames_count is not None:
2331
+ n_frames = frames_count
2332
+ elif end_frame_index is not None:
2333
+ n_frames = end_frame_index - start_frame_index
2334
+ elif duration is not None:
2335
+ fps = video_info.frames_count / video_info.duration
2336
+ n_frames = int(duration * fps)
2337
+ else:
2338
+ n_frames = video_info.frames_count
2339
+
2340
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
2341
+
2342
+ logger.debug(
2343
+ f"Video info:",
2344
+ extra=dict(
2345
+ w=video_info.frame_width,
2346
+ h=video_info.frame_height,
2347
+ start_frame_index=start_frame_index,
2348
+ n_frames=n_frames,
2349
+ ),
2350
+ )
2351
+
2352
+ # start downloading video in background
2353
+ self.cache.run_cache_task_manually(api, None, video_id=video_id)
2354
+
2355
+ progress_total = (n_frames + step - 1) // step
2356
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
2357
+
2358
+ _upload_f = partial(
2359
+ self.upload_predictions_to_video,
2360
+ api=api,
2361
+ video_info=video_info,
2362
+ track_id=track_id,
2363
+ context=inference_request.context,
2364
+ progress_cb=inference_request.done,
2365
+ inference_request=inference_request,
2366
+ )
2367
+
2368
+ _range = (start_frame_index, start_frame_index + direction * n_frames)
2369
+ if _range[0] > _range[1]:
2370
+ _range = (_range[1], _range[0])
2371
+
2372
+ def _notify_f(predictions: List[Prediction]):
2373
+ logger.debug(
2374
+ "Notifying tracking progress...",
2375
+ extra={
2376
+ "track_id": track_id,
2377
+ "range": _range,
2378
+ "current": inference_request.progress.current,
2379
+ "total": inference_request.progress.total,
2380
+ },
2381
+ )
2382
+ stopped = self.api.video.notify_progress(
2383
+ track_id=track_id,
2384
+ video_id=video_info.id,
2385
+ frame_start=_range[0],
2386
+ frame_end=_range[1],
2387
+ current=inference_request.progress.current,
2388
+ total=inference_request.progress.total,
2389
+ )
2390
+ if stopped:
2391
+ inference_request.stop()
2392
+ logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
2393
+
2394
+ def _exception_handler(e: Exception):
2395
+ self.api.video.notify_tracking_error(
2396
+ track_id=track_id,
2397
+ error=str(type(e)),
2398
+ message=str(e),
2399
+ )
2400
+ raise e
2401
+
2402
+ with Uploader(
2403
+ upload_f=_upload_f,
2404
+ notify_f=_notify_f,
2405
+ exception_handler=_exception_handler,
2406
+ logger=logger,
2407
+ ) as uploader:
2408
+ for batch in batched(
2409
+ range(
2410
+ start_frame_index, start_frame_index + direction * n_frames, direction * step
2411
+ ),
2412
+ batch_size,
2413
+ ):
2414
+ if inference_request.is_stopped():
2415
+ logger.debug(
2416
+ f"Cancelling inference video...",
2417
+ extra={"inference_request_uuid": inference_request.uuid},
2418
+ )
2419
+ break
2420
+ logger.debug(
2421
+ f"Inferring frames {batch[0]}-{batch[-1]}:",
2422
+ )
2423
+ frames = self.cache.download_frames(
2424
+ api, video_info.id, batch, redownload_video=True
2425
+ )
2426
+ anns, slides_data = self._inference_auto(
2427
+ source=frames,
2428
+ settings=inference_settings,
2429
+ )
2430
+
2431
+ if inference_request.tracker is not None:
2432
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
2433
+
2434
+ predictions = [
2435
+ Prediction(
2436
+ ann,
2437
+ model_meta=self.model_meta,
2438
+ frame_index=frame_index,
2439
+ video_id=video_info.id,
2440
+ dataset_id=video_info.dataset_id,
2441
+ project_id=video_info.project_id,
2442
+ )
2443
+ for ann, frame_index in zip(anns, batch)
2444
+ ]
2445
+ for pred, this_slides_data in zip(predictions, slides_data):
2446
+ pred.extra_data["slides_data"] = this_slides_data
2447
+ uploader.put(predictions)
2448
+ video_ann_json = None
2449
+ if inference_request.tracker is not None:
2450
+ inference_request.set_stage("Postprocess...", 0, 1)
2451
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
2452
+ inference_request.done()
2453
+ inference_request.final_result = {"video_ann": video_ann_json}
2454
+ return video_ann_json
2455
+
2456
+
2292
2457
  def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
2293
2458
  """Inference project images.
2294
2459
  If "output_project_id" in state, upload images and annotations to the output project.
@@ -2955,6 +3120,83 @@ class Inference:
2955
3120
  inference_request.add_results(results)
2956
3121
  inference_request.done(len(results))
2957
3122
 
3123
+ def upload_predictions_to_video(
3124
+ self,
3125
+ predictions: List[Prediction],
3126
+ api: Api,
3127
+ video_info: VideoInfo,
3128
+ track_id: str,
3129
+ context: Dict,
3130
+ progress_cb=None,
3131
+ inference_request: InferenceRequest = None,
3132
+ ):
3133
+ key_id_map = KeyIdMap()
3134
+ project_meta = context.get("project_meta", None)
3135
+ if project_meta is None:
3136
+ project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
3137
+ context["project_meta"] = project_meta
3138
+ meta_changed = False
3139
+ for prediction in predictions:
3140
+ project_meta, ann, meta_changed_ = update_meta_and_ann(
3141
+ project_meta, prediction.annotation, None
3142
+ )
3143
+ prediction.annotation = ann
3144
+ meta_changed = meta_changed or meta_changed_
3145
+ if meta_changed:
3146
+ project_meta = api.project.update_meta(video_info.project_id, project_meta)
3147
+ context["project_meta"] = project_meta
3148
+
3149
+ figure_data_by_object_id = defaultdict(list)
3150
+
3151
+ tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
3152
+ new_tracks: Dict[int, VideoObject] = {}
3153
+ for prediction in predictions:
3154
+ annotation = prediction.annotation
3155
+ tracks = annotation.custom_data
3156
+ for track, label in zip(tracks, annotation.labels):
3157
+ if track not in tracks_to_object_ids and track not in new_tracks:
3158
+ video_object = VideoObject(obj_class=label.obj_class)
3159
+ new_tracks[track] = video_object
3160
+ if new_tracks:
3161
+ tracks, video_objects = zip(*new_tracks.items())
3162
+ added_object_ids = api.video.object.append_bulk(
3163
+ video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
3164
+ )
3165
+ for track, object_id in zip(tracks, added_object_ids):
3166
+ tracks_to_object_ids[track] = object_id
3167
+ for prediction in predictions:
3168
+ annotation = prediction.annotation
3169
+ tracks = annotation.custom_data
3170
+ for track, label in zip(tracks, annotation.labels):
3171
+ object_id = tracks_to_object_ids[track]
3172
+ figure_data_by_object_id[object_id].append(
3173
+ {
3174
+ ApiField.OBJECT_ID: object_id,
3175
+ ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
3176
+ ApiField.GEOMETRY: label.geometry.to_json(),
3177
+ ApiField.META: {ApiField.FRAME: prediction.frame_index},
3178
+ ApiField.TRACK_ID: track_id,
3179
+ }
3180
+ )
3181
+
3182
+ for object_id, figures_data in figure_data_by_object_id.items():
3183
+ figures_keys = [uuid.uuid4() for _ in figures_data]
3184
+ api.video.figure._append_bulk(
3185
+ entity_id=video_info.id,
3186
+ figures_json=figures_data,
3187
+ figures_keys=figures_keys,
3188
+ key_id_map=key_id_map,
3189
+ )
3190
+ logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
3191
+ if progress_cb:
3192
+ progress_cb(len(predictions))
3193
+ if inference_request is not None:
3194
+ results = self._format_output(predictions)
3195
+ for result in results:
3196
+ result["annotation"] = None
3197
+ result["data"] = None
3198
+ inference_request.add_results(results)
3199
+
2958
3200
  def serve(self):
2959
3201
  if not self._use_gui and not self._is_cli_deploy:
2960
3202
  Progress("Deploying model ...", 1)
@@ -3352,6 +3594,22 @@ class Inference:
3352
3594
  "inference_request_uuid": inference_request.uuid,
3353
3595
  }
3354
3596
 
3597
+ @server.post("/tracking_by_detection")
3598
+ def tracking_by_detection(response: Response, request: Request):
3599
+ state = request.state.state
3600
+ context = request.state.context
3601
+ state.update(context)
3602
+ if state.get("tracker") is None:
3603
+ state["tracker"] = "botsort"
3604
+
3605
+ logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
3606
+ self.validate_inference_state(state)
3607
+ api = self.api_from_request(request)
3608
+ inference_request, future = self.inference_requests_manager.schedule_task(
3609
+ self._tracking_by_detection, api, state
3610
+ )
3611
+ return {"message": "Track task started."}
3612
+
3355
3613
  @server.post("/inference_project_id_async")
3356
3614
  def inference_project_id_async(response: Response, request: Request):
3357
3615
  state = request.state.state
@@ -3415,10 +3673,7 @@ class Inference:
3415
3673
  data = {**inference_request.to_json(), **log_extra}
3416
3674
  if inference_request.stage != InferenceRequest.Stage.INFERENCE:
3417
3675
  data["progress"] = {"current": 0, "total": 1}
3418
- logger.debug(
3419
- f"Sending inference progress with uuid:",
3420
- extra=data,
3421
- )
3676
+ logger.debug(f"Sending inference progress with uuid:", extra=data)
3422
3677
  return data
3423
3678
 
3424
3679
  @server.post(f"/pop_inference_results")
@@ -4411,6 +4666,7 @@ def _filter_duplicated_predictions_from_ann_cpu(
4411
4666
 
4412
4667
  return pred_ann.clone(labels=new_labels)
4413
4668
 
4669
+
4414
4670
  def _filter_duplicated_predictions_from_ann(
4415
4671
  gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
4416
4672
  ) -> Annotation:
@@ -4644,7 +4900,7 @@ def get_gpu_count():
4644
4900
  gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
4645
4901
  return gpu_count
4646
4902
  except (subprocess.CalledProcessError, FileNotFoundError) as exc:
4647
- logger.warn("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
4903
+ logger.warning("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
4648
4904
  return 0
4649
4905
 
4650
4906
 
@@ -5111,7 +5367,8 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
5111
5367
  return data[key]
5112
5368
  return None
5113
5369
 
5114
- def torch_load_safe(checkpoint_path: str, device:str = "cpu"):
5370
+
5371
+ def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
5115
5372
  import torch # pylint: disable=import-error
5116
5373
 
5117
5374
  # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0
@@ -14,13 +14,6 @@ from supervisely.sly_logger import logger
14
14
  from supervisely.task.progress import Progress
15
15
 
16
16
 
17
- def generate_uuid(self) -> str:
18
- """
19
- Generates a unique UUID for the inference request.
20
- """
21
- return uuid.uuid5(namespace=uuid.NAMESPACE_URL, name=f"{time.time()}-{rand_str(10)}").hex
22
-
23
-
24
17
  class InferenceRequest:
25
18
  class Stage:
26
19
  PREPARING = "Preparing model for inference..."
@@ -59,7 +52,7 @@ class InferenceRequest:
59
52
  self._created_at = time.monotonic()
60
53
  self._updated_at = self._created_at
61
54
  self._finished = False
62
-
55
+
63
56
  self.tracker = None
64
57
 
65
58
  self.global_progress = None
@@ -252,7 +245,8 @@ class InferenceRequest:
252
245
  status_data.pop(key, None)
253
246
  status_data.update(self.get_usage())
254
247
  return status_data
255
-
248
+
249
+
256
250
  class GlobalProgress:
257
251
  def __init__(self):
258
252
  self.progress = Progress(message="Ready", total_cnt=1)
@@ -2,6 +2,9 @@ import threading
2
2
  from typing import Any, Dict, List
3
3
 
4
4
  from supervisely.api.api import Api
5
+ from supervisely.api.dataset_api import DatasetInfo
6
+ from supervisely.api.project_api import ProjectInfo
7
+ from supervisely.api.video.video_api import VideoInfo
5
8
  from supervisely.app.widgets import (
6
9
  Button,
7
10
  Card,
@@ -202,6 +205,53 @@ class InputSelector:
202
205
 
203
206
  self.select_video.add_rows(rows)
204
207
 
208
+ def select_project(self, project_id: int, project_info: ProjectInfo = None):
209
+ if project_info is None:
210
+ project_info = self.api.project.get_info_by_id(project_id)
211
+ if project_info.type == ProjectType.IMAGES.value:
212
+ self.select_dataset_for_images.set_project_id(project_id)
213
+ self.select_dataset_for_images.select_all()
214
+ self.radio.set_value(ProjectType.IMAGES.value)
215
+ elif project_info.type == ProjectType.VIDEOS.value:
216
+ self.select_dataset_for_video.set_project_id(project_id)
217
+ self.select_dataset_for_video.select_all()
218
+ self._refresh_video_table()
219
+ self.select_video.select_rows(list(range(len(self.select_video._rows_total))))
220
+ self.radio.set_value(ProjectType.VIDEOS.value)
221
+ else:
222
+ raise ValueError(f"Project of type {project_info.type} is not supported.")
223
+
224
+ def select_datasets(self, dataset_ids: List[int], dataset_infos: List[DatasetInfo] = None):
225
+ if dataset_infos is None:
226
+ dataset_infos = [self.api.dataset.get_info_by_id(ds_id) for ds_id in dataset_ids]
227
+ project_ids = set(ds.project_id for ds in dataset_infos)
228
+ if len(project_ids) > 1:
229
+ raise ValueError("Cannot select datasets from different projects")
230
+ project_id = project_ids.pop()
231
+ project_info = self.api.project.get_info_by_id(project_id)
232
+ if project_info.type == ProjectType.IMAGES.value:
233
+ self.select_dataset_for_images.set_project_id(project_id)
234
+ self.select_dataset_for_images.set_dataset_ids(dataset_ids)
235
+ self.radio.set_value(ProjectType.IMAGES.value)
236
+ elif project_info.type == ProjectType.VIDEOS.value:
237
+ self.select_dataset_for_video.set_project_id(project_id)
238
+ self.select_dataset_for_video.set_dataset_ids(dataset_ids)
239
+ self._refresh_video_table()
240
+ self.select_video.select_rows(list(range(self.select_video._rows_total)))
241
+ self.radio.set_value(ProjectType.VIDEOS.value)
242
+ else:
243
+ raise ValueError(f"Project of type {project_info.type} is not supported.")
244
+
245
+ def select_videos(self, video_ids: List[int], video_infos: List[VideoInfo] = None):
246
+ if video_infos is None:
247
+ video_infos = self.api.video.get_info_by_id_batch(video_ids)
248
+ project_id = video_infos[0].project_id
249
+ self.select_dataset_for_video.set_project_id(project_id)
250
+ self.select_dataset_for_video.select_all()
251
+ self._refresh_video_table()
252
+ self.select_video.select_row_by_value("id", video_ids)
253
+ self.radio.set_value(ProjectType.VIDEOS.value)
254
+
205
255
  def disable(self):
206
256
  for widget in self.widgets_to_disable:
207
257
  widget.disable()
@@ -249,37 +299,13 @@ class InputSelector:
249
299
  video_infos = self.api.video.get_info_by_id_batch(video_ids)
250
300
  if not video_infos:
251
301
  raise ValueError(f"Videos with video ids {video_ids} are not found")
252
- project_id = video_infos[0].project_id
253
- self.select_dataset_for_video.set_project_id(project_id)
254
- self.select_dataset_for_video.select_all()
255
- self.select_video.select_row_by_value("id", data["video_ids"])
256
- self.radio.set_value(ProjectType.VIDEOS.value)
302
+ self.select_videos(video_ids, video_infos)
257
303
  elif "dataset_ids" in data:
258
304
  dataset_ids = data["dataset_ids"]
259
- if len(dataset_ids) == 0:
260
- raise ValueError("Dataset ids cannot be empty")
261
- dataset_id = dataset_ids[0]
262
- dataset_info = self.api.dataset.get_info_by_id(dataset_id)
263
- project_info = self.api.project.get_info_by_id(dataset_info.project_id)
264
- if project_info.type == ProjectType.VIDEOS:
265
- self.select_dataset_for_video.set_project_id(project_info.id)
266
- self.select_dataset_for_video.set_dataset_ids(dataset_ids)
267
- self.radio.set_value(ProjectType.VIDEOS.value)
268
- else:
269
- self.select_dataset_for_images.set_project_id(project_info.id)
270
- self.select_dataset_for_images.set_dataset_ids(dataset_ids)
271
- self.radio.set_value(ProjectType.IMAGES.value)
305
+ self.select_datasets(dataset_ids)
272
306
  elif "project_id" in data:
273
307
  project_id = data["project_id"]
274
- project_info = self.api.project.get_info_by_id(project_id)
275
- if project_info.type == ProjectType.VIDEOS:
276
- self.select_dataset_for_video.set_project_id(project_id)
277
- self.select_dataset_for_video.select_all()
278
- self.radio.set_value(ProjectType.VIDEOS.value)
279
- else:
280
- self.select_dataset_for_images.set_project_id(project_id)
281
- self.select_dataset_for_images.select_all()
282
- self.radio.set_value(ProjectType.IMAGES.value)
308
+ self.select_project(project_id)
283
309
 
284
310
  def get_project_id(self) -> int:
285
311
  if self.radio.get_value() == ProjectType.IMAGES.value:
@@ -441,46 +441,52 @@ class SessionJSON:
441
441
  prev_current = 0
442
442
  if preparing_cb:
443
443
  # wait for inference status
444
- resp = self._get_preparing_progress()
445
- awaiting_preparing_progress = 0
446
- break_flag = False
447
- while resp.get("status") is None:
448
- time.sleep(1)
449
- awaiting_preparing_progress += 1
450
- if awaiting_preparing_progress > 30:
451
- break_flag = True
444
+ try:
452
445
  resp = self._get_preparing_progress()
453
- if break_flag:
454
- logger.warning(
455
- "Unable to get preparing progress. Continue without prepaing progress status."
456
- )
457
- if not break_flag:
458
- if resp["status"] == "download_info":
446
+ for i in range(30):
447
+ logger.info(
448
+ f"Waiting for preparing progress... {30 - i} seconds left until timeout"
449
+ )
450
+ resp = self._get_preparing_progress()
451
+ if resp.get("status") is not None:
452
+ break
453
+ time.sleep(1)
454
+ if not resp.get("status"):
455
+ raise RuntimeError("Preparing progress status is not available.")
456
+
457
+ if resp.get("status") == "download_info":
458
+ logger.info("Downloading infos...")
459
459
  progress_widget = preparing_cb(
460
460
  message="Downloading infos", total=resp["total"], unit="it"
461
461
  )
462
- while resp["status"] == "download_info":
463
- current = resp["current"]
464
- # pylint: disable=possibly-used-before-assignment
465
- progress_widget.update(current - prev_current)
466
- prev_current = current
467
- resp = self._get_preparing_progress()
468
-
469
- if resp["status"] == "download_project":
462
+ while resp["status"] == "download_info":
463
+ current = resp["current"]
464
+ # pylint: disable=possibly-used-before-assignment
465
+ progress_widget.update(current - prev_current)
466
+ prev_current = current
467
+ resp = self._get_preparing_progress()
468
+
469
+ if resp.get("status") == "download_project":
470
+ logger.info("Downloading project...")
470
471
  progress_widget = preparing_cb(message="Download project", total=resp["total"])
471
- while resp["status"] == "download_project":
472
- current = resp["current"]
473
- progress_widget.update(current - prev_current)
474
- prev_current = current
475
- resp = self._get_preparing_progress()
476
-
477
- if resp["status"] == "warmup":
472
+ while resp.get("status") == "download_project":
473
+ current = resp["current"]
474
+ progress_widget.update(current - prev_current)
475
+ prev_current = current
476
+ resp = self._get_preparing_progress()
477
+
478
+ if resp.get("status") == "warmup":
479
+ logger.info("Running warmup...")
478
480
  progress_widget = preparing_cb(message="Running warmup", total=resp["total"])
479
- while resp["status"] == "warmup":
480
- current = resp["current"]
481
- progress_widget.update(current - prev_current)
482
- prev_current = current
483
- resp = self._get_preparing_progress()
481
+ while resp.get("status") == "warmup":
482
+ current = resp["current"]
483
+ progress_widget.update(current - prev_current)
484
+ prev_current = current
485
+ resp = self._get_preparing_progress()
486
+ except Exception as ex:
487
+ logger.warning(
488
+ f"An error occurred while getting preparing progress: {ex}. Continue without preparing progress status."
489
+ )
484
490
 
485
491
  logger.info("Inference has started:", extra={"response": resp})
486
492
  resp, has_started = self._wait_for_async_inference_start()
@@ -537,7 +543,9 @@ class SessionJSON:
537
543
  t0 = time.time()
538
544
  while not has_started and not timeout_exceeded:
539
545
  resp = self._get_inference_progress()
540
- has_started = bool(resp.get("result")) or resp["progress"]["total"] != 1
546
+ pending_results = resp.get("pending_results", None)
547
+ has_results = bool(pending_results)
548
+ has_started = bool(resp.get("result")) or resp["progress"]["total"] != 1 or has_results
541
549
  if not has_started:
542
550
  time.sleep(delay)
543
551
  timeout_exceeded = timeout and time.time() - t0 > timeout