supervisely-6.73.326-py3-none-any.whl → supervisely-6.73.327-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. supervisely/annotation/annotation.py +1 -1
  2. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +17 -14
  3. supervisely/app/widgets/pretrained_models_selector/template.html +2 -1
  4. supervisely/convert/image/yolo/yolo_helper.py +95 -25
  5. supervisely/nn/inference/gui/serving_gui_template.py +2 -3
  6. supervisely/nn/inference/inference.py +33 -25
  7. supervisely/nn/training/gui/classes_selector.py +24 -19
  8. supervisely/nn/training/gui/gui.py +90 -37
  9. supervisely/nn/training/gui/hyperparameters_selector.py +32 -15
  10. supervisely/nn/training/gui/input_selector.py +13 -2
  11. supervisely/nn/training/gui/model_selector.py +16 -6
  12. supervisely/nn/training/gui/train_val_splits_selector.py +10 -1
  13. supervisely/nn/training/gui/training_artifacts.py +23 -4
  14. supervisely/nn/training/gui/training_logs.py +15 -3
  15. supervisely/nn/training/gui/training_process.py +14 -13
  16. supervisely/nn/training/train_app.py +59 -24
  17. supervisely/nn/utils.py +9 -0
  18. supervisely/project/project.py +16 -3
  19. {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/METADATA +1 -1
  20. {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/RECORD +24 -24
  21. {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/LICENSE +0 -0
  22. {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/WHEEL +0 -0
  23. {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/entry_points.txt +0 -0
  24. {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/top_level.txt +0 -0
@@ -3088,7 +3088,7 @@ class Annotation:
3088
3088
  def to_yolo(
3089
3089
  self,
3090
3090
  class_names: List[str],
3091
- task_type: Literal["detection", "segmentation", "pose"] = "detection",
3091
+ task_type: Literal["detect", "segment", "pose"] = "detect",
3092
3092
  ) -> List[str]:
3093
3093
  """
3094
3094
  Convert Supervisely annotation to YOLO annotation format.
@@ -4,7 +4,7 @@ from supervisely.api.api import Api
4
4
  from supervisely.app.content import DataJson, StateJson
5
5
  from supervisely.app.widgets import Widget
6
6
  from supervisely.io.fs import get_file_ext
7
- from supervisely.nn.utils import ModelSource
7
+ from supervisely.nn.utils import ModelSource, _get_model_name
8
8
 
9
9
 
10
10
  class PretrainedModelsSelector(Widget):
@@ -147,9 +147,11 @@ class PretrainedModelsSelector(Widget):
147
147
  if train_version == "v1":
148
148
  model_name = selected_model.get(model_name_column)
149
149
  if model_name is None:
150
- raise ValueError(
151
- "Could not find model name. Make sure you have column 'Model' in your models list."
152
- )
150
+ model_name = _get_model_name(selected_model)
151
+ if model_name is None:
152
+ raise ValueError(
153
+ "Could not find model name. Make sure you have column 'Model' in your models list."
154
+ )
153
155
 
154
156
  model_meta = selected_model.get("meta")
155
157
  if model_meta is None:
@@ -230,24 +232,25 @@ class PretrainedModelsSelector(Widget):
230
232
  for task_type in self._table_data:
231
233
  for arch_type in self._table_data[task_type]:
232
234
  for idx, model in enumerate(self._table_data[task_type][arch_type]):
233
- model_meta = model.get("meta", {})
234
- if model_meta.get("model_name") == model_name:
235
- self.set_active_task_type(task_type)
236
- self.set_active_arch_type(arch_type)
237
- self.set_active_row(idx)
238
- return
235
+ name_from_info = _get_model_name(model)
236
+ if name_from_info is not None:
237
+ if name_from_info.lower() == model_name.lower():
238
+ self.set_active_task_type(task_type)
239
+ self.set_active_arch_type(arch_type)
240
+ self.set_active_row(idx)
241
+ return
239
242
 
240
243
  def get_by_model_name(self, model_name: str) -> Union[Dict, None]:
241
244
  for task_type in self._table_data:
242
245
  for arch_type in self._table_data[task_type]:
243
246
  for idx, model in enumerate(self._table_data[task_type][arch_type]):
244
- model_meta = model.get("meta", {})
245
- if model_meta.get("model_name") == model_name:
246
- return model
247
+ name_from_info = _get_model_name(model)
248
+ if name_from_info is not None:
249
+ if name_from_info.lower() == model_name.lower():
250
+ return model
247
251
 
248
252
  def _filter_and_sort_models(self, models: List[Dict], sort_models: bool = True) -> Dict:
249
253
  filtered_models = {}
250
-
251
254
  for model in models:
252
255
  for key in model:
253
256
  if isinstance(model[key], (int, float)):
@@ -62,7 +62,8 @@
62
62
  <tr>
63
63
  <template v-for="(col, colKey) in data.{{{widget.widget_id}}}.tableData[state.{{{widget.widget_id}}}.selectedTaskType][state.{{{widget.widget_id}}}.selectedArchType][0]">
64
64
  <th v-if="colKey != 'meta'" :key="colKey">
65
- <div>{{ colKey }}</div>
65
+ <div v-if="colKey==='model_name'">Model</div>
66
+ <div v-else>{{ colKey }}</div>
66
67
  </th>
67
68
  </template>
68
69
  </tr>
@@ -17,7 +17,8 @@ from supervisely.geometry.polygon import Polygon
17
17
  from supervisely.geometry.polyline import Polyline
18
18
  from supervisely.geometry.rectangle import Rectangle
19
19
  from supervisely.imaging.color import generate_rgb
20
- from supervisely.io.fs import get_file_name_with_ext, touch
20
+ from supervisely.io.fs import get_file_name, get_file_name_with_ext, touch
21
+ from supervisely.nn.task_type import TaskType
21
22
  from supervisely.project.project import Dataset, OpenMode, Project
22
23
  from supervisely.project.project_meta import ProjectMeta
23
24
  from supervisely.sly_logger import logger
@@ -27,6 +28,20 @@ YOLO_DETECTION_COORDS_NUM = 4
27
28
  YOLO_SEGM_MIN_COORDS_NUM = 6
28
29
  YOLO_KEYPOINTS_MIN_COORDS_NUM = 6
29
30
 
31
+
32
class YOLOTaskType:
    """String constants naming the YOLO task types.

    These are the values the ``task_type`` parameters of this module are
    annotated with (``Literal["detect", "segment", "pose"]``).
    """

    DETECT = "detect"
    SEGMENT = "segment"
    POSE = "pose"


# Translation table from Supervisely task types to the YOLO task type used
# when exporting. The values are the ``YOLOTaskType`` attributes themselves.
SLY_YOLO_TASK_TYPE_MAP = {
    TaskType.OBJECT_DETECTION: YOLOTaskType.DETECT,
    TaskType.INSTANCE_SEGMENTATION: YOLOTaskType.SEGMENT,
    TaskType.POSE_ESTIMATION: YOLOTaskType.POSE,
}
43
+
44
+
30
45
  coco_classes = [
31
46
  "person",
32
47
  "bicycle",
@@ -398,22 +413,22 @@ def keypoints_to_yolo_line(
398
413
 
399
414
  def convert_label_geometry_if_needed(
400
415
  label: Label,
401
- task_type: Literal["detection", "segmentation", "pose"],
416
+ task_type: Literal["detect", "segment", "pose"],
402
417
  verbose: bool = False,
403
418
  ) -> List[Label]:
404
- if task_type == "detection":
419
+ if task_type == YOLOTaskType.DETECT:
405
420
  available_geometry_type = Rectangle
406
421
  convertable_geometry_types = [Polygon, GraphNodes, Bitmap, Polyline, AlphaMask, AnyGeometry]
407
- elif task_type == "segmentation":
422
+ elif task_type == YOLOTaskType.SEGMENT:
408
423
  available_geometry_type = Polygon
409
424
  convertable_geometry_types = [Bitmap, AlphaMask, AnyGeometry]
410
- elif task_type == "pose":
425
+ elif task_type == YOLOTaskType.POSE:
411
426
  available_geometry_type = GraphNodes
412
427
  convertable_geometry_types = []
413
428
  else:
414
429
  raise ValueError(
415
430
  f"Unsupported task type: {task_type}. "
416
- "Supported types: 'detection', 'segmentation', 'pose'"
431
+ f"Supported types: '{YOLOTaskType.DETECT}', '{YOLOTaskType.SEGMENT}', '{YOLOTaskType.POSE}'"
417
432
  )
418
433
 
419
434
  if label.obj_class.geometry_type == available_geometry_type:
@@ -438,7 +453,7 @@ def label_to_yolo_lines(
438
453
  img_height: int,
439
454
  img_width: int,
440
455
  class_names: List[str],
441
- task_type: Literal["detection", "segmentation", "pose"],
456
+ task_type: Literal["detect", "segment", "pose"],
442
457
  ) -> List[str]:
443
458
  """
444
459
  Convert the Supervisely Label to a line in the YOLO format.
@@ -449,21 +464,21 @@ def label_to_yolo_lines(
449
464
 
450
465
  lines = []
451
466
  for label in labels:
452
- if task_type == "detection":
467
+ if task_type == YOLOTaskType.DETECT:
453
468
  yolo_line = rectangle_to_yolo_line(
454
469
  class_idx=class_idx,
455
470
  geometry=label.geometry,
456
471
  img_height=img_height,
457
472
  img_width=img_width,
458
473
  )
459
- elif task_type == "segmentation":
474
+ elif task_type == YOLOTaskType.SEGMENT:
460
475
  yolo_line = polygon_to_yolo_line(
461
476
  class_idx=class_idx,
462
477
  geometry=label.geometry,
463
478
  img_height=img_height,
464
479
  img_width=img_width,
465
480
  )
466
- elif task_type == "pose":
481
+ elif task_type == YOLOTaskType.POSE:
467
482
  nodes_field = label.obj_class.geometry_type.items_json_field
468
483
  max_kpts_count = len(label.obj_class.geometry_config[nodes_field])
469
484
  yolo_line = keypoints_to_yolo_line(
@@ -474,7 +489,10 @@ def label_to_yolo_lines(
474
489
  max_kpts_count=max_kpts_count,
475
490
  )
476
491
  else:
477
- raise ValueError(f"Unsupported task type: {task_type}")
492
+ raise ValueError(
493
+ f"Unsupported task type: {task_type}. "
494
+ f"Supported types: '{YOLOTaskType.DETECT}', '{YOLOTaskType.SEGMENT}', '{YOLOTaskType.POSE}'"
495
+ )
478
496
 
479
497
  if yolo_line is not None:
480
498
  lines.append(yolo_line)
@@ -485,12 +503,11 @@ def label_to_yolo_lines(
485
503
  def sly_ann_to_yolo(
486
504
  ann: Annotation,
487
505
  class_names: List[str],
488
- task_type: Literal["detection", "segmentation", "pose"] = "detection",
506
+ task_type: Literal["detect", "segment", "pose"] = "detect",
489
507
  ) -> List[str]:
490
508
  """
491
509
  Convert the Supervisely annotation to the YOLO format.
492
510
  """
493
-
494
511
  h, w = ann.img_size
495
512
  yolo_lines = []
496
513
  for label in ann.labels:
@@ -509,11 +526,12 @@ def sly_ds_to_yolo(
509
526
  dataset: Dataset,
510
527
  meta: ProjectMeta,
511
528
  dest_dir: Optional[str] = None,
512
- task_type: Literal["detection", "segmentation", "pose"] = "detection",
529
+ task_type: Literal["detect", "segment", "pose"] = "detect",
513
530
  log_progress: bool = False,
514
531
  progress_cb: Optional[Union[tqdm, Callable]] = None,
532
+ is_val: Optional[bool] = None,
515
533
  ) -> str:
516
-
534
+ task_type = validate_task_type(task_type)
517
535
  if progress_cb is not None:
518
536
  log_progress = False
519
537
 
@@ -543,15 +561,19 @@ def sly_ds_to_yolo(
543
561
  ann_path = dataset.get_ann_path(name)
544
562
  ann = Annotation.load_json_file(ann_path, meta)
545
563
 
546
- images_dir = val_images_dir if ann.img_tags.get("val") else train_images_dir
547
- labels_dir = val_labels_dir if ann.img_tags.get("val") else train_labels_dir
564
+ if is_val is not None:
565
+ images_dir = val_images_dir if is_val else train_images_dir
566
+ labels_dir = val_labels_dir if is_val else train_labels_dir
567
+ else:
568
+ images_dir = val_images_dir if ann.img_tags.get("val") else train_images_dir
569
+ labels_dir = val_labels_dir if ann.img_tags.get("val") else train_labels_dir
548
570
 
549
571
  img_path = Path(dataset.get_img_path(name))
550
572
  img_name = f"{dataset.short_name}_{get_file_name_with_ext(img_path)}"
551
573
  img_name = generate_free_name(used_names, img_name, with_ext=True, extend_used_names=True)
552
574
  shutil.copy2(img_path, images_dir / img_name)
553
575
 
554
- label_path = str(labels_dir / f"{img_name}.txt")
576
+ label_path = str(labels_dir / f"{get_file_name(img_name)}.txt")
555
577
  yolo_lines = ann.to_yolo(class_names, task_type)
556
578
  if len(yolo_lines) > 0:
557
579
  with open(label_path, "w") as f:
@@ -565,7 +587,8 @@ def sly_ds_to_yolo(
565
587
  # * save data config file if it does not exist
566
588
  config_path = dest_dir / "data_config.yaml"
567
589
  if not config_path.exists():
568
- save_yolo_config(meta, dest_dir, with_keypoint=task_type == "pose")
590
+ with_keypoint = task_type is YOLOTaskType.POSE
591
+ save_yolo_config(meta, dest_dir, with_keypoint=with_keypoint)
569
592
 
570
593
  return str(dest_dir)
571
594
 
@@ -578,6 +601,8 @@ def save_yolo_config(meta: ProjectMeta, dest_dir: str, with_keypoint: bool = Fal
578
601
  data_yaml = {
579
602
  "train": f"../{str(dest_dir.name)}/images/train",
580
603
  "val": f"../{str(dest_dir.name)}/images/val",
604
+ "train_labels": f"../{str(dest_dir.name)}/labels/train",
605
+ "val_labels": f"../{str(dest_dir.name)}/labels/val",
581
606
  "nc": len(class_names),
582
607
  "names": class_names,
583
608
  "colors": class_colors,
@@ -590,6 +615,7 @@ def save_yolo_config(meta: ProjectMeta, dest_dir: str, with_keypoint: bool = Fal
590
615
  field_name = obj_class.geometry_type.items_json_field
591
616
  max_kpts_count = max(max_kpts_count, len(obj_class.geometry_config[field_name]))
592
617
  data_yaml["kpt_shape"] = [max_kpts_count, 3]
618
+ data_yaml["flip_idx"] = [i for i in range(max_kpts_count)]
593
619
  with open(save_path, "w") as f:
594
620
  yaml.dump(data_yaml, f, default_flow_style=None)
595
621
 
@@ -599,21 +625,31 @@ def save_yolo_config(meta: ProjectMeta, dest_dir: str, with_keypoint: bool = Fal
599
625
  def sly_project_to_yolo(
600
626
  project: Union[Project, str],
601
627
  dest_dir: Optional[str] = None,
602
- task_type: Literal["detection", "segmentation", "pose"] = "detection",
628
+ task_type: Literal["detect", "segment", "pose"] = "detect",
603
629
  log_progress: bool = False,
604
630
  progress_cb: Optional[Callable] = None,
631
+ val_datasets: Optional[List[str]] = None,
605
632
  ):
606
633
  """
607
634
  Convert Supervisely project to YOLO format.
608
635
 
636
+ :param project: Supervisely project or path to the directory with the project.
637
+ :type project: :class:`supervisely.project.project.Project` or :class:`str`
609
638
  :param dest_dir: Destination directory.
610
639
  :type dest_dir: :class:`str`, optional
640
+ :param task_type: Task type.
641
+ :type task_type: :class:`str`, optional
611
642
  :param log_progress: Show uploading progress bar.
612
643
  :type log_progress: :class:`bool`
613
644
  :param progress_cb: Function for tracking conversion progress (for all items in the project).
614
645
  :type progress_cb: callable, optional
615
- :return: None
616
- :rtype: NoneType
646
+ :param val_datasets: List of dataset names for validation.
647
+ Full dataset names are required (e.g., 'ds0/nested_ds1/ds3').
648
+ If specified, datasets from the list will be marked as val, others as train.
649
+ If not specified, the function will determine the validation datasets automatically.
650
+ :type val_datasets: :class:`list`, optional
651
+ :return: Path to the destination directory.
652
+ :rtype: :class:`str`
617
653
 
618
654
  :Usage example:
619
655
 
@@ -627,6 +663,7 @@ def sly_project_to_yolo(
627
663
  # Convert Project to YOLO format
628
664
  sly.Project(project_directory).to_yolo(log_progress=True)
629
665
  """
666
+ task_type = validate_task_type(task_type)
630
667
  if isinstance(project, str):
631
668
  project = Project(project, mode=OpenMode.READ)
632
669
 
@@ -644,9 +681,15 @@ def sly_project_to_yolo(
644
681
  desc="Converting Supervisely project to YOLO format", total=project.total_items
645
682
  ).update
646
683
 
647
- save_yolo_config(project.meta, dest_dir, with_keypoint=task_type == "pose")
684
+ with_keypoint = task_type is YOLOTaskType.POSE
685
+ save_yolo_config(project.meta, dest_dir, with_keypoint=with_keypoint)
648
686
 
649
687
  for dataset in project.datasets:
688
+ if val_datasets is not None:
689
+ is_val = dataset.name in val_datasets
690
+ else:
691
+ is_val = None
692
+
650
693
  dataset: Dataset
651
694
  dataset.to_yolo(
652
695
  meta=project.meta,
@@ -654,18 +697,23 @@ def sly_project_to_yolo(
654
697
  task_type=task_type,
655
698
  log_progress=log_progress,
656
699
  progress_cb=progress_cb,
700
+ is_val=is_val,
657
701
  )
658
702
  logger.info(f"Dataset '{dataset.short_name}' has been converted to YOLO format.")
659
703
  logger.info(f"Project '{project.name}' has been converted to YOLO format.")
660
704
 
705
+ return str(dest_dir)
706
+
661
707
 
662
708
  def to_yolo(
663
709
  input_data: Union[Project, Dataset, str],
664
710
  dest_dir: Optional[str] = None,
665
- task_type: Literal["detection", "segmentation", "pose"] = "detection",
711
+ task_type: Literal["detect", "segment", "pose"] = "detect",
666
712
  meta: Optional[ProjectMeta] = None,
667
713
  log_progress: bool = True,
668
714
  progress_cb: Optional[Callable] = None,
715
+ val_datasets: Optional[List[str]] = None,
716
+ is_val: Optional[bool] = None,
669
717
  ) -> Union[None, str]:
670
718
  """
671
719
  Universal function to convert Supervisely project or dataset to YOLO format.
@@ -691,6 +739,13 @@ def to_yolo(
691
739
  :type log_progress: :class:`bool`
692
740
  :param progress_cb: Function for tracking conversion progress (for all items in the project).
693
741
  :type progress_cb: callable, optional
742
+ :param val_datasets: List of dataset names for validation.
743
+ Full dataset names are required (e.g., 'ds0/nested_ds1/ds3').
744
+ If specified, datasets from the list will be marked as val, others as train.
745
+ If not specified, the function will determine the validation datasets automatically.
746
+ :type val_datasets: :class:`list`, optional
747
+ :param is_val: Whether the dataset is for validation.
748
+ :type is_val: :class:`bool`, optional
694
749
  :return: None, list of YOLO lines, or path to the destination directory.
695
750
  :rtype: NoneType, list, str
696
751
 
@@ -711,7 +766,7 @@ def to_yolo(
711
766
 
712
767
  # Convert Dataset to YOLO format
713
768
  dataset: sly.Dataset = project_fs.datasets.get("dataset_name")
714
- sly.convert.to_yolo(dataset, dest_dir="./yolo", meta=project_fs.meta)
769
+ sly.convert.to_yolo(dataset, dest_dir="./yolo", meta=project_fs.meta, is_val=True)
715
770
  """
716
771
  if isinstance(input_data, str):
717
772
  try:
@@ -728,6 +783,7 @@ def to_yolo(
728
783
  task_type=task_type,
729
784
  log_progress=log_progress,
730
785
  progress_cb=progress_cb,
786
+ val_datasets=val_datasets,
731
787
  )
732
788
  elif isinstance(input_data, Dataset):
733
789
  return sly_ds_to_yolo(
@@ -737,6 +793,20 @@ def to_yolo(
737
793
  task_type=task_type,
738
794
  log_progress=log_progress,
739
795
  progress_cb=progress_cb,
796
+ is_val=is_val,
740
797
  )
741
798
  else:
742
799
  raise ValueError("Unsupported input type. Only Project or Dataset are supported.")
800
+
801
+
802
def validate_task_type(task_type: Literal["detect", "segment", "pose"]) -> str:
    """Normalize ``task_type`` to one of the canonical ``YOLOTaskType`` values.

    Accepts either a YOLO task type (``"detect"``, ``"segment"``, ``"pose"``) or a
    Supervisely task type that is a key of ``SLY_YOLO_TASK_TYPE_MAP``.

    The returned object is always the ``YOLOTaskType`` class attribute itself, not
    the caller's string. Callers in this module compare with ``is``
    (``task_type is YOLOTaskType.POSE``). That only works reliably when the
    identical object is returned; a string built at runtime (read from a config,
    concatenated, etc.) is equal but generally not identical.

    :param task_type: YOLO or Supervisely task type.
    :type task_type: :class:`str`
    :return: Canonical YOLO task type.
    :rtype: :class:`str`
    :raises ValueError: If ``task_type`` is neither a YOLO task type nor a mapped
        Supervisely task type.
    """
    yolo_task_types = (YOLOTaskType.DETECT, YOLOTaskType.SEGMENT, YOLOTaskType.POSE)
    for yolo_task_type in yolo_task_types:
        if task_type == yolo_task_type:
            # Return the constant, not the input, so identity checks downstream hold.
            return yolo_task_type

    normalized = SLY_YOLO_TASK_TYPE_MAP.get(task_type)
    if normalized is None:
        # Report the value the caller actually passed (not the failed lookup result)
        # and list both the YOLO names and the accepted Supervisely names.
        supported = ", ".join(
            f"'{name}'" for name in (*yolo_task_types, *SLY_YOLO_TASK_TYPE_MAP)
        )
        raise ValueError(f"Unsupported task type: {task_type}. Supported types: {supported}")
    return normalized
@@ -23,7 +23,7 @@ from supervisely.app.widgets.pretrained_models_selector.pretrained_models_select
23
23
  )
24
24
  from supervisely.nn.experiments import get_experiment_infos
25
25
  from supervisely.nn.inference.gui.serving_gui import ServingGUI
26
- from supervisely.nn.utils import ModelSource, RuntimeType
26
+ from supervisely.nn.utils import ModelSource, RuntimeType, _get_model_name
27
27
 
28
28
 
29
29
  class ServingGUITemplate(ServingGUI):
@@ -149,8 +149,7 @@ class ServingGUITemplate(ServingGUI):
149
149
  @property
150
150
  def model_name(self) -> Optional[str]:
151
151
  if self.model_source == ModelSource.PRETRAINED:
152
- model_meta = self.model_info.get("meta", {})
153
- return model_meta.get("model_name")
152
+ return _get_model_name(self.model_info)
154
153
  else:
155
154
  return self.model_info.get("model_name")
156
155
 
@@ -75,6 +75,7 @@ from supervisely.nn.utils import (
75
75
  ModelPrecision,
76
76
  ModelSource,
77
77
  RuntimeType,
78
+ _get_model_name,
78
79
  )
79
80
  from supervisely.project import ProjectType
80
81
  from supervisely.project.download import download_to_cache, read_from_cached_project
@@ -173,9 +174,7 @@ class Inference:
173
174
  self._use_gui = False
174
175
  deploy_params, need_download = self._get_deploy_params_from_args()
175
176
  if need_download:
176
- local_model_files = self._download_model_files(
177
- deploy_params["model_source"], deploy_params["model_files"], False
178
- )
177
+ local_model_files = self._download_model_files(deploy_params, False)
179
178
  deploy_params["model_files"] = local_model_files
180
179
  self._load_model_headless(**deploy_params)
181
180
 
@@ -210,14 +209,12 @@ class Inference:
210
209
  self.initialize_gui()
211
210
 
212
211
  def on_serve_callback(
213
- gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
212
+ gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate],
214
213
  ):
215
214
  Progress("Deploying model ...", 1)
216
215
  if isinstance(self.gui, GUI.ServingGUITemplate):
217
216
  deploy_params = self.get_params_from_gui()
218
- model_files = self._download_model_files(
219
- deploy_params["model_source"], deploy_params["model_files"]
220
- )
217
+ model_files = self._download_model_files(deploy_params)
221
218
  deploy_params["model_files"] = model_files
222
219
  self._load_model_headless(**deploy_params)
223
220
  elif isinstance(self.gui, GUI.ServingGUI):
@@ -230,7 +227,7 @@ class Inference:
230
227
  gui.show_deployed_model_info(self)
231
228
 
232
229
  def on_change_model_callback(
233
- gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
230
+ gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate],
234
231
  ):
235
232
  self.shutdown_model()
236
233
  if isinstance(self.gui, (GUI.ServingGUI, GUI.ServingGUITemplate)):
@@ -567,13 +564,23 @@ class Inference:
567
564
  def _checkpoints_cache_dir(self):
568
565
  return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
569
566
 
570
- def _download_model_files(
571
- self, model_source: str, model_files: List[str], log_progress: bool = True
572
- ) -> dict:
573
- if model_source == ModelSource.PRETRAINED:
574
- return self._download_pretrained_model(model_files, log_progress)
575
- elif model_source == ModelSource.CUSTOM:
576
- return self._download_custom_model(model_files, log_progress)
567
+ def _download_model_files(self, deploy_params: dict, log_progress: bool = True) -> dict:
568
+ if deploy_params["runtime"] != RuntimeType.PYTORCH:
569
+ export = deploy_params["model_info"].get("export", {})
570
+ export_model = export.get(deploy_params["runtime"], None)
571
+ if export_model is not None:
572
+ if sly_fs.get_file_name(export_model) == sly_fs.get_file_name(
573
+ deploy_params["model_files"]["checkpoint"]
574
+ ):
575
+ deploy_params["model_files"]["checkpoint"] = (
576
+ deploy_params["model_info"]["artifacts_dir"] + export_model
577
+ )
578
+ logger.info(f"Found model checkpoint for '{deploy_params['runtime']}'")
579
+
580
+ if deploy_params["model_source"] == ModelSource.PRETRAINED:
581
+ return self._download_pretrained_model(deploy_params["model_files"], log_progress)
582
+ elif deploy_params["model_source"] == ModelSource.CUSTOM:
583
+ return self._download_custom_model(deploy_params["model_files"], log_progress)
577
584
 
578
585
  def _download_pretrained_model(self, model_files: dict, log_progress: bool = True):
579
586
  """
@@ -2929,9 +2936,7 @@ class Inference:
2929
2936
  state = request.state.state
2930
2937
  deploy_params = state["deploy_params"]
2931
2938
  if isinstance(self.gui, GUI.ServingGUITemplate):
2932
- model_files = self._download_model_files(
2933
- deploy_params["model_source"], deploy_params["model_files"]
2934
- )
2939
+ model_files = self._download_model_files(deploy_params)
2935
2940
  deploy_params["model_files"] = model_files
2936
2941
  self._load_model_headless(**deploy_params)
2937
2942
  elif isinstance(self.gui, GUI.ServingGUI):
@@ -3061,7 +3066,7 @@ class Inference:
3061
3066
  raise ValueError("No pretrained models found.")
3062
3067
 
3063
3068
  model = self.pretrained_models[0]
3064
- model_name = model.get("meta", {}).get("model_name", None)
3069
+ model_name = _get_model_name(model)
3065
3070
  if model_name is None:
3066
3071
  raise ValueError("No model name found in the first pretrained model.")
3067
3072
 
@@ -3126,7 +3131,7 @@ class Inference:
3126
3131
  meta = m.get("meta", None)
3127
3132
  if meta is None:
3128
3133
  continue
3129
- model_name = meta.get("model_name", None)
3134
+ model_name = _get_model_name(m)
3130
3135
  if model_name is None:
3131
3136
  continue
3132
3137
  m_files = meta.get("model_files", None)
@@ -3135,7 +3140,7 @@ class Inference:
3135
3140
  checkpoint = m_files.get("checkpoint", None)
3136
3141
  if checkpoint is None:
3137
3142
  continue
3138
- if model == m["meta"]["model_name"]:
3143
+ if model.lower() == model_name.lower():
3139
3144
  model_info = m
3140
3145
  model_source = ModelSource.PRETRAINED
3141
3146
  model_files = {"checkpoint": checkpoint}
@@ -3153,8 +3158,6 @@ class Inference:
3153
3158
  model_meta_path = os.path.join(artifacts_dir, "model_meta.json")
3154
3159
  model_info["model_meta"] = self._load_json_file(model_meta_path)
3155
3160
  original_model_files = model_info.get("model_files")
3156
- if not original_model_files:
3157
- raise ValueError("Invalid 'experiment_info.json'. Missing 'model_files' key.")
3158
3161
  return model_info, original_model_files
3159
3162
 
3160
3163
  def _prepare_local_model_files(artifacts_dir, checkpoint_path, original_model_files):
@@ -3201,6 +3204,7 @@ class Inference:
3201
3204
  model_files = _prepare_local_model_files(
3202
3205
  artifacts_dir, checkpoint_path, original_model_files
3203
3206
  )
3207
+
3204
3208
  else:
3205
3209
  local_artifacts_dir = os.path.join(
3206
3210
  self.model_dir, "local_deploy", os.path.basename(artifacts_dir)
@@ -3298,7 +3302,11 @@ class Inference:
3298
3302
  if draw:
3299
3303
  raise ValueError("Draw visualization is not supported for project inference")
3300
3304
 
3301
- state = {"projectId": project_id, "dataset_ids": dataset_ids, "settings": settings}
3305
+ state = {
3306
+ "projectId": project_id,
3307
+ "dataset_ids": dataset_ids,
3308
+ "settings": settings,
3309
+ }
3302
3310
  if upload:
3303
3311
  source_project = api.project.get_info_by_id(project_id)
3304
3312
  workspace_id = source_project.workspace_id
@@ -3472,7 +3480,7 @@ class Inference:
3472
3480
  def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
3473
3481
  if model_source == ModelSource.PRETRAINED:
3474
3482
  checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
3475
- checkpoint_name = model_info["meta"]["model_name"]
3483
+ checkpoint_name = _get_model_name(model_info)
3476
3484
  else:
3477
3485
  checkpoint_name = sly_fs.get_file_name_with_ext(model_files["checkpoint"])
3478
3486
  checkpoint_url = os.path.join(
@@ -4,21 +4,28 @@ from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text
4
4
 
5
5
  class ClassesSelector:
6
6
  title = "Classes Selector"
7
- description = (
8
- "Select classes that will be used for training. "
9
- "Supported shapes are Bitmap, Polygon, Rectangle."
10
- )
7
+ description = "Select classes that will be used for training"
11
8
  lock_message = "Select training and validation splits to unlock"
12
9
 
13
10
  def __init__(self, project_id: int, classes: list, app_options: dict = {}):
11
+ # Init widgets
12
+ self.qa_stats_text = None
13
+ self.classes_table = None
14
+ self.validator_text = None
15
+ self.button = None
16
+ self.container = None
17
+ self.card = None
18
+ # -------------------------------- #
19
+
14
20
  self.display_widgets = []
21
+ self.app_options = app_options
15
22
 
16
23
  # GUI Components
17
24
  if is_development() or is_debug_with_sly_net():
18
25
  qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
19
26
  else:
20
27
  qa_stats_link = f"/projects/{project_id}/stats/datasets"
21
- qa_stats_text = Text(
28
+ self.qa_stats_text = Text(
22
29
  text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
23
30
  )
24
31
 
@@ -32,7 +39,7 @@ class ClassesSelector:
32
39
  self.validator_text.hide()
33
40
  self.button = Button("Select")
34
41
  self.display_widgets.extend(
35
- [qa_stats_text, self.classes_table, self.validator_text, self.button]
42
+ [self.qa_stats_text, self.classes_table, self.validator_text, self.button]
36
43
  )
37
44
  # -------------------------------- #
38
45
 
@@ -42,7 +49,7 @@ class ClassesSelector:
42
49
  description=self.description,
43
50
  content=self.container,
44
51
  lock_message=self.lock_message,
45
- collapsable=app_options.get("collapsable", False),
52
+ collapsable=self.app_options.get("collapsable", False),
46
53
  )
47
54
  self.card.lock()
48
55
 
@@ -62,14 +69,14 @@ class ClassesSelector:
62
69
  def validate_step(self) -> bool:
63
70
  self.validator_text.hide()
64
71
 
65
- if len(self.classes_table.project_meta.obj_classes) == 0:
72
+ project_classes = self.classes_table.project_meta.obj_classes
73
+ if len(project_classes) == 0:
66
74
  self.validator_text.set(text="Project has no classes", status="error")
67
75
  self.validator_text.show()
68
76
  return False
69
77
 
70
78
  selected_classes = self.classes_table.get_selected_classes()
71
79
  table_data = self.classes_table._table_data
72
-
73
80
  empty_classes = [
74
81
  row[0]["data"]
75
82
  for row in table_data
@@ -78,23 +85,21 @@ class ClassesSelector:
78
85
 
79
86
  n_classes = len(selected_classes)
80
87
  if n_classes == 0:
81
- self.validator_text.set(text="Please select at least one class", status="error")
88
+ message = "Please select at least one class"
89
+ status = "error"
82
90
  else:
83
- warning_text = ""
91
+ class_text = "class" if n_classes == 1 else "classes"
92
+ message = f"Selected {n_classes} {class_text}"
84
93
  status = "success"
85
94
  if empty_classes:
86
95
  intersections = set(selected_classes).intersection(empty_classes)
87
96
  if intersections:
88
- warning_text = (
89
- f". Selected class has no annotations: {', '.join(intersections)}"
90
- if len(intersections) == 1
91
- else f". Selected classes have no annotations: {', '.join(intersections)}"
97
+ class_text = "class" if len(intersections) == 1 else "classes"
98
+ message += (
99
+ f". Selected {class_text} have no annotations: {', '.join(intersections)}"
92
100
  )
93
101
  status = "warning"
94
102
 
95
- class_text = "class" if n_classes == 1 else "classes"
96
- self.validator_text.set(
97
- text=f"Selected {n_classes} {class_text}{warning_text}", status=status
98
- )
103
+ self.validator_text.set(text=message, status=status)
99
104
  self.validator_text.show()
100
105
  return n_classes > 0