argus-cv 1.1.0__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of argus-cv has been flagged as potentially problematic; see the registry's advisory page for more details.

Files changed (35) hide show
  1. {argus_cv-1.1.0 → argus_cv-1.2.0}/CHANGELOG.md +16 -0
  2. {argus_cv-1.1.0 → argus_cv-1.2.0}/PKG-INFO +1 -1
  3. {argus_cv-1.1.0 → argus_cv-1.2.0}/pyproject.toml +1 -1
  4. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/__init__.py +1 -1
  5. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/cli.py +262 -35
  6. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/yolo.py +355 -9
  7. {argus_cv-1.1.0 → argus_cv-1.2.0}/tests/conftest.py +71 -0
  8. argus_cv-1.2.0/tests/test_classification.py +205 -0
  9. {argus_cv-1.1.0 → argus_cv-1.2.0}/.github/workflows/ci.yml +0 -0
  10. {argus_cv-1.1.0 → argus_cv-1.2.0}/.github/workflows/docs.yml +0 -0
  11. {argus_cv-1.1.0 → argus_cv-1.2.0}/.github/workflows/release.yml +0 -0
  12. {argus_cv-1.1.0 → argus_cv-1.2.0}/.gitignore +0 -0
  13. {argus_cv-1.1.0 → argus_cv-1.2.0}/README.md +0 -0
  14. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/assets/javascripts/extra.js +0 -0
  15. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/assets/stylesheets/extra.css +0 -0
  16. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/getting-started/installation.md +0 -0
  17. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/getting-started/quickstart.md +0 -0
  18. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/datasets.md +0 -0
  19. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/listing.md +0 -0
  20. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/splitting.md +0 -0
  21. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/stats.md +0 -0
  22. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/guides/viewer.md +0 -0
  23. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/index.md +0 -0
  24. {argus_cv-1.1.0 → argus_cv-1.2.0}/docs/reference/cli.md +0 -0
  25. {argus_cv-1.1.0 → argus_cv-1.2.0}/mkdocs.yml +0 -0
  26. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/__main__.py +0 -0
  27. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/commands/__init__.py +0 -0
  28. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/__init__.py +0 -0
  29. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/base.py +0 -0
  30. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/coco.py +0 -0
  31. {argus_cv-1.1.0 → argus_cv-1.2.0}/src/argus/core/split.py +0 -0
  32. {argus_cv-1.1.0 → argus_cv-1.2.0}/tests/test_list_command.py +0 -0
  33. {argus_cv-1.1.0 → argus_cv-1.2.0}/tests/test_split_command.py +0 -0
  34. {argus_cv-1.1.0 → argus_cv-1.2.0}/tests/test_stats_command.py +0 -0
  35. {argus_cv-1.1.0 → argus_cv-1.2.0}/uv.lock +0 -0
@@ -2,6 +2,22 @@
2
2
 
3
3
  <!-- version list -->
4
4
 
