supervisely 6.73.420__py3-none-any.whl → 6.73.422__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. supervisely/api/api.py +10 -5
  2. supervisely/api/app_api.py +71 -4
  3. supervisely/api/module_api.py +4 -0
  4. supervisely/api/nn/deploy_api.py +15 -9
  5. supervisely/api/nn/ecosystem_models_api.py +201 -0
  6. supervisely/api/nn/neural_network_api.py +12 -3
  7. supervisely/api/project_api.py +35 -6
  8. supervisely/api/task_api.py +5 -1
  9. supervisely/app/widgets/__init__.py +8 -1
  10. supervisely/app/widgets/agent_selector/template.html +1 -0
  11. supervisely/app/widgets/deploy_model/__init__.py +0 -0
  12. supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
  13. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  14. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  15. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  16. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  17. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
  18. supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
  19. supervisely/app/widgets/fast_table/fast_table.py +402 -74
  20. supervisely/app/widgets/fast_table/script.js +364 -96
  21. supervisely/app/widgets/fast_table/style.css +24 -0
  22. supervisely/app/widgets/fast_table/template.html +43 -3
  23. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  24. supervisely/app/widgets/select/select.py +6 -4
  25. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
  26. supervisely/app/widgets/tabs/tabs.py +22 -6
  27. supervisely/app/widgets/tabs/template.html +5 -1
  28. supervisely/nn/artifacts/__init__.py +1 -1
  29. supervisely/nn/artifacts/artifacts.py +10 -2
  30. supervisely/nn/artifacts/detectron2.py +1 -0
  31. supervisely/nn/artifacts/hrda.py +1 -0
  32. supervisely/nn/artifacts/mmclassification.py +20 -0
  33. supervisely/nn/artifacts/mmdetection.py +5 -3
  34. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  35. supervisely/nn/artifacts/ritm.py +1 -0
  36. supervisely/nn/artifacts/rtdetr.py +1 -0
  37. supervisely/nn/artifacts/unet.py +1 -0
  38. supervisely/nn/artifacts/utils.py +3 -0
  39. supervisely/nn/artifacts/yolov5.py +2 -0
  40. supervisely/nn/artifacts/yolov8.py +1 -0
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  42. supervisely/nn/experiments.py +9 -0
  43. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  44. supervisely/nn/inference/inference.py +160 -94
  45. supervisely/nn/inference/predict_app/__init__.py +0 -0
  46. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  47. supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
  48. supervisely/nn/inference/predict_app/gui/gui.py +710 -0
  49. supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
  50. supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
  51. supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
  52. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  53. supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
  54. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  55. supervisely/nn/inference/predict_app/gui/utils.py +282 -0
  56. supervisely/nn/inference/predict_app/predict_app.py +184 -0
  57. supervisely/nn/inference/uploader.py +9 -5
  58. supervisely/nn/model/prediction.py +2 -0
  59. supervisely/nn/model/prediction_session.py +20 -3
  60. supervisely/nn/training/gui/gui.py +131 -44
  61. supervisely/nn/training/gui/model_selector.py +8 -6
  62. supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
  63. supervisely/nn/training/gui/training_artifacts.py +0 -5
  64. supervisely/nn/training/train_app.py +161 -44
  65. supervisely/template/experiment/experiment.html.jinja +74 -17
  66. supervisely/template/experiment/experiment_generator.py +258 -112
  67. supervisely/template/experiment/header.html.jinja +31 -13
  68. supervisely/template/experiment/sly-style.css +7 -2
  69. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/METADATA +3 -1
  70. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/RECORD +74 -56
  71. supervisely/app/widgets/experiment_selector/style.css +0 -27
  72. supervisely/app/widgets/experiment_selector/template.html +0 -61
  73. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/LICENSE +0 -0
  74. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/WHEEL +0 -0
  75. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/entry_points.txt +0 -0
  76. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,7 @@ from pathlib import Path
19
19
  from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
20
20
  from urllib.request import urlopen
21
21
 
22
+ import _pickle
22
23
  import numpy as np
23
24
  import requests
24
25
  import uvicorn
