supervisely-6.73.326-py3-none-any.whl → supervisely-6.73.327-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/annotation/annotation.py +1 -1
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +17 -14
- supervisely/app/widgets/pretrained_models_selector/template.html +2 -1
- supervisely/convert/image/yolo/yolo_helper.py +95 -25
- supervisely/nn/inference/gui/serving_gui_template.py +2 -3
- supervisely/nn/inference/inference.py +33 -25
- supervisely/nn/training/gui/classes_selector.py +24 -19
- supervisely/nn/training/gui/gui.py +90 -37
- supervisely/nn/training/gui/hyperparameters_selector.py +32 -15
- supervisely/nn/training/gui/input_selector.py +13 -2
- supervisely/nn/training/gui/model_selector.py +16 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +10 -1
- supervisely/nn/training/gui/training_artifacts.py +23 -4
- supervisely/nn/training/gui/training_logs.py +15 -3
- supervisely/nn/training/gui/training_process.py +14 -13
- supervisely/nn/training/train_app.py +59 -24
- supervisely/nn/utils.py +9 -0
- supervisely/project/project.py +16 -3
- {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/METADATA +1 -1
- {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/RECORD +24 -24
- {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/LICENSE +0 -0
- {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/WHEEL +0 -0
- {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.326.dist-info → supervisely-6.73.327.dist-info}/top_level.txt +0 -0
supervisely/annotation/annotation.py

@@ -3088,7 +3088,7 @@ class Annotation:
     def to_yolo(
         self,
         class_names: List[str],
-        task_type: Literal["
+        task_type: Literal["detect", "segment", "pose"] = "detect",
     ) -> List[str]:
         """
         Convert Supervisely annotation to YOLO annotation format.
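A minimal usage sketch of the updated `Annotation.to_yolo` signature; the annotation and meta loading below are illustrative and not part of the diff:

```python
import supervisely as sly

# Hypothetical paths; load the project meta and one annotation.
meta = sly.ProjectMeta.from_json(sly.json.load_json_file("meta.json"))
ann = sly.Annotation.load_json_file("ann.json", meta)

# task_type now defaults to "detect"; "segment" and "pose" are the other options.
class_names = [obj_class.name for obj_class in meta.obj_classes]
yolo_lines = ann.to_yolo(class_names, task_type="segment")
```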
supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py

@@ -4,7 +4,7 @@ from supervisely.api.api import Api
 from supervisely.app.content import DataJson, StateJson
 from supervisely.app.widgets import Widget
 from supervisely.io.fs import get_file_ext
-from supervisely.nn.utils import ModelSource
+from supervisely.nn.utils import ModelSource, _get_model_name
 
 
 class PretrainedModelsSelector(Widget):
@@ -147,9 +147,11 @@ class PretrainedModelsSelector(Widget):
         if train_version == "v1":
             model_name = selected_model.get(model_name_column)
             if model_name is None:
-
-
-
+                model_name = _get_model_name(selected_model)
+                if model_name is None:
+                    raise ValueError(
+                        "Could not find model name. Make sure you have column 'Model' in your models list."
+                    )
 
         model_meta = selected_model.get("meta")
         if model_meta is None:
@@ -230,24 +232,25 @@ class PretrainedModelsSelector(Widget):
         for task_type in self._table_data:
             for arch_type in self._table_data[task_type]:
                 for idx, model in enumerate(self._table_data[task_type][arch_type]):
-
-                    if
-
-
-
-
+                    name_from_info = _get_model_name(model)
+                    if name_from_info is not None:
+                        if name_from_info.lower() == model_name.lower():
+                            self.set_active_task_type(task_type)
+                            self.set_active_arch_type(arch_type)
+                            self.set_active_row(idx)
+                            return
 
     def get_by_model_name(self, model_name: str) -> Union[Dict, None]:
         for task_type in self._table_data:
             for arch_type in self._table_data[task_type]:
                 for idx, model in enumerate(self._table_data[task_type][arch_type]):
-
-                    if
-
+                    name_from_info = _get_model_name(model)
+                    if name_from_info is not None:
+                        if name_from_info.lower() == model_name.lower():
+                            return model
 
     def _filter_and_sort_models(self, models: List[Dict], sort_models: bool = True) -> Dict:
         filtered_models = {}
-
         for model in models:
             for key in model:
                 if isinstance(model[key], (int, float)):
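The `_get_model_name` helper itself lives in `supervisely/nn/utils.py` (+9 lines, not shown in this section), so its exact body is not visible here. A plausible sketch, assuming it simply probes the common name keys of a model entry and returns None when nothing matches:

```python
from typing import Optional

def _get_model_name(model: dict) -> Optional[str]:
    # Hypothetical reconstruction: try the keys used across model lists
    # ("model_name", the "Model" table column, or the nested meta dict).
    for key in ("model_name", "Model"):
        if model.get(key) is not None:
            return model[key]
    meta = model.get("meta", {})
    return meta.get("model_name")
```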
supervisely/app/widgets/pretrained_models_selector/template.html

@@ -62,7 +62,8 @@
 <tr>
   <template v-for="(col, colKey) in data.{{{widget.widget_id}}}.tableData[state.{{{widget.widget_id}}}.selectedTaskType][state.{{{widget.widget_id}}}.selectedArchType][0]">
     <th v-if="colKey != 'meta'" :key="colKey">
-      <div
+      <div v-if="colKey==='model_name'">Model</div>
+      <div v-else>{{ colKey }}</div>
     </th>
   </template>
 </tr>
supervisely/convert/image/yolo/yolo_helper.py

@@ -17,7 +17,8 @@ from supervisely.geometry.polygon import Polygon
 from supervisely.geometry.polyline import Polyline
 from supervisely.geometry.rectangle import Rectangle
 from supervisely.imaging.color import generate_rgb
-from supervisely.io.fs import get_file_name_with_ext, touch
+from supervisely.io.fs import get_file_name, get_file_name_with_ext, touch
+from supervisely.nn.task_type import TaskType
 from supervisely.project.project import Dataset, OpenMode, Project
 from supervisely.project.project_meta import ProjectMeta
 from supervisely.sly_logger import logger
@@ -27,6 +28,20 @@ YOLO_DETECTION_COORDS_NUM = 4
 YOLO_SEGM_MIN_COORDS_NUM = 6
 YOLO_KEYPOINTS_MIN_COORDS_NUM = 6
 
+
+class YOLOTaskType:
+    DETECT = "detect"
+    SEGMENT = "segment"
+    POSE = "pose"
+
+
+SLY_YOLO_TASK_TYPE_MAP = {
+    TaskType.OBJECT_DETECTION: YOLOTaskType.DETECT,
+    TaskType.INSTANCE_SEGMENTATION: YOLOTaskType.SEGMENT,
+    TaskType.POSE_ESTIMATION: YOLOTaskType.POSE,
+}
+
+
 coco_classes = [
     "person",
     "bicycle",
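A short sketch of how this mapping is meant to be used: Supervisely task types are translated into the YOLO task names accepted by the converters. The import path and the assertion assume the new names are exposed from this module as added above:

```python
from supervisely.nn.task_type import TaskType
from supervisely.convert.image.yolo.yolo_helper import SLY_YOLO_TASK_TYPE_MAP, YOLOTaskType

# Supervisely task type -> YOLO task name
yolo_task = SLY_YOLO_TASK_TYPE_MAP[TaskType.OBJECT_DETECTION]
assert yolo_task == YOLOTaskType.DETECT == "detect"
```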
@@ -398,22 +413,22 @@ def keypoints_to_yolo_line(
 
 def convert_label_geometry_if_needed(
     label: Label,
-    task_type: Literal["
+    task_type: Literal["detect", "segment", "pose"],
     verbose: bool = False,
 ) -> List[Label]:
-    if task_type ==
+    if task_type == YOLOTaskType.DETECT:
         available_geometry_type = Rectangle
         convertable_geometry_types = [Polygon, GraphNodes, Bitmap, Polyline, AlphaMask, AnyGeometry]
-    elif task_type ==
+    elif task_type == YOLOTaskType.SEGMENT:
         available_geometry_type = Polygon
         convertable_geometry_types = [Bitmap, AlphaMask, AnyGeometry]
-    elif task_type ==
+    elif task_type == YOLOTaskType.POSE:
         available_geometry_type = GraphNodes
         convertable_geometry_types = []
     else:
         raise ValueError(
             f"Unsupported task type: {task_type}. "
-            "Supported types: '
+            f"Supported types: '{YOLOTaskType.DETECT}', '{YOLOTaskType.SEGMENT}', '{YOLOTaskType.POSE}'"
         )
 
     if label.obj_class.geometry_type == available_geometry_type:
@@ -438,7 +453,7 @@ def label_to_yolo_lines(
     img_height: int,
     img_width: int,
     class_names: List[str],
-    task_type: Literal["
+    task_type: Literal["detect", "segment", "pose"],
 ) -> List[str]:
     """
     Convert the Supervisely Label to a line in the YOLO format.
@@ -449,21 +464,21 @@ def label_to_yolo_lines(
 
     lines = []
     for label in labels:
-        if task_type ==
+        if task_type == YOLOTaskType.DETECT:
             yolo_line = rectangle_to_yolo_line(
                 class_idx=class_idx,
                 geometry=label.geometry,
                 img_height=img_height,
                 img_width=img_width,
             )
-        elif task_type ==
+        elif task_type == YOLOTaskType.SEGMENT:
             yolo_line = polygon_to_yolo_line(
                 class_idx=class_idx,
                 geometry=label.geometry,
                 img_height=img_height,
                 img_width=img_width,
             )
-        elif task_type ==
+        elif task_type == YOLOTaskType.POSE:
             nodes_field = label.obj_class.geometry_type.items_json_field
             max_kpts_count = len(label.obj_class.geometry_config[nodes_field])
             yolo_line = keypoints_to_yolo_line(
@@ -474,7 +489,10 @@ def label_to_yolo_lines(
                 max_kpts_count=max_kpts_count,
             )
         else:
-            raise ValueError(
+            raise ValueError(
+                f"Unsupported task type: {task_type}. "
+                f"Supported types: '{YOLOTaskType.DETECT}', '{YOLOTaskType.SEGMENT}', '{YOLOTaskType.POSE}'"
+            )
 
         if yolo_line is not None:
             lines.append(yolo_line)
@@ -485,12 +503,11 @@ def label_to_yolo_lines(
 def sly_ann_to_yolo(
     ann: Annotation,
     class_names: List[str],
-    task_type: Literal["
+    task_type: Literal["detect", "segment", "pose"] = "detect",
 ) -> List[str]:
     """
     Convert the Supervisely annotation to the YOLO format.
     """
-
     h, w = ann.img_size
     yolo_lines = []
     for label in ann.labels:
@@ -509,11 +526,12 @@ def sly_ds_to_yolo(
     dataset: Dataset,
     meta: ProjectMeta,
     dest_dir: Optional[str] = None,
-    task_type: Literal["
+    task_type: Literal["detect", "segment", "pose"] = "detect",
     log_progress: bool = False,
     progress_cb: Optional[Union[tqdm, Callable]] = None,
+    is_val: Optional[bool] = None,
 ) -> str:
-
+    task_type = validate_task_type(task_type)
     if progress_cb is not None:
         log_progress = False
 
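A minimal sketch of the new per-dataset routing, assuming a local project has already been opened; when `is_val` is passed, the whole dataset is written to the val split instead of relying on per-image "val" tags (paths and dataset names below are hypothetical):

```python
import supervisely as sly

project_fs = sly.Project("./sly_project", sly.OpenMode.READ)  # hypothetical path
dataset = project_fs.datasets.get("ds0")

# Route every image of this dataset to images/val and labels/val.
sly.convert.to_yolo(dataset, dest_dir="./yolo", meta=project_fs.meta, is_val=True)
```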
@@ -543,15 +561,19 @@ def sly_ds_to_yolo(
         ann_path = dataset.get_ann_path(name)
         ann = Annotation.load_json_file(ann_path, meta)
 
-
-
+        if is_val is not None:
+            images_dir = val_images_dir if is_val else train_images_dir
+            labels_dir = val_labels_dir if is_val else train_labels_dir
+        else:
+            images_dir = val_images_dir if ann.img_tags.get("val") else train_images_dir
+            labels_dir = val_labels_dir if ann.img_tags.get("val") else train_labels_dir
 
         img_path = Path(dataset.get_img_path(name))
         img_name = f"{dataset.short_name}_{get_file_name_with_ext(img_path)}"
         img_name = generate_free_name(used_names, img_name, with_ext=True, extend_used_names=True)
         shutil.copy2(img_path, images_dir / img_name)
 
-        label_path = str(labels_dir / f"{img_name}.txt")
+        label_path = str(labels_dir / f"{get_file_name(img_name)}.txt")
         yolo_lines = ann.to_yolo(class_names, task_type)
         if len(yolo_lines) > 0:
             with open(label_path, "w") as f:
@@ -565,7 +587,8 @@ def sly_ds_to_yolo(
     # * save data config file if it does not exist
     config_path = dest_dir / "data_config.yaml"
    if not config_path.exists():
-
+        with_keypoint = task_type is YOLOTaskType.POSE
+        save_yolo_config(meta, dest_dir, with_keypoint=with_keypoint)
 
     return str(dest_dir)
 
@@ -578,6 +601,8 @@ def save_yolo_config(meta: ProjectMeta, dest_dir: str, with_keypoint: bool = Fal
     data_yaml = {
         "train": f"../{str(dest_dir.name)}/images/train",
         "val": f"../{str(dest_dir.name)}/images/val",
+        "train_labels": f"../{str(dest_dir.name)}/labels/train",
+        "val_labels": f"../{str(dest_dir.name)}/labels/val",
         "nc": len(class_names),
         "names": class_names,
         "colors": class_colors,
@@ -590,6 +615,7 @@ def save_yolo_config(meta: ProjectMeta, dest_dir: str, with_keypoint: bool = Fal
             field_name = obj_class.geometry_type.items_json_field
             max_kpts_count = max(max_kpts_count, len(obj_class.geometry_config[field_name]))
         data_yaml["kpt_shape"] = [max_kpts_count, 3]
+        data_yaml["flip_idx"] = [i for i in range(max_kpts_count)]
     with open(save_path, "w") as f:
         yaml.dump(data_yaml, f, default_flow_style=None)
 
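Taken together, `save_yolo_config` now also records the label directories and an identity keypoint flip index. A rough sketch of the dict that gets dumped to `data_config.yaml` for a hypothetical two-class pose project (all values illustrative):

```python
# Roughly what yaml.dump(data_yaml, ...) would emit for a pose project
# with 17 keypoints and classes ["person", "animal"] (illustrative values).
data_yaml = {
    "train": "../yolo/images/train",
    "val": "../yolo/images/val",
    "train_labels": "../yolo/labels/train",   # new in 6.73.327
    "val_labels": "../yolo/labels/val",       # new in 6.73.327
    "nc": 2,
    "names": ["person", "animal"],
    "colors": [[255, 0, 0], [0, 255, 0]],
    "kpt_shape": [17, 3],
    "flip_idx": list(range(17)),              # new: identity mapping, no left/right swap
}
```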
@@ -599,21 +625,31 @@ def save_yolo_config(meta: ProjectMeta, dest_dir: str, with_keypoint: bool = Fal
 def sly_project_to_yolo(
     project: Union[Project, str],
     dest_dir: Optional[str] = None,
-    task_type: Literal["
+    task_type: Literal["detect", "segment", "pose"] = "detect",
     log_progress: bool = False,
     progress_cb: Optional[Callable] = None,
+    val_datasets: Optional[List[str]] = None,
 ):
     """
     Convert Supervisely project to YOLO format.
 
+    :param project: Supervisely project or path to the directory with the project.
+    :type project: :class:`supervisely.project.project.Project` or :class:`str`
     :param dest_dir: Destination directory.
     :type dest_dir: :class:`str`, optional
+    :param task_type: Task type.
+    :type task_type: :class:`str`, optional
     :param log_progress: Show uploading progress bar.
     :type log_progress: :class:`bool`
     :param progress_cb: Function for tracking conversion progress (for all items in the project).
     :type progress_cb: callable, optional
-    :
-
+    :param val_datasets: List of dataset names for validation.
+        Full dataset names are required (e.g., 'ds0/nested_ds1/ds3').
+        If specified, datasets from the list will be marked as val, others as train.
+        If not specified, the function will determine the validation datasets automatically.
+    :type val_datasets: :class:`list`, optional
+    :return: Path to the destination directory.
+    :rtype: :class:`str`
 
     :Usage example:
 
@@ -627,6 +663,7 @@ def sly_project_to_yolo(
         # Convert Project to YOLO format
         sly.Project(project_directory).to_yolo(log_progress=True)
     """
+    task_type = validate_task_type(task_type)
     if isinstance(project, str):
         project = Project(project, mode=OpenMode.READ)
 
@@ -644,9 +681,15 @@ def sly_project_to_yolo(
         desc="Converting Supervisely project to YOLO format", total=project.total_items
     ).update
 
-
+    with_keypoint = task_type is YOLOTaskType.POSE
+    save_yolo_config(project.meta, dest_dir, with_keypoint=with_keypoint)
 
     for dataset in project.datasets:
+        if val_datasets is not None:
+            is_val = dataset.name in val_datasets
+        else:
+            is_val = None
+
         dataset: Dataset
         dataset.to_yolo(
             meta=project.meta,
@@ -654,18 +697,23 @@ def sly_project_to_yolo(
             task_type=task_type,
             log_progress=log_progress,
             progress_cb=progress_cb,
+            is_val=is_val,
         )
         logger.info(f"Dataset '{dataset.short_name}' has been converted to YOLO format.")
     logger.info(f"Project '{project.name}' has been converted to YOLO format.")
 
+    return str(dest_dir)
+
 
 def to_yolo(
     input_data: Union[Project, Dataset, str],
     dest_dir: Optional[str] = None,
-    task_type: Literal["
+    task_type: Literal["detect", "segment", "pose"] = "detect",
     meta: Optional[ProjectMeta] = None,
     log_progress: bool = True,
     progress_cb: Optional[Callable] = None,
+    val_datasets: Optional[List[str]] = None,
+    is_val: Optional[bool] = None,
 ) -> Union[None, str]:
     """
     Universal function to convert Supervisely project or dataset to YOLO format.
@@ -691,6 +739,13 @@ def to_yolo(
     :type log_progress: :class:`bool`
     :param progress_cb: Function for tracking conversion progress (for all items in the project).
     :type progress_cb: callable, optional
+    :param val_datasets: List of dataset names for validation.
+        Full dataset names are required (e.g., 'ds0/nested_ds1/ds3').
+        If specified, datasets from the list will be marked as val, others as train.
+        If not specified, the function will determine the validation datasets automatically.
+    :type val_datasets: :class:`list`, optional
+    :param is_val: Whether the dataset is for validation.
+    :type is_val: :class:`bool`, optional
     :return: None, list of YOLO lines, or path to the destination directory.
     :rtype: NoneType, list, str
 
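A minimal sketch of the project-level conversion with the new `val_datasets` argument; the project path and dataset names are hypothetical:

```python
import supervisely as sly

# Convert a local project; everything in 'ds0/val' goes to the val split,
# all other datasets become train.
out_dir = sly.convert.to_yolo(
    "./sly_project",           # hypothetical project directory
    dest_dir="./yolo",
    task_type="segment",
    val_datasets=["ds0/val"],  # full dataset names, e.g. 'ds0/nested_ds1/ds3'
)
print(out_dir)  # path to the YOLO-formatted project
```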
@@ -711,7 +766,7 @@ def to_yolo(
 
         # Convert Dataset to YOLO format
         dataset: sly.Dataset = project_fs.datasets.get("dataset_name")
-        sly.convert.to_yolo(dataset, dest_dir="./yolo", meta=project_fs.meta)
+        sly.convert.to_yolo(dataset, dest_dir="./yolo", meta=project_fs.meta, is_val=True)
     """
     if isinstance(input_data, str):
         try:
@@ -728,6 +783,7 @@ def to_yolo(
             task_type=task_type,
             log_progress=log_progress,
             progress_cb=progress_cb,
+            val_datasets=val_datasets,
         )
     elif isinstance(input_data, Dataset):
         return sly_ds_to_yolo(
@@ -737,6 +793,20 @@ def to_yolo(
             task_type=task_type,
             log_progress=log_progress,
             progress_cb=progress_cb,
+            is_val=is_val,
         )
     else:
         raise ValueError("Unsupported input type. Only Project or Dataset are supported.")
+
+
+def validate_task_type(task_type: Literal["detect", "segment", "pose"]) -> str:
+    if task_type not in [YOLOTaskType.DETECT, YOLOTaskType.SEGMENT, YOLOTaskType.POSE]:
+        task_type = SLY_YOLO_TASK_TYPE_MAP.get(task_type)
+        if task_type is None:
+            raise ValueError(
+                f"Unsupported task type: {task_type}. "
+                f"Supported types: '{YOLOTaskType.DETECT}', '{SLY_YOLO_TASK_TYPE_MAP[TaskType.OBJECT_DETECTION]}', "
+                f"'{YOLOTaskType.SEGMENT}', '{SLY_YOLO_TASK_TYPE_MAP[TaskType.INSTANCE_SEGMENTATION]}', "
+                f"'{YOLOTaskType.POSE}', '{SLY_YOLO_TASK_TYPE_MAP[TaskType.POSE_ESTIMATION]}'"
+            )
+    return task_type
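For illustration, `validate_task_type` lets callers pass either a YOLO task name or a Supervisely `TaskType` value; the behavior below follows the added code, while the import path is assumed:

```python
from supervisely.nn.task_type import TaskType
from supervisely.convert.image.yolo.yolo_helper import validate_task_type

validate_task_type("segment")                  # -> "segment" (already a YOLO name)
validate_task_type(TaskType.OBJECT_DETECTION)  # -> "detect"  (mapped via SLY_YOLO_TASK_TYPE_MAP)
validate_task_type("something-else")           # raises ValueError
```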
supervisely/nn/inference/gui/serving_gui_template.py

@@ -23,7 +23,7 @@ from supervisely.app.widgets.pretrained_models_selector.pretrained_models_select
 )
 from supervisely.nn.experiments import get_experiment_infos
 from supervisely.nn.inference.gui.serving_gui import ServingGUI
-from supervisely.nn.utils import ModelSource, RuntimeType
+from supervisely.nn.utils import ModelSource, RuntimeType, _get_model_name
 
 
 class ServingGUITemplate(ServingGUI):
@@ -149,8 +149,7 @@ class ServingGUITemplate(ServingGUI):
     @property
     def model_name(self) -> Optional[str]:
         if self.model_source == ModelSource.PRETRAINED:
-
-            return model_meta.get("model_name")
+            return _get_model_name(self.model_info)
         else:
             return self.model_info.get("model_name")
 
supervisely/nn/inference/inference.py

@@ -75,6 +75,7 @@ from supervisely.nn.utils import (
     ModelPrecision,
     ModelSource,
     RuntimeType,
+    _get_model_name,
 )
 from supervisely.project import ProjectType
 from supervisely.project.download import download_to_cache, read_from_cached_project
@@ -173,9 +174,7 @@ class Inference:
             self._use_gui = False
             deploy_params, need_download = self._get_deploy_params_from_args()
             if need_download:
-                local_model_files = self._download_model_files(
-                    deploy_params["model_source"], deploy_params["model_files"], False
-                )
+                local_model_files = self._download_model_files(deploy_params, False)
                 deploy_params["model_files"] = local_model_files
             self._load_model_headless(**deploy_params)
 
@@ -210,14 +209,12 @@ class Inference:
         self.initialize_gui()
 
         def on_serve_callback(
-            gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
+            gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate],
         ):
             Progress("Deploying model ...", 1)
             if isinstance(self.gui, GUI.ServingGUITemplate):
                 deploy_params = self.get_params_from_gui()
-                model_files = self._download_model_files(
-                    deploy_params["model_source"], deploy_params["model_files"]
-                )
+                model_files = self._download_model_files(deploy_params)
                 deploy_params["model_files"] = model_files
                 self._load_model_headless(**deploy_params)
             elif isinstance(self.gui, GUI.ServingGUI):
@@ -230,7 +227,7 @@ class Inference:
             gui.show_deployed_model_info(self)
 
         def on_change_model_callback(
-            gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
+            gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate],
         ):
             self.shutdown_model()
             if isinstance(self.gui, (GUI.ServingGUI, GUI.ServingGUITemplate)):
@@ -567,13 +564,23 @@ class Inference:
     def _checkpoints_cache_dir(self):
         return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
 
-    def _download_model_files(
-
-
-
-
-
-
+    def _download_model_files(self, deploy_params: dict, log_progress: bool = True) -> dict:
+        if deploy_params["runtime"] != RuntimeType.PYTORCH:
+            export = deploy_params["model_info"].get("export", {})
+            export_model = export.get(deploy_params["runtime"], None)
+            if export_model is not None:
+                if sly_fs.get_file_name(export_model) == sly_fs.get_file_name(
+                    deploy_params["model_files"]["checkpoint"]
+                ):
+                    deploy_params["model_files"]["checkpoint"] = (
+                        deploy_params["model_info"]["artifacts_dir"] + export_model
+                    )
+                    logger.info(f"Found model checkpoint for '{deploy_params['runtime']}'")
+
+        if deploy_params["model_source"] == ModelSource.PRETRAINED:
+            return self._download_pretrained_model(deploy_params["model_files"], log_progress)
+        elif deploy_params["model_source"] == ModelSource.CUSTOM:
+            return self._download_custom_model(deploy_params["model_files"], log_progress)
 
     def _download_pretrained_model(self, model_files: dict, log_progress: bool = True):
         """
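The helper now receives the whole `deploy_params` dict instead of separate `model_source`/`model_files` arguments, so it can swap in a runtime-specific exported checkpoint before downloading. A sketch of the dict shape it expects, inferred from the keys accessed above (all values are made up for illustration):

```python
# Keys read by the new _download_model_files (illustrative values only):
deploy_params = {
    "model_source": "Pretrained models",          # ModelSource.PRETRAINED or ModelSource.CUSTOM
    "model_files": {"checkpoint": "yolo11n.pt"},  # checkpoint may be replaced by a matching export
    "model_info": {
        "artifacts_dir": "/experiments/123/",
        "export": {"ONNXRuntime": "export/yolo11n.onnx"},  # per-runtime exports, if any
    },
    "runtime": "ONNXRuntime",                      # a RuntimeType value
}
```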
@@ -2929,9 +2936,7 @@ class Inference:
         state = request.state.state
         deploy_params = state["deploy_params"]
         if isinstance(self.gui, GUI.ServingGUITemplate):
-            model_files = self._download_model_files(
-                deploy_params["model_source"], deploy_params["model_files"]
-            )
+            model_files = self._download_model_files(deploy_params)
             deploy_params["model_files"] = model_files
             self._load_model_headless(**deploy_params)
         elif isinstance(self.gui, GUI.ServingGUI):
@@ -3061,7 +3066,7 @@ class Inference:
             raise ValueError("No pretrained models found.")
 
         model = self.pretrained_models[0]
-        model_name = model
+        model_name = _get_model_name(model)
         if model_name is None:
             raise ValueError("No model name found in the first pretrained model.")
 
@@ -3126,7 +3131,7 @@ class Inference:
             meta = m.get("meta", None)
             if meta is None:
                 continue
-            model_name =
+            model_name = _get_model_name(m)
             if model_name is None:
                 continue
             m_files = meta.get("model_files", None)
@@ -3135,7 +3140,7 @@ class Inference:
             checkpoint = m_files.get("checkpoint", None)
             if checkpoint is None:
                 continue
-            if model ==
+            if model.lower() == model_name.lower():
                 model_info = m
                 model_source = ModelSource.PRETRAINED
                 model_files = {"checkpoint": checkpoint}
@@ -3153,8 +3158,6 @@ class Inference:
             model_meta_path = os.path.join(artifacts_dir, "model_meta.json")
             model_info["model_meta"] = self._load_json_file(model_meta_path)
             original_model_files = model_info.get("model_files")
-            if not original_model_files:
-                raise ValueError("Invalid 'experiment_info.json'. Missing 'model_files' key.")
             return model_info, original_model_files
 
         def _prepare_local_model_files(artifacts_dir, checkpoint_path, original_model_files):
@@ -3201,6 +3204,7 @@ class Inference:
             model_files = _prepare_local_model_files(
                 artifacts_dir, checkpoint_path, original_model_files
             )
+
         else:
             local_artifacts_dir = os.path.join(
                 self.model_dir, "local_deploy", os.path.basename(artifacts_dir)
@@ -3298,7 +3302,11 @@ class Inference:
         if draw:
             raise ValueError("Draw visualization is not supported for project inference")
 
-        state = {
+        state = {
+            "projectId": project_id,
+            "dataset_ids": dataset_ids,
+            "settings": settings,
+        }
         if upload:
             source_project = api.project.get_info_by_id(project_id)
             workspace_id = source_project.workspace_id
@@ -3472,7 +3480,7 @@ class Inference:
     def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
         if model_source == ModelSource.PRETRAINED:
             checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
-            checkpoint_name = model_info
+            checkpoint_name = _get_model_name(model_info)
         else:
             checkpoint_name = sly_fs.get_file_name_with_ext(model_files["checkpoint"])
             checkpoint_url = os.path.join(
supervisely/nn/training/gui/classes_selector.py

@@ -4,21 +4,28 @@ from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text
 
 class ClassesSelector:
     title = "Classes Selector"
-    description =
-        "Select classes that will be used for training. "
-        "Supported shapes are Bitmap, Polygon, Rectangle."
-    )
+    description = "Select classes that will be used for training"
     lock_message = "Select training and validation splits to unlock"
 
     def __init__(self, project_id: int, classes: list, app_options: dict = {}):
+        # Init widgets
+        self.qa_stats_text = None
+        self.classes_table = None
+        self.validator_text = None
+        self.button = None
+        self.container = None
+        self.card = None
+        # -------------------------------- #
+
         self.display_widgets = []
+        self.app_options = app_options
 
         # GUI Components
         if is_development() or is_debug_with_sly_net():
             qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
         else:
             qa_stats_link = f"/projects/{project_id}/stats/datasets"
-        qa_stats_text = Text(
+        self.qa_stats_text = Text(
             text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
         )
 
@@ -32,7 +39,7 @@ class ClassesSelector:
         self.validator_text.hide()
         self.button = Button("Select")
         self.display_widgets.extend(
-            [qa_stats_text, self.classes_table, self.validator_text, self.button]
+            [self.qa_stats_text, self.classes_table, self.validator_text, self.button]
         )
         # -------------------------------- #
 
@@ -42,7 +49,7 @@ class ClassesSelector:
             description=self.description,
             content=self.container,
             lock_message=self.lock_message,
-            collapsable=app_options.get("collapsable", False),
+            collapsable=self.app_options.get("collapsable", False),
         )
         self.card.lock()
 
@@ -62,14 +69,14 @@ class ClassesSelector:
     def validate_step(self) -> bool:
         self.validator_text.hide()
 
-
+        project_classes = self.classes_table.project_meta.obj_classes
+        if len(project_classes) == 0:
             self.validator_text.set(text="Project has no classes", status="error")
             self.validator_text.show()
             return False
 
         selected_classes = self.classes_table.get_selected_classes()
         table_data = self.classes_table._table_data
-
         empty_classes = [
             row[0]["data"]
             for row in table_data
@@ -78,23 +85,21 @@ class ClassesSelector:
 
         n_classes = len(selected_classes)
         if n_classes == 0:
-
+            message = "Please select at least one class"
+            status = "error"
         else:
-
+            class_text = "class" if n_classes == 1 else "classes"
+            message = f"Selected {n_classes} {class_text}"
             status = "success"
             if empty_classes:
                 intersections = set(selected_classes).intersection(empty_classes)
                 if intersections:
-
-
-
-                    else f". Selected classes have no annotations: {', '.join(intersections)}"
+                    class_text = "class" if len(intersections) == 1 else "classes"
+                    message += (
+                        f". Selected {class_text} have no annotations: {', '.join(intersections)}"
                     )
                     status = "warning"
 
-
-        self.validator_text.set(
-            text=f"Selected {n_classes} {class_text}{warning_text}", status=status
-        )
+        self.validator_text.set(text=message, status=status)
         self.validator_text.show()
         return n_classes > 0