5
+ ## v1.2.0 (2026-01-15)
6
+
7
+ ### Code Style
8
+
9
+ - Fix ruff linting errors
10
+ ([`b2d5ea2`](https://github.com/pirnerjonas/argus/commit/b2d5ea2c4d0715a474d4ffaa5be60d0499d200a2))
11
+
12
+ - Remove unused pytest import in test_classification.py
13
+ ([`e22175a`](https://github.com/pirnerjonas/argus/commit/e22175a7378276dd2840754e672c7d4d1ed0e067))
14
+
15
+ ### Features
16
+
17
+ - Add classification dataset support with grid viewer
18
+ ([`8089bd3`](https://github.com/pirnerjonas/argus/commit/8089bd3367ede4c288b0276ac3da4d3ef9960c4d))
19
+
20
+
5
21
  ## v1.1.0 (2026-01-14)
6
22
 
7
23
  ### Code Style
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: argus-cv
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: CLI tool for working with vision AI datasets
5
5
  Requires-Python: >=3.10
6
6
  Requires-Dist: numpy>=1.24.0
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "argus-cv"
3
- version = "1.1.0"
3
+ version = "1.2.0"
4
4
  description = "CLI tool for working with vision AI datasets"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"
@@ -1,3 +1,3 @@
1
1
  """Argus - Vision AI dataset toolkit."""
2
2
 
3
- __version__ = "1.1.0"
3
+ __version__ = "1.2.0"
@@ -12,6 +12,7 @@ from rich.progress import Progress, SpinnerColumn, TextColumn
12
12
  from rich.table import Table
13
13
 
14
14
  from argus.core import COCODataset, Dataset, YOLODataset
15
+ from argus.core.base import TaskType
15
16
  from argus.core.split import (
16
17
  is_coco_unsplit,
17
18
  parse_ratio,
@@ -238,19 +239,29 @@ def view(
238
239
  help="Specific split to view (train, val, test).",
239
240
  ),
240
241
  ] = None,
242
+ max_classes: Annotated[
243
+ int | None,
244
+ typer.Option(
245
+ "--max-classes",
246
+ "-m",
247
+ help="Maximum classes to show in grid (classification only).",
248
+ ),
249
+ ] = None,
241
250
  ) -> None:
242
251
  """View annotated images in a dataset.
243
252
 
244
253
  Opens an interactive viewer to browse images with their annotations
245
254
  (bounding boxes and segmentation masks) overlaid.
246
255
 
256
+ For classification datasets, shows a grid view with one image per class.
257
+
247
258
  Controls:
248
- - Right Arrow / N: Next image
249
- - Left Arrow / P: Previous image
250
- - Mouse Wheel: Zoom in/out
251
- - Mouse Drag: Pan when zoomed
252
- - R: Reset zoom
253
- - T: Toggle annotations
259
+ - Right Arrow / N: Next image(s)
260
+ - Left Arrow / P: Previous image(s)
261
+ - Mouse Wheel: Zoom in/out (detection/segmentation only)
262
+ - Mouse Drag: Pan when zoomed (detection/segmentation only)
263
+ - R: Reset zoom / Reset to first images
264
+ - T: Toggle annotations (detection/segmentation only)
254
265
  - Q / ESC: Quit viewer
255
266
  """
256
267
  # Resolve path and validate
@@ -281,39 +292,76 @@ def view(
281
292
  )
282
293
  raise typer.Exit(1)
283
294
 
284
- # Get image paths
285
- with Progress(
286
- SpinnerColumn(),
287
- TextColumn("[progress.description]{task.description}"),
288
- console=console,
289
- transient=True,
290
- ) as progress:
291
- progress.add_task("Loading images...", total=None)
292
- image_paths = dataset.get_image_paths(split)
295
+ # Generate consistent colors for each class
296
+ class_colors = _generate_class_colors(dataset.class_names)
293
297
 
294
- if not image_paths:
295
- console.print("[yellow]No images found in the dataset.[/yellow]")
296
- return
298
+ # Handle classification datasets with grid viewer
299
+ if dataset.task == TaskType.CLASSIFICATION:
300
+ # Use first split if specified, otherwise let get_images_by_class handle it
301
+ view_split = split if split else (dataset.splits[0] if dataset.splits else None)
297
302
 
298
- console.print(
299
- f"[green]Found {len(image_paths)} images. "
300
- f"Opening viewer...[/green]\n"
301
- "[dim]Controls: ← / → or P / N to navigate, "
302
- "Mouse wheel to zoom, Drag to pan, R to reset, T to toggle annotations, "
303
- "Q / ESC to quit[/dim]"
304
- )
303
+ with Progress(
304
+ SpinnerColumn(),
305
+ TextColumn("[progress.description]{task.description}"),
306
+ console=console,
307
+ transient=True,
308
+ ) as progress:
309
+ progress.add_task("Loading images by class...", total=None)
310
+ images_by_class = dataset.get_images_by_class(view_split)
305
311
 
306
- # Generate consistent colors for each class
307
- class_colors = _generate_class_colors(dataset.class_names)
312
+ total_images = sum(len(imgs) for imgs in images_by_class.values())
313
+ if total_images == 0:
314
+ console.print("[yellow]No images found in the dataset.[/yellow]")
315
+ return
308
316
 
309
- # Create and run the interactive viewer
310
- viewer = _ImageViewer(
311
- image_paths=image_paths,
312
- dataset=dataset,
313
- class_colors=class_colors,
314
- window_name=f"Argus Viewer - {dataset_path.name}",
315
- )
316
- viewer.run()
317
+ num_classes = len(dataset.class_names)
318
+ display_classes = min(num_classes, max_classes) if max_classes else num_classes
319
+
320
+ console.print(
321
+ f"[green]Found {total_images} images across {num_classes} classes "
322
+ f"(showing {display_classes}). Opening grid viewer...[/green]\n"
323
+ "[dim]Controls: ← / → or P / N to navigate all classes, "
324
+ "R to reset, Q / ESC to quit[/dim]"
325
+ )
326
+
327
+ viewer = _ClassificationGridViewer(
328
+ images_by_class=images_by_class,
329
+ class_names=dataset.class_names,
330
+ class_colors=class_colors,
331
+ window_name=f"Argus Classification Viewer - {dataset_path.name}",
332
+ max_classes=max_classes,
333
+ )
334
+ viewer.run()
335
+ else:
336
+ # Detection/Segmentation viewer
337
+ with Progress(
338
+ SpinnerColumn(),
339
+ TextColumn("[progress.description]{task.description}"),
340
+ console=console,
341
+ transient=True,
342
+ ) as progress:
343
+ progress.add_task("Loading images...", total=None)
344
+ image_paths = dataset.get_image_paths(split)
345
+
346
+ if not image_paths:
347
+ console.print("[yellow]No images found in the dataset.[/yellow]")
348
+ return
349
+
350
+ console.print(
351
+ f"[green]Found {len(image_paths)} images. "
352
+ f"Opening viewer...[/green]\n"
353
+ "[dim]Controls: ← / → or P / N to navigate, "
354
+ "Mouse wheel to zoom, Drag to pan, R to reset, T to toggle annotations, "
355
+ "Q / ESC to quit[/dim]"
356
+ )
357
+
358
+ viewer = _ImageViewer(
359
+ image_paths=image_paths,
360
+ dataset=dataset,
361
+ class_colors=class_colors,
362
+ window_name=f"Argus Viewer - {dataset_path.name}",
363
+ )
364
+ viewer.run()
317
365
 
318
366
  console.print("[green]Viewer closed.[/green]")
319
367
 
@@ -656,6 +704,185 @@ class _ImageViewer:
656
704
  cv2.destroyAllWindows()
657
705
 
658
706
 
707
+ class _ClassificationGridViewer:
708
+ """Grid viewer for classification datasets showing one image per class."""
709
+
710
+ def __init__(
711
+ self,
712
+ images_by_class: dict[str, list[Path]],
713
+ class_names: list[str],
714
+ class_colors: dict[str, tuple[int, int, int]],
715
+ window_name: str,
716
+ max_classes: int | None = None,
717
+ tile_size: int = 300,
718
+ ):
719
+ # Limit classes if max_classes specified
720
+ if max_classes and len(class_names) > max_classes:
721
+ self.class_names = class_names[:max_classes]
722
+ else:
723
+ self.class_names = class_names
724
+
725
+ self.images_by_class = {
726
+ cls: images_by_class.get(cls, []) for cls in self.class_names
727
+ }
728
+ self.class_colors = class_colors
729
+ self.window_name = window_name
730
+ self.tile_size = tile_size
731
+
732
+ # Global image index (same for all classes)
733
+ self.current_index = 0
734
+
735
+ # Calculate max images across all classes
736
+ self.max_images = max(
737
+ len(imgs) for imgs in self.images_by_class.values()
738
+ ) if self.images_by_class else 0
739
+
740
+ # Calculate grid layout
741
+ self.cols, self.rows = self._calculate_grid_layout()
742
+
743
+ def _calculate_grid_layout(self) -> tuple[int, int]:
744
+ """Calculate optimal grid layout based on number of classes."""
745
+ n = len(self.class_names)
746
+ if n <= 0:
747
+ return 1, 1
748
+
749
+ # Try to make a roughly square grid
750
+ import math
751
+
752
+ cols = int(math.ceil(math.sqrt(n)))
753
+ rows = int(math.ceil(n / cols))
754
+ return cols, rows
755
+
756
+ def _create_tile(
757
+ self, class_name: str, image_path: Path | None, index: int, total: int
758
+ ) -> np.ndarray:
759
+ """Create a single tile for a class."""
760
+ tile = np.zeros((self.tile_size, self.tile_size, 3), dtype=np.uint8)
761
+
762
+ if image_path is not None and image_path.exists():
763
+ # Load and resize image
764
+ img = cv2.imread(str(image_path))
765
+ if img is not None:
766
+ # Resize maintaining aspect ratio
767
+ h, w = img.shape[:2]
768
+ scale = min(self.tile_size / w, self.tile_size / h)
769
+ new_w = int(w * scale)
770
+ new_h = int(h * scale)
771
+ resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
772
+
773
+ # Center in tile
774
+ x_offset = (self.tile_size - new_w) // 2
775
+ y_offset = (self.tile_size - new_h) // 2
776
+ tile[y_offset : y_offset + new_h, x_offset : x_offset + new_w] = resized
777
+
778
+ # Draw label at top: "class_name (N/M)"
779
+ if image_path is not None:
780
+ label = f"{class_name} ({index + 1}/{total})"
781
+ else:
782
+ label = f"{class_name} (-/{total})"
783
+
784
+ font = cv2.FONT_HERSHEY_SIMPLEX
785
+ font_scale = 0.5
786
+ thickness = 1
787
+ (label_w, label_h), baseline = cv2.getTextSize(
788
+ label, font, font_scale, thickness
789
+ )
790
+
791
+ # Semi-transparent background for label
792
+ overlay = tile.copy()
793
+ label_bg_height = label_h + baseline + 10
794
+ cv2.rectangle(overlay, (0, 0), (self.tile_size, label_bg_height), (0, 0, 0), -1)
795
+ cv2.addWeighted(overlay, 0.6, tile, 0.4, 0, tile)
796
+
797
+ cv2.putText(
798
+ tile,
799
+ label,
800
+ (5, label_h + 5),
801
+ font,
802
+ font_scale,
803
+ (255, 255, 255),
804
+ thickness,
805
+ )
806
+
807
+ # Draw thin border
808
+ border_end = self.tile_size - 1
809
+ cv2.rectangle(tile, (0, 0), (border_end, border_end), (80, 80, 80), 1)
810
+
811
+ return tile
812
+
813
+ def _compose_grid(self) -> np.ndarray:
814
+ """Compose all tiles into a single grid image."""
815
+ grid_h = self.rows * self.tile_size
816
+ grid_w = self.cols * self.tile_size
817
+ grid = np.zeros((grid_h, grid_w, 3), dtype=np.uint8)
818
+
819
+ for i, class_name in enumerate(self.class_names):
820
+ row = i // self.cols
821
+ col = i % self.cols
822
+
823
+ images = self.images_by_class[class_name]
824
+ total = len(images)
825
+
826
+ # Use global index - show black tile if class doesn't have this image
827
+ if self.current_index < total:
828
+ image_path = images[self.current_index]
829
+ display_index = self.current_index
830
+ else:
831
+ image_path = None
832
+ display_index = self.current_index
833
+
834
+ tile = self._create_tile(class_name, image_path, display_index, total)
835
+
836
+ y_start = row * self.tile_size
837
+ x_start = col * self.tile_size
838
+ y_end = y_start + self.tile_size
839
+ x_end = x_start + self.tile_size
840
+ grid[y_start:y_end, x_start:x_end] = tile
841
+
842
+ return grid
843
+
844
+ def _next_images(self) -> None:
845
+ """Advance to next image index."""
846
+ if self.max_images > 0:
847
+ self.current_index = min(self.current_index + 1, self.max_images - 1)
848
+
849
+ def _prev_images(self) -> None:
850
+ """Go back to previous image index."""
851
+ self.current_index = max(self.current_index - 1, 0)
852
+
853
+ def _reset_indices(self) -> None:
854
+ """Reset to first image."""
855
+ self.current_index = 0
856
+
857
+ def run(self) -> None:
858
+ """Run the interactive grid viewer."""
859
+ if not self.class_names:
860
+ console.print("[yellow]No classes to display.[/yellow]")
861
+ return
862
+
863
+ cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
864
+
865
+ while True:
866
+ # Compose and display grid
867
+ grid = self._compose_grid()
868
+ cv2.imshow(self.window_name, grid)
869
+
870
+ # Wait for input
871
+ key = cv2.waitKey(30) & 0xFF
872
+
873
+ # Handle keyboard input
874
+ if key == ord("q") or key == 27: # Q or ESC
875
+ break
876
+ elif key == ord("n") or key == 83 or key == 3: # N or Right arrow
877
+ self._next_images()
878
+ elif key == ord("p") or key == 81 or key == 2: # P or Left arrow
879
+ self._prev_images()
880
+ elif key == ord("r"): # R to reset
881
+ self._reset_indices()
882
+
883
+ cv2.destroyAllWindows()
884
+
885
+
659
886
  def _generate_class_colors(class_names: list[str]) -> dict[str, tuple[int, int, int]]:
660
887
  """Generate consistent colors for each class name.
661
888
 
@@ -12,9 +12,9 @@ from argus.core.base import Dataset, DatasetFormat, TaskType
12
12
  class YOLODataset(Dataset):
13
13
  """YOLO format dataset.
14
14
 
15
- Supports detection and segmentation tasks.
15
+ Supports detection, segmentation, and classification tasks.
16
16
 
17
- Structure:
17
+ Structure (detection/segmentation):
18
18
  dataset/
19
19
  ├── data.yaml (or *.yaml/*.yml with 'names' key)
20
20
  ├── images/
@@ -23,6 +23,19 @@ class YOLODataset(Dataset):
23
23
  └── labels/
24
24
  ├── train/
25
25
  └── val/
26
+
27
+ Structure (classification):
28
+ dataset/
29
+ ├── images/
30
+ │ ├── train/
31
+ │ │ ├── class1/
32
+ │ │ │ ├── img1.jpg
33
+ │ │ │ └── img2.jpg
34
+ │ │ └── class2/
35
+ │ │ └── img1.jpg
36
+ │ └── val/
37
+ │ ├── class1/
38
+ │ └── class2/
26
39
  """
27
40
 
28
41
  config_file: Path | None = None
@@ -43,8 +56,13 @@ class YOLODataset(Dataset):
43
56
  if not path.is_dir():
44
57
  return None
45
58
 
46
- # Try detection/segmentation (YAML-based)
47
- return cls._detect_yaml_based(path)
59
+ # Try detection/segmentation (YAML-based) first
60
+ result = cls._detect_yaml_based(path)
61
+ if result:
62
+ return result
63
+
64
+ # Try classification (directory-based structure)
65
+ return cls._detect_classification(path)
48
66
 
49
67
  @classmethod
50
68
  def _detect_yaml_based(cls, path: Path) -> "YOLODataset | None":
@@ -103,16 +121,121 @@ class YOLODataset(Dataset):
103
121
 
104
122
  return None
105
123
 
124
+ @classmethod
125
+ def _detect_classification(cls, path: Path) -> "YOLODataset | None":
126
+ """Detect classification dataset from directory structure.
127
+
128
+ Classification datasets can have two structures:
129
+
130
+ 1. Split structure:
131
+ images/{split}/class_name/image.jpg
132
+
133
+ 2. Flat structure (unsplit):
134
+ class_name/image.jpg
135
+
136
+ No YAML config required - class names inferred from directory names.
137
+
138
+ Args:
139
+ path: Directory path to check.
140
+
141
+ Returns:
142
+ YOLODataset if classification structure found, None otherwise.
143
+ """
144
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
145
+
146
+ # Try split structure first: images/{split}/class/
147
+ images_root = path / "images"
148
+ if images_root.is_dir():
149
+ splits: list[str] = []
150
+ class_names_set: set[str] = set()
151
+
152
+ for split_name in ["train", "val", "test"]:
153
+ split_dir = images_root / split_name
154
+ if not split_dir.is_dir():
155
+ continue
156
+
157
+ # Get subdirectories (potential class folders)
158
+ class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
159
+ if not class_dirs:
160
+ continue
161
+
162
+ # Check if at least one class dir contains images
163
+ has_images = False
164
+ for class_dir in class_dirs:
165
+ for f in class_dir.iterdir():
166
+ if f.suffix.lower() in image_extensions:
167
+ has_images = True
168
+ break
169
+ if has_images:
170
+ break
171
+
172
+ if has_images:
173
+ splits.append(split_name)
174
+ class_names_set.update(d.name for d in class_dirs)
175
+
176
+ if splits and class_names_set:
177
+ class_names = sorted(class_names_set)
178
+ return cls(
179
+ path=path,
180
+ task=TaskType.CLASSIFICATION,
181
+ num_classes=len(class_names),
182
+ class_names=class_names,
183
+ splits=splits,
184
+ config_file=None,
185
+ )
186
+
187
+ # Try flat structure: class_name/image.jpg (no images/ or split dirs)
188
+ # Check if root contains subdirectories with images
189
+ class_dirs = [d for d in path.iterdir() if d.is_dir()]
190
+
191
+ # Filter out common non-class directories
192
+ excluded_dirs = {"images", "labels", "annotations", ".git", "__pycache__"}
193
+ class_dirs = [d for d in class_dirs if d.name not in excluded_dirs]
194
+
195
+ if not class_dirs:
196
+ return None
197
+
198
+ # Check if these are class directories (contain images directly)
199
+ class_names_set = set()
200
+ for class_dir in class_dirs:
201
+ has_images = any(
202
+ f.suffix.lower() in image_extensions
203
+ for f in class_dir.iterdir()
204
+ if f.is_file()
205
+ )
206
+ if has_images:
207
+ class_names_set.add(class_dir.name)
208
+
209
+ # Need at least 2 classes to be a valid classification dataset
210
+ if len(class_names_set) < 2:
211
+ return None
212
+
213
+ class_names = sorted(class_names_set)
214
+ return cls(
215
+ path=path,
216
+ task=TaskType.CLASSIFICATION,
217
+ num_classes=len(class_names),
218
+ class_names=class_names,
219
+ splits=[], # No splits for flat structure
220
+ config_file=None,
221
+ )
222
+
106
223
  def get_instance_counts(self) -> dict[str, dict[str, int]]:
107
224
  """Get the number of annotation instances per class, per split.
108
225
 
109
- Parses all label files in labels/{split}/*.txt and counts
110
- occurrences of each class ID. For unsplit datasets, uses "unsplit"
111
- as the split name.
226
+ For detection/segmentation: Parses all label files in labels/{split}/*.txt
227
+ and counts occurrences of each class ID.
228
+
229
+ For classification: Counts images in each class directory
230
+ (1 image = 1 instance).
112
231
 
113
232
  Returns:
114
233
  Dictionary mapping split name to dict of class name to instance count.
115
234
  """
235
+ # Handle classification datasets differently
236
+ if self.task == TaskType.CLASSIFICATION:
237
+ return self._get_classification_instance_counts()
238
+
116
239
  counts: dict[str, dict[str, int]] = {}
117
240
 
118
241
  # Build class_id -> class_name mapping
@@ -162,15 +285,77 @@ class YOLODataset(Dataset):
162
285
 
163
286
  return counts
164
287
 
288
+ def _get_classification_instance_counts(self) -> dict[str, dict[str, int]]:
289
+ """Get instance counts for classification datasets.
290
+
291
+ Each image is one instance of its class.
292
+
293
+ Returns:
294
+ Dictionary mapping split name to dict of class name to image count.
295
+ """
296
+ counts: dict[str, dict[str, int]] = {}
297
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
298
+
299
+ # Handle flat structure (no splits)
300
+ if not self.splits:
301
+ split_counts: dict[str, int] = {}
302
+ for class_name in self.class_names:
303
+ class_dir = self.path / class_name
304
+ if not class_dir.is_dir():
305
+ split_counts[class_name] = 0
306
+ continue
307
+
308
+ image_count = sum(
309
+ 1
310
+ for f in class_dir.iterdir()
311
+ if f.suffix.lower() in image_extensions
312
+ )
313
+ split_counts[class_name] = image_count
314
+
315
+ counts["unsplit"] = split_counts
316
+ return counts
317
+
318
+ # Handle split structure
319
+ images_root = self.path / "images"
320
+ for split in self.splits:
321
+ split_dir = images_root / split
322
+ if not split_dir.is_dir():
323
+ continue
324
+
325
+ split_counts = {}
326
+ for class_name in self.class_names:
327
+ class_dir = split_dir / class_name
328
+ if not class_dir.is_dir():
329
+ split_counts[class_name] = 0
330
+ continue
331
+
332
+ image_count = sum(
333
+ 1
334
+ for f in class_dir.iterdir()
335
+ if f.suffix.lower() in image_extensions
336
+ )
337
+ split_counts[class_name] = image_count
338
+
339
+ counts[split] = split_counts
340
+
341
+ return counts
342
+
165
343
  def get_image_counts(self) -> dict[str, dict[str, int]]:
166
344
  """Get image counts per split, including background images.
167
345
 
168
- Counts label files in labels/{split}/*.txt. Empty files are
169
- counted as background images.
346
+ For detection/segmentation: Counts label files in labels/{split}/*.txt.
347
+ Empty files are counted as background images.
348
+
349
+ For classification: Counts total images across all class directories.
350
+ Background count is always 0 (no background concept in classification).
170
351
 
171
352
  Returns:
172
353
  Dictionary mapping split name to dict with "total" and "background" counts.
173
354
  """
355
+ # Handle classification datasets differently
356
+ if self.task == TaskType.CLASSIFICATION:
357
+ return self._get_classification_image_counts()
358
+
174
359
  counts: dict[str, dict[str, int]] = {}
175
360
 
176
361
  labels_root = self.path / "labels"
@@ -203,6 +388,56 @@ class YOLODataset(Dataset):
203
388
 
204
389
  return counts
205
390
 
391
+ def _get_classification_image_counts(self) -> dict[str, dict[str, int]]:
392
+ """Get image counts for classification datasets.
393
+
394
+ Returns:
395
+ Dictionary mapping split name to dict with "total" and "background" counts.
396
+ Background is always 0 for classification.
397
+ """
398
+ counts: dict[str, dict[str, int]] = {}
399
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
400
+
401
+ # Handle flat structure (no splits)
402
+ if not self.splits:
403
+ total = 0
404
+ for class_name in self.class_names:
405
+ class_dir = self.path / class_name
406
+ if not class_dir.is_dir():
407
+ continue
408
+
409
+ total += sum(
410
+ 1
411
+ for f in class_dir.iterdir()
412
+ if f.suffix.lower() in image_extensions
413
+ )
414
+
415
+ counts["unsplit"] = {"total": total, "background": 0}
416
+ return counts
417
+
418
+ # Handle split structure
419
+ images_root = self.path / "images"
420
+ for split in self.splits:
421
+ split_dir = images_root / split
422
+ if not split_dir.is_dir():
423
+ continue
424
+
425
+ total = 0
426
+ for class_name in self.class_names:
427
+ class_dir = split_dir / class_name
428
+ if not class_dir.is_dir():
429
+ continue
430
+
431
+ total += sum(
432
+ 1
433
+ for f in class_dir.iterdir()
434
+ if f.suffix.lower() in image_extensions
435
+ )
436
+
437
+ counts[split] = {"total": total, "background": 0}
438
+
439
+ return counts
440
+
206
441
  @classmethod
207
442
  def _detect_splits(cls, path: Path, config: dict) -> list[str]:
208
443
  """Detect available splits from config and filesystem.
@@ -301,6 +536,10 @@ class YOLODataset(Dataset):
301
536
  Returns:
302
537
  List of image file paths sorted alphabetically.
303
538
  """
539
+ # Handle classification datasets differently
540
+ if self.task == TaskType.CLASSIFICATION:
541
+ return self._get_classification_image_paths(split)
542
+
304
543
  image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
305
544
  images_root = self.path / "images"
306
545
  image_paths: list[Path] = []
@@ -345,6 +584,51 @@ class YOLODataset(Dataset):
345
584
 
346
585
  return sorted(image_paths, key=lambda p: p.name)
347
586
 
587
+ def _get_classification_image_paths(self, split: str | None = None) -> list[Path]:
588
+ """Get image paths for classification datasets.
589
+
590
+ Args:
591
+ split: Specific split to get images from. If None, returns all images.
592
+
593
+ Returns:
594
+ List of image file paths sorted alphabetically.
595
+ """
596
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
597
+ image_paths: list[Path] = []
598
+
599
+ # Handle flat structure (no splits)
600
+ if not self.splits:
601
+ for class_name in self.class_names:
602
+ class_dir = self.path / class_name
603
+ if not class_dir.is_dir():
604
+ continue
605
+
606
+ for img_file in class_dir.iterdir():
607
+ if img_file.suffix.lower() in image_extensions:
608
+ image_paths.append(img_file)
609
+
610
+ return sorted(image_paths, key=lambda p: p.name)
611
+
612
+ # Handle split structure
613
+ images_root = self.path / "images"
614
+ splits_to_search = [split] if split else self.splits
615
+
616
+ for s in splits_to_search:
617
+ split_dir = images_root / s
618
+ if not split_dir.is_dir():
619
+ continue
620
+
621
+ for class_name in self.class_names:
622
+ class_dir = split_dir / class_name
623
+ if not class_dir.is_dir():
624
+ continue
625
+
626
+ for img_file in class_dir.iterdir():
627
+ if img_file.suffix.lower() in image_extensions:
628
+ image_paths.append(img_file)
629
+
630
+ return sorted(image_paths, key=lambda p: p.name)
631
+
348
632
  def get_annotations_for_image(self, image_path: Path) -> list[dict]:
349
633
  """Get annotations for a specific image.
350
634
 
@@ -445,3 +729,65 @@ class YOLODataset(Dataset):
445
729
  pass
446
730
 
447
731
  return annotations
732
+
733
+ def get_images_by_class(self, split: str | None = None) -> dict[str, list[Path]]:
734
+ """Get images grouped by class for classification datasets.
735
+
736
+ Args:
737
+ split: Specific split to get images from. If None, uses first
738
+ available split or all images for flat structure.
739
+
740
+ Returns:
741
+ Dictionary mapping class name to list of image paths.
742
+ """
743
+ if self.task != TaskType.CLASSIFICATION:
744
+ return {}
745
+
746
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
747
+ images_by_class: dict[str, list[Path]] = {cls: [] for cls in self.class_names}
748
+
749
+ # Handle flat structure (no splits)
750
+ if not self.splits:
751
+ for class_name in self.class_names:
752
+ class_dir = self.path / class_name
753
+ if not class_dir.is_dir():
754
+ continue
755
+
756
+ for img_file in class_dir.iterdir():
757
+ if img_file.suffix.lower() in image_extensions:
758
+ images_by_class[class_name].append(img_file)
759
+
760
+ # Sort images within each class
761
+ for class_name in images_by_class:
762
+ images_by_class[class_name] = sorted(
763
+ images_by_class[class_name], key=lambda p: p.name
764
+ )
765
+
766
+ return images_by_class
767
+
768
+ # Handle split structure
769
+ images_root = self.path / "images"
770
+ default_splits = self.splits[:1] if self.splits else []
771
+ splits_to_search = [split] if split else default_splits
772
+
773
+ for s in splits_to_search:
774
+ split_dir = images_root / s
775
+ if not split_dir.is_dir():
776
+ continue
777
+
778
+ for class_name in self.class_names:
779
+ class_dir = split_dir / class_name
780
+ if not class_dir.is_dir():
781
+ continue
782
+
783
+ for img_file in class_dir.iterdir():
784
+ if img_file.suffix.lower() in image_extensions:
785
+ images_by_class[class_name].append(img_file)
786
+
787
+ # Sort images within each class for consistent ordering
788
+ for class_name in images_by_class:
789
+ images_by_class[class_name] = sorted(
790
+ images_by_class[class_name], key=lambda p: p.name
791
+ )
792
+
793
+ return images_by_class
@@ -389,3 +389,74 @@ names:
389
389
  (annotations_dir / "annotations.json").write_text(json.dumps(coco_data))
390
390
 
391
391
  return root_path
392
+
393
+
394
@pytest.fixture
def yolo_classification_dataset(tmp_path: Path) -> Path:
    """Create a valid YOLO classification dataset.

    Structure:
        dataset/
        └── images/
            ├── train/
            │   ├── cat/  (img001.jpg, img002.jpg)
            │   └── dog/  (img001.jpg)
            └── val/
                ├── cat/  (img001.jpg)
                └── dog/  (img001.jpg)
    """
    dataset_path = tmp_path / "yolo_classification"
    dataset_path.mkdir()

    # (split, class) -> image file names; each file holds dummy bytes.
    layout = {
        ("train", "cat"): ["img001.jpg", "img002.jpg"],
        ("train", "dog"): ["img001.jpg"],
        ("val", "cat"): ["img001.jpg"],
        ("val", "dog"): ["img001.jpg"],
    }
    for (split, cls), filenames in layout.items():
        class_dir = dataset_path / "images" / split / cls
        class_dir.mkdir(parents=True)
        for filename in filenames:
            (class_dir / filename).write_bytes(f"fake {cls}".encode())

    return dataset_path
433
+
434
+
435
@pytest.fixture
def yolo_classification_multiclass_dataset(tmp_path: Path) -> Path:
    """Create a YOLO classification dataset with more classes.

    Structure:
        dataset/
        └── images/
            └── train/
                ├── apple/   (1 image)
                ├── banana/  (2 images)
                ├── cherry/  (3 images)
                └── date/    (4 images)
    """
    dataset_path = tmp_path / "yolo_cls_multiclass"
    dataset_path.mkdir()

    classes = ["apple", "banana", "cherry", "date"]

    # enumerate replaces the quadratic classes.index(cls) lookup while
    # keeping the same varying image count per class (1..len(classes)).
    for num_images, cls in enumerate(classes, start=1):
        class_dir = dataset_path / "images" / "train" / cls
        class_dir.mkdir(parents=True)
        for i in range(num_images):
            (class_dir / f"img{i:03d}.jpg").write_bytes(b"fake image")

    return dataset_path
@@ -0,0 +1,205 @@
1
+ """Tests for YOLO classification dataset support."""
2
+
3
+ from pathlib import Path
4
+
5
+ from argus.core.base import DatasetFormat, TaskType
6
+ from argus.core.yolo import YOLODataset
7
+
8
+
9
+ class TestClassificationDetection:
10
+ """Tests for detecting YOLO classification datasets."""
11
+
12
+ def test_detect_classification_dataset(
13
+ self, yolo_classification_dataset: Path
14
+ ) -> None:
15
+ """Test that classification dataset is correctly detected."""
16
+ dataset = YOLODataset.detect(yolo_classification_dataset)
17
+
18
+ assert dataset is not None
19
+ assert dataset.task == TaskType.CLASSIFICATION
20
+ assert dataset.format == DatasetFormat.YOLO
21
+ assert set(dataset.class_names) == {"cat", "dog"}
22
+ assert dataset.num_classes == 2
23
+ assert set(dataset.splits) == {"train", "val"}
24
+
25
+ def test_detect_classification_multiclass(
26
+ self, yolo_classification_multiclass_dataset: Path
27
+ ) -> None:
28
+ """Test detection with multiple classes."""
29
+ dataset = YOLODataset.detect(yolo_classification_multiclass_dataset)
30
+
31
+ assert dataset is not None
32
+ assert dataset.task == TaskType.CLASSIFICATION
33
+ assert set(dataset.class_names) == {"apple", "banana", "cherry", "date"}
34
+ assert dataset.num_classes == 4
35
+ assert "train" in dataset.splits
36
+
37
+ def test_classification_no_yaml_required(
38
+ self, yolo_classification_dataset: Path
39
+ ) -> None:
40
+ """Test that classification datasets don't need a YAML config."""
41
+ dataset = YOLODataset.detect(yolo_classification_dataset)
42
+
43
+ assert dataset is not None
44
+ assert dataset.config_file is None
45
+
46
+ def test_classification_not_detected_for_detection(
47
+ self, yolo_detection_dataset: Path
48
+ ) -> None:
49
+ """Test that detection datasets are not classified as classification."""
50
+ dataset = YOLODataset.detect(yolo_detection_dataset)
51
+
52
+ assert dataset is not None
53
+ assert dataset.task != TaskType.CLASSIFICATION
54
+ assert dataset.task == TaskType.DETECTION
55
+
56
+
57
+ class TestClassificationGetImagesByClass:
58
+ """Tests for get_images_by_class method."""
59
+
60
+ def test_get_images_by_class(self, yolo_classification_dataset: Path) -> None:
61
+ """Test getting images grouped by class."""
62
+ dataset = YOLODataset.detect(yolo_classification_dataset)
63
+ assert dataset is not None
64
+
65
+ images_by_class = dataset.get_images_by_class("train")
66
+
67
+ assert "cat" in images_by_class
68
+ assert "dog" in images_by_class
69
+ assert len(images_by_class["cat"]) == 2
70
+ assert len(images_by_class["dog"]) == 1
71
+
72
+ def test_get_images_by_class_val_split(
73
+ self, yolo_classification_dataset: Path
74
+ ) -> None:
75
+ """Test getting images from val split."""
76
+ dataset = YOLODataset.detect(yolo_classification_dataset)
77
+ assert dataset is not None
78
+
79
+ images_by_class = dataset.get_images_by_class("val")
80
+
81
+ assert len(images_by_class["cat"]) == 1
82
+ assert len(images_by_class["dog"]) == 1
83
+
84
+ def test_get_images_by_class_default_split(
85
+ self, yolo_classification_dataset: Path
86
+ ) -> None:
87
+ """Test that None split uses first available split."""
88
+ dataset = YOLODataset.detect(yolo_classification_dataset)
89
+ assert dataset is not None
90
+
91
+ images_by_class = dataset.get_images_by_class(None)
92
+
93
+ # Should return images from first split (train)
94
+ total_images = sum(len(imgs) for imgs in images_by_class.values())
95
+ assert total_images > 0
96
+
97
+ def test_get_images_by_class_non_classification(
98
+ self, yolo_detection_dataset: Path
99
+ ) -> None:
100
+ """Test that non-classification datasets return empty dict."""
101
+ dataset = YOLODataset.detect(yolo_detection_dataset)
102
+ assert dataset is not None
103
+
104
+ images_by_class = dataset.get_images_by_class("train")
105
+
106
+ assert images_by_class == {}
107
+
108
+
109
+ class TestClassificationInstanceCounts:
110
+ """Tests for instance counts in classification datasets."""
111
+
112
+ def test_get_instance_counts(self, yolo_classification_dataset: Path) -> None:
113
+ """Test counting images per class per split."""
114
+ dataset = YOLODataset.detect(yolo_classification_dataset)
115
+ assert dataset is not None
116
+
117
+ counts = dataset.get_instance_counts()
118
+
119
+ assert "train" in counts
120
+ assert "val" in counts
121
+ assert counts["train"]["cat"] == 2
122
+ assert counts["train"]["dog"] == 1
123
+ assert counts["val"]["cat"] == 1
124
+ assert counts["val"]["dog"] == 1
125
+
126
+ def test_get_image_counts(self, yolo_classification_dataset: Path) -> None:
127
+ """Test total image counts per split."""
128
+ dataset = YOLODataset.detect(yolo_classification_dataset)
129
+ assert dataset is not None
130
+
131
+ counts = dataset.get_image_counts()
132
+
133
+ assert "train" in counts
134
+ assert "val" in counts
135
+ assert counts["train"]["total"] == 3 # 2 cat + 1 dog
136
+ assert counts["train"]["background"] == 0
137
+ assert counts["val"]["total"] == 2 # 1 cat + 1 dog
138
+ assert counts["val"]["background"] == 0
139
+
140
+
141
+ class TestClassificationImagePaths:
142
+ """Tests for get_image_paths with classification datasets."""
143
+
144
+ def test_get_image_paths_all(self, yolo_classification_dataset: Path) -> None:
145
+ """Test getting all image paths."""
146
+ dataset = YOLODataset.detect(yolo_classification_dataset)
147
+ assert dataset is not None
148
+
149
+ paths = dataset.get_image_paths()
150
+
151
+ assert len(paths) == 5 # 3 train + 2 val
152
+
153
+ def test_get_image_paths_train(self, yolo_classification_dataset: Path) -> None:
154
+ """Test getting image paths for train split."""
155
+ dataset = YOLODataset.detect(yolo_classification_dataset)
156
+ assert dataset is not None
157
+
158
+ paths = dataset.get_image_paths("train")
159
+
160
+ assert len(paths) == 3
161
+
162
+ def test_get_image_paths_val(self, yolo_classification_dataset: Path) -> None:
163
+ """Test getting image paths for val split."""
164
+ dataset = YOLODataset.detect(yolo_classification_dataset)
165
+ assert dataset is not None
166
+
167
+ paths = dataset.get_image_paths("val")
168
+
169
+ assert len(paths) == 2
170
+
171
+
172
+ class TestClassificationCLI:
173
+ """Tests for CLI integration with classification datasets."""
174
+
175
+ def test_list_shows_classification(
176
+ self, yolo_classification_dataset: Path
177
+ ) -> None:
178
+ """Test that list command shows classification datasets."""
179
+ from typer.testing import CliRunner
180
+
181
+ from argus.cli import app
182
+
183
+ runner = CliRunner()
184
+ result = runner.invoke(app, ["list", "-p", str(yolo_classification_dataset)])
185
+
186
+ assert result.exit_code == 0
187
+ assert "classification" in result.output.lower()
188
+ assert "yolo" in result.output.lower()
189
+
190
+ def test_stats_shows_classification_counts(
191
+ self, yolo_classification_dataset: Path
192
+ ) -> None:
193
+ """Test that stats command shows image counts per class."""
194
+ from typer.testing import CliRunner
195
+
196
+ from argus.cli import app
197
+
198
+ runner = CliRunner()
199
+ result = runner.invoke(
200
+ app, ["stats", "-d", str(yolo_classification_dataset)]
201
+ )
202
+
203
+ assert result.exit_code == 0
204
+ assert "cat" in result.output
205
+ assert "dog" in result.output
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes