argus-cv 1.1.0__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of argus-cv might be problematic. Click here for more details.
- {argus_cv-1.1.0 → argus_cv-1.2.0}/CHANGELOG.md +16 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/PKG-INFO +1 -1
- {argus_cv-1.1.0 → argus_cv-1.2.0}/pyproject.toml +1 -1
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/__init__.py +1 -1
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/cli.py +262 -35
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/yolo.py +355 -9
- {argus_cv-1.1.0 → argus_cv-1.2.0}/tests/conftest.py +71 -0
- argus_cv-1.2.0/tests/test_classification.py +205 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/.github/workflows/ci.yml +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/.github/workflows/docs.yml +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/.github/workflows/release.yml +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/.gitignore +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/README.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/assets/javascripts/extra.js +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/assets/stylesheets/extra.css +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/getting-started/installation.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/getting-started/quickstart.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/datasets.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/listing.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/splitting.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/stats.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/viewer.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/index.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/reference/cli.md +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/mkdocs.yml +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/__main__.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/commands/__init__.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/__init__.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/base.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/coco.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/split.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/tests/test_list_command.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/tests/test_split_command.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/tests/test_stats_command.py +0 -0
- {argus_cv-1.1.0 → argus_cv-1.2.0}/uv.lock +0 -0
|
@@ -2,6 +2,22 @@
|
|
|
2
2
|
|
|
3
3
|
<!-- version list -->
|
|
4
4
|
|
|
5
|
+
## v1.2.0 (2026-01-15)
|
|
6
|
+
|
|
7
|
+
### Code Style
|
|
8
|
+
|
|
9
|
+
- Fix ruff linting errors
|
|
10
|
+
([`b2d5ea2`](https://github.com/pirnerjonas/argus/commit/b2d5ea2c4d0715a474d4ffaa5be60d0499d200a2))
|
|
11
|
+
|
|
12
|
+
- Remove unused pytest import in test_classification.py
|
|
13
|
+
([`e22175a`](https://github.com/pirnerjonas/argus/commit/e22175a7378276dd2840754e672c7d4d1ed0e067))
|
|
14
|
+
|
|
15
|
+
### Features
|
|
16
|
+
|
|
17
|
+
- Add classification dataset support with grid viewer
|
|
18
|
+
([`8089bd3`](https://github.com/pirnerjonas/argus/commit/8089bd3367ede4c288b0276ac3da4d3ef9960c4d))
|
|
19
|
+
|
|
20
|
+
|
|
5
21
|
## v1.1.0 (2026-01-14)
|
|
6
22
|
|
|
7
23
|
### Code Style
|
|
@@ -12,6 +12,7 @@ from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
|
12
12
|
from rich.table import Table
|
|
13
13
|
|
|
14
14
|
from argus.core import COCODataset, Dataset, YOLODataset
|
|
15
|
+
from argus.core.base import TaskType
|
|
15
16
|
from argus.core.split import (
|
|
16
17
|
is_coco_unsplit,
|
|
17
18
|
parse_ratio,
|
|
@@ -238,19 +239,29 @@ def view(
|
|
|
238
239
|
help="Specific split to view (train, val, test).",
|
|
239
240
|
),
|
|
240
241
|
] = None,
|
|
242
|
+
max_classes: Annotated[
|
|
243
|
+
int | None,
|
|
244
|
+
typer.Option(
|
|
245
|
+
"--max-classes",
|
|
246
|
+
"-m",
|
|
247
|
+
help="Maximum classes to show in grid (classification only).",
|
|
248
|
+
),
|
|
249
|
+
] = None,
|
|
241
250
|
) -> None:
|
|
242
251
|
"""View annotated images in a dataset.
|
|
243
252
|
|
|
244
253
|
Opens an interactive viewer to browse images with their annotations
|
|
245
254
|
(bounding boxes and segmentation masks) overlaid.
|
|
246
255
|
|
|
256
|
+
For classification datasets, shows a grid view with one image per class.
|
|
257
|
+
|
|
247
258
|
Controls:
|
|
248
|
-
- Right Arrow / N: Next image
|
|
249
|
-
- Left Arrow / P: Previous image
|
|
250
|
-
- Mouse Wheel: Zoom in/out
|
|
251
|
-
- Mouse Drag: Pan when zoomed
|
|
252
|
-
- R: Reset zoom
|
|
253
|
-
- T: Toggle annotations
|
|
259
|
+
- Right Arrow / N: Next image(s)
|
|
260
|
+
- Left Arrow / P: Previous image(s)
|
|
261
|
+
- Mouse Wheel: Zoom in/out (detection/segmentation only)
|
|
262
|
+
- Mouse Drag: Pan when zoomed (detection/segmentation only)
|
|
263
|
+
- R: Reset zoom / Reset to first images
|
|
264
|
+
- T: Toggle annotations (detection/segmentation only)
|
|
254
265
|
- Q / ESC: Quit viewer
|
|
255
266
|
"""
|
|
256
267
|
# Resolve path and validate
|
|
@@ -281,39 +292,76 @@ def view(
|
|
|
281
292
|
)
|
|
282
293
|
raise typer.Exit(1)
|
|
283
294
|
|
|
284
|
-
#
|
|
285
|
-
|
|
286
|
-
SpinnerColumn(),
|
|
287
|
-
TextColumn("[progress.description]{task.description}"),
|
|
288
|
-
console=console,
|
|
289
|
-
transient=True,
|
|
290
|
-
) as progress:
|
|
291
|
-
progress.add_task("Loading images...", total=None)
|
|
292
|
-
image_paths = dataset.get_image_paths(split)
|
|
295
|
+
# Generate consistent colors for each class
|
|
296
|
+
class_colors = _generate_class_colors(dataset.class_names)
|
|
293
297
|
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
298
|
+
# Handle classification datasets with grid viewer
|
|
299
|
+
if dataset.task == TaskType.CLASSIFICATION:
|
|
300
|
+
# Use first split if specified, otherwise let get_images_by_class handle it
|
|
301
|
+
view_split = split if split else (dataset.splits[0] if dataset.splits else None)
|
|
297
302
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
303
|
+
with Progress(
|
|
304
|
+
SpinnerColumn(),
|
|
305
|
+
TextColumn("[progress.description]{task.description}"),
|
|
306
|
+
console=console,
|
|
307
|
+
transient=True,
|
|
308
|
+
) as progress:
|
|
309
|
+
progress.add_task("Loading images by class...", total=None)
|
|
310
|
+
images_by_class = dataset.get_images_by_class(view_split)
|
|
305
311
|
|
|
306
|
-
|
|
307
|
-
|
|
312
|
+
total_images = sum(len(imgs) for imgs in images_by_class.values())
|
|
313
|
+
if total_images == 0:
|
|
314
|
+
console.print("[yellow]No images found in the dataset.[/yellow]")
|
|
315
|
+
return
|
|
308
316
|
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
+
num_classes = len(dataset.class_names)
|
|
318
|
+
display_classes = min(num_classes, max_classes) if max_classes else num_classes
|
|
319
|
+
|
|
320
|
+
console.print(
|
|
321
|
+
f"[green]Found {total_images} images across {num_classes} classes "
|
|
322
|
+
f"(showing {display_classes}). Opening grid viewer...[/green]\n"
|
|
323
|
+
"[dim]Controls: ← / → or P / N to navigate all classes, "
|
|
324
|
+
"R to reset, Q / ESC to quit[/dim]"
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
viewer = _ClassificationGridViewer(
|
|
328
|
+
images_by_class=images_by_class,
|
|
329
|
+
class_names=dataset.class_names,
|
|
330
|
+
class_colors=class_colors,
|
|
331
|
+
window_name=f"Argus Classification Viewer - {dataset_path.name}",
|
|
332
|
+
max_classes=max_classes,
|
|
333
|
+
)
|
|
334
|
+
viewer.run()
|
|
335
|
+
else:
|
|
336
|
+
# Detection/Segmentation viewer
|
|
337
|
+
with Progress(
|
|
338
|
+
SpinnerColumn(),
|
|
339
|
+
TextColumn("[progress.description]{task.description}"),
|
|
340
|
+
console=console,
|
|
341
|
+
transient=True,
|
|
342
|
+
) as progress:
|
|
343
|
+
progress.add_task("Loading images...", total=None)
|
|
344
|
+
image_paths = dataset.get_image_paths(split)
|
|
345
|
+
|
|
346
|
+
if not image_paths:
|
|
347
|
+
console.print("[yellow]No images found in the dataset.[/yellow]")
|
|
348
|
+
return
|
|
349
|
+
|
|
350
|
+
console.print(
|
|
351
|
+
f"[green]Found {len(image_paths)} images. "
|
|
352
|
+
f"Opening viewer...[/green]\n"
|
|
353
|
+
"[dim]Controls: ← / → or P / N to navigate, "
|
|
354
|
+
"Mouse wheel to zoom, Drag to pan, R to reset, T to toggle annotations, "
|
|
355
|
+
"Q / ESC to quit[/dim]"
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
viewer = _ImageViewer(
|
|
359
|
+
image_paths=image_paths,
|
|
360
|
+
dataset=dataset,
|
|
361
|
+
class_colors=class_colors,
|
|
362
|
+
window_name=f"Argus Viewer - {dataset_path.name}",
|
|
363
|
+
)
|
|
364
|
+
viewer.run()
|
|
317
365
|
|
|
318
366
|
console.print("[green]Viewer closed.[/green]")
|
|
319
367
|
|
|
@@ -656,6 +704,185 @@ class _ImageViewer:
|
|
|
656
704
|
cv2.destroyAllWindows()
|
|
657
705
|
|
|
658
706
|
|
|
707
|
+
class _ClassificationGridViewer:
|
|
708
|
+
"""Grid viewer for classification datasets showing one image per class."""
|
|
709
|
+
|
|
710
|
+
def __init__(
|
|
711
|
+
self,
|
|
712
|
+
images_by_class: dict[str, list[Path]],
|
|
713
|
+
class_names: list[str],
|
|
714
|
+
class_colors: dict[str, tuple[int, int, int]],
|
|
715
|
+
window_name: str,
|
|
716
|
+
max_classes: int | None = None,
|
|
717
|
+
tile_size: int = 300,
|
|
718
|
+
):
|
|
719
|
+
# Limit classes if max_classes specified
|
|
720
|
+
if max_classes and len(class_names) > max_classes:
|
|
721
|
+
self.class_names = class_names[:max_classes]
|
|
722
|
+
else:
|
|
723
|
+
self.class_names = class_names
|
|
724
|
+
|
|
725
|
+
self.images_by_class = {
|
|
726
|
+
cls: images_by_class.get(cls, []) for cls in self.class_names
|
|
727
|
+
}
|
|
728
|
+
self.class_colors = class_colors
|
|
729
|
+
self.window_name = window_name
|
|
730
|
+
self.tile_size = tile_size
|
|
731
|
+
|
|
732
|
+
# Global image index (same for all classes)
|
|
733
|
+
self.current_index = 0
|
|
734
|
+
|
|
735
|
+
# Calculate max images across all classes
|
|
736
|
+
self.max_images = max(
|
|
737
|
+
len(imgs) for imgs in self.images_by_class.values()
|
|
738
|
+
) if self.images_by_class else 0
|
|
739
|
+
|
|
740
|
+
# Calculate grid layout
|
|
741
|
+
self.cols, self.rows = self._calculate_grid_layout()
|
|
742
|
+
|
|
743
|
+
def _calculate_grid_layout(self) -> tuple[int, int]:
|
|
744
|
+
"""Calculate optimal grid layout based on number of classes."""
|
|
745
|
+
n = len(self.class_names)
|
|
746
|
+
if n <= 0:
|
|
747
|
+
return 1, 1
|
|
748
|
+
|
|
749
|
+
# Try to make a roughly square grid
|
|
750
|
+
import math
|
|
751
|
+
|
|
752
|
+
cols = int(math.ceil(math.sqrt(n)))
|
|
753
|
+
rows = int(math.ceil(n / cols))
|
|
754
|
+
return cols, rows
|
|
755
|
+
|
|
756
|
+
def _create_tile(
|
|
757
|
+
self, class_name: str, image_path: Path | None, index: int, total: int
|
|
758
|
+
) -> np.ndarray:
|
|
759
|
+
"""Create a single tile for a class."""
|
|
760
|
+
tile = np.zeros((self.tile_size, self.tile_size, 3), dtype=np.uint8)
|
|
761
|
+
|
|
762
|
+
if image_path is not None and image_path.exists():
|
|
763
|
+
# Load and resize image
|
|
764
|
+
img = cv2.imread(str(image_path))
|
|
765
|
+
if img is not None:
|
|
766
|
+
# Resize maintaining aspect ratio
|
|
767
|
+
h, w = img.shape[:2]
|
|
768
|
+
scale = min(self.tile_size / w, self.tile_size / h)
|
|
769
|
+
new_w = int(w * scale)
|
|
770
|
+
new_h = int(h * scale)
|
|
771
|
+
resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
|
772
|
+
|
|
773
|
+
# Center in tile
|
|
774
|
+
x_offset = (self.tile_size - new_w) // 2
|
|
775
|
+
y_offset = (self.tile_size - new_h) // 2
|
|
776
|
+
tile[y_offset : y_offset + new_h, x_offset : x_offset + new_w] = resized
|
|
777
|
+
|
|
778
|
+
# Draw label at top: "class_name (N/M)"
|
|
779
|
+
if image_path is not None:
|
|
780
|
+
label = f"{class_name} ({index + 1}/{total})"
|
|
781
|
+
else:
|
|
782
|
+
label = f"{class_name} (-/{total})"
|
|
783
|
+
|
|
784
|
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
|
785
|
+
font_scale = 0.5
|
|
786
|
+
thickness = 1
|
|
787
|
+
(label_w, label_h), baseline = cv2.getTextSize(
|
|
788
|
+
label, font, font_scale, thickness
|
|
789
|
+
)
|
|
790
|
+
|
|
791
|
+
# Semi-transparent background for label
|
|
792
|
+
overlay = tile.copy()
|
|
793
|
+
label_bg_height = label_h + baseline + 10
|
|
794
|
+
cv2.rectangle(overlay, (0, 0), (self.tile_size, label_bg_height), (0, 0, 0), -1)
|
|
795
|
+
cv2.addWeighted(overlay, 0.6, tile, 0.4, 0, tile)
|
|
796
|
+
|
|
797
|
+
cv2.putText(
|
|
798
|
+
tile,
|
|
799
|
+
label,
|
|
800
|
+
(5, label_h + 5),
|
|
801
|
+
font,
|
|
802
|
+
font_scale,
|
|
803
|
+
(255, 255, 255),
|
|
804
|
+
thickness,
|
|
805
|
+
)
|
|
806
|
+
|
|
807
|
+
# Draw thin border
|
|
808
|
+
border_end = self.tile_size - 1
|
|
809
|
+
cv2.rectangle(tile, (0, 0), (border_end, border_end), (80, 80, 80), 1)
|
|
810
|
+
|
|
811
|
+
return tile
|
|
812
|
+
|
|
813
|
+
def _compose_grid(self) -> np.ndarray:
|
|
814
|
+
"""Compose all tiles into a single grid image."""
|
|
815
|
+
grid_h = self.rows * self.tile_size
|
|
816
|
+
grid_w = self.cols * self.tile_size
|
|
817
|
+
grid = np.zeros((grid_h, grid_w, 3), dtype=np.uint8)
|
|
818
|
+
|
|
819
|
+
for i, class_name in enumerate(self.class_names):
|
|
820
|
+
row = i // self.cols
|
|
821
|
+
col = i % self.cols
|
|
822
|
+
|
|
823
|
+
images = self.images_by_class[class_name]
|
|
824
|
+
total = len(images)
|
|
825
|
+
|
|
826
|
+
# Use global index - show black tile if class doesn't have this image
|
|
827
|
+
if self.current_index < total:
|
|
828
|
+
image_path = images[self.current_index]
|
|
829
|
+
display_index = self.current_index
|
|
830
|
+
else:
|
|
831
|
+
image_path = None
|
|
832
|
+
display_index = self.current_index
|
|
833
|
+
|
|
834
|
+
tile = self._create_tile(class_name, image_path, display_index, total)
|
|
835
|
+
|
|
836
|
+
y_start = row * self.tile_size
|
|
837
|
+
x_start = col * self.tile_size
|
|
838
|
+
y_end = y_start + self.tile_size
|
|
839
|
+
x_end = x_start + self.tile_size
|
|
840
|
+
grid[y_start:y_end, x_start:x_end] = tile
|
|
841
|
+
|
|
842
|
+
return grid
|
|
843
|
+
|
|
844
|
+
def _next_images(self) -> None:
|
|
845
|
+
"""Advance to next image index."""
|
|
846
|
+
if self.max_images > 0:
|
|
847
|
+
self.current_index = min(self.current_index + 1, self.max_images - 1)
|
|
848
|
+
|
|
849
|
+
def _prev_images(self) -> None:
|
|
850
|
+
"""Go back to previous image index."""
|
|
851
|
+
self.current_index = max(self.current_index - 1, 0)
|
|
852
|
+
|
|
853
|
+
def _reset_indices(self) -> None:
|
|
854
|
+
"""Reset to first image."""
|
|
855
|
+
self.current_index = 0
|
|
856
|
+
|
|
857
|
+
def run(self) -> None:
|
|
858
|
+
"""Run the interactive grid viewer."""
|
|
859
|
+
if not self.class_names:
|
|
860
|
+
console.print("[yellow]No classes to display.[/yellow]")
|
|
861
|
+
return
|
|
862
|
+
|
|
863
|
+
cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
|
|
864
|
+
|
|
865
|
+
while True:
|
|
866
|
+
# Compose and display grid
|
|
867
|
+
grid = self._compose_grid()
|
|
868
|
+
cv2.imshow(self.window_name, grid)
|
|
869
|
+
|
|
870
|
+
# Wait for input
|
|
871
|
+
key = cv2.waitKey(30) & 0xFF
|
|
872
|
+
|
|
873
|
+
# Handle keyboard input
|
|
874
|
+
if key == ord("q") or key == 27: # Q or ESC
|
|
875
|
+
break
|
|
876
|
+
elif key == ord("n") or key == 83 or key == 3: # N or Right arrow
|
|
877
|
+
self._next_images()
|
|
878
|
+
elif key == ord("p") or key == 81 or key == 2: # P or Left arrow
|
|
879
|
+
self._prev_images()
|
|
880
|
+
elif key == ord("r"): # R to reset
|
|
881
|
+
self._reset_indices()
|
|
882
|
+
|
|
883
|
+
cv2.destroyAllWindows()
|
|
884
|
+
|
|
885
|
+
|
|
659
886
|
def _generate_class_colors(class_names: list[str]) -> dict[str, tuple[int, int, int]]:
|
|
660
887
|
"""Generate consistent colors for each class name.
|
|
661
888
|
|
|
@@ -12,9 +12,9 @@ from argus.core.base import Dataset, DatasetFormat, TaskType
|
|
|
12
12
|
class YOLODataset(Dataset):
|
|
13
13
|
"""YOLO format dataset.
|
|
14
14
|
|
|
15
|
-
Supports detection and
|
|
15
|
+
Supports detection, segmentation, and classification tasks.
|
|
16
16
|
|
|
17
|
-
Structure:
|
|
17
|
+
Structure (detection/segmentation):
|
|
18
18
|
dataset/
|
|
19
19
|
├── data.yaml (or *.yaml/*.yml with 'names' key)
|
|
20
20
|
├── images/
|
|
@@ -23,6 +23,19 @@ class YOLODataset(Dataset):
|
|
|
23
23
|
└── labels/
|
|
24
24
|
├── train/
|
|
25
25
|
└── val/
|
|
26
|
+
|
|
27
|
+
Structure (classification):
|
|
28
|
+
dataset/
|
|
29
|
+
├── images/
|
|
30
|
+
│ ├── train/
|
|
31
|
+
│ │ ├── class1/
|
|
32
|
+
│ │ │ ├── img1.jpg
|
|
33
|
+
│ │ │ └── img2.jpg
|
|
34
|
+
│ │ └── class2/
|
|
35
|
+
│ │ └── img1.jpg
|
|
36
|
+
│ └── val/
|
|
37
|
+
│ ├── class1/
|
|
38
|
+
│ └── class2/
|
|
26
39
|
"""
|
|
27
40
|
|
|
28
41
|
config_file: Path | None = None
|
|
@@ -43,8 +56,13 @@ class YOLODataset(Dataset):
|
|
|
43
56
|
if not path.is_dir():
|
|
44
57
|
return None
|
|
45
58
|
|
|
46
|
-
# Try detection/segmentation (YAML-based)
|
|
47
|
-
|
|
59
|
+
# Try detection/segmentation (YAML-based) first
|
|
60
|
+
result = cls._detect_yaml_based(path)
|
|
61
|
+
if result:
|
|
62
|
+
return result
|
|
63
|
+
|
|
64
|
+
# Try classification (directory-based structure)
|
|
65
|
+
return cls._detect_classification(path)
|
|
48
66
|
|
|
49
67
|
@classmethod
|
|
50
68
|
def _detect_yaml_based(cls, path: Path) -> "YOLODataset | None":
|
|
@@ -103,16 +121,121 @@ class YOLODataset(Dataset):
|
|
|
103
121
|
|
|
104
122
|
return None
|
|
105
123
|
|
|
124
|
+
@classmethod
|
|
125
|
+
def _detect_classification(cls, path: Path) -> "YOLODataset | None":
|
|
126
|
+
"""Detect classification dataset from directory structure.
|
|
127
|
+
|
|
128
|
+
Classification datasets can have two structures:
|
|
129
|
+
|
|
130
|
+
1. Split structure:
|
|
131
|
+
images/{split}/class_name/image.jpg
|
|
132
|
+
|
|
133
|
+
2. Flat structure (unsplit):
|
|
134
|
+
class_name/image.jpg
|
|
135
|
+
|
|
136
|
+
No YAML config required - class names inferred from directory names.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
path: Directory path to check.
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
YOLODataset if classification structure found, None otherwise.
|
|
143
|
+
"""
|
|
144
|
+
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
|
|
145
|
+
|
|
146
|
+
# Try split structure first: images/{split}/class/
|
|
147
|
+
images_root = path / "images"
|
|
148
|
+
if images_root.is_dir():
|
|
149
|
+
splits: list[str] = []
|
|
150
|
+
class_names_set: set[str] = set()
|
|
151
|
+
|
|
152
|
+
for split_name in ["train", "val", "test"]:
|
|
153
|
+
split_dir = images_root / split_name
|
|
154
|
+
if not split_dir.is_dir():
|
|
155
|
+
continue
|
|
156
|
+
|
|
157
|
+
# Get subdirectories (potential class folders)
|
|
158
|
+
class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
|
|
159
|
+
if not class_dirs:
|
|
160
|
+
continue
|
|
161
|
+
|
|
162
|
+
# Check if at least one class dir contains images
|
|
163
|
+
has_images = False
|
|
164
|
+
for class_dir in class_dirs:
|
|
165
|
+
for f in class_dir.iterdir():
|
|
166
|
+
if f.suffix.lower() in image_extensions:
|
|
167
|
+
has_images = True
|
|
168
|
+
break
|
|
169
|
+
if has_images:
|
|
170
|
+
break
|
|
171
|
+
|
|
172
|
+
if has_images:
|
|
173
|
+
splits.append(split_name)
|
|
174
|
+
class_names_set.update(d.name for d in class_dirs)
|
|
175
|
+
|
|
176
|
+
if splits and class_names_set:
|
|
177
|
+
class_names = sorted(class_names_set)
|
|
178
|
+
return cls(
|
|
179
|
+
path=path,
|
|
180
|
+
task=TaskType.CLASSIFICATION,
|
|
181
|
+
num_classes=len(class_names),
|
|
182
|
+
class_names=class_names,
|
|
183
|
+
splits=splits,
|
|
184
|
+
config_file=None,
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
# Try flat structure: class_name/image.jpg (no images/ or split dirs)
|
|
188
|
+
# Check if root contains subdirectories with images
|
|
189
|
+
class_dirs = [d for d in path.iterdir() if d.is_dir()]
|
|
190
|
+
|
|
191
|
+
# Filter out common non-class directories
|
|
192
|
+
excluded_dirs = {"images", "labels", "annotations", ".git", "__pycache__"}
|
|
193
|
+
class_dirs = [d for d in class_dirs if d.name not in excluded_dirs]
|
|
194
|
+
|
|
195
|
+
if not class_dirs:
|
|
196
|
+
return None
|
|
197
|
+
|
|
198
|
+
# Check if these are class directories (contain images directly)
|
|
199
|
+
class_names_set = set()
|
|
200
|
+
for class_dir in class_dirs:
|
|
201
|
+
has_images = any(
|
|
202
|
+
f.suffix.lower() in image_extensions
|
|
203
|
+
for f in class_dir.iterdir()
|
|
204
|
+
if f.is_file()
|
|
205
|
+
)
|
|
206
|
+
if has_images:
|
|
207
|
+
class_names_set.add(class_dir.name)
|
|
208
|
+
|
|
209
|
+
# Need at least 2 classes to be a valid classification dataset
|
|
210
|
+
if len(class_names_set) < 2:
|
|
211
|
+
return None
|
|
212
|
+
|
|
213
|
+
class_names = sorted(class_names_set)
|
|
214
|
+
return cls(
|
|
215
|
+
path=path,
|
|
216
|
+
task=TaskType.CLASSIFICATION,
|
|
217
|
+
num_classes=len(class_names),
|
|
218
|
+
class_names=class_names,
|
|
219
|
+
splits=[], # No splits for flat structure
|
|
220
|
+
config_file=None,
|
|
221
|
+
)
|
|
222
|
+
|
|
106
223
|
def get_instance_counts(self) -> dict[str, dict[str, int]]:
|
|
107
224
|
"""Get the number of annotation instances per class, per split.
|
|
108
225
|
|
|
109
|
-
Parses all label files in labels/{split}/*.txt
|
|
110
|
-
occurrences of each class ID.
|
|
111
|
-
|
|
226
|
+
For detection/segmentation: Parses all label files in labels/{split}/*.txt
|
|
227
|
+
and counts occurrences of each class ID.
|
|
228
|
+
|
|
229
|
+
For classification: Counts images in each class directory
|
|
230
|
+
(1 image = 1 instance).
|
|
112
231
|
|
|
113
232
|
Returns:
|
|
114
233
|
Dictionary mapping split name to dict of class name to instance count.
|
|
115
234
|
"""
|
|
235
|
+
# Handle classification datasets differently
|
|
236
|
+
if self.task == TaskType.CLASSIFICATION:
|
|
237
|
+
return self._get_classification_instance_counts()
|
|
238
|
+
|
|
116
239
|
counts: dict[str, dict[str, int]] = {}
|
|
117
240
|
|
|
118
241
|
# Build class_id -> class_name mapping
|
|
@@ -162,15 +285,77 @@ class YOLODataset(Dataset):
|
|
|
162
285
|
|
|
163
286
|
return counts
|
|
164
287
|
|
|
288
|
+
def _get_classification_instance_counts(self) -> dict[str, dict[str, int]]:
|
|
289
|
+
"""Get instance counts for classification datasets.
|
|
290
|
+
|
|
291
|
+
Each image is one instance of its class.
|
|
292
|
+
|
|
293
|
+
Returns:
|
|
294
|
+
Dictionary mapping split name to dict of class name to image count.
|
|
295
|
+
"""
|
|
296
|
+
counts: dict[str, dict[str, int]] = {}
|
|
297
|
+
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
|
|
298
|
+
|
|
299
|
+
# Handle flat structure (no splits)
|
|
300
|
+
if not self.splits:
|
|
301
|
+
split_counts: dict[str, int] = {}
|
|
302
|
+
for class_name in self.class_names:
|
|
303
|
+
class_dir = self.path / class_name
|
|
304
|
+
if not class_dir.is_dir():
|
|
305
|
+
split_counts[class_name] = 0
|
|
306
|
+
continue
|
|
307
|
+
|
|
308
|
+
image_count = sum(
|
|
309
|
+
1
|
|
310
|
+
for f in class_dir.iterdir()
|
|
311
|
+
if f.suffix.lower() in image_extensions
|
|
312
|
+
)
|
|
313
|
+
split_counts[class_name] = image_count
|
|
314
|
+
|
|
315
|
+
counts["unsplit"] = split_counts
|
|
316
|
+
return counts
|
|
317
|
+
|
|
318
|
+
# Handle split structure
|
|
319
|
+
images_root = self.path / "images"
|
|
320
|
+
for split in self.splits:
|
|
321
|
+
split_dir = images_root / split
|
|
322
|
+
if not split_dir.is_dir():
|
|
323
|
+
continue
|
|
324
|
+
|
|
325
|
+
split_counts = {}
|
|
326
|
+
for class_name in self.class_names:
|
|
327
|
+
class_dir = split_dir / class_name
|
|
328
|
+
if not class_dir.is_dir():
|
|
329
|
+
split_counts[class_name] = 0
|
|
330
|
+
continue
|
|
331
|
+
|
|
332
|
+
image_count = sum(
|
|
333
|
+
1
|
|
334
|
+
for f in class_dir.iterdir()
|
|
335
|
+
if f.suffix.lower() in image_extensions
|
|
336
|
+
)
|
|
337
|
+
split_counts[class_name] = image_count
|
|
338
|
+
|
|
339
|
+
counts[split] = split_counts
|
|
340
|
+
|
|
341
|
+
return counts
|
|
342
|
+
|
|
165
343
|
def get_image_counts(self) -> dict[str, dict[str, int]]:
|
|
166
344
|
"""Get image counts per split, including background images.
|
|
167
345
|
|
|
168
|
-
Counts label files in labels/{split}/*.txt.
|
|
169
|
-
counted as background images.
|
|
346
|
+
For detection/segmentation: Counts label files in labels/{split}/*.txt.
|
|
347
|
+
Empty files are counted as background images.
|
|
348
|
+
|
|
349
|
+
For classification: Counts total images across all class directories.
|
|
350
|
+
Background count is always 0 (no background concept in classification).
|
|
170
351
|
|
|
171
352
|
Returns:
|
|
172
353
|
Dictionary mapping split name to dict with "total" and "background" counts.
|
|
173
354
|
"""
|
|
355
|
+
# Handle classification datasets differently
|
|
356
|
+
if self.task == TaskType.CLASSIFICATION:
|
|
357
|
+
return self._get_classification_image_counts()
|
|
358
|
+
|
|
174
359
|
counts: dict[str, dict[str, int]] = {}
|
|
175
360
|
|
|
176
361
|
labels_root = self.path / "labels"
|
|
@@ -203,6 +388,56 @@ class YOLODataset(Dataset):
|
|
|
203
388
|
|
|
204
389
|
return counts
|
|
205
390
|
|
|
391
|
+
def _get_classification_image_counts(self) -> dict[str, dict[str, int]]:
|
|
392
|
+
"""Get image counts for classification datasets.
|
|
393
|
+
|
|
394
|
+
Returns:
|
|
395
|
+
Dictionary mapping split name to dict with "total" and "background" counts.
|
|
396
|
+
Background is always 0 for classification.
|
|
397
|
+
"""
|
|
398
|
+
counts: dict[str, dict[str, int]] = {}
|
|
399
|
+
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
|
|
400
|
+
|
|
401
|
+
# Handle flat structure (no splits)
|
|
402
|
+
if not self.splits:
|
|
403
|
+
total = 0
|
|
404
|
+
for class_name in self.class_names:
|
|
405
|
+
class_dir = self.path / class_name
|
|
406
|
+
if not class_dir.is_dir():
|
|
407
|
+
continue
|
|
408
|
+
|
|
409
|
+
total += sum(
|
|
410
|
+
1
|
|
411
|
+
for f in class_dir.iterdir()
|
|
412
|
+
if f.suffix.lower() in image_extensions
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
counts["unsplit"] = {"total": total, "background": 0}
|
|
416
|
+
return counts
|
|
417
|
+
|
|
418
|
+
# Handle split structure
|
|
419
|
+
images_root = self.path / "images"
|
|
420
|
+
for split in self.splits:
|
|
421
|
+
split_dir = images_root / split
|
|
422
|
+
if not split_dir.is_dir():
|
|
423
|
+
continue
|
|
424
|
+
|
|
425
|
+
total = 0
|
|
426
|
+
for class_name in self.class_names:
|
|
427
|
+
class_dir = split_dir / class_name
|
|
428
|
+
if not class_dir.is_dir():
|
|
429
|
+
continue
|
|
430
|
+
|
|
431
|
+
total += sum(
|
|
432
|
+
1
|
|
433
|
+
for f in class_dir.iterdir()
|
|
434
|
+
if f.suffix.lower() in image_extensions
|
|
435
|
+
)
|
|
436
|
+
|
|
437
|
+
counts[split] = {"total": total, "background": 0}
|
|
438
|
+
|
|
439
|
+
return counts
|
|
440
|
+
|
|
206
441
|
@classmethod
|
|
207
442
|
def _detect_splits(cls, path: Path, config: dict) -> list[str]:
|
|
208
443
|
"""Detect available splits from config and filesystem.
|
|
@@ -301,6 +536,10 @@ class YOLODataset(Dataset):
|
|
|
301
536
|
Returns:
|
|
302
537
|
List of image file paths sorted alphabetically.
|
|
303
538
|
"""
|
|
539
|
+
# Handle classification datasets differently
|
|
540
|
+
if self.task == TaskType.CLASSIFICATION:
|
|
541
|
+
return self._get_classification_image_paths(split)
|
|
542
|
+
|
|
304
543
|
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
|
|
305
544
|
images_root = self.path / "images"
|
|
306
545
|
image_paths: list[Path] = []
|
|
@@ -345,6 +584,51 @@ class YOLODataset(Dataset):
|
|
|
345
584
|
|
|
346
585
|
return sorted(image_paths, key=lambda p: p.name)
|
|
347
586
|
|
|
587
|
+
def _get_classification_image_paths(self, split: str | None = None) -> list[Path]:
|
|
588
|
+
"""Get image paths for classification datasets.
|
|
589
|
+
|
|
590
|
+
Args:
|
|
591
|
+
split: Specific split to get images from. If None, returns all images.
|
|
592
|
+
|
|
593
|
+
Returns:
|
|
594
|
+
List of image file paths sorted alphabetically.
|
|
595
|
+
"""
|
|
596
|
+
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
|
|
597
|
+
image_paths: list[Path] = []
|
|
598
|
+
|
|
599
|
+
# Handle flat structure (no splits)
|
|
600
|
+
if not self.splits:
|
|
601
|
+
for class_name in self.class_names:
|
|
602
|
+
class_dir = self.path / class_name
|
|
603
|
+
if not class_dir.is_dir():
|
|
604
|
+
continue
|
|
605
|
+
|
|
606
|
+
for img_file in class_dir.iterdir():
|
|
607
|
+
if img_file.suffix.lower() in image_extensions:
|
|
608
|
+
image_paths.append(img_file)
|
|
609
|
+
|
|
610
|
+
return sorted(image_paths, key=lambda p: p.name)
|
|
611
|
+
|
|
612
|
+
# Handle split structure
|
|
613
|
+
images_root = self.path / "images"
|
|
614
|
+
splits_to_search = [split] if split else self.splits
|
|
615
|
+
|
|
616
|
+
for s in splits_to_search:
|
|
617
|
+
split_dir = images_root / s
|
|
618
|
+
if not split_dir.is_dir():
|
|
619
|
+
continue
|
|
620
|
+
|
|
621
|
+
for class_name in self.class_names:
|
|
622
|
+
class_dir = split_dir / class_name
|
|
623
|
+
if not class_dir.is_dir():
|
|
624
|
+
continue
|
|
625
|
+
|
|
626
|
+
for img_file in class_dir.iterdir():
|
|
627
|
+
if img_file.suffix.lower() in image_extensions:
|
|
628
|
+
image_paths.append(img_file)
|
|
629
|
+
|
|
630
|
+
return sorted(image_paths, key=lambda p: p.name)
|
|
631
|
+
|
|
348
632
|
def get_annotations_for_image(self, image_path: Path) -> list[dict]:
|
|
349
633
|
"""Get annotations for a specific image.
|
|
350
634
|
|
|
@@ -445,3 +729,65 @@ class YOLODataset(Dataset):
|
|
|
445
729
|
pass
|
|
446
730
|
|
|
447
731
|
return annotations
|
|
732
|
+
|
|
733
|
+
def get_images_by_class(self, split: str | None = None) -> dict[str, list[Path]]:
|
|
734
|
+
"""Get images grouped by class for classification datasets.
|
|
735
|
+
|
|
736
|
+
Args:
|
|
737
|
+
split: Specific split to get images from. If None, uses first
|
|
738
|
+
available split or all images for flat structure.
|
|
739
|
+
|
|
740
|
+
Returns:
|
|
741
|
+
Dictionary mapping class name to list of image paths.
|
|
742
|
+
"""
|
|
743
|
+
if self.task != TaskType.CLASSIFICATION:
|
|
744
|
+
return {}
|
|
745
|
+
|
|
746
|
+
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
|
|
747
|
+
images_by_class: dict[str, list[Path]] = {cls: [] for cls in self.class_names}
|
|
748
|
+
|
|
749
|
+
# Handle flat structure (no splits)
|
|
750
|
+
if not self.splits:
|
|
751
|
+
for class_name in self.class_names:
|
|
752
|
+
class_dir = self.path / class_name
|
|
753
|
+
if not class_dir.is_dir():
|
|
754
|
+
continue
|
|
755
|
+
|
|
756
|
+
for img_file in class_dir.iterdir():
|
|
757
|
+
if img_file.suffix.lower() in image_extensions:
|
|
758
|
+
images_by_class[class_name].append(img_file)
|
|
759
|
+
|
|
760
|
+
# Sort images within each class
|
|
761
|
+
for class_name in images_by_class:
|
|
762
|
+
images_by_class[class_name] = sorted(
|
|
763
|
+
images_by_class[class_name], key=lambda p: p.name
|
|
764
|
+
)
|
|
765
|
+
|
|
766
|
+
return images_by_class
|
|
767
|
+
|
|
768
|
+
# Handle split structure
|
|
769
|
+
images_root = self.path / "images"
|
|
770
|
+
default_splits = self.splits[:1] if self.splits else []
|
|
771
|
+
splits_to_search = [split] if split else default_splits
|
|
772
|
+
|
|
773
|
+
for s in splits_to_search:
|
|
774
|
+
split_dir = images_root / s
|
|
775
|
+
if not split_dir.is_dir():
|
|
776
|
+
continue
|
|
777
|
+
|
|
778
|
+
for class_name in self.class_names:
|
|
779
|
+
class_dir = split_dir / class_name
|
|
780
|
+
if not class_dir.is_dir():
|
|
781
|
+
continue
|
|
782
|
+
|
|
783
|
+
for img_file in class_dir.iterdir():
|
|
784
|
+
if img_file.suffix.lower() in image_extensions:
|
|
785
|
+
images_by_class[class_name].append(img_file)
|
|
786
|
+
|
|
787
|
+
# Sort images within each class for consistent ordering
|
|
788
|
+
for class_name in images_by_class:
|
|
789
|
+
images_by_class[class_name] = sorted(
|
|
790
|
+
images_by_class[class_name], key=lambda p: p.name
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
return images_by_class
|
|
@@ -389,3 +389,74 @@ names:
|
|
|
389
389
|
(annotations_dir / "annotations.json").write_text(json.dumps(coco_data))
|
|
390
390
|
|
|
391
391
|
return root_path
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
@pytest.fixture
def yolo_classification_dataset(tmp_path: Path) -> Path:
    """Create a valid YOLO classification dataset.

    Structure:
        dataset/
        └── images/
            ├── train/
            │   ├── cat/
            │   │   ├── img001.jpg
            │   │   └── img002.jpg
            │   └── dog/
            │       └── img001.jpg
            └── val/
                ├── cat/
                │   └── img001.jpg
                └── dog/
                    └── img001.jpg
    """
    root = tmp_path / "yolo_classification"
    root.mkdir()

    # Table of split -> class -> image filenames; driving the creation from a
    # mapping keeps the on-disk layout and file contents identical to spelling
    # each path out by hand.
    layout: dict[str, dict[str, list[str]]] = {
        "train": {"cat": ["img001.jpg", "img002.jpg"], "dog": ["img001.jpg"]},
        "val": {"cat": ["img001.jpg"], "dog": ["img001.jpg"]},
    }

    for split, classes in layout.items():
        for class_name, filenames in classes.items():
            class_dir = root / "images" / split / class_name
            class_dir.mkdir(parents=True)
            # Matches the original literal payloads: b"fake cat" / b"fake dog".
            payload = f"fake {class_name}".encode()
            for filename in filenames:
                (class_dir / filename).write_bytes(payload)

    return root
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
@pytest.fixture
def yolo_classification_multiclass_dataset(tmp_path: Path) -> Path:
    """Create a YOLO classification dataset with more classes.

    Structure:
        dataset/
        └── images/
            └── train/
                ├── apple/   (1 image)
                ├── banana/  (2 images)
                ├── cherry/  (3 images)
                └── date/    (4 images)

    Each class holds one more image than the previous one, so class
    imbalance is exercised by the tests that consume this fixture.
    """
    dataset_path = tmp_path / "yolo_cls_multiclass"
    dataset_path.mkdir()

    classes = ["apple", "banana", "cherry", "date"]

    # enumerate(..., start=1) yields the per-class image count directly,
    # replacing the O(n) classes.index(cls) lookup previously run on every
    # iteration of the loop over `classes`.
    for num_images, cls in enumerate(classes, start=1):
        class_dir = dataset_path / "images" / "train" / cls
        class_dir.mkdir(parents=True)
        for i in range(num_images):
            (class_dir / f"img{i:03d}.jpg").write_bytes(b"fake image")

    return dataset_path
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""Tests for YOLO classification dataset support."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from argus.core.base import DatasetFormat, TaskType
|
|
6
|
+
from argus.core.yolo import YOLODataset
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestClassificationDetection:
    """Tests for detecting YOLO classification datasets."""

    def test_detect_classification_dataset(
        self, yolo_classification_dataset: Path
    ) -> None:
        """Test that classification dataset is correctly detected."""
        ds = YOLODataset.detect(yolo_classification_dataset)

        assert ds is not None
        assert ds.format == DatasetFormat.YOLO
        assert ds.task == TaskType.CLASSIFICATION
        assert ds.num_classes == 2
        assert sorted(ds.class_names) == ["cat", "dog"]
        assert sorted(ds.splits) == ["train", "val"]

    def test_detect_classification_multiclass(
        self, yolo_classification_multiclass_dataset: Path
    ) -> None:
        """Test detection with multiple classes."""
        ds = YOLODataset.detect(yolo_classification_multiclass_dataset)

        assert ds is not None
        assert ds.task == TaskType.CLASSIFICATION
        assert ds.num_classes == 4
        assert sorted(ds.class_names) == ["apple", "banana", "cherry", "date"]
        assert "train" in ds.splits

    def test_classification_no_yaml_required(
        self, yolo_classification_dataset: Path
    ) -> None:
        """Test that classification datasets don't need a YAML config."""
        ds = YOLODataset.detect(yolo_classification_dataset)

        assert ds is not None
        assert ds.config_file is None

    def test_classification_not_detected_for_detection(
        self, yolo_detection_dataset: Path
    ) -> None:
        """Test that detection datasets are not classified as classification."""
        ds = YOLODataset.detect(yolo_detection_dataset)

        assert ds is not None
        # A bounding-box dataset must resolve to DETECTION, never CLASSIFICATION.
        assert ds.task == TaskType.DETECTION
        assert ds.task != TaskType.CLASSIFICATION
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TestClassificationGetImagesByClass:
    """Tests for get_images_by_class method."""

    def test_get_images_by_class(self, yolo_classification_dataset: Path) -> None:
        """Test getting images grouped by class."""
        ds = YOLODataset.detect(yolo_classification_dataset)
        assert ds is not None

        grouped = ds.get_images_by_class("train")

        # Fixture layout: train has 2 cat images and 1 dog image.
        assert set(grouped) >= {"cat", "dog"}
        assert len(grouped["cat"]) == 2
        assert len(grouped["dog"]) == 1

    def test_get_images_by_class_val_split(
        self, yolo_classification_dataset: Path
    ) -> None:
        """Test getting images from val split."""
        ds = YOLODataset.detect(yolo_classification_dataset)
        assert ds is not None

        grouped = ds.get_images_by_class("val")

        # Fixture layout: val has exactly one image per class.
        assert len(grouped["cat"]) == 1
        assert len(grouped["dog"]) == 1

    def test_get_images_by_class_default_split(
        self, yolo_classification_dataset: Path
    ) -> None:
        """Test that None split uses first available split."""
        ds = YOLODataset.detect(yolo_classification_dataset)
        assert ds is not None

        grouped = ds.get_images_by_class(None)

        # With no explicit split the first one (train) is used, so the
        # result must be non-empty.
        assert sum(len(files) for files in grouped.values()) > 0

    def test_get_images_by_class_non_classification(
        self, yolo_detection_dataset: Path
    ) -> None:
        """Test that non-classification datasets return empty dict."""
        ds = YOLODataset.detect(yolo_detection_dataset)
        assert ds is not None

        assert ds.get_images_by_class("train") == {}
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class TestClassificationInstanceCounts:
    """Tests for instance counts in classification datasets."""

    def test_get_instance_counts(self, yolo_classification_dataset: Path) -> None:
        """Test counting images per class per split."""
        ds = YOLODataset.detect(yolo_classification_dataset)
        assert ds is not None

        counts = ds.get_instance_counts()

        assert {"train", "val"} <= set(counts)
        # Per-class counts mirror the fixture layout exactly.
        assert counts["train"] == {"cat": 2, "dog": 1} or (
            counts["train"]["cat"] == 2 and counts["train"]["dog"] == 1
        )
        assert counts["val"]["cat"] == 1
        assert counts["val"]["dog"] == 1

    def test_get_image_counts(self, yolo_classification_dataset: Path) -> None:
        """Test total image counts per split."""
        ds = YOLODataset.detect(yolo_classification_dataset)
        assert ds is not None

        counts = ds.get_image_counts()

        assert {"train", "val"} <= set(counts)
        # Every image belongs to a class directory, so there are no
        # background (unlabeled) images in a classification dataset.
        assert counts["train"]["total"] == 3  # 2 cat + 1 dog
        assert counts["train"]["background"] == 0
        assert counts["val"]["total"] == 2  # 1 cat + 1 dog
        assert counts["val"]["background"] == 0
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class TestClassificationImagePaths:
    """Tests for get_image_paths with classification datasets."""

    def test_get_image_paths_all(self, yolo_classification_dataset: Path) -> None:
        """Test getting all image paths."""
        ds = YOLODataset.detect(yolo_classification_dataset)
        assert ds is not None

        # No split argument -> every split is included: 3 train + 2 val.
        assert len(ds.get_image_paths()) == 5

    def test_get_image_paths_train(self, yolo_classification_dataset: Path) -> None:
        """Test getting image paths for train split."""
        ds = YOLODataset.detect(yolo_classification_dataset)
        assert ds is not None

        assert len(ds.get_image_paths("train")) == 3

    def test_get_image_paths_val(self, yolo_classification_dataset: Path) -> None:
        """Test getting image paths for val split."""
        ds = YOLODataset.detect(yolo_classification_dataset)
        assert ds is not None

        assert len(ds.get_image_paths("val")) == 2
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class TestClassificationCLI:
    """Tests for CLI integration with classification datasets."""

    def test_list_shows_classification(
        self, yolo_classification_dataset: Path
    ) -> None:
        """Test that list command shows classification datasets."""
        # Imported lazily so the CLI stack is only loaded by these tests.
        from typer.testing import CliRunner

        from argus.cli import app

        outcome = CliRunner().invoke(
            app, ["list", "-p", str(yolo_classification_dataset)]
        )

        assert outcome.exit_code == 0
        lowered = outcome.output.lower()
        assert "classification" in lowered
        assert "yolo" in lowered

    def test_stats_shows_classification_counts(
        self, yolo_classification_dataset: Path
    ) -> None:
        """Test that stats command shows image counts per class."""
        from typer.testing import CliRunner

        from argus.cli import app

        outcome = CliRunner().invoke(
            app, ["stats", "-d", str(yolo_classification_dataset)]
        )

        assert outcome.exit_code == 0
        # Both class names from the fixture must appear in the stats table.
        assert "cat" in outcome.output
        assert "dog" in outcome.output
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|