@@ -904,9 +905,15 @@ class Inference:
904
905
  if extracted_files:
905
906
  local_model_files[file] = file_path
906
907
  return local_model_files
908
+ except _pickle.UnpicklingError as e:
909
+ # TODO: raise error - checkpoint is corrupted
910
+ logger.warning(f"Couldn't load '{file_name}'. Checkpoint might be corrupted. Error: {repr(e)}")
911
+ logger.warning("Model files will be downloaded from Team Files")
912
+ local_model_files[file] = file_path
913
+ continue
907
914
  except Exception as e:
908
- logger.debug(f"Failed to process checkpoint '{file_name}' to extract auxiliary files: {repr(e)}")
909
- logger.debug("Model files will be downloaded from Team Files")
915
+ logger.warning(f"Failed to process checkpoint '{file_name}' to extract auxiliary files: {repr(e)}")
916
+ logger.warning("Model files will be downloaded from Team Files")
910
917
  local_model_files[file] = file_path
911
918
  continue
912
919
 
@@ -975,8 +982,7 @@ class Inference:
975
982
  # --- LOCAL ---
976
983
  try:
977
984
  logger.debug("Reading state dict...")
978
- import torch # pylint: disable=import-error
979
- ckpt = torch.load(checkpoint_path, map_location="cpu")
985
+ ckpt = torch_load_safe(checkpoint_path)
980
986
  model_info = ckpt.get("model_info", {})
981
987
  model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
982
988
  model_files["checkpoint"] = checkpoint_path
@@ -1016,10 +1022,8 @@ class Inference:
1016
1022
  if file_ext not in (".pth", ".pt"):
1017
1023
  return extracted_files
1018
1024
 
1019
- import torch # pylint: disable=import-error
1020
1025
  logger.debug(f"Reading checkpoint: {checkpoint_path}")
1021
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
1022
-
1026
+ checkpoint = torch_load_safe(checkpoint_path)
1023
1027
  # 1. Extract additional model files embedded into checkpoint (if any)
1024
1028
  ckpt_files = checkpoint.get("model_files", None)
1025
1029
  if ckpt_files and isinstance(ckpt_files, dict):
@@ -1235,7 +1239,7 @@ class Inference:
1235
1239
 
1236
1240
  def _set_checkpoint_info_pretrained(self, deploy_params: dict):
1237
1241
  checkpoint_name = os.path.basename(deploy_params["model_files"]["checkpoint"])
1238
- model_name = deploy_params["model_info"]["model_name"]
1242
+ model_name = _get_model_name(deploy_params["model_info"])
1239
1243
  checkpoint_url = deploy_params["model_info"]["meta"]["model_files"]["checkpoint"]
1240
1244
  model_source = ModelSource.PRETRAINED
1241
1245
  self.checkpoint_info = CheckpointInfo(
@@ -1941,6 +1945,7 @@ class Inference:
1941
1945
  raise ValueError("Image ids are not provided")
1942
1946
  if not isinstance(image_ids, list):
1943
1947
  image_ids = [image_ids]
1948
+ model_prediction_suffix = state.get("model_prediction_suffix", None)
1944
1949
  upload_mode = state.get("upload_mode", None)
1945
1950
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
1946
1951
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
@@ -2005,6 +2010,7 @@ class Inference:
2005
2010
  progress_cb=inference_request.done,
2006
2011
  iou_merge_threshold=iou_merge_threshold,
2007
2012
  inference_request=inference_request,
2013
+ model_prediction_suffix=model_prediction_suffix,
2008
2014
  )
2009
2015
 
