argus-cv 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of argus-cv might be problematic. Click here for more details.
- argus/__init__.py +1 -1
- argus/cli.py +345 -1
- argus/core/__init__.py +20 -0
- argus/core/coco.py +46 -8
- argus/core/convert.py +277 -0
- argus/core/filter.py +670 -0
- argus/core/yolo.py +29 -0
- {argus_cv-1.4.0.dist-info → argus_cv-1.5.1.dist-info}/METADATA +1 -1
- argus_cv-1.5.1.dist-info/RECORD +16 -0
- argus_cv-1.4.0.dist-info/RECORD +0 -14
- {argus_cv-1.4.0.dist-info → argus_cv-1.5.1.dist-info}/WHEEL +0 -0
- {argus_cv-1.4.0.dist-info → argus_cv-1.5.1.dist-info}/entry_points.txt +0 -0
argus/__init__.py
CHANGED
argus/cli.py
CHANGED
|
@@ -8,11 +8,23 @@ import cv2
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import typer
|
|
10
10
|
from rich.console import Console
|
|
11
|
-
from rich.progress import
|
|
11
|
+
from rich.progress import (
|
|
12
|
+
BarColumn,
|
|
13
|
+
Progress,
|
|
14
|
+
SpinnerColumn,
|
|
15
|
+
TaskProgressColumn,
|
|
16
|
+
TextColumn,
|
|
17
|
+
)
|
|
12
18
|
from rich.table import Table
|
|
13
19
|
|
|
14
20
|
from argus.core import COCODataset, Dataset, MaskDataset, YOLODataset
|
|
15
21
|
from argus.core.base import DatasetFormat, TaskType
|
|
22
|
+
from argus.core.convert import convert_mask_to_yolo_seg
|
|
23
|
+
from argus.core.filter import (
|
|
24
|
+
filter_coco_dataset,
|
|
25
|
+
filter_mask_dataset,
|
|
26
|
+
filter_yolo_dataset,
|
|
27
|
+
)
|
|
16
28
|
from argus.core.split import (
|
|
17
29
|
is_coco_unsplit,
|
|
18
30
|
parse_ratio,
|
|
@@ -632,6 +644,338 @@ def split_dataset(
|
|
|
632
644
|
)
|
|
633
645
|
|
|
634
646
|
|
|
647
|
+
@app.command(name="convert")
def convert_dataset(
    input_path: Annotated[
        Path,
        typer.Option(
            "--input-path",
            "-i",
            help="Path to the source dataset.",
        ),
    ] = Path("."),
    output_path: Annotated[
        Path,
        typer.Option(
            "--output-path",
            "-o",
            help="Output directory for converted dataset.",
        ),
    ] = Path("converted"),
    to_format: Annotated[
        str,
        typer.Option(
            "--to",
            help="Target format (currently only 'yolo-seg' is supported).",
        ),
    ] = "yolo-seg",
    epsilon_factor: Annotated[
        float,
        typer.Option(
            "--epsilon-factor",
            "-e",
            help="Polygon simplification factor (Douglas-Peucker algorithm).",
            min=0.0,
            max=1.0,
        ),
    ] = 0.005,
    min_area: Annotated[
        float,
        typer.Option(
            "--min-area",
            "-a",
            help="Minimum contour area in pixels to include.",
            min=0.0,
        ),
    ] = 100.0,
) -> None:
    """Convert a dataset from one format to another.

    Currently supports converting MaskDataset to YOLO segmentation format.

    Example:
        uvx argus-cv convert -i /path/to/masks -o /path/to/output --to yolo-seg
    """
    # Validate format: 'yolo-seg' is the only target this command implements.
    if to_format != "yolo-seg":
        console.print(
            f"[red]Error: Unsupported target format '{to_format}'.[/red]\n"
            "[yellow]Currently only 'yolo-seg' is supported.[/yellow]"
        )
        raise typer.Exit(1)

    # Resolve and validate input path
    input_path = input_path.resolve()
    if not input_path.exists():
        console.print(f"[red]Error: Path does not exist: {input_path}[/red]")
        raise typer.Exit(1)
    if not input_path.is_dir():
        console.print(f"[red]Error: Path is not a directory: {input_path}[/red]")
        raise typer.Exit(1)

    # Detect source dataset - must be MaskDataset for yolo-seg conversion
    dataset = MaskDataset.detect(input_path)
    if not dataset:
        console.print(
            f"[red]Error: No MaskDataset found at {input_path}[/red]\n"
            "[yellow]Ensure the path contains images/ + masks/ directories "
            "(or equivalent patterns like img/+gt/ or leftImg8bit/+gtFine/).[/yellow]"
        )
        raise typer.Exit(1)

    # Resolve output path. A relative output path is placed under the input
    # dataset's parent directory (i.e. next to the dataset), not under CWD.
    if not output_path.is_absolute():
        output_path = input_path.parent / output_path
    output_path = output_path.resolve()

    # Refuse to write into an existing file or a non-empty directory.
    # is_dir() is checked first because iterdir() raises NotADirectoryError
    # when the output path names an existing regular file.
    if output_path.exists():
        if not output_path.is_dir():
            console.print(
                f"[red]Error: Output path exists and is not a directory: "
                f"{output_path}[/red]"
            )
            raise typer.Exit(1)
        if any(output_path.iterdir()):
            console.print(
                f"[red]Error: Output directory already exists and is not empty: "
                f"{output_path}[/red]"
            )
            raise typer.Exit(1)

    # Show conversion info
    console.print("[cyan]Converting MaskDataset to YOLO segmentation format[/cyan]")
    console.print(f"  Source: {input_path}")
    console.print(f"  Output: {output_path}")
    console.print(f"  Classes: {dataset.num_classes}")
    splits_str = ", ".join(dataset.splits) if dataset.splits else "unsplit"
    console.print(f"  Splits: {splits_str}")
    console.print()

    # Run conversion with progress bar
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        console=console,
    ) as progress:
        # total=None starts the task without a known total; the callback
        # below supplies the total once the converter reports it.
        task = progress.add_task("Processing images...", total=None)

        def update_progress(current: int, total: int) -> None:
            progress.update(task, completed=current, total=total)

        try:
            stats = convert_mask_to_yolo_seg(
                dataset=dataset,
                output_path=output_path,
                epsilon_factor=epsilon_factor,
                min_area=min_area,
                progress_callback=update_progress,
            )
        except Exception as exc:
            # Top-level CLI boundary: report any conversion failure and exit 1.
            console.print(f"[red]Error during conversion: {exc}[/red]")
            raise typer.Exit(1) from exc

    # Show results
    console.print()
    console.print("[green]Conversion complete![/green]")
    console.print(f"  Images processed: {stats['images']}")
    console.print(f"  Labels created: {stats['labels']}")
    console.print(f"  Polygons extracted: {stats['polygons']}")

    if stats["skipped"] > 0:
        skipped = stats["skipped"]
        console.print(f"  [yellow]Skipped: {skipped} (no mask or empty)[/yellow]")
    if stats["warnings"] > 0:
        console.print(f"  [yellow]Warnings: {stats['warnings']}[/yellow]")

    console.print(f"\n[cyan]Output dataset: {output_path}[/cyan]")
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
@app.command(name="filter")
def filter_dataset(
    dataset_path: Annotated[
        Path,
        typer.Option(
            "--dataset-path",
            "-d",
            help="Path to the dataset root directory.",
        ),
    ] = Path("."),
    output_path: Annotated[
        Path,
        typer.Option(
            "--output",
            "-o",
            help="Output directory for filtered dataset.",
        ),
    ] = Path("filtered"),
    classes: Annotated[
        str,
        typer.Option(
            "--classes",
            "-c",
            help="Comma-separated list of class names to keep.",
        ),
    ] = "",
    no_background: Annotated[
        bool,
        typer.Option(
            "--no-background",
            help="Exclude images with no annotations after filtering.",
        ),
    ] = False,
    use_symlinks: Annotated[
        bool,
        typer.Option(
            "--symlinks",
            help="Use symlinks instead of copying images.",
        ),
    ] = False,
) -> None:
    """Filter a dataset by class names.

    Creates a filtered copy of the dataset containing only the specified classes.
    Class IDs are remapped to sequential values (0, 1, 2, ...).

    Examples:
        argus-cv filter -d dataset -o output --classes ball --no-background
        argus-cv filter -d dataset -o output --classes ball,player
        argus-cv filter -d dataset -o output --classes ball --symlinks
    """
    # Resolve path and validate
    dataset_path = dataset_path.resolve()
    if not dataset_path.exists():
        console.print(f"[red]Error: Path does not exist: {dataset_path}[/red]")
        raise typer.Exit(1)
    if not dataset_path.is_dir():
        console.print(f"[red]Error: Path is not a directory: {dataset_path}[/red]")
        raise typer.Exit(1)

    # Parse classes
    if not classes:
        console.print(
            "[red]Error: No classes specified. "
            "Use --classes to specify classes to keep.[/red]"
        )
        raise typer.Exit(1)

    # Strip whitespace, drop empty entries, and drop duplicates (first
    # occurrence wins) so "ball,ball" is passed to the filters as ["ball"].
    class_list = list(
        dict.fromkeys(c.strip() for c in classes.split(",") if c.strip())
    )
    if not class_list:
        console.print("[red]Error: No valid class names provided.[/red]")
        raise typer.Exit(1)

    # Detect dataset
    dataset = _detect_dataset(dataset_path)
    if not dataset:
        console.print(
            f"[red]Error: No dataset found at {dataset_path}[/red]\n"
            "[yellow]Ensure the path points to a dataset root containing "
            "data.yaml (YOLO), annotations/ folder (COCO), or "
            "images/ + masks/ directories (Mask).[/yellow]"
        )
        raise typer.Exit(1)

    # Validate classes exist in dataset
    missing_classes = [c for c in class_list if c not in dataset.class_names]
    if missing_classes:
        available = ", ".join(dataset.class_names)
        missing = ", ".join(missing_classes)
        console.print(
            f"[red]Error: Classes not found in dataset: {missing}[/red]\n"
            f"[yellow]Available classes: {available}[/yellow]"
        )
        raise typer.Exit(1)

    # Resolve output path. A relative output path is placed under the dataset's
    # parent directory (i.e. next to the dataset), not under CWD.
    if not output_path.is_absolute():
        output_path = dataset_path.parent / output_path
    output_path = output_path.resolve()

    # Refuse to write into an existing file or a non-empty directory.
    # is_dir() is checked first because iterdir() raises NotADirectoryError
    # when the output path names an existing regular file.
    if output_path.exists():
        if not output_path.is_dir():
            console.print(
                f"[red]Error: Output path exists and is not a directory: "
                f"{output_path}[/red]"
            )
            raise typer.Exit(1)
        if any(output_path.iterdir()):
            console.print(
                f"[red]Error: Output directory already exists and is not empty: "
                f"{output_path}[/red]"
            )
            raise typer.Exit(1)

    # Show filter info
    console.print(f"[cyan]Filtering {dataset.format.value.upper()} dataset[/cyan]")
    console.print(f"  Source: {dataset_path}")
    console.print(f"  Output: {output_path}")
    console.print(f"  Classes to keep: {', '.join(class_list)}")
    console.print(f"  Exclude background: {no_background}")
    console.print(f"  Use symlinks: {use_symlinks}")
    console.print()

    # Run filtering with progress bar
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        console=console,
    ) as progress:
        # total=None starts the task without a known total; the callback
        # below supplies the total once the filter reports it.
        task = progress.add_task("Filtering dataset...", total=None)

        def update_progress(current: int, total: int) -> None:
            progress.update(task, completed=current, total=total)

        try:
            # The isinstance asserts narrow the dataset type for the
            # format-specific filter functions.
            if dataset.format == DatasetFormat.YOLO:
                assert isinstance(dataset, YOLODataset)
                stats = filter_yolo_dataset(
                    dataset=dataset,
                    output_path=output_path,
                    classes=class_list,
                    no_background=no_background,
                    use_symlinks=use_symlinks,
                    progress_callback=update_progress,
                )
            elif dataset.format == DatasetFormat.COCO:
                assert isinstance(dataset, COCODataset)
                stats = filter_coco_dataset(
                    dataset=dataset,
                    output_path=output_path,
                    classes=class_list,
                    no_background=no_background,
                    use_symlinks=use_symlinks,
                    progress_callback=update_progress,
                )
            elif dataset.format == DatasetFormat.MASK:
                assert isinstance(dataset, MaskDataset)
                stats = filter_mask_dataset(
                    dataset=dataset,
                    output_path=output_path,
                    classes=class_list,
                    no_background=no_background,
                    use_symlinks=use_symlinks,
                    progress_callback=update_progress,
                )
            else:
                console.print(
                    f"[red]Error: Unsupported dataset format: {dataset.format}[/red]"
                )
                raise typer.Exit(1)
        except typer.Exit:
            # typer.Exit is click's Exit, which subclasses RuntimeError; without
            # this clause the generic handler below would catch it and print a
            # spurious "Error during filtering: 1" after the message above.
            raise
        except ValueError as exc:
            console.print(f"[red]Error: {exc}[/red]")
            raise typer.Exit(1) from exc
        except Exception as exc:
            # Top-level CLI boundary: report any other failure and exit 1.
            console.print(f"[red]Error during filtering: {exc}[/red]")
            raise typer.Exit(1) from exc

    # Show results; which counters exist depends on the dataset format.
    console.print()
    console.print("[green]Filtering complete![/green]")
    console.print(f"  Images: {stats.get('images', 0)}")
    if "labels" in stats:
        console.print(f"  Labels: {stats['labels']}")
    if "annotations" in stats:
        console.print(f"  Annotations: {stats['annotations']}")
    if "masks" in stats:
        console.print(f"  Masks: {stats['masks']}")
    if stats.get("skipped", 0) > 0:
        skipped = stats["skipped"]
        console.print(f"  [yellow]Skipped: {skipped} (background images)[/yellow]")

    console.print(f"\n[cyan]Output dataset: {output_path}[/cyan]")
|
|
977
|
+
|
|
978
|
+
|
|
635
979
|
class _ImageViewer:
|
|
636
980
|
"""Interactive image viewer with zoom and pan support."""
|
|
637
981
|
|
argus/core/__init__.py
CHANGED
|
@@ -2,6 +2,18 @@
|
|
|
2
2
|
|
|
3
3
|
from argus.core.base import Dataset
|
|
4
4
|
from argus.core.coco import COCODataset
|
|
5
|
+
from argus.core.convert import (
|
|
6
|
+
ConversionParams,
|
|
7
|
+
Polygon,
|
|
8
|
+
convert_mask_to_yolo_labels,
|
|
9
|
+
convert_mask_to_yolo_seg,
|
|
10
|
+
mask_to_polygons,
|
|
11
|
+
)
|
|
12
|
+
from argus.core.filter import (
|
|
13
|
+
filter_coco_dataset,
|
|
14
|
+
filter_mask_dataset,
|
|
15
|
+
filter_yolo_dataset,
|
|
16
|
+
)
|
|
5
17
|
from argus.core.mask import ConfigurationError, MaskDataset
|
|
6
18
|
from argus.core.split import split_coco_dataset, split_yolo_dataset
|
|
7
19
|
from argus.core.yolo import YOLODataset
|
|
@@ -14,4 +26,12 @@ __all__ = [
|
|
|
14
26
|
"ConfigurationError",
|
|
15
27
|
"split_coco_dataset",
|
|
16
28
|
"split_yolo_dataset",
|
|
29
|
+
"filter_yolo_dataset",
|
|
30
|
+
"filter_coco_dataset",
|
|
31
|
+
"filter_mask_dataset",
|
|
32
|
+
"ConversionParams",
|
|
33
|
+
"Polygon",
|
|
34
|
+
"mask_to_polygons",
|
|
35
|
+
"convert_mask_to_yolo_labels",
|
|
36
|
+
"convert_mask_to_yolo_seg",
|
|
17
37
|
]
|
argus/core/coco.py
CHANGED
|
@@ -75,6 +75,13 @@ class COCODataset(Dataset):
|
|
|
75
75
|
# Also check root directory for single annotation file
|
|
76
76
|
annotation_files.extend(path.glob("*.json"))
|
|
77
77
|
|
|
78
|
+
# Check split directories for Roboflow COCO format
|
|
79
|
+
for split_name in ["train", "valid", "val", "test"]:
|
|
80
|
+
split_dir = path / split_name
|
|
81
|
+
if split_dir.is_dir():
|
|
82
|
+
annotation_files.extend(split_dir.glob("*annotations*.json"))
|
|
83
|
+
annotation_files.extend(split_dir.glob("*coco*.json"))
|
|
84
|
+
|
|
78
85
|
# Filter to only include files that might be COCO annotations
|
|
79
86
|
# (exclude package.json, tsconfig.json, etc.)
|
|
80
87
|
filtered_files = []
|
|
@@ -185,8 +192,10 @@ class COCODataset(Dataset):
|
|
|
185
192
|
if isinstance(cat, dict) and "id" in cat and "name" in cat:
|
|
186
193
|
id_to_name[cat["id"]] = cat["name"]
|
|
187
194
|
|
|
188
|
-
# Determine split from filename
|
|
189
|
-
split = self._get_split_from_filename(
|
|
195
|
+
# Determine split from filename or parent directory
|
|
196
|
+
split = self._get_split_from_filename(
|
|
197
|
+
ann_file.stem, ann_file.parent.name
|
|
198
|
+
)
|
|
190
199
|
|
|
191
200
|
# Count annotations per category
|
|
192
201
|
split_counts: dict[str, int] = counts.get(split, {})
|
|
@@ -224,7 +233,9 @@ class COCODataset(Dataset):
|
|
|
224
233
|
if not isinstance(data, dict):
|
|
225
234
|
continue
|
|
226
235
|
|
|
227
|
-
split = self._get_split_from_filename(
|
|
236
|
+
split = self._get_split_from_filename(
|
|
237
|
+
ann_file.stem, ann_file.parent.name
|
|
238
|
+
)
|
|
228
239
|
|
|
229
240
|
images = data.get("images", [])
|
|
230
241
|
annotations = data.get("annotations", [])
|
|
@@ -256,11 +267,12 @@ class COCODataset(Dataset):
|
|
|
256
267
|
return counts
|
|
257
268
|
|
|
258
269
|
@staticmethod
|
|
259
|
-
def _get_split_from_filename(filename: str) -> str:
|
|
260
|
-
"""Extract split name from annotation filename.
|
|
270
|
+
def _get_split_from_filename(filename: str, parent_dir: str | None = None) -> str:
|
|
271
|
+
"""Extract split name from annotation filename or parent directory.
|
|
261
272
|
|
|
262
273
|
Args:
|
|
263
274
|
filename: Annotation file stem (without extension).
|
|
275
|
+
parent_dir: Optional parent directory name (for Roboflow COCO format).
|
|
264
276
|
|
|
265
277
|
Returns:
|
|
266
278
|
Split name (train, val, test) or 'train' as default.
|
|
@@ -272,6 +284,17 @@ class COCODataset(Dataset):
|
|
|
272
284
|
return "val"
|
|
273
285
|
elif "test" in name_lower:
|
|
274
286
|
return "test"
|
|
287
|
+
|
|
288
|
+
# Check parent directory name (Roboflow COCO format)
|
|
289
|
+
if parent_dir:
|
|
290
|
+
parent_lower = parent_dir.lower()
|
|
291
|
+
if parent_lower == "train":
|
|
292
|
+
return "train"
|
|
293
|
+
elif parent_lower in ("val", "valid"):
|
|
294
|
+
return "val"
|
|
295
|
+
elif parent_lower == "test":
|
|
296
|
+
return "test"
|
|
297
|
+
|
|
275
298
|
return "train"
|
|
276
299
|
|
|
277
300
|
@classmethod
|
|
@@ -301,7 +324,7 @@ class COCODataset(Dataset):
|
|
|
301
324
|
|
|
302
325
|
@classmethod
|
|
303
326
|
def _detect_splits(cls, annotation_files: list[Path]) -> list[str]:
|
|
304
|
-
"""Detect available splits from annotation filenames.
|
|
327
|
+
"""Detect available splits from annotation filenames or parent directories.
|
|
305
328
|
|
|
306
329
|
Args:
|
|
307
330
|
annotation_files: List of annotation file paths.
|
|
@@ -313,13 +336,22 @@ class COCODataset(Dataset):
|
|
|
313
336
|
|
|
314
337
|
for ann_file in annotation_files:
|
|
315
338
|
name_lower = ann_file.stem.lower()
|
|
339
|
+
parent_lower = ann_file.parent.name.lower()
|
|
316
340
|
|
|
341
|
+
# Check filename first
|
|
317
342
|
if "train" in name_lower and "train" not in splits:
|
|
318
343
|
splits.append("train")
|
|
319
344
|
elif "val" in name_lower and "val" not in splits:
|
|
320
345
|
splits.append("val")
|
|
321
346
|
elif "test" in name_lower and "test" not in splits:
|
|
322
347
|
splits.append("test")
|
|
348
|
+
# Check parent directory (Roboflow COCO format)
|
|
349
|
+
elif parent_lower == "train" and "train" not in splits:
|
|
350
|
+
splits.append("train")
|
|
351
|
+
elif parent_lower in ("val", "valid") and "val" not in splits:
|
|
352
|
+
splits.append("val")
|
|
353
|
+
elif parent_lower == "test" and "test" not in splits:
|
|
354
|
+
splits.append("test")
|
|
323
355
|
|
|
324
356
|
# If no splits detected from filenames, default to train
|
|
325
357
|
if not splits:
|
|
@@ -342,7 +374,9 @@ class COCODataset(Dataset):
|
|
|
342
374
|
for ann_file in self.annotation_files:
|
|
343
375
|
# Filter by split if specified
|
|
344
376
|
if split:
|
|
345
|
-
file_split = self._get_split_from_filename(
|
|
377
|
+
file_split = self._get_split_from_filename(
|
|
378
|
+
ann_file.stem, ann_file.parent.name
|
|
379
|
+
)
|
|
346
380
|
if file_split != split:
|
|
347
381
|
continue
|
|
348
382
|
|
|
@@ -354,7 +388,9 @@ class COCODataset(Dataset):
|
|
|
354
388
|
continue
|
|
355
389
|
|
|
356
390
|
images = data.get("images", [])
|
|
357
|
-
file_split = self._get_split_from_filename(
|
|
391
|
+
file_split = self._get_split_from_filename(
|
|
392
|
+
ann_file.stem, ann_file.parent.name
|
|
393
|
+
)
|
|
358
394
|
|
|
359
395
|
for img in images:
|
|
360
396
|
if not isinstance(img, dict) or "file_name" not in img:
|
|
@@ -371,6 +407,8 @@ class COCODataset(Dataset):
|
|
|
371
407
|
self.path / "images" / file_name,
|
|
372
408
|
self.path / file_split / file_name,
|
|
373
409
|
self.path / file_name,
|
|
410
|
+
# Roboflow format: images alongside annotations
|
|
411
|
+
ann_file.parent / file_name,
|
|
374
412
|
]
|
|
375
413
|
|
|
376
414
|
for img_path in possible_paths:
|