supervisely-6.73.459-py3-none-any.whl → supervisely-6.73.468-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of supervisely might be problematic.
- supervisely/api/dataset_api.py +74 -12
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/widgets/fast_table/fast_table.py +101 -45
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +10 -2
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/nn/inference/cache.py +8 -2
- supervisely/nn/inference/inference.py +272 -15
- supervisely/nn/inference/inference_request.py +3 -9
- supervisely/nn/inference/predict_app/gui/input_selector.py +53 -27
- supervisely/nn/inference/session.py +43 -35
- supervisely/video/sampling.py +39 -20
- supervisely/video/video.py +25 -10
- supervisely/volume/stl_converter.py +2 -0
- {supervisely-6.73.459.dist-info → supervisely-6.73.468.dist-info}/METADATA +11 -9
- {supervisely-6.73.459.dist-info → supervisely-6.73.468.dist-info}/RECORD +19 -19
- {supervisely-6.73.459.dist-info → supervisely-6.73.468.dist-info}/LICENSE +0 -0
- {supervisely-6.73.459.dist-info → supervisely-6.73.468.dist-info}/WHEEL +0 -0
- {supervisely-6.73.459.dist-info → supervisely-6.73.468.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.459.dist-info → supervisely-6.73.468.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/inference.py

@@ -11,6 +11,7 @@ import subprocess
 import tempfile
 import threading
 import time
+import uuid
 from collections import OrderedDict, defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict, dataclass

@@ -52,6 +53,7 @@ from supervisely.annotation.tag_meta import TagMeta, TagValueType
 from supervisely.api.api import Api, ApiField
 from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
 from supervisely.api.image_api import ImageInfo
+from supervisely.api.video.video_api import VideoInfo
 from supervisely.app.content import get_data_dir
 from supervisely.app.fastapi.subapp import (
     Application,

@@ -102,6 +104,11 @@ from supervisely.video_annotation.video_figure import VideoFigure
 from supervisely.video_annotation.video_object import VideoObject
 from supervisely.video_annotation.video_object_collection import VideoObjectCollection
 from supervisely.video_annotation.video_tag_collection import VideoTagCollection
+from supervisely.video_annotation.key_id_map import KeyIdMap
+from supervisely.video_annotation.video_object_collection import (
+    VideoObject,
+    VideoObjectCollection,
+)
 
 try:
     from typing import Literal

@@ -435,7 +442,7 @@ class Inference:
 
             device = "cuda" if torch.cuda.is_available() else "cpu"
         except Exception as e:
-            logger.
+            logger.warning(
                 f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
             )
             device = "cpu"

@@ -1367,6 +1374,7 @@ class Inference:
 
         if tracker == "botsort":
             from supervisely.nn.tracker import BotSortTracker
+
             device = tracker_settings.get("device", self.device)
             logger.debug(f"Initializing BotSort tracker with device: {device}")
             return BotSortTracker(settings=tracker_settings, device=device)

@@ -1383,15 +1391,15 @@ class Inference:
             if classes is not None:
                 num_classes = len(classes)
         except NotImplementedError:
-            logger.
+            logger.warning(f"get_classes() function not implemented for {type(self)} object.")
         except AttributeError:
-            logger.
+            logger.warning("Probably, get_classes() function not working without model deploy.")
         except Exception as exc:
-            logger.
+            logger.warning("Unknown exception. Please, contact support")
             logger.exception(exc)
 
         if num_classes is None:
-            logger.
+            logger.warning(f"get_classes() function return {classes}; skip classes processing.")
 
         return {
             "app_name": get_name_from_env(default="Neural Network Serving"),

@@ -1412,7 +1420,7 @@ class Inference:
     def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
         """
         Get default parameters for all available tracking algorithms.
-
+
         Returns:
             {"botsort": {"track_high_thresh": 0.6, ...}}
             Empty dict if tracking not supported.

@@ -1430,6 +1438,7 @@ class Inference:
         try:
             if tracker_name == "botsort":
                 from supervisely.nn.tracker import BotSortTracker
+
                 trackers_params[tracker_name] = BotSortTracker.get_default_params()
             # Add other trackers here as elif blocks
             else:

@@ -1441,7 +1450,7 @@ class Inference:
         for tracker_name, params in trackers_params.items():
             trackers_params[tracker_name] = {
                 k: v for k, v in params.items() if k not in INTERNAL_FIELDS
-
+            }
         return trackers_params
 
     def get_human_readable_info(self, replace_none_with: Optional[str] = None):

@@ -2270,8 +2279,8 @@ class Inference:
                         frame_index=frame_index,
                         video_id=video_info.id,
                         dataset_id=video_info.dataset_id,
-
-
+                        project_id=video_info.project_id,
+                    )
                     for ann, frame_index in zip(anns, batch)
                 ]
                 for pred, this_slides_data in zip(predictions, slides_data):

@@ -2289,6 +2298,162 @@ class Inference:
         inference_request.final_result = {"video_ann": video_ann_json}
         return video_ann_json
 
+    def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
+        logger.debug("Inferring video_id...", extra={"state": state})
+        inference_settings = self._get_inference_settings(state)
+        logger.debug(f"Inference settings:", extra=inference_settings)
+        batch_size = self._get_batch_size_from_state(state)
+        video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
+        if video_id is None:
+            raise ValueError("Video id is not provided")
+        video_info = api.video.get_info_by_id(video_id)
+        start_frame_index = get_value_for_keys(
+            state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
+        )
+        if start_frame_index is None:
+            start_frame_index = 0
+        step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
+        if step is None:
+            step = 1
+        end_frame_index = get_value_for_keys(
+            state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
+        )
+        duration = state.get("duration", None)
+        frames_count = get_value_for_keys(
+            state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
+        )
+        tracking = state.get("tracker", None)
+        direction = state.get("direction", "forward")
+        direction = 1 if direction == "forward" else -1
+        track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
+
+        if frames_count is not None:
+            n_frames = frames_count
+        elif end_frame_index is not None:
+            n_frames = end_frame_index - start_frame_index
+        elif duration is not None:
+            fps = video_info.frames_count / video_info.duration
+            n_frames = int(duration * fps)
+        else:
+            n_frames = video_info.frames_count
+
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+
+        logger.debug(
+            f"Video info:",
+            extra=dict(
+                w=video_info.frame_width,
+                h=video_info.frame_height,
+                start_frame_index=start_frame_index,
+                n_frames=n_frames,
+            ),
+        )
+
+        # start downloading video in background
+        self.cache.run_cache_task_manually(api, None, video_id=video_id)
+
+        progress_total = (n_frames + step - 1) // step
+        inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
+
+        _upload_f = partial(
+            self.upload_predictions_to_video,
+            api=api,
+            video_info=video_info,
+            track_id=track_id,
+            context=inference_request.context,
+            progress_cb=inference_request.done,
+            inference_request=inference_request,
+        )
+
+        _range = (start_frame_index, start_frame_index + direction * n_frames)
+        if _range[0] > _range[1]:
+            _range = (_range[1], _range[0])
+
+        def _notify_f(predictions: List[Prediction]):
+            logger.debug(
+                "Notifying tracking progress...",
+                extra={
+                    "track_id": track_id,
+                    "range": _range,
+                    "current": inference_request.progress.current,
+                    "total": inference_request.progress.total,
+                },
+            )
+            stopped = self.api.video.notify_progress(
+                track_id=track_id,
+                video_id=video_info.id,
+                frame_start=_range[0],
+                frame_end=_range[1],
+                current=inference_request.progress.current,
+                total=inference_request.progress.total,
+            )
+            if stopped:
+                inference_request.stop()
+                logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
+
+        def _exception_handler(e: Exception):
+            self.api.video.notify_tracking_error(
+                track_id=track_id,
+                error=str(type(e)),
+                message=str(e),
+            )
+            raise e
+
+        with Uploader(
+            upload_f=_upload_f,
+            notify_f=_notify_f,
+            exception_handler=_exception_handler,
+            logger=logger,
+        ) as uploader:
+            for batch in batched(
+                range(
+                    start_frame_index, start_frame_index + direction * n_frames, direction * step
+                ),
+                batch_size,
+            ):
+                if inference_request.is_stopped():
+                    logger.debug(
+                        f"Cancelling inference video...",
+                        extra={"inference_request_uuid": inference_request.uuid},
+                    )
+                    break
+                logger.debug(
+                    f"Inferring frames {batch[0]}-{batch[-1]}:",
+                )
+                frames = self.cache.download_frames(
+                    api, video_info.id, batch, redownload_video=True
+                )
+                anns, slides_data = self._inference_auto(
+                    source=frames,
+                    settings=inference_settings,
+                )
+
+                if inference_request.tracker is not None:
+                    anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
+
+                predictions = [
+                    Prediction(
+                        ann,
+                        model_meta=self.model_meta,
+                        frame_index=frame_index,
+                        video_id=video_info.id,
+                        dataset_id=video_info.dataset_id,
+                        project_id=video_info.project_id,
+                    )
+                    for ann, frame_index in zip(anns, batch)
+                ]
+                for pred, this_slides_data in zip(predictions, slides_data):
+                    pred.extra_data["slides_data"] = this_slides_data
+                uploader.put(predictions)
+        video_ann_json = None
+        if inference_request.tracker is not None:
+            inference_request.set_stage("Postprocess...", 0, 1)
+            video_ann_json = inference_request.tracker.video_annotation.to_json()
+            inference_request.done()
+        inference_request.final_result = {"video_ann": video_ann_json}
+        return video_ann_json
+
+
     def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
         """Inference project images.
         If "output_project_id" in state, upload images and annotations to the output project.

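The new method walks video frames with a configurable stride and direction before batching them for inference. A minimal sketch of that index generation, with a local batched() helper standing in for the one supervisely uses (function names here are illustrative, not the library's):

from typing import Iterator, List


def batched(seq, batch_size: int) -> Iterator[List[int]]:
    # Local stand-in for supervisely's batched() utility.
    seq = list(seq)
    for i in range(0, len(seq), batch_size):
        yield seq[i : i + batch_size]


def frame_batches(start: int, n_frames: int, step: int, forward: bool, batch_size: int):
    # Mirrors the hunk above: the direction flips both the range end and the stride.
    direction = 1 if forward else -1
    return batched(range(start, start + direction * n_frames, direction * step), batch_size)


# Frames 100 down to 91, sampled every 3rd frame, in batches of 2:
for batch in frame_batches(start=100, n_frames=10, step=3, forward=False, batch_size=2):
    print(batch)  # [100, 97], then [94, 91]
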
@@ -2955,6 +3120,83 @@ class Inference:
         inference_request.add_results(results)
         inference_request.done(len(results))
 
+    def upload_predictions_to_video(
+        self,
+        predictions: List[Prediction],
+        api: Api,
+        video_info: VideoInfo,
+        track_id: str,
+        context: Dict,
+        progress_cb=None,
+        inference_request: InferenceRequest = None,
+    ):
+        key_id_map = KeyIdMap()
+        project_meta = context.get("project_meta", None)
+        if project_meta is None:
+            project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
+            context["project_meta"] = project_meta
+        meta_changed = False
+        for prediction in predictions:
+            project_meta, ann, meta_changed_ = update_meta_and_ann(
+                project_meta, prediction.annotation, None
+            )
+            prediction.annotation = ann
+            meta_changed = meta_changed or meta_changed_
+        if meta_changed:
+            project_meta = api.project.update_meta(video_info.project_id, project_meta)
+            context["project_meta"] = project_meta
+
+        figure_data_by_object_id = defaultdict(list)
+
+        tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
+        new_tracks: Dict[int, VideoObject] = {}
+        for prediction in predictions:
+            annotation = prediction.annotation
+            tracks = annotation.custom_data
+            for track, label in zip(tracks, annotation.labels):
+                if track not in tracks_to_object_ids and track not in new_tracks:
+                    video_object = VideoObject(obj_class=label.obj_class)
+                    new_tracks[track] = video_object
+        if new_tracks:
+            tracks, video_objects = zip(*new_tracks.items())
+            added_object_ids = api.video.object.append_bulk(
+                video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
+            )
+            for track, object_id in zip(tracks, added_object_ids):
+                tracks_to_object_ids[track] = object_id
+        for prediction in predictions:
+            annotation = prediction.annotation
+            tracks = annotation.custom_data
+            for track, label in zip(tracks, annotation.labels):
+                object_id = tracks_to_object_ids[track]
+                figure_data_by_object_id[object_id].append(
+                    {
+                        ApiField.OBJECT_ID: object_id,
+                        ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
+                        ApiField.GEOMETRY: label.geometry.to_json(),
+                        ApiField.META: {ApiField.FRAME: prediction.frame_index},
+                        ApiField.TRACK_ID: track_id,
+                    }
+                )
+
+        for object_id, figures_data in figure_data_by_object_id.items():
+            figures_keys = [uuid.uuid4() for _ in figures_data]
+            api.video.figure._append_bulk(
+                entity_id=video_info.id,
+                figures_json=figures_data,
+                figures_keys=figures_keys,
+                key_id_map=key_id_map,
+            )
+            logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
+        if progress_cb:
+            progress_cb(len(predictions))
+        if inference_request is not None:
+            results = self._format_output(predictions)
+            for result in results:
+                result["annotation"] = None
+                result["data"] = None
+            inference_request.add_results(results)
+
     def serve(self):
         if not self._use_gui and not self._is_cli_deploy:
             Progress("Deploying model ...", 1)

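The upload path creates a video object once per track (reusing ids stored in the request context across batches) and then groups all figures by the resolved object id before a single bulk append per object. The grouping idea in isolation, using plain dicts instead of the API calls (data and ids below are hypothetical):

from collections import defaultdict

# (track_key, frame_index, geometry) triples as the predictions would yield them.
figures = [
    ("track-1", 0, {"bbox": [0, 0, 10, 10]}),
    ("track-2", 0, {"bbox": [5, 5, 20, 20]}),
    ("track-1", 1, {"bbox": [1, 1, 11, 11]}),
]

tracks_to_object_ids = {}  # persisted in inference_request.context across batches
next_object_id = 1000      # stand-in for ids returned by the server

figure_data_by_object_id = defaultdict(list)
for track, frame_index, geometry in figures:
    if track not in tracks_to_object_ids:
        # In the real code this id comes from api.video.object.append_bulk(...)
        tracks_to_object_ids[track] = next_object_id
        next_object_id += 1
    object_id = tracks_to_object_ids[track]
    figure_data_by_object_id[object_id].append({"frame": frame_index, "geometry": geometry})

for object_id, figs in figure_data_by_object_id.items():
    print(object_id, len(figs))  # one bulk append per object id
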
@@ -3352,6 +3594,22 @@ class Inference:
                 "inference_request_uuid": inference_request.uuid,
             }
 
+        @server.post("/tracking_by_detection")
+        def tracking_by_detection(response: Response, request: Request):
+            state = request.state.state
+            context = request.state.context
+            state.update(context)
+            if state.get("tracker") is None:
+                state["tracker"] = "botsort"
+
+            logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
+            self.validate_inference_state(state)
+            api = self.api_from_request(request)
+            inference_request, future = self.inference_requests_manager.schedule_task(
+                self._tracking_by_detection, api, state
+            )
+            return {"message": "Track task started."}
+
         @server.post("/inference_project_id_async")
         def inference_project_id_async(response: Response, request: Request):
             state = request.state.state

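The endpoint schedules the task and returns immediately. A hedged client-side sketch, assuming the usual serving payload shape with a top-level "state" object; the URL is a placeholder and the state keys are taken from the lookups in _tracking_by_detection above:

import requests

SERVER = "http://localhost:8000"  # placeholder: your serving task URL

payload = {
    "state": {
        "videoId": 123456,        # required; "video_id" is also accepted
        "startFrameIndex": 0,
        "framesCount": 300,       # or endFrameIndex / duration
        "stride": 1,
        "direction": "forward",
        "trackId": "my-track-1",  # placeholder id used in progress callbacks
        "tracker": "botsort",     # default when omitted
        "tracker_settings": {},
    },
    "context": {},
}

resp = requests.post(f"{SERVER}/tracking_by_detection", json=payload, timeout=60)
print(resp.json())  # {"message": "Track task started."}
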
@@ -3415,10 +3673,7 @@ class Inference:
             data = {**inference_request.to_json(), **log_extra}
             if inference_request.stage != InferenceRequest.Stage.INFERENCE:
                 data["progress"] = {"current": 0, "total": 1}
-            logger.debug(
-                f"Sending inference progress with uuid:",
-                extra=data,
-            )
+            logger.debug(f"Sending inference progress with uuid:", extra=data)
             return data
 
         @server.post(f"/pop_inference_results")

@@ -4411,6 +4666,7 @@ def _filter_duplicated_predictions_from_ann_cpu(
 
     return pred_ann.clone(labels=new_labels)
 
+
 def _filter_duplicated_predictions_from_ann(
     gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
 ) -> Annotation:

@@ -4644,7 +4900,7 @@ def get_gpu_count():
         gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
         return gpu_count
     except (subprocess.CalledProcessError, FileNotFoundError) as exc:
-        logger.
+        logger.warning("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
         return 0
 
 

@@ -5111,7 +5367,8 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
             return data[key]
     return None
 
-
+
+def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
     import torch  # pylint: disable=import-error
 
     # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0

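get_value_for_keys is the alias-tolerant lookup used throughout the new tracking code (e.g. "videoId" vs "video_id"). A plausible standalone re-implementation, inferred from the signature and the call sites above (the hunk only shows the function's tail, so treat this as a sketch of the visible behavior):

from typing import List


def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
    # Return the first value found under any of the candidate keys.
    for key in keys:
        if key in data:
            if ignore_none and data[key] is None:
                continue
            return data[key]
    return None


state = {"videoId": None, "video_id": 42}
assert get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True) == 42
assert get_value_for_keys(state, ["videoId", "video_id"]) is None  # first hit wins
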
supervisely/nn/inference/inference_request.py

@@ -14,13 +14,6 @@ from supervisely.sly_logger import logger
 from supervisely.task.progress import Progress
 
 
-def generate_uuid(self) -> str:
-    """
-    Generates a unique UUID for the inference request.
-    """
-    return uuid.uuid5(namespace=uuid.NAMESPACE_URL, name=f"{time.time()}-{rand_str(10)}").hex
-
-
 class InferenceRequest:
     class Stage:
         PREPARING = "Preparing model for inference..."

@@ -59,7 +52,7 @@ class InferenceRequest:
         self._created_at = time.monotonic()
         self._updated_at = self._created_at
         self._finished = False
-
+
         self.tracker = None
 
         self.global_progress = None

@@ -252,7 +245,8 @@ class InferenceRequest:
             status_data.pop(key, None)
         status_data.update(self.get_usage())
         return status_data
-
+
+
 class GlobalProgress:
     def __init__(self):
         self.progress = Progress(message="Ready", total_cnt=1)

supervisely/nn/inference/predict_app/gui/input_selector.py

@@ -2,6 +2,9 @@ import threading
 from typing import Any, Dict, List
 
 from supervisely.api.api import Api
+from supervisely.api.dataset_api import DatasetInfo
+from supervisely.api.project_api import ProjectInfo
+from supervisely.api.video.video_api import VideoInfo
 from supervisely.app.widgets import (
     Button,
     Card,

@@ -202,6 +205,53 @@ class InputSelector:
 
         self.select_video.add_rows(rows)
 
+    def select_project(self, project_id: int, project_info: ProjectInfo = None):
+        if project_info is None:
+            project_info = self.api.project.get_info_by_id(project_id)
+        if project_info.type == ProjectType.IMAGES.value:
+            self.select_dataset_for_images.set_project_id(project_id)
+            self.select_dataset_for_images.select_all()
+            self.radio.set_value(ProjectType.IMAGES.value)
+        elif project_info.type == ProjectType.VIDEOS.value:
+            self.select_dataset_for_video.set_project_id(project_id)
+            self.select_dataset_for_video.select_all()
+            self._refresh_video_table()
+            self.select_video.select_rows(list(range(len(self.select_video._rows_total))))
+            self.radio.set_value(ProjectType.VIDEOS.value)
+        else:
+            raise ValueError(f"Project of type {project_info.type} is not supported.")
+
+    def select_datasets(self, dataset_ids: List[int], dataset_infos: List[DatasetInfo] = None):
+        if dataset_infos is None:
+            dataset_infos = [self.api.dataset.get_info_by_id(ds_id) for ds_id in dataset_ids]
+        project_ids = set(ds.project_id for ds in dataset_infos)
+        if len(project_ids) > 1:
+            raise ValueError("Cannot select datasets from different projects")
+        project_id = project_ids.pop()
+        project_info = self.api.project.get_info_by_id(project_id)
+        if project_info.type == ProjectType.IMAGES.value:
+            self.select_dataset_for_images.set_project_id(project_id)
+            self.select_dataset_for_images.set_dataset_ids(dataset_ids)
+            self.radio.set_value(ProjectType.IMAGES.value)
+        elif project_info.type == ProjectType.VIDEOS.value:
+            self.select_dataset_for_video.set_project_id(project_id)
+            self.select_dataset_for_video.set_dataset_ids(dataset_ids)
+            self._refresh_video_table()
+            self.select_video.select_rows(list(range(self.select_video._rows_total)))
+            self.radio.set_value(ProjectType.VIDEOS.value)
+        else:
+            raise ValueError(f"Project of type {project_info.type} is not supported.")
+
+    def select_videos(self, video_ids: List[int], video_infos: List[VideoInfo] = None):
+        if video_infos is None:
+            video_infos = self.api.video.get_info_by_id_batch(video_ids)
+        project_id = video_infos[0].project_id
+        self.select_dataset_for_video.set_project_id(project_id)
+        self.select_dataset_for_video.select_all()
+        self._refresh_video_table()
+        self.select_video.select_row_by_value("id", video_ids)
+        self.radio.set_value(ProjectType.VIDEOS.value)
+
     def disable(self):
         for widget in self.widgets_to_disable:
             widget.disable()

@@ -249,37 +299,13 @@ class InputSelector:
             video_infos = self.api.video.get_info_by_id_batch(video_ids)
             if not video_infos:
                 raise ValueError(f"Videos with video ids {video_ids} are not found")
-
-            self.select_dataset_for_video.set_project_id(project_id)
-            self.select_dataset_for_video.select_all()
-            self.select_video.select_row_by_value("id", data["video_ids"])
-            self.radio.set_value(ProjectType.VIDEOS.value)
+            self.select_videos(video_ids, video_infos)
         elif "dataset_ids" in data:
             dataset_ids = data["dataset_ids"]
-
-                raise ValueError("Dataset ids cannot be empty")
-            dataset_id = dataset_ids[0]
-            dataset_info = self.api.dataset.get_info_by_id(dataset_id)
-            project_info = self.api.project.get_info_by_id(dataset_info.project_id)
-            if project_info.type == ProjectType.VIDEOS:
-                self.select_dataset_for_video.set_project_id(project_info.id)
-                self.select_dataset_for_video.set_dataset_ids(dataset_ids)
-                self.radio.set_value(ProjectType.VIDEOS.value)
-            else:
-                self.select_dataset_for_images.set_project_id(project_info.id)
-                self.select_dataset_for_images.set_dataset_ids(dataset_ids)
-                self.radio.set_value(ProjectType.IMAGES.value)
+            self.select_datasets(dataset_ids)
         elif "project_id" in data:
             project_id = data["project_id"]
-
-            if project_info.type == ProjectType.VIDEOS:
-                self.select_dataset_for_video.set_project_id(project_id)
-                self.select_dataset_for_video.select_all()
-                self.radio.set_value(ProjectType.VIDEOS.value)
-            else:
-                self.select_dataset_for_images.set_project_id(project_id)
-                self.select_dataset_for_images.select_all()
-                self.radio.set_value(ProjectType.IMAGES.value)
+            self.select_project(project_id)
 
     def get_project_id(self) -> int:
         if self.radio.get_value() == ProjectType.IMAGES.value:

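The refactor collapses three inline branches into one call each to the new helper methods, so external state with "video_ids", "dataset_ids", or "project_id" maps to a single entry point. The dispatch shape in isolation (a sketch; the handler name is hypothetical, the methods mirror the widget's new API):

def apply_selection(widget, data: dict) -> None:
    # Mirrors the refactored branch above: one helper per input kind.
    if "video_ids" in data:
        widget.select_videos(data["video_ids"])
    elif "dataset_ids" in data:
        widget.select_datasets(data["dataset_ids"])
    elif "project_id" in data:
        widget.select_project(data["project_id"])
    else:
        raise ValueError("Expected one of: video_ids, dataset_ids, project_id")
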
supervisely/nn/inference/session.py

@@ -441,46 +441,52 @@ class SessionJSON:
         prev_current = 0
         if preparing_cb:
             # wait for inference status
-
-            awaiting_preparing_progress = 0
-            break_flag = False
-            while resp.get("status") is None:
-                time.sleep(1)
-                awaiting_preparing_progress += 1
-                if awaiting_preparing_progress > 30:
-                    break_flag = True
+            try:
                 resp = self._get_preparing_progress()
-
-
-
-
-
-
+                for i in range(30):
+                    logger.info(
+                        f"Waiting for preparing progress... {30 - i} seconds left until timeout"
+                    )
+                    resp = self._get_preparing_progress()
+                    if resp.get("status") is not None:
+                        break
+                    time.sleep(1)
+                if not resp.get("status"):
+                    raise RuntimeError("Preparing progress status is not available.")
+
+                if resp.get("status") == "download_info":
+                    logger.info("Downloading infos...")
                     progress_widget = preparing_cb(
                         message="Downloading infos", total=resp["total"], unit="it"
                     )
-
-
-
-
-
-
-
-            if resp
+                    while resp["status"] == "download_info":
+                        current = resp["current"]
+                        # pylint: disable=possibly-used-before-assignment
+                        progress_widget.update(current - prev_current)
+                        prev_current = current
+                        resp = self._get_preparing_progress()
+
+                if resp.get("status") == "download_project":
+                    logger.info("Downloading project...")
                     progress_widget = preparing_cb(message="Download project", total=resp["total"])
-
-
-
-
-
-
-            if resp
+                    while resp.get("status") == "download_project":
+                        current = resp["current"]
+                        progress_widget.update(current - prev_current)
+                        prev_current = current
+                        resp = self._get_preparing_progress()
+
+                if resp.get("status") == "warmup":
+                    logger.info("Running warmup...")
                     progress_widget = preparing_cb(message="Running warmup", total=resp["total"])
-
-
-
-
-
+                    while resp.get("status") == "warmup":
+                        current = resp["current"]
+                        progress_widget.update(current - prev_current)
+                        prev_current = current
+                        resp = self._get_preparing_progress()
+            except Exception as ex:
+                logger.warning(
+                    f"An error occurred while getting preparing progress: {ex}. Continue without preparing progress status."
+                )
 
         logger.info("Inference has started:", extra={"response": resp})
         resp, has_started = self._wait_for_async_inference_start()

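The session now bounds the wait for a preparing status to roughly thirty one-second polls instead of the old open-ended loop, and any failure is downgraded to a warning so inference proceeds without progress reporting. The retry pattern on its own, as a generic sketch with no Supervisely calls:

import time
from typing import Callable


def wait_for_status(poll: Callable[[], dict], attempts: int = 30, delay: float = 1.0) -> dict:
    # Poll until the response carries a "status"; give up after `attempts` tries.
    resp = poll()  # mirrors the hunk: one fetch up front, then re-poll in the loop
    for attempt in range(attempts):
        resp = poll()
        if resp.get("status") is not None:
            return resp
        time.sleep(delay)
    raise RuntimeError("Preparing progress status is not available.")
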
@@ -537,7 +543,9 @@ class SessionJSON:
         t0 = time.time()
         while not has_started and not timeout_exceeded:
            resp = self._get_inference_progress()
-
+            pending_results = resp.get("pending_results", None)
+            has_results = bool(pending_results)
+            has_started = bool(resp.get("result")) or resp["progress"]["total"] != 1 or has_results
             if not has_started:
                 time.sleep(delay)
                 timeout_exceeded = timeout and time.time() - t0 > timeout
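The start-detection condition now also treats queued pending results as evidence that inference has begun, alongside a result payload or a progress total other than the placeholder 1. The predicate on its own, with the response dicts below as hypothetical samples:

def inference_has_started(resp: dict) -> bool:
    # Any of: a result payload, a real progress total, or queued pending results.
    has_results = bool(resp.get("pending_results"))
    return bool(resp.get("result")) or resp["progress"]["total"] != 1 or has_results


assert inference_has_started({"result": None, "progress": {"total": 1}, "pending_results": [{}]})
assert not inference_has_started({"result": None, "progress": {"total": 1}, "pending_results": []})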