2010
2016
  _add_results_to_request = partial(
@@ -2020,8 +2026,8 @@ class Inference:
2020
2026
  with Uploader(upload_f, logger=logger) as uploader:
2021
2027
  for image_ids_batch in batched(image_ids, batch_size=batch_size):
2022
2028
  if uploader.has_exception():
2023
- exception = uploader.exception()
2024
- raise RuntimeError(f"Error in upload loop: {exception}") from exception
2029
+ exception = uploader.exception
2030
+ raise exception
2025
2031
  if inference_request.is_stopped():
2026
2032
  logger.debug(
2027
2033
  f"Cancelling inference project...",
@@ -2177,6 +2183,8 @@ class Inference:
2177
2183
  project_info = api.project.get_info_by_id(project_id)
2178
2184
  if project_info.type != str(ProjectType.IMAGES):
2179
2185
  raise ValueError("Only images projects are supported.")
2186
+
2187
+ model_prediction_suffix = state.get("model_prediction_suffix", None)
2180
2188
  upload_mode = state.get("upload_mode", None)
2181
2189
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
2182
2190
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
@@ -2251,6 +2259,7 @@ class Inference:
2251
2259
  progress_cb=inference_request.done,
2252
2260
  iou_merge_threshold=iou_merge_threshold,
2253
2261
  inference_request=inference_request,
2262
+ model_prediction_suffix=model_prediction_suffix,
2254
2263
  )
2255
2264
 
2256
2265
  _add_results_to_request = partial(
@@ -2276,7 +2285,7 @@ class Inference:
2276
2285
  return
2277
2286
  if uploader.has_exception():
2278
2287
  exception = uploader.exception
2279
- raise RuntimeError(f"Error in upload loop: {exception}") from exception
2288
+ raise exception
2280
2289
  if cache_project_on_model:
2281
2290
  images_paths, _ = zip(
2282
2291
  *read_from_cached_project(
@@ -2405,7 +2414,7 @@ class Inference:
2405
2414
  return
2406
2415
  if uploader.has_exception():
2407
2416
  exception = uploader.exception
2408
- raise RuntimeError(f"Error in upload loop: {exception}") from exception
2417
+ raise exception
2409
2418
  if i == num_warmup:
2410
2419
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, num_iterations)
2411
2420
 
@@ -2470,6 +2479,7 @@ class Inference:
2470
2479
  # raise DialogWindowError(title="Call undeployed model.", description=msg)
2471
2480
  raise RuntimeError(msg)
2472
2481
  return func(*args, **kwargs)
2482
+
2473
2483
  return wrapper
2474
2484
 
2475
2485
  def _freeze_model(self):
@@ -2621,6 +2631,7 @@ class Inference:
2621
2631
  progress_cb=None,
2622
2632
  iou_merge_threshold: float = None,
2623
2633
  inference_request: InferenceRequest = None,
2634
+ model_prediction_suffix: str = None,
2624
2635
  ):
2625
2636
  ds_predictions: Dict[int, List[Prediction]] = defaultdict(list)
2626
2637
  for prediction in predictions:
@@ -2682,7 +2693,9 @@ class Inference:
2682
2693
  meta_changed = False
2683
2694
  for pred in preds:
2684
2695
  ann = pred.annotation
2685
- project_meta, ann, meta_changed_ = update_meta_and_ann(project_meta, ann)
2696
+ project_meta, ann, meta_changed_ = update_meta_and_ann(
2697
+ project_meta, ann, model_prediction_suffix
2698
+ )
2686
2699
  meta_changed = meta_changed or meta_changed_
2687
2700
  pred.annotation = ann
2688
2701
  prediction.model_meta = project_meta
@@ -2746,7 +2759,9 @@ class Inference:
2746
2759
  meta_changed = False
2747
2760
  for pred in preds:
2748
2761
  ann = pred.annotation
2749
- project_meta, ann, meta_changed_ = update_meta_and_ann(project_meta, ann)
2762
+ project_meta, ann, meta_changed_ = update_meta_and_ann(
2763
+ project_meta, ann, model_prediction_suffix
2764
+ )
2750
2765
  meta_changed = meta_changed or meta_changed_
2751
2766
  pred.annotation = ann
2752
2767
  prediction.model_meta = project_meta
@@ -2828,12 +2843,12 @@ class Inference:
2828
2843
  # Predict and shutdown
2829
2844
  if self._args.mode == "predict":
2830
2845
  if any(
2831
- [
2832
- self._args.input,
2833
- self._args.project_id,
2834
- self._args.dataset_id,
2835
- self._args.image_id,
2836
- ]
2846
+ [
2847
+ self._args.input,
2848
+ self._args.project_id,
2849
+ self._args.dataset_id,
2850
+ self._args.image_id,
2851
+ ]
2837
2852
  ):
2838
2853
  self._parse_inference_settings_from_args()
2839
2854
  self._inference_by_cli_deploy_args()
@@ -3687,6 +3702,7 @@ class Inference:
3687
3702
 
3688
3703
  def _parse_inference_settings_from_args(self):
3689
3704
  logger.debug("Parsing inference settings from args")
3705
+
3690
3706
  def parse_value(value: str):
3691
3707
  if value.lower() in ("true", "false"):
3692
3708
  return value.lower() == "true"
@@ -3813,8 +3829,7 @@ class Inference:
3813
3829
  try:
3814
3830
  # Read data from checkpoint
3815
3831
  logger.debug(f"Reading data from checkpoint: {checkpoint_path}")
3816
- import torch # pylint: disable=import-error
3817
- checkpoint = torch.load(checkpoint_path)
3832
+ checkpoint = torch_load_safe(checkpoint_path)
3818
3833
  model_info = checkpoint["model_info"]
3819
3834
  model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
3820
3835
  model_meta = os.path.join(self.model_dir, "model_meta.json")
@@ -4044,6 +4059,7 @@ class Inference:
4044
4059
  draw: bool = False,
4045
4060
  ):
4046
4061
  logger.info(f"Predicting Local Data: {input_path}")
4062
+
4047
4063
  def postprocess_image(image_path: str, ann: Annotation, pred_dir: str = None):
4048
4064
  image_name = sly_fs.get_file_name_with_ext(image_path)
4049
4065
  if pred_dir is not None:
@@ -4166,13 +4182,14 @@ class Inference:
4166
4182
 
4167
4183
  task_id = experiment_info.task_id
4168
4184
  self.gui.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
4169
- self.gui.experiment_selector.set_by_task_id(task_id)
4185
+ self.gui.experiment_selector.set_selected_row_by_task_id(task_id)
4170
4186
 
4171
4187
  best_ckpt = experiment_info.best_checkpoint
4172
4188
  if best_ckpt:
4173
- row = self.gui.experiment_selector.get_by_task_id(task_id)
4189
+ row = self.gui.experiment_selector.get_selected_row_by_task_id(task_id)
4174
4190
  if row is not None:
4175
4191
  row.set_selected_checkpoint_by_name(best_ckpt)
4192
+
4176
4193
  except Exception as e:
4177
4194
  logger.warning(f"Failed to set checkpoint from experiment info: {repr(e)}")
4178
4195
 
@@ -4181,6 +4198,7 @@ class Inference:
4181
4198
  return
4182
4199
  self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
4183
4200
 
4201
+
4184
4202
  def _exclude_duplicated_predictions(
4185
4203
  api: Api,
4186
4204
  pred_anns: List[Annotation],
@@ -4456,7 +4474,7 @@ def _fix_classes_names(meta: ProjectMeta, ann: Annotation):
4456
4474
  return meta, ann, replaced_classes_in_meta, list(replaced_classes_in_ann)
4457
4475
 
4458
4476
 
4459
- def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
4477
+ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suffix: str = None):
4460
4478
  """Update project meta and annotation to match each other
4461
4479
  If obj class or tag meta from annotation conflicts with project meta
4462
4480
  add suffix to obj class or tag meta.
@@ -4464,8 +4482,13 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
4464
4482
  """
4465
4483
  obj_classes_suffixes = ["_nn"]
4466
4484
  tag_meta_suffixes = ["_nn"]
4467
- ann_obj_classes = {}
4468
- ann_tag_metas = {}
4485
+ if model_prediction_suffix is not None:
4486
+ obj_classes_suffixes = [model_prediction_suffix]
4487
+ tag_meta_suffixes = [model_prediction_suffix]
4488
+ logger.debug(
4489
+ f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
4490
+ )
4491
+ logger.debug("source meta", extra={"meta": meta.to_json()})
4469
4492
  meta_changed = False
4470
4493
 
4471
4494
  meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
@@ -4476,91 +4499,116 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
4476
4499
  extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
4477
4500
  )
