argus-cv: argus_cv-1.0.1-py3-none-any.whl → argus_cv-1.2.0-py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of argus-cv might be problematic. Click here for more details.

argus/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  """Argus - Vision AI dataset toolkit."""
2
2
 
3
- __version__ = "1.0.1"
3
+ __version__ = "1.2.0"
argus/cli.py CHANGED
@@ -12,6 +12,7 @@ from rich.progress import Progress, SpinnerColumn, TextColumn
12
12
  from rich.table import Table
13
13
 
14
14
  from argus.core import COCODataset, Dataset, YOLODataset
15
+ from argus.core.base import TaskType
15
16
  from argus.core.split import (
16
17
  is_coco_unsplit,
17
18
  parse_ratio,
@@ -238,18 +239,29 @@ def view(
238
239
  help="Specific split to view (train, val, test).",
239
240
  ),
240
241
  ] = None,
242
+ max_classes: Annotated[
243
+ int | None,
244
+ typer.Option(
245
+ "--max-classes",
246
+ "-m",
247
+ help="Maximum classes to show in grid (classification only).",
248
+ ),
249
+ ] = None,
241
250
  ) -> None:
242
251
  """View annotated images in a dataset.
243
252
 
244
253
  Opens an interactive viewer to browse images with their annotations
245
254
  (bounding boxes and segmentation masks) overlaid.
246
255
 
256
+ For classification datasets, shows a grid view with one image per class.
257
+
247
258
  Controls:
248
- - Right Arrow / N: Next image
249
- - Left Arrow / P: Previous image
250
- - Mouse Wheel: Zoom in/out
251
- - Mouse Drag: Pan when zoomed
252
- - R: Reset zoom
259
+ - Right Arrow / N: Next image(s)
260
+ - Left Arrow / P: Previous image(s)
261
+ - Mouse Wheel: Zoom in/out (detection/segmentation only)
262
+ - Mouse Drag: Pan when zoomed (detection/segmentation only)
263
+ - R: Reset zoom / Reset to first images
264
+ - T: Toggle annotations (detection/segmentation only)
253
265
  - Q / ESC: Quit viewer
254
266
  """
255
267
  # Resolve path and validate
@@ -280,38 +292,76 @@ def view(
280
292
  )
281
293
  raise typer.Exit(1)
282
294
 
283
- # Get image paths
284
- with Progress(
285
- SpinnerColumn(),
286
- TextColumn("[progress.description]{task.description}"),
287
- console=console,
288
- transient=True,
289
- ) as progress:
290
- progress.add_task("Loading images...", total=None)
291
- image_paths = dataset.get_image_paths(split)
295
+ # Generate consistent colors for each class
296
+ class_colors = _generate_class_colors(dataset.class_names)
292
297
 
293
- if not image_paths:
294
- console.print("[yellow]No images found in the dataset.[/yellow]")
295
- return
298
+ # Handle classification datasets with grid viewer
299
+ if dataset.task == TaskType.CLASSIFICATION:
300
+ # Use first split if specified, otherwise let get_images_by_class handle it
301
+ view_split = split if split else (dataset.splits[0] if dataset.splits else None)
296
302
 
297
- console.print(
298
- f"[green]Found {len(image_paths)} images. "
299
- f"Opening viewer...[/green]\n"
300
- "[dim]Controls: ← / → or P / N to navigate, "
301
- "Mouse wheel to zoom, Drag to pan, R to reset, Q / ESC to quit[/dim]"
302
- )
303
+ with Progress(
304
+ SpinnerColumn(),
305
+ TextColumn("[progress.description]{task.description}"),
306
+ console=console,
307
+ transient=True,
308
+ ) as progress:
309
+ progress.add_task("Loading images by class...", total=None)
310
+ images_by_class = dataset.get_images_by_class(view_split)
303
311
 
304
- # Generate consistent colors for each class
305
- class_colors = _generate_class_colors(dataset.class_names)
312
+ total_images = sum(len(imgs) for imgs in images_by_class.values())
313
+ if total_images == 0:
314
+ console.print("[yellow]No images found in the dataset.[/yellow]")
315
+ return
306
316
 
