supervisely 6.73.419__py3-none-any.whl → 6.73.421__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- supervisely/api/api.py +10 -5
- supervisely/api/app_api.py +71 -4
- supervisely/api/module_api.py +4 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/project_api.py +35 -6
- supervisely/api/task_api.py +5 -1
- supervisely/app/widgets/__init__.py +8 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/deploy_model/__init__.py +0 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
- supervisely/app/widgets/fast_table/fast_table.py +402 -74
- supervisely/app/widgets/fast_table/script.js +364 -96
- supervisely/app/widgets/fast_table/style.css +24 -0
- supervisely/app/widgets/fast_table/template.html +43 -3
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +160 -94
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
- supervisely/nn/inference/predict_app/gui/gui.py +710 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +282 -0
- supervisely/nn/inference/predict_app/predict_app.py +184 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/prediction.py +2 -0
- supervisely/nn/model/prediction_session.py +20 -3
- supervisely/nn/training/gui/gui.py +131 -44
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
- supervisely/nn/training/gui/training_artifacts.py +0 -5
- supervisely/nn/training/train_app.py +161 -44
- supervisely/project/project.py +211 -73
- supervisely/template/experiment/experiment.html.jinja +74 -17
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/METADATA +3 -1
- {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/RECORD +75 -57
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/LICENSE +0 -0
- {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/WHEEL +0 -0
- {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.419.dist-info → supervisely-6.73.421.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/inference.py:

```diff
@@ -19,6 +19,7 @@ from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from urllib.request import urlopen

+import _pickle
 import numpy as np
 import requests
 import uvicorn
```
```diff
@@ -904,9 +905,15 @@ class Inference:
 if extracted_files:
 local_model_files[file] = file_path
 return local_model_files
+except _pickle.UnpicklingError as e:
+# TODO: raise error - checkpoint is corrupted
+logger.warning(f"Couldn't load '{file_name}'. Checkpoint might be corrupted. Error: {repr(e)}")
+logger.warning("Model files will be downloaded from Team Files")
+local_model_files[file] = file_path
+continue
 except Exception as e:
-logger.
-logger.
+logger.warning(f"Failed to process checkpoint '{file_name}' to extract auxiliary files: {repr(e)}")
+logger.warning("Model files will be downloaded from Team Files")
 local_model_files[file] = file_path
 continue

@@ -975,8 +982,7 @@ class Inference:
 # --- LOCAL ---
 try:
 logger.debug("Reading state dict...")
-
-ckpt = torch.load(checkpoint_path, map_location="cpu")
+ckpt = torch_load_safe(checkpoint_path)
 model_info = ckpt.get("model_info", {})
 model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
 model_files["checkpoint"] = checkpoint_path
@@ -1016,10 +1022,8 @@ class Inference:
 if file_ext not in (".pth", ".pt"):
 return extracted_files

-import torch # pylint: disable=import-error
 logger.debug(f"Reading checkpoint: {checkpoint_path}")
-checkpoint =
-
+checkpoint = torch_load_safe(checkpoint_path)
 # 1. Extract additional model files embedded into checkpoint (if any)
 ckpt_files = checkpoint.get("model_files", None)
 if ckpt_files and isinstance(ckpt_files, dict):
```
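The new `except _pickle.UnpicklingError` branch handles checkpoints whose embedded pickle payload cannot be deserialized: instead of failing, the code logs a warning and falls back to downloading the model files from Team Files. A minimal sketch of the exception class being caught, using the plain `pickle` module (which re-exports the C-level `_pickle` errors); the byte string is only an illustration of a corrupted payload:

```python
import pickle  # pickle.UnpicklingError is the same class as _pickle.UnpicklingError

try:
    pickle.loads(b"not a valid pickle stream")
except pickle.UnpicklingError as e:
    # This is the error the new branch catches before falling back to Team Files.
    print(f"corrupted payload: {e!r}")
```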
```diff
@@ -1235,7 +1239,7 @@ class Inference:

 def _set_checkpoint_info_pretrained(self, deploy_params: dict):
 checkpoint_name = os.path.basename(deploy_params["model_files"]["checkpoint"])
-model_name = deploy_params["model_info"]
+model_name = _get_model_name(deploy_params["model_info"])
 checkpoint_url = deploy_params["model_info"]["meta"]["model_files"]["checkpoint"]
 model_source = ModelSource.PRETRAINED
 self.checkpoint_info = CheckpointInfo(
```
```diff
@@ -1941,6 +1945,7 @@ class Inference:
 raise ValueError("Image ids are not provided")
 if not isinstance(image_ids, list):
 image_ids = [image_ids]
+model_prediction_suffix = state.get("model_prediction_suffix", None)
 upload_mode = state.get("upload_mode", None)
 iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
 if upload_mode == "iou_merge" and iou_merge_threshold is None:
@@ -2005,6 +2010,7 @@ class Inference:
 progress_cb=inference_request.done,
 iou_merge_threshold=iou_merge_threshold,
 inference_request=inference_request,
+model_prediction_suffix=model_prediction_suffix,
 )

 _add_results_to_request = partial(
@@ -2020,8 +2026,8 @@ class Inference:
 with Uploader(upload_f, logger=logger) as uploader:
 for image_ids_batch in batched(image_ids, batch_size=batch_size):
 if uploader.has_exception():
-exception = uploader.exception
-raise
+exception = uploader.exception
+raise exception
 if inference_request.is_stopped():
 logger.debug(
 f"Cancelling inference project...",
@@ -2177,6 +2183,8 @@ class Inference:
 project_info = api.project.get_info_by_id(project_id)
 if project_info.type != str(ProjectType.IMAGES):
 raise ValueError("Only images projects are supported.")
+
+model_prediction_suffix = state.get("model_prediction_suffix", None)
 upload_mode = state.get("upload_mode", None)
 iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
 if upload_mode == "iou_merge" and iou_merge_threshold is None:
@@ -2251,6 +2259,7 @@ class Inference:
 progress_cb=inference_request.done,
 iou_merge_threshold=iou_merge_threshold,
 inference_request=inference_request,
+model_prediction_suffix=model_prediction_suffix,
 )

 _add_results_to_request = partial(
@@ -2276,7 +2285,7 @@ class Inference:
 return
 if uploader.has_exception():
 exception = uploader.exception
-raise
+raise exception
 if cache_project_on_model:
 images_paths, _ = zip(
 *read_from_cached_project(
@@ -2405,7 +2414,7 @@ class Inference:
 return
 if uploader.has_exception():
 exception = uploader.exception
-raise
+raise exception
 if i == num_warmup:
 inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, num_iterations)

```
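Both batch-inference endpoints now read an optional `model_prediction_suffix` key from the request state and pass it down to the upload step. A hedged sketch of the relevant part of such a request state; the keys are the ones read in the hunks above, while the concrete values are invented for illustration:

```python
# Keys mirror the lookups in the hunks above; values are illustrative only.
state = {
    "upload_mode": "iou_merge",
    "model_prediction_suffix": "_rtdetr",  # appended to conflicting class/tag names
}
inference_settings = {
    "existing_objects_iou_thresh": 0.7,  # required when upload_mode == "iou_merge"
}

model_prediction_suffix = state.get("model_prediction_suffix", None)
upload_mode = state.get("upload_mode", None)
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
print(model_prediction_suffix, upload_mode, iou_merge_threshold)
```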
```diff
@@ -2470,6 +2479,7 @@ class Inference:
 # raise DialogWindowError(title="Call undeployed model.", description=msg)
 raise RuntimeError(msg)
 return func(*args, **kwargs)
+
 return wrapper

 def _freeze_model(self):
@@ -2621,6 +2631,7 @@ class Inference:
 progress_cb=None,
 iou_merge_threshold: float = None,
 inference_request: InferenceRequest = None,
+model_prediction_suffix: str = None,
 ):
 ds_predictions: Dict[int, List[Prediction]] = defaultdict(list)
 for prediction in predictions:
@@ -2682,7 +2693,9 @@ class Inference:
 meta_changed = False
 for pred in preds:
 ann = pred.annotation
-project_meta, ann, meta_changed_ = update_meta_and_ann(
+project_meta, ann, meta_changed_ = update_meta_and_ann(
+project_meta, ann, model_prediction_suffix
+)
 meta_changed = meta_changed or meta_changed_
 pred.annotation = ann
 prediction.model_meta = project_meta
@@ -2746,7 +2759,9 @@ class Inference:
 meta_changed = False
 for pred in preds:
 ann = pred.annotation
-project_meta, ann, meta_changed_ = update_meta_and_ann(
+project_meta, ann, meta_changed_ = update_meta_and_ann(
+project_meta, ann, model_prediction_suffix
+)
 meta_changed = meta_changed or meta_changed_
 pred.annotation = ann
 prediction.model_meta = project_meta
```
```diff
@@ -2828,12 +2843,12 @@ class Inference:
 # Predict and shutdown
 if self._args.mode == "predict":
 if any(
-
-
-
-
-
-
+[
+self._args.input,
+self._args.project_id,
+self._args.dataset_id,
+self._args.image_id,
+]
 ):
 self._parse_inference_settings_from_args()
 self._inference_by_cli_deploy_args()
```
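The reformatted condition simply checks whether at least one CLI input was supplied; unset argparse attributes are `None` and therefore falsy, so `any()` over the list does the job. A one-line illustration:

```python
# Unset argparse attributes default to None (falsy), so any() detects a provided input.
print(any([None, None, None, None]))  # False: no input, project, dataset, or image given
print(any([None, 1234, None, None]))  # True: e.g. a project id was passed
```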
```diff
@@ -3687,6 +3702,7 @@ class Inference:

 def _parse_inference_settings_from_args(self):
 logger.debug("Parsing inference settings from args")
+
 def parse_value(value: str):
 if value.lower() in ("true", "false"):
 return value.lower() == "true"
@@ -3813,8 +3829,7 @@ class Inference:
 try:
 # Read data from checkpoint
 logger.debug(f"Reading data from checkpoint: {checkpoint_path}")
-
-checkpoint = torch.load(checkpoint_path)
+checkpoint = torch_load_safe(checkpoint_path)
 model_info = checkpoint["model_info"]
 model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
 model_meta = os.path.join(self.model_dir, "model_meta.json")
```
```diff
@@ -4044,6 +4059,7 @@ class Inference:
 draw: bool = False,
 ):
 logger.info(f"Predicting Local Data: {input_path}")
+
 def postprocess_image(image_path: str, ann: Annotation, pred_dir: str = None):
 image_name = sly_fs.get_file_name_with_ext(image_path)
 if pred_dir is not None:
@@ -4166,13 +4182,14 @@ class Inference:

 task_id = experiment_info.task_id
 self.gui.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
-self.gui.experiment_selector.
+self.gui.experiment_selector.set_selected_row_by_task_id(task_id)

 best_ckpt = experiment_info.best_checkpoint
 if best_ckpt:
-row = self.gui.experiment_selector.
+row = self.gui.experiment_selector.get_selected_row_by_task_id(task_id)
 if row is not None:
 row.set_selected_checkpoint_by_name(best_ckpt)
+
 except Exception as e:
 logger.warning(f"Failed to set checkpoint from experiment info: {repr(e)}")

```
```diff
@@ -4181,6 +4198,7 @@ class Inference:
 return
 self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)

+
 def _exclude_duplicated_predictions(
 api: Api,
 pred_anns: List[Annotation],
```
```diff
@@ -4456,7 +4474,7 @@ def _fix_classes_names(meta: ProjectMeta, ann: Annotation):
 return meta, ann, replaced_classes_in_meta, list(replaced_classes_in_ann)


-def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
+def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suffix: str = None):
 """Update project meta and annotation to match each other
 If obj class or tag meta from annotation conflicts with project meta
 add suffix to obj class or tag meta.
@@ -4464,8 +4482,13 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
 """
 obj_classes_suffixes = ["_nn"]
 tag_meta_suffixes = ["_nn"]
-
-
+if model_prediction_suffix is not None:
+obj_classes_suffixes = [model_prediction_suffix]
+tag_meta_suffixes = [model_prediction_suffix]
+logger.debug(
+f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
+)
+logger.debug("source meta", extra={"meta": meta.to_json()})
 meta_changed = False

 meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
```
```diff
@@ -4476,91 +4499,116 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
 extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
 )

-
+updated_labels = []
+any_label_updated = False
+for label in ann.labels:
+original_obj_class_name = label.obj_class.name
+suffix_found = False
+for suffix in ["", *obj_classes_suffixes]:
+label_obj_class = label.obj_class
+label_obj_class_name = label_obj_class.name + suffix
+if suffix:
+label_obj_class = label_obj_class.clone(name=label_obj_class_name)
+label = label.clone(obj_class=label_obj_class)
+any_label_updated = True
+meta_obj_class = meta.get_obj_class(label_obj_class_name)
+if meta_obj_class is None:
+# if obj class is not in meta, add it with suffix
+meta = meta.add_obj_class(label_obj_class)
+updated_labels.append(label)
+meta_changed = True
+suffix_found = True
+break
+elif meta_obj_class.geometry_type.geometry_name() == label.geometry.geometry_name():
+# if label geometry is the same as in meta, use meta obj class
+label = label.clone(obj_class=meta_obj_class)
+updated_labels.append(label)
+suffix_found = True
+any_label_updated = True
+break
+elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
+# if meta obj class is AnyGeometry, use it in label
+label = label.clone(obj_class=meta_obj_class)
+updated_labels.append(label)
+suffix_found = True
+any_label_updated = True
+break
+if not suffix_found:
+# if no suffix found, raise error
+raise ValueError(
+f"Can't add obj class {original_obj_class_name} to project meta. "
+"Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
+"Please check if model geometry type is compatible with existing obj classes."
+)
+if any_label_updated:
+ann = ann.clone(labels=updated_labels)
+
+# check if tag metas are in project meta
+# if not, add them with suffix
+ann_tag_metas = {}
 for label in ann.labels:
-ann_obj_classes[label.obj_class.name] = label.obj_class
 for tag in label.tags:
 ann_tag_metas[tag.meta.name] = tag.meta
 for tag in ann.img_tags:
 ann_tag_metas[tag.meta.name] = tag.meta

-# check if obj classes are in project meta
-# if not, add them.
-# if shape is different, add them with suffix
-changed_obj_classes = {}
-for ann_obj_class in ann_obj_classes.values():
-if meta.get_obj_class(ann_obj_class.name) is None:
-meta = meta.add_obj_class(ann_obj_class)
-meta_changed = True
-elif (
-meta.get_obj_class(ann_obj_class.name).geometry_type != ann_obj_class.geometry_type
-and meta.get_obj_class(ann_obj_class.name).geometry_type != AnyGeometry
-):
-found = False
-for suffix in obj_classes_suffixes:
-new_obj_class_name = ann_obj_class.name + suffix
-meta_obj_class = meta.get_obj_class(new_obj_class_name)
-if meta_obj_class is None:
-new_obj_class = ann_obj_class.clone(name=new_obj_class_name)
-meta = meta.add_obj_class(new_obj_class)
-meta_changed = True
-changed_obj_classes[ann_obj_class.name] = new_obj_class
-found = True
-break
-if meta_obj_class.geometry_type == ann_obj_class.geometry_type:
-changed_obj_classes[ann_obj_class.name] = meta_obj_class
-found = True
-break
-if not found:
-raise ValueError(f"Can't add obj class {ann_obj_class.name} to project meta")
-
-# check if tag metas are in project meta
-# if not, add them with suffix
 changed_tag_metas = {}
-for
-
-
+for ann_tag_meta in ann_tag_metas.values():
+meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
+if meta_tag_meta is None:
+meta = meta.add_tag_meta(ann_tag_meta)
 meta_changed = True
-elif not
-
+elif not meta_tag_meta.is_compatible(ann_tag_meta):
+suffix_found = False
 for suffix in tag_meta_suffixes:
-new_tag_meta_name =
+new_tag_meta_name = ann_tag_meta.name + suffix
 meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
 if meta_tag_meta is None:
-new_tag_meta =
+new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
 meta = meta.add_tag_meta(new_tag_meta)
-changed_tag_metas[
+changed_tag_metas[ann_tag_meta.name] = new_tag_meta
 meta_changed = True
-
+suffix_found = True
 break
-if meta_tag_meta.is_compatible(
-changed_tag_metas[
-
+if meta_tag_meta.is_compatible(ann_tag_meta):
+changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
+suffix_found = True
 break
-if not
-raise ValueError(f"Can't add tag meta {
+if not suffix_found:
+raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")

-
-
-
-
-
-
-
+if changed_tag_metas:
+labels = []
+any_label_updated = False
+for label in ann.labels:
+any_tag_updated = False
+label_tags = []
+for tag in label.tags:
+if tag.meta.name in changed_tag_metas:
+label_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+any_tag_updated = True
+else:
+label_tags.append(tag)
+if any_tag_updated:
+label = label.clone(tags=TagCollection(label_tags))
+any_label_updated = True
+labels.append(label)
+img_tags = []
+any_tag_updated = False
+for tag in ann.img_tags:
 if tag.meta.name in changed_tag_metas:
-
+img_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+any_tag_updated = True
 else:
-
-
-
-
-
-
-
-
-
-ann = ann.clone(labels=labels, img_tags=TagCollection(img_tags))
+img_tags.append(tag)
+if any_tag_updated or any_label_updated:
+if any_tag_updated:
+img_tags = TagCollection(img_tags)
+else:
+img_tags = None
+if not any_label_updated:
+labels = None
+ann = ann.clone(img_tags=TagCollection(img_tags))
 return meta, ann, meta_changed


```
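The rewritten class-matching loop tries the predicted name first, then each configured suffix, and accepts a candidate if the name is free, the geometry matches, or the existing class is `AnyGeometry`. A simplified, self-contained re-implementation of that name-resolution idea for illustration only (it works on plain dicts, not on `ProjectMeta`/`Annotation` objects):

```python
# Simplified sketch of the suffix-resolution logic; not the SDK function itself.
def resolve_class_name(name, geometry, existing, suffixes=("_nn",)):
    """existing maps class name -> geometry name ('any' stands in for AnyGeometry)."""
    for suffix in ("", *suffixes):
        candidate = name + suffix
        if candidate not in existing or existing[candidate] in (geometry, "any"):
            return candidate
    raise ValueError(f"Can't add obj class {name} to project meta")

project_classes = {"person": "polygon"}  # project already has 'person' with a different shape
print(resolve_class_name("person", "rectangle", project_classes))              # person_nn
print(resolve_class_name("person", "rectangle", project_classes, ("_yolo",)))  # person_yolo
```

With `model_prediction_suffix` set, the `_nn` default is replaced by the caller-supplied suffix, which is what the new parameter threads through from the request state.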
```diff
@@ -4673,3 +4721,21 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
 continue
 return data[key]
 return None
+
+def torch_load_safe(checkpoint_path: str, device:str = "cpu"):
+import torch # pylint: disable=import-error
+
+# TODO: handle torch.load(weights_only=True) - change in torch 2.6.0
+try:
+logger.debug(f"Loading checkpoint from {checkpoint_path} on {device}")
+checkpoint = torch.load(checkpoint_path, map_location=device)
+logger.debug(f"Checkpoint loaded from {checkpoint_path} on {device}")
+except:
+logger.debug(
+f"Failed to load checkpoint from {checkpoint_path} on {device}. Trying again with weights_only=False"
+)
+checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+logger.debug(
+f"Checkpoint loaded from {checkpoint_path} on {device} with weights_only=False"
+)
+return checkpoint
```
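`torch_load_safe` first calls `torch.load` with its defaults and retries with `weights_only=False` if that fails; the TODO refers to PyTorch 2.6, which changed the default of `weights_only` to `True`, so checkpoints that pickle arbitrary Python objects fail on the first attempt and need the explicit fallback. A self-contained sketch of the same pattern, assuming only that `torch` is installed; the file name is hypothetical:

```python
# Same try/fallback pattern as torch_load_safe, written as a standalone sketch.
import torch

path = "demo_checkpoint.pth"  # hypothetical path for the sketch
torch.save({"model_info": {"name": "demo-model"}, "state_dict": {}}, path)

try:
    ckpt = torch.load(path, map_location="cpu")
except Exception:
    # PyTorch >= 2.6 defaults to weights_only=True; retry with the permissive mode
    # only for checkpoints you trust.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

print(ckpt["model_info"])
```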
supervisely/nn/inference/predict_app/gui/classes_selector.py (new file, @@ -0,0 +1,91 @@):

```python
from typing import Dict, Any
from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text


class ClassesSelector:
    title = "Classes Selector"
    description = "Select classes that will be used for inference"
    lock_message = "Select previous step to unlock"

    def __init__(self):
        # Init Step
        self.display_widgets = []
        # -------------------------------- #

        # Init Base Widgets
        self.validator_text = None
        self.button = None
        self.container = None
        self.card = None
        # -------------------------------- #

        # Init Step Widgets
        self.classes_table = None
        # -------------------------------- #

        # Classes
        self.classes_table = ClassesTable()
        self.classes_table.hide()
        # Add widgets to display ------------ #
        self.display_widgets.extend([self.classes_table])
        # ----------------------------------- #

        # Base Widgets
        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        self.display_widgets.extend([self.validator_text, self.button])
        # -------------------------------- #

        # Card Layout
        self.container = Container(self.display_widgets)
        self.card = Card(
            title=self.title,
            description=self.description,
            content=self.container,
            lock_message=self.lock_message,
        )
        self.card.lock()
        # -------------------------------- #

    @property
    def widgets_to_disable(self) -> list:
        return [self.classes_table]

    def load_from_json(self, data: Dict[str, Any]) -> None:
        if "classes" in data:
            self.set_classes(data["classes"])

    def get_selected_classes(self) -> list:
        return self.classes_table.get_selected_classes()

    def set_classes(self, classes) -> None:
        self.classes_table.select_classes(classes)

    def select_all_classes(self) -> None:
        self.classes_table.select_all()

    def get_settings(self) -> Dict[str, Any]:
        return {"classes": self.get_selected_classes()}

    def validate_step(self) -> bool:
        if self.classes_table.is_hidden():
            return True

        self.validator_text.hide()
        selected_classes = self.classes_table.get_selected_classes()
        n_classes = len(selected_classes)

        if n_classes == 0:
            self.validator_text.set(text="Please select at least one class", status="error")
            self.validator_text.show()
            return False

        class_word = "class" if n_classes == 1 else "classes"
        message_parts = [f"Selected {n_classes} {class_word}"]
        status = "success"
        is_valid = True

        self.validator_text.set(text=". ".join(message_parts), status=status)
        self.validator_text.show()
        return is_valid
```