4478
4501
 
4479
- # get all obj classes and tag metas from annotation
4502
+ updated_labels = []
4503
+ any_label_updated = False
4504
+ for label in ann.labels:
4505
+ original_obj_class_name = label.obj_class.name
4506
+ suffix_found = False
4507
+ for suffix in ["", *obj_classes_suffixes]:
4508
+ label_obj_class = label.obj_class
4509
+ label_obj_class_name = label_obj_class.name + suffix
4510
+ if suffix:
4511
+ label_obj_class = label_obj_class.clone(name=label_obj_class_name)
4512
+ label = label.clone(obj_class=label_obj_class)
4513
+ any_label_updated = True
4514
+ meta_obj_class = meta.get_obj_class(label_obj_class_name)
4515
+ if meta_obj_class is None:
4516
+ # if obj class is not in meta, add it with suffix
4517
+ meta = meta.add_obj_class(label_obj_class)
4518
+ updated_labels.append(label)
4519
+ meta_changed = True
4520
+ suffix_found = True
4521
+ break
4522
+ elif meta_obj_class.geometry_type.geometry_name() == label.geometry.geometry_name():
4523
+ # if label geometry is the same as in meta, use meta obj class
4524
+ label = label.clone(obj_class=meta_obj_class)
4525
+ updated_labels.append(label)
4526
+ suffix_found = True
4527
+ any_label_updated = True
4528
+ break
4529
+ elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
4530
+ # if meta obj class is AnyGeometry, use it in label
4531
+ label = label.clone(obj_class=meta_obj_class)
4532
+ updated_labels.append(label)
4533
+ suffix_found = True
4534
+ any_label_updated = True
4535
+ break
4536
+ if not suffix_found:
4537
+ # if no suffix found, raise error
4538
+ raise ValueError(
4539
+ f"Can't add obj class {original_obj_class_name} to project meta. "
4540
+ "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
4541
+ "Please check if model geometry type is compatible with existing obj classes."
4542
+ )
4543
+ if any_label_updated:
4544
+ ann = ann.clone(labels=updated_labels)
4545
+
4546
+ # check if tag metas are in project meta
4547
+ # if not, add them with suffix
4548
+ ann_tag_metas = {}
4480
4549
  for label in ann.labels:
4481
- ann_obj_classes[label.obj_class.name] = label.obj_class
4482
4550
  for tag in label.tags:
4483
4551
  ann_tag_metas[tag.meta.name] = tag.meta
4484
4552
  for tag in ann.img_tags:
4485
4553
  ann_tag_metas[tag.meta.name] = tag.meta
4486
4554
 
4487
- # check if obj classes are in project meta
4488
- # if not, add them.
4489
- # if shape is different, add them with suffix
4490
- changed_obj_classes = {}
4491
- for ann_obj_class in ann_obj_classes.values():
4492
- if meta.get_obj_class(ann_obj_class.name) is None:
4493
- meta = meta.add_obj_class(ann_obj_class)
4494
- meta_changed = True
4495
- elif (
4496
- meta.get_obj_class(ann_obj_class.name).geometry_type != ann_obj_class.geometry_type
4497
- and meta.get_obj_class(ann_obj_class.name).geometry_type != AnyGeometry
4498
- ):
4499
- found = False
4500
- for suffix in obj_classes_suffixes:
4501
- new_obj_class_name = ann_obj_class.name + suffix
4502
- meta_obj_class = meta.get_obj_class(new_obj_class_name)
4503
- if meta_obj_class is None:
4504
- new_obj_class = ann_obj_class.clone(name=new_obj_class_name)
4505
- meta = meta.add_obj_class(new_obj_class)
4506
- meta_changed = True
4507
- changed_obj_classes[ann_obj_class.name] = new_obj_class
4508
- found = True
4509
- break
4510
- if meta_obj_class.geometry_type == ann_obj_class.geometry_type:
4511
- changed_obj_classes[ann_obj_class.name] = meta_obj_class
4512
- found = True
4513
- break
4514
- if not found:
4515
- raise ValueError(f"Can't add obj class {ann_obj_class.name} to project meta")
4516
-
4517
- # check if tag metas are in project meta
4518
- # if not, add them with suffix
4519
4555
  changed_tag_metas = {}
4520
- for tag_meta in ann_tag_metas.values():
4521
- if meta.get_tag_meta(tag_meta.name) is None:
4522
- meta = meta.add_tag_meta(tag_meta)
4556
+ for ann_tag_meta in ann_tag_metas.values():
4557
+ meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
4558
+ if meta_tag_meta is None:
4559
+ meta = meta.add_tag_meta(ann_tag_meta)
4523
4560
  meta_changed = True
4524
- elif not meta.get_tag_meta(tag_meta.name).is_compatible(tag_meta):
4525
- found = False
4561
+ elif not meta_tag_meta.is_compatible(ann_tag_meta):
4562
+ suffix_found = False
4526
4563
  for suffix in tag_meta_suffixes:
4527
- new_tag_meta_name = tag_meta.name + suffix
4564
+ new_tag_meta_name = ann_tag_meta.name + suffix
4528
4565
  meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
4529
4566
  if meta_tag_meta is None:
4530
- new_tag_meta = tag_meta.clone(name=new_tag_meta_name)
4567
+ new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
4531
4568
  meta = meta.add_tag_meta(new_tag_meta)
4532
- changed_tag_metas[tag_meta.name] = new_tag_meta
4569
+ changed_tag_metas[ann_tag_meta.name] = new_tag_meta
4533
4570
  meta_changed = True
4534
- found = True
4571
+ suffix_found = True
4535
4572
  break
4536
- if meta_tag_meta.is_compatible(tag_meta):
4537
- changed_tag_metas[tag_meta.name] = meta_tag_meta
4538
- found = True
4573
+ if meta_tag_meta.is_compatible(ann_tag_meta):
4574
+ changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
4575
+ suffix_found = True
4539
4576
  break
4540
- if not found:
4541
- raise ValueError(f"Can't add tag meta {tag_meta.name} to project meta")
4577
+ if not suffix_found:
4578
+ raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
4542
4579
 
4543
- labels = []
4544
- for label in ann.labels:
4545
- if label.obj_class.name in changed_obj_classes:
4546
- label = label.clone(obj_class=changed_obj_classes[label.obj_class.name])
4547
-
4548
- label_tags = []
4549
- for tag in label.tags:
4580
+ if changed_tag_metas:
4581
+ labels = []
4582
+ any_label_updated = False
4583
+ for label in ann.labels:
4584
+ any_tag_updated = False
4585
+ label_tags = []
4586
+ for tag in label.tags:
4587
+ if tag.meta.name in changed_tag_metas:
4588
+ label_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
4589
+ any_tag_updated = True
4590
+ else:
4591
+ label_tags.append(tag)
4592
+ if any_tag_updated:
4593
+ label = label.clone(tags=TagCollection(label_tags))
4594
+ any_label_updated = True
4595
+ labels.append(label)
4596
+ img_tags = []
4597
+ any_tag_updated = False
4598
+ for tag in ann.img_tags:
4550
4599
  if tag.meta.name in changed_tag_metas:
4551
- label_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
4600
+ img_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
4601
+ any_tag_updated = True
4552
4602
  else:
4553
- label_tags.append(tag)
4554
-
4555
- labels.append(label.clone(tags=TagCollection(label_tags)))
4556
- img_tags = []
4557
- for tag in ann.img_tags:
4558
- if tag.meta.name in changed_tag_metas:
4559
- img_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
4560
- else:
4561
- img_tags.append(tag)
4562
-
4563
- ann = ann.clone(labels=labels, img_tags=TagCollection(img_tags))
4603
+ img_tags.append(tag)
4604
+ if any_tag_updated or any_label_updated:
4605
+ if any_tag_updated:
4606
+ img_tags = TagCollection(img_tags)
4607
+ else:
4608
+ img_tags = None
4609
+ if not any_label_updated:
4610
+ labels = None
4611
+ ann = ann.clone(img_tags=TagCollection(img_tags))
4564
4612
  return meta, ann, meta_changed
4565
4613
 
4566
4614
 
@@ -4673,3 +4721,21 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
4673
4721
  continue