307
- # Create and run the interactive viewer
308
- viewer = _ImageViewer(
309
- image_paths=image_paths,
310
- dataset=dataset,
311
- class_colors=class_colors,
312
- window_name=f"Argus Viewer - {dataset_path.name}",
313
- )
314
- viewer.run()
317
+ num_classes = len(dataset.class_names)
318
+ display_classes = min(num_classes, max_classes) if max_classes else num_classes
319
+
320
+ console.print(
321
+ f"[green]Found {total_images} images across {num_classes} classes "
322
+ f"(showing {display_classes}). Opening grid viewer...[/green]\n"
323
+ "[dim]Controls: ← / → or P / N to navigate all classes, "
324
+ "R to reset, Q / ESC to quit[/dim]"
325
+ )
326
+
327
+ viewer = _ClassificationGridViewer(
328
+ images_by_class=images_by_class,
329
+ class_names=dataset.class_names,
330
+ class_colors=class_colors,
331
+ window_name=f"Argus Classification Viewer - {dataset_path.name}",
332
+ max_classes=max_classes,
333
+ )
334
+ viewer.run()
335
+ else:
336
+ # Detection/Segmentation viewer
337
+ with Progress(
338
+ SpinnerColumn(),
339
+ TextColumn("[progress.description]{task.description}"),
340
+ console=console,
341
+ transient=True,
342
+ ) as progress:
343
+ progress.add_task("Loading images...", total=None)
344
+ image_paths = dataset.get_image_paths(split)
345
+
346
+ if not image_paths:
347
+ console.print("[yellow]No images found in the dataset.[/yellow]")
348
+ return
349
+
350
+ console.print(
351
+ f"[green]Found {len(image_paths)} images. "
352
+ f"Opening viewer...[/green]\n"
353
+ "[dim]Controls: ← / → or P / N to navigate, "
354
+ "Mouse wheel to zoom, Drag to pan, R to reset, T to toggle annotations, "
355
+ "Q / ESC to quit[/dim]"
356
+ )
357
+
358
+ viewer = _ImageViewer(
359
+ image_paths=image_paths,
360
+ dataset=dataset,
361
+ class_colors=class_colors,
362
+ window_name=f"Argus Viewer - {dataset_path.name}",
363
+ )
364
+ viewer.run()
315
365
 
316
366
  console.print("[green]Viewer closed.[/green]")
317
367
 
@@ -480,6 +530,9 @@ class _ImageViewer:
480
530
  self.current_img: np.ndarray | None = None
481
531
  self.annotated_img: np.ndarray | None = None
482
532
 
533
+ # Annotation visibility toggle
534
+ self.show_annotations = True
535
+
483
536
  def _load_current_image(self) -> bool:
484
537
  """Load and annotate the current image."""
485
538
  image_path = self.image_paths[self.current_idx]
@@ -500,7 +553,12 @@ class _ImageViewer:
500
553
  if self.annotated_img is None:
501
554
  return np.zeros((480, 640, 3), dtype=np.uint8)
502
555
 
503
- img = self.annotated_img
556
+ if self.show_annotations:
557
+ img = self.annotated_img
558
+ elif self.current_img is not None:
559
+ img = self.current_img
560
+ else:
561
+ img = self.annotated_img
504
562
  h, w = img.shape[:2]
505
563
 
506
564
  if self.zoom == 1.0 and self.pan_x == 0.0 and self.pan_y == 0.0:
@@ -537,6 +595,8 @@ class _ImageViewer:
537
595
  info_text = f"[{idx}/{total}] {image_path.name}"
538
596
  if self.zoom > 1.0:
539
597
  info_text += f" (Zoom: {self.zoom:.1f}x)"
598
+ if not self.show_annotations:
599
+ info_text += " [Annotations: OFF]"
540
600
 