4674
4722
  return data[key]
4675
4723
  return None
4724
+
4725
def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
    """Load a PyTorch checkpoint, retrying with ``weights_only=False`` on failure.

    Since torch 2.6.0 ``torch.load`` defaults to ``weights_only=True``, which
    rejects checkpoints that pickle arbitrary Python objects. This helper first
    tries the default load and, if it raises, retries once with
    ``weights_only=False``.

    :param checkpoint_path: Path to the checkpoint file.
    :param device: Value passed as ``map_location`` to ``torch.load``.
    :returns: Whatever ``torch.load`` returns for the checkpoint.
    :raises Exception: Any error raised by the retry ``torch.load`` call
        (chained to the error from the first attempt).
    """
    import torch  # pylint: disable=import-error

    # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0
    logger.debug(f"Loading checkpoint from {checkpoint_path} on {device}")
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        # Catch Exception (not a bare except) so Ctrl+C / SystemExit are not
        # swallowed and turned into a second load attempt.
        logger.debug(
            f"Failed to load checkpoint from {checkpoint_path} on {device}. Trying again with weights_only=False",
            extra={"error": repr(e)},
        )
        # NOTE(security): weights_only=False unpickles arbitrary objects and can
        # execute code embedded in the file - only use on trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        logger.debug(
            f"Checkpoint loaded from {checkpoint_path} on {device} with weights_only=False"
        )
    else:
        logger.debug(f"Checkpoint loaded from {checkpoint_path} on {device}")
    return checkpoint
File without changes
File without changes
@@ -0,0 +1,91 @@
1
+ from typing import Dict, Any
2
+ from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text
3
+
4
+
5
class ClassesSelector:
    """GUI step for choosing the classes that will be used for inference.

    The step is rendered as a locked ``Card``. The card contains a classes table
    (hidden at construction), a hidden validation message and a "Select" button.
    """

    title = "Classes Selector"
    description = "Select classes that will be used for inference"
    lock_message = "Select previous step to unlock"

    def __init__(self):
        # Step widget: the classes table starts hidden.
        table = ClassesTable()
        table.hide()
        self.classes_table = table

        # Base widgets: validation message (hidden) and the confirmation button.
        message = Text("")
        message.hide()
        self.validator_text = message
        self.button = Button("Select")

        # Everything the card displays, in display order.
        self.display_widgets = [self.classes_table, self.validator_text, self.button]

        # Card layout; the step is locked right after creation.
        self.container = Container(self.display_widgets)
        self.card = Card(
            title=self.title,
            description=self.description,
            content=self.container,
            lock_message=self.lock_message,
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Return the widgets of this step that can be disabled (only the classes table)."""
        return [self.classes_table]

    def load_from_json(self, data: Dict[str, Any]) -> None:
        """Apply ``data["classes"]`` to the table if the key is present; otherwise do nothing."""
        if "classes" not in data:
            return
        self.set_classes(data["classes"])

    def get_selected_classes(self) -> list:
        """Return the classes currently selected in the table."""
        return self.classes_table.get_selected_classes()

    def set_classes(self, classes) -> None:
        """Select the given classes in the table."""
        self.classes_table.select_classes(classes)

    def select_all_classes(self) -> None:
        """Select every class in the table."""
        self.classes_table.select_all()

    def get_settings(self) -> Dict[str, Any]:
        """Return the step settings as ``{"classes": <selected classes>}``."""
        return {"classes": self.get_selected_classes()}

    def validate_step(self) -> bool:
        """Check that at least one class is selected and show the result.

        Returns ``True`` immediately, without touching the message, when the
        table is hidden. Otherwise, show an error (and return ``False``) if
        nothing is selected, or a success message with the count.
        """
        if self.classes_table.is_hidden():
            return True

        self.validator_text.hide()
        count = len(self.classes_table.get_selected_classes())
        if not count:
            self.validator_text.set(text="Please select at least one class", status="error")
            self.validator_text.show()
            return False

        noun = "classes" if count != 1 else "class"
        self.validator_text.set(text=f"Selected {count} {noun}", status="success")
        self.validator_text.show()
        return True