541
601
  cv2.putText(
542
602
  display, info_text, (10, 30),
@@ -638,6 +698,187 @@ class _ImageViewer:
638
698
  self._prev_image()
639
699
  elif key == ord("r"): # R to reset zoom
640
700
  self._reset_view()
701
+ elif key == ord("t"): # T to toggle annotations
702
+ self.show_annotations = not self.show_annotations
703
+
704
+ cv2.destroyAllWindows()
705
+
706
+
707
+ class _ClassificationGridViewer:
708
+ """Grid viewer for classification datasets showing one image per class."""
709
+
710
+ def __init__(
711
+ self,
712
+ images_by_class: dict[str, list[Path]],
713
+ class_names: list[str],
714
+ class_colors: dict[str, tuple[int, int, int]],
715
+ window_name: str,
716
+ max_classes: int | None = None,
717
+ tile_size: int = 300,
718
+ ):
719
+ # Limit classes if max_classes specified
720
+ if max_classes and len(class_names) > max_classes:
721
+ self.class_names = class_names[:max_classes]
722
+ else:
723
+ self.class_names = class_names
724
+
725
+ self.images_by_class = {
726
+ cls: images_by_class.get(cls, []) for cls in self.class_names
727
+ }
728
+ self.class_colors = class_colors
729
+ self.window_name = window_name
730
+ self.tile_size = tile_size
731
+
732
+ # Global image index (same for all classes)
733
+ self.current_index = 0
734
+
735
+ # Calculate max images across all classes
736
+ self.max_images = max(
737
+ len(imgs) for imgs in self.images_by_class.values()
738
+ ) if self.images_by_class else 0
739
+
740
+ # Calculate grid layout
741
+ self.cols, self.rows = self._calculate_grid_layout()
742
+
743
+ def _calculate_grid_layout(self) -> tuple[int, int]:
744
+ """Calculate optimal grid layout based on number of classes."""
745
+ n = len(self.class_names)
746
+ if n <= 0:
747
+ return 1, 1
748
+
749
+ # Try to make a roughly square grid
750
+ import math
751
+
752
+ cols = int(math.ceil(math.sqrt(n)))
753
+ rows = int(math.ceil(n / cols))
754
+ return cols, rows
755
+
756
+ def _create_tile(
757
+ self, class_name: str, image_path: Path | None, index: int, total: int
758
+ ) -> np.ndarray:
759
+ """Create a single tile for a class."""
760
+ tile = np.zeros((self.tile_size, self.tile_size, 3), dtype=np.uint8)
761
+
762
+ if image_path is not None and image_path.exists():
763
+ # Load and resize image
764
+ img = cv2.imread(str(image_path))
765
+ if img is not None:
766
+ # Resize maintaining aspect ratio
767
+ h, w = img.shape[:2]
768
+ scale = min(self.tile_size / w, self.tile_size / h)
769
+ new_w = int(w * scale)
770
+ new_h = int(h * scale)
771
+ resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
772
+
773
+ # Center in tile
774
+ x_offset = (self.tile_size - new_w) // 2
775
+ y_offset = (self.tile_size - new_h) // 2
776
+ tile[y_offset : y_offset + new_h, x_offset : x_offset + new_w] = resized
777
+
778
+ # Draw label at top: "class_name (N/M)"
779
+ if image_path is not None:
780
+ label = f"{class_name} ({index + 1}/{total})"
781
+ else:
782
+ label = f"{class_name} (-/{total})"
783
+
784
+ font = cv2.FONT_HERSHEY_SIMPLEX
785
+ font_scale = 0.5
786
+ thickness = 1
787
+ (label_w, label_h), baseline = cv2.getTextSize(
788
+ label, font, font_scale, thickness
789
+ )
790
+
791
+ # Semi-transparent background for label
792
+ overlay = tile.copy()
793
+ label_bg_height = label_h + baseline + 10
794
+ cv2.rectangle(overlay, (0, 0), (self.tile_size, label_bg_height), (0, 0, 0), -1)
795
+ cv2.addWeighted(overlay, 0.6, tile, 0.4, 0, tile)
796
+
797
+ cv2.putText(
798
+ tile,
799
+ label,
800
+ (5, label_h + 5),
801
+ font,
802
+ font_scale,
803
+ (255, 255, 255),
804
+ thickness,
805
+ )
806
+
807
+ # Draw thin border
808
+ border_end = self.tile_size - 1
809
+ cv2.rectangle(tile, (0, 0), (border_end, border_end), (80, 80, 80), 1)
810
+
811
+ return tile
812
+
813
+ def _compose_grid(self) -> np.ndarray:
814
+ """Compose all tiles into a single grid image."""
815
+ grid_h = self.rows * self.tile_size
816
+ grid_w = self.cols * self.tile_size
817
+ grid = np.zeros((grid_h, grid_w, 3), dtype=np.uint8)
818
+
819
+ for i, class_name in enumerate(self.class_names):
820
+ row = i // self.cols
821
+ col = i % self.cols
822
+
823
+ images = self.images_by_class[class_name]
824
+ total = len(images)
825
+
826
+ # Use global index - show black tile if class doesn't have this image
827
+ if self.current_index < total:
828
+ image_path = images[self.current_index]
829
+ display_index = self.current_index
830
+ else:
831
+ image_path = None
832
+ display_index = self.current_index
833
+
834
+ tile = self._create_tile(class_name, image_path, display_index, total)
835
+
836
+ y_start = row * self.tile_size
837
+ x_start = col * self.tile_size
838
+ y_end = y_start + self.tile_size
839
+ x_end = x_start + self.tile_size
840
+ grid[y_start:y_end, x_start:x_end] = tile
841
+
842
+ return grid
843
+
844
+ def _next_images(self) -> None:
845
+ """Advance to next image index."""
846
+ if self.max_images > 0:
847
+ self.current_index = min(self.current_index + 1, self.max_images - 1)
848
+
849
+ def _prev_images(self) -> None:
850
+ """Go back to previous image index."""
851
+ self.current_index = max(self.current_index - 1, 0)
852
+
853
+ def _reset_indices(self) -> None:
854
+ """Reset to first image."""
855
+ self.current_index = 0
856
+
857
+ def run(self) -> None:
858
+ """Run the interactive grid viewer."""
859
+ if not self.class_names:
860
+ console.print("[yellow]No classes to display.[/yellow]")
861
+ return
862
+
863
+ cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
864
+
865
+ while True:
866
+ # Compose and display grid
867
+ grid = self._compose_grid()
868
+ cv2.imshow(self.window_name, grid)
869
+
870
+ # Wait for input
871
+ key = cv2.waitKey(30) & 0xFF
872
+
873
+ # Handle keyboard input
874
+ if key == ord("q") or key == 27: # Q or ESC
875
+ break
876
+ elif key == ord("n") or key == 83 or key == 3: # N or Right arrow
877
+ self._next_images()
878
+ elif key == ord("p") or key == 81 or key == 2: # P or Left arrow
879
+ self._prev_images()
880
+ elif key == ord("r"): # R to reset
881
+ self._reset_indices()
641
882
 
642
883
  cv2.destroyAllWindows()
643
884
 
argus/core/yolo.py CHANGED
@@ -12,9 +12,9 @@ from argus.core.base import Dataset, DatasetFormat, TaskType
12
12
  class YOLODataset(Dataset):
13
13
  """YOLO format dataset.
14
14
 
15
- Supports detection and segmentation tasks.
15
+ Supports detection, segmentation, and classification tasks.
16
16
 
17
- Structure:
17
+ Structure (detection/segmentation):
18
18
  dataset/
19
19
  ├── data.yaml (or *.yaml/*.yml with 'names' key)
20
20
  ├── images/
@@ -23,6 +23,19 @@ class YOLODataset(Dataset):
23
23
  └── labels/
24
24
  ├── train/
25
25
  └── val/
26
+
27
+ Structure (classification):
28
+ dataset/
29
+ ├── images/
30
+ │ ├── train/
31
+ │ │ ├── class1/
32
+ │ │ │ ├── img1.jpg
33
+ │ │ │ └── img2.jpg
34
+ │ │ └── class2/
35
+ │ │ └── img1.jpg
36
+ │ └── val/
37
+ │ ├── class1/
38
+ │ └── class2/
26
39
  """
27
40
 
28
41
  config_file: Path | None = None
@@ -43,8 +56,13 @@ class YOLODataset(Dataset):
43
56
  if not path.is_dir():
44
57
  return None
45
58
 
46
- # Try detection/segmentation (YAML-based)
47
- return cls._detect_yaml_based(path)
59
+ # Try detection/segmentation (YAML-based) first
60
+ result = cls._detect_yaml_based(path)
61
+ if result:
62
+ return result
63
+
64
+ # Try classification (directory-based structure)
65
+ return cls._detect_classification(path)
48
66
 
49
67
  @classmethod
50
68
  def _detect_yaml_based(cls, path: Path) -> "YOLODataset | None":
@@ -103,16 +121,121 @@ class YOLODataset(Dataset):
103
121
 
104
122
  return None
105
123
 
124
+ @classmethod
125
+ def _detect_classification(cls, path: Path) -> "YOLODataset | None":
126
+ """Detect classification dataset from directory structure.
127
+
128
+ Classification datasets can have two structures:
129
+
130
+ 1. Split structure:
131
+ images/{split}/class_name/image.jpg
132
+
133
+ 2. Flat structure (unsplit):
134
+ class_name/image.jpg
135
+
136
+ No YAML config required - class names inferred from directory names.
137
+
138
+ Args:
139
+ path: Directory path to check.
140
+
141
+ Returns:
142
+ YOLODataset if classification structure found, None otherwise.
143
+ """
144
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
145
+
146
+ # Try split structure first: images/{split}/class/
147
+ images_root = path / "images"
148
+ if images_root.is_dir():
149
+ splits: list[str] = []
150
+ class_names_set: set[str] = set()
151
+
152
+ for split_name in ["train", "val", "test"]:
153
+ split_dir = images_root / split_name
154
+ if not split_dir.is_dir():
155
+ continue
156
+
157
+ # Get subdirectories (potential class folders)
158
+ class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
159
+ if not class_dirs:
160
+ continue
161
+
162
+ # Check if at least one class dir contains images
163
+ has_images = False
164
+ for class_dir in class_dirs:
165
+ for f in class_dir.iterdir():
166
+ if f.suffix.lower() in image_extensions:
167
+ has_images = True
168
+ break
169
+ if has_images:
170
+ break
171
+
172
+ if has_images:
173
+ splits.append(split_name)
174
+ class_names_set.update(d.name for d in class_dirs)
175
+
176
+ if splits and class_names_set:
177
+ class_names = sorted(class_names_set)
178
+ return cls(
179
+ path=path,
180
+ task=TaskType.CLASSIFICATION,
181
+ num_classes=len(class_names),
182
+ class_names=class_names,
183
+ splits=splits,
184
+ config_file=None,
185
+ )
186
+
187
+ # Try flat structure: class_name/image.jpg (no images/ or split dirs)
188
+ # Check if root contains subdirectories with images
189
+ class_dirs = [d for d in path.iterdir() if d.is_dir()]
190
+
191
+ # Filter out common non-class directories
192
+ excluded_dirs = {"images", "labels", "annotations", ".git", "__pycache__"}
193
+ class_dirs = [d for d in class_dirs if d.name not in excluded_dirs]
194
+
195
+ if not class_dirs:
196
+ return None
197
+
198
+ # Check if these are class directories (contain images directly)
199
+ class_names_set = set()
200
+ for class_dir in class_dirs:
201
+ has_images = any(
202
+ f.suffix.lower() in image_extensions
203
+ for f in class_dir.iterdir()
204
+ if f.is_file()
205
+ )
206
+ if has_images:
207
+ class_names_set.add(class_dir.name)
208
+
209
+ # Need at least 2 classes to be a valid classification dataset
210
+ if len(class_names_set) < 2:
211
+ return None
212
+
213
+ class_names = sorted(class_names_set)
214
+ return cls(
215
+ path=path,
216
+ task=TaskType.CLASSIFICATION,
217
+ num_classes=len(class_names),
218
+ class_names=class_names,
219
+ splits=[], # No splits for flat structure
220
+ config_file=None,
221
+ )
222
+
106
223
  def get_instance_counts(self) -> dict[str, dict[str, int]]:
107
224
  """Get the number of annotation instances per class, per split.
108
225
 
109
- Parses all label files in labels/{split}/*.txt and counts
110
- occurrences of each class ID. For unsplit datasets, uses "unsplit"
111
- as the split name.
226
+ For detection/segmentation: Parses all label files in labels/{split}/*.txt
227
+ and counts occurrences of each class ID.
228
+
229
+ For classification: Counts images in each class directory
230
+ (1 image = 1 instance).
112
231
 
113
232
  Returns:
114
233
  Dictionary mapping split name to dict of class name to instance count.
115
234
  """
235
+ # Handle classification datasets differently
236
+ if self.task == TaskType.CLASSIFICATION:
237
+ return self._get_classification_instance_counts()
238
+
116
239
  counts: dict[str, dict[str, int]] = {}
117
240
 
118
241
  # Build class_id -> class_name mapping
@@ -162,15 +285,77 @@ class YOLODataset(Dataset):
162
285
 
163
286
  return counts
164
287
 
288
+ def _get_classification_instance_counts(self) -> dict[str, dict[str, int]]:
289
+ """Get instance counts for classification datasets.
290
+
291
+ Each image is one instance of its class.
292
+
293
+ Returns:
294
+ Dictionary mapping split name to dict of class name to image count.
295
+ """
296
+ counts: dict[str, dict[str, int]] = {}
297
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
298
+
299
+ # Handle flat structure (no splits)
300
+ if not self.splits:
301
+ split_counts: dict[str, int] = {}
302
+ for class_name in self.class_names:
303
+ class_dir = self.path / class_name
304
+ if not class_dir.is_dir():
305
+ split_counts[class_name] = 0
306
+ continue
307
+
308
+ image_count = sum(
309
+ 1
310
+ for f in class_dir.iterdir()
311
+ if f.suffix.lower() in image_extensions
312
+ )
313
+ split_counts[class_name] = image_count
314
+
315
+ counts["unsplit"] = split_counts
316
+ return counts
317
+
318
+ # Handle split structure
319
+ images_root = self.path / "images"
320
+ for split in self.splits:
321
+ split_dir = images_root / split
322
+ if not split_dir.is_dir():
323
+ continue
324
+
325
+ split_counts = {}
326
+ for class_name in self.class_names:
327
+ class_dir = split_dir / class_name
328
+ if not class_dir.is_dir():
329
+ split_counts[class_name] = 0
330
+ continue
331
+
332
+ image_count = sum(
333
+ 1
334
+ for f in class_dir.iterdir()
335
+ if f.suffix.lower() in image_extensions
336
+ )
337
+ split_counts[class_name] = image_count
338
+
339
+ counts[split] = split_counts
340
+
341
+ return counts
342
+
165
343
  def get_image_counts(self) -> dict[str, dict[str, int]]:
166
344
  """Get image counts per split, including background images.
167
345
 
168
- Counts label files in labels/{split}/*.txt. Empty files are
169
- counted as background images.
346
+ For detection/segmentation: Counts label files in labels/{split}/*.txt.
347
+ Empty files are counted as background images.
348
+
349
+ For classification: Counts total images across all class directories.
350
+ Background count is always 0 (no background concept in classification).
170
351
 
171
352
  Returns:
172
353
  Dictionary mapping split name to dict with "total" and "background" counts.
173
354
  """
355
+ # Handle classification datasets differently
356
+ if self.task == TaskType.CLASSIFICATION:
357
+ return self._get_classification_image_counts()
358
+
174
359
  counts: dict[str, dict[str, int]] = {}
175
360
 
176
361
  labels_root = self.path / "labels"
@@ -203,6 +388,56 @@ class YOLODataset(Dataset):
203
388
 
204
389
  return counts
205
390
 
391
+ def _get_classification_image_counts(self) -> dict[str, dict[str, int]]:
392
+ """Get image counts for classification datasets.
393
+
394
+ Returns:
395
+ Dictionary mapping split name to dict with "total" and "background" counts.
396
+ Background is always 0 for classification.
397
+ """
398
+ counts: dict[str, dict[str, int]] = {}
399
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
400
+
401
+ # Handle flat structure (no splits)
402
+ if not self.splits:
403
+ total = 0
404
+ for class_name in self.class_names:
405
+ class_dir = self.path / class_name
406
+ if not class_dir.is_dir():
407
+ continue
408
+
409
+ total += sum(
410
+ 1
411
+ for f in class_dir.iterdir()
412
+ if f.suffix.lower() in image_extensions
413
+ )
414
+
415
+ counts["unsplit"] = {"total": total, "background": 0}
416
+ return counts
417
+
418
+ # Handle split structure
419
+ images_root = self.path / "images"
420
+ for split in self.splits:
421
+ split_dir = images_root / split
422
+ if not split_dir.is_dir():
423
+ continue
424
+
425
+ total = 0
426
+ for class_name in self.class_names:
427
+ class_dir = split_dir / class_name
428
+ if not class_dir.is_dir():
429
+ continue
430
+
431
+ total += sum(
432
+ 1
433
+ for f in class_dir.iterdir()
434
+ if f.suffix.lower() in image_extensions
435
+ )
436
+
437
+ counts[split] = {"total": total, "background": 0}
438
+
439
+ return counts
440
+
206
441
  @classmethod
207
442
  def _detect_splits(cls, path: Path, config: dict) -> list[str]:
208
443
  """Detect available splits from config and filesystem.
@@ -301,6 +536,10 @@ class YOLODataset(Dataset):
301
536
  Returns:
302
537
  List of image file paths sorted alphabetically.
303
538
  """
539
+ # Handle classification datasets differently
540
+ if self.task == TaskType.CLASSIFICATION:
541
+ return self._get_classification_image_paths(split)
542
+
304
543
  image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
305
544
  images_root = self.path / "images"
306
545
  image_paths: list[Path] = []
@@ -345,6 +584,51 @@ class YOLODataset(Dataset):
345
584
 
346
585
  return sorted(image_paths, key=lambda p: p.name)
347
586
 
587
+ def _get_classification_image_paths(self, split: str | None = None) -> list[Path]:
588
+ """Get image paths for classification datasets.
589
+
590
+ Args:
591
+ split: Specific split to get images from. If None, returns all images.
592
+
593
+ Returns:
594
+ List of image file paths sorted alphabetically.
595
+ """
596
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
597
+ image_paths: list[Path] = []
598
+
599
+ # Handle flat structure (no splits)
600
+ if not self.splits:
601
+ for class_name in self.class_names:
602
+ class_dir = self.path / class_name
603
+ if not class_dir.is_dir():
604
+ continue
605
+
606
+ for img_file in class_dir.iterdir():
607
+ if img_file.suffix.lower() in image_extensions:
608
+ image_paths.append(img_file)
609
+
610
+ return sorted(image_paths, key=lambda p: p.name)
611
+
612
+ # Handle split structure
613
+ images_root = self.path / "images"
614
+ splits_to_search = [split] if split else self.splits
615
+
616
+ for s in splits_to_search:
617
+ split_dir = images_root / s
618
+ if not split_dir.is_dir():
619
+ continue
620
+
621
+ for class_name in self.class_names:
622
+ class_dir = split_dir / class_name
623
+ if not class_dir.is_dir():
624
+ continue
625
+
626
+ for img_file in class_dir.iterdir():
627
+ if img_file.suffix.lower() in image_extensions:
628
+ image_paths.append(img_file)
629
+
630
+ return sorted(image_paths, key=lambda p: p.name)
631
+
348
632
  def get_annotations_for_image(self, image_path: Path) -> list[dict]:
349
633
  """Get annotations for a specific image.
350
634
 
@@ -445,3 +729,65 @@ class YOLODataset(Dataset):
445
729
  pass
446
730
 
447
731
  return annotations
732
+
733
+ def get_images_by_class(self, split: str | None = None) -> dict[str, list[Path]]:
734
+ """Get images grouped by class for classification datasets.
735
+
736
+ Args:
737
+ split: Specific split to get images from. If None, uses first
738
+ available split or all images for flat structure.
739
+
740
+ Returns:
741
+ Dictionary mapping class name to list of image paths.
742
+ """
743
+ if self.task != TaskType.CLASSIFICATION:
744
+ return {}
745
+
746
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
747
+ images_by_class: dict[str, list[Path]] = {cls: [] for cls in self.class_names}
748
+
749
+ # Handle flat structure (no splits)
750
+ if not self.splits:
751
+ for class_name in self.class_names:
752
+ class_dir = self.path / class_name
753
+ if not class_dir.is_dir():
754
+ continue
755
+
756
+ for img_file in class_dir.iterdir():
757
+ if img_file.suffix.lower() in image_extensions:
758
+ images_by_class[class_name].append(img_file)
759
+
760
+ # Sort images within each class
761
+ for class_name in images_by_class:
762
+ images_by_class[class_name] = sorted(
763
+ images_by_class[class_name], key=lambda p: p.name
764
+ )
765
+
766
+ return images_by_class
767
+
768
+ # Handle split structure
769
+ images_root = self.path / "images"
770
+ default_splits = self.splits[:1] if self.splits else []
771
+ splits_to_search = [split] if split else default_splits
772
+
773
+ for s in splits_to_search:
774
+ split_dir = images_root / s
775
+ if not split_dir.is_dir():
776
+ continue
777
+
778
+ for class_name in self.class_names:
779
+ class_dir = split_dir / class_name
780
+ if not class_dir.is_dir():
781
+ continue
782
+
783
+ for img_file in class_dir.iterdir():
784
+ if img_file.suffix.lower() in image_extensions:
785
+ images_by_class[class_name].append(img_file)
786
+
787
+ # Sort images within each class for consistent ordering
788
+ for class_name in images_by_class:
789
+ images_by_class[class_name] = sorted(
790
+ images_by_class[class_name], key=lambda p: p.name
791
+ )
792
+
793
+ return images_by_class
argus_cv-1.0.1.dist-info/METADATA → argus_cv-1.2.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: argus-cv
3
- Version: 1.0.1
3
+ Version: 1.2.0
4
4
  Summary: CLI tool for working with vision AI datasets
5
5
  Requires-Python: >=3.10
6
6
  Requires-Dist: numpy>=1.24.0
argus_cv-1.0.1.dist-info/RECORD → argus_cv-1.2.0.dist-info/RECORD CHANGED
@@ -1,13 +1,13 @@
1
- argus/__init__.py,sha256=wDcE-H1pju8c4LVZEuSCsT1-7UQukOEsbSH6suOhWR8,64
1
+ argus/__init__.py,sha256=e8XoTWOOBcvGAX8HIUZFycf6Yi3IMFZfXSj4oiK2Cwg,64
2
2
  argus/__main__.py,sha256=63ezHx8eL_lCMoZrCbKhmpao0fmdvYVw1chbknGg-oI,104
3
- argus/cli.py,sha256=piAznlc1Z7EF8Ja5p2TYWeDW6UrpLvh_vnIoI-P-SwE,25325
3
+ argus/cli.py,sha256=SHAN07n1ffqWil2qhLkc0xIdU7iXhjTPW8nFMDUhnNQ,34103
4
4
  argus/commands/__init__.py,sha256=i2oor9hpVpF-_1qZWCGDLwwi1pZGJfZnUKJZ_NMBG18,30
5
5
  argus/core/__init__.py,sha256=Plv_tk0Wq9OlGLDPOSQWxrd5cTwNK9kEZANTim3s23A,348
6
6
  argus/core/base.py,sha256=Vd_2xR6L3lhu9vHoyLeFTc0Dg59py_D9kaye1tta5Co,3678
7
7
  argus/core/coco.py,sha256=bJvOhBzwjsOU8DBijGDysnSPlprwetkPf4Z02UOmqw0,15757
8
8
  argus/core/split.py,sha256=kEWtbdg6bH-WiNFf83HkqZD90EL4gsavw6JiefuAETs,10776
9
- argus/core/yolo.py,sha256=KTWgmEguxKZ_C0WsMxUB-B-zbx_Oi1ieGDk3Osuh0xY,15876
10
- argus_cv-1.0.1.dist-info/METADATA,sha256=3rISiuP5iBxwE2JasTNwV0ckEAEehB9eLzEr-4l3YGw,1070
11
- argus_cv-1.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
12
- argus_cv-1.0.1.dist-info/entry_points.txt,sha256=dvJFH7BkrOxJnifSjPhwq1YCafPaqdngWyBuFYE73yY,43
13
- argus_cv-1.0.1.dist-info/RECORD,,
9
+ argus/core/yolo.py,sha256=W7WH7RwNYvJu8nYdNcXVNRkK5IBov-1dG23esOkCB1M,28213
10
+ argus_cv-1.2.0.dist-info/METADATA,sha256=SP_8kVPy0Mv_9F8CM9zWTARqCOmzYFetyF7QaEgNT-Y,1070
11
+ argus_cv-1.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
12
+ argus_cv-1.2.0.dist-info/entry_points.txt,sha256=dvJFH7BkrOxJnifSjPhwq1YCafPaqdngWyBuFYE73yY,43
13
+ argus_cv-1.2.0.dist-info/RECORD,,