argus-cv 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in that registry.

Note: this release of argus-cv has been flagged as potentially problematic.

argus/cli.py CHANGED
@@ -11,7 +11,8 @@ from rich.console import Console
  from rich.progress import Progress, SpinnerColumn, TextColumn
  from rich.table import Table

- from argus.core import COCODataset, Dataset, YOLODataset
+ from argus.core import COCODataset, Dataset, MaskDataset, YOLODataset
+ from argus.core.base import DatasetFormat, TaskType
  from argus.core.split import (
      is_coco_unsplit,
      parse_ratio,
@@ -127,12 +128,19 @@ def stats(
      dataset = _detect_dataset(dataset_path)
      if not dataset:
          console.print(
-             f"[red]Error: No YOLO or COCO dataset found at {dataset_path}[/red]\n"
+             f"[red]Error: No dataset found at {dataset_path}[/red]\n"
              "[yellow]Ensure the path points to a dataset root containing "
-             "data.yaml (YOLO) or annotations/ folder (COCO).[/yellow]"
+             "data.yaml (YOLO), annotations/ folder (COCO), or "
+             "images/ + masks/ directories (Mask).[/yellow]"
          )
          raise typer.Exit(1)

+     # Handle mask datasets with pixel statistics
+     if dataset.format == DatasetFormat.MASK:
+         assert isinstance(dataset, MaskDataset)
+         _show_mask_stats(dataset, dataset_path)
+         return
+
      # Get instance counts with progress indicator
      with Progress(
          SpinnerColumn(),
@@ -211,10 +219,100 @@ def stats(
          else:
              image_parts.append(f"{split}: {img_total}")

-     console.print(f"\n[green]Dataset: {dataset.format.value.upper()} | "
-                   f"Task: {dataset.task.value} | "
-                   f"Classes: {len(sorted_classes)} | "
-                   f"Total instances: {grand_total}[/green]")
+     console.print(
+         f"\n[green]Dataset: {dataset.format.value.upper()} | "
+         f"Task: {dataset.task.value} | "
+         f"Classes: {len(sorted_classes)} | "
+         f"Total instances: {grand_total}[/green]"
+     )
+
+     if image_parts:
+         console.print(f"[blue]Images: {' | '.join(image_parts)}[/blue]")
+
+
+ def _show_mask_stats(dataset: MaskDataset, dataset_path: Path) -> None:
+     """Show statistics for mask datasets with pixel-level information.
+
+     Args:
+         dataset: The MaskDataset instance.
+         dataset_path: Path to the dataset root.
+     """
+     with Progress(
+         SpinnerColumn(),
+         TextColumn("[progress.description]{task.description}"),
+         console=console,
+         transient=True,
+     ) as progress:
+         progress.add_task("Analyzing mask dataset...", total=None)
+         pixel_counts = dataset.get_pixel_counts()
+         image_presence = dataset.get_image_class_presence()
+         image_counts = dataset.get_image_counts()
+
+     # Get class mapping
+     class_mapping = dataset.get_class_mapping()
+
+     # Calculate total non-ignored pixels
+     total_pixels = sum(
+         count
+         for class_id, count in pixel_counts.items()
+         if class_id != dataset.ignore_index
+     )
+     ignored_pixels = pixel_counts.get(dataset.ignore_index, 0)
+
+     # Calculate total images
+     total_images = sum(ic["total"] for ic in image_counts.values())
+
+     # Create table
+     splits_str = ", ".join(dataset.splits) if dataset.splits else "unsplit"
+     title = f"Class Statistics: {dataset_path.name} ({splits_str})"
+     table = Table(title=title)
+     table.add_column("Class", style="cyan")
+     table.add_column("Total Pixels", justify="right", style="green")
+     table.add_column("% Coverage", justify="right", style="magenta")
+     table.add_column("Images With", justify="right", style="yellow")
+
+     # Sort classes by class_id
+     sorted_class_ids = sorted(class_mapping.keys())
+
+     for class_id in sorted_class_ids:
+         class_name = class_mapping[class_id]
+         pixels = pixel_counts.get(class_id, 0)
+         presence = image_presence.get(class_id, 0)
+
+         # Calculate coverage percentage
+         coverage = (pixels / total_pixels * 100) if total_pixels > 0 else 0.0
+
+         table.add_row(
+             class_name,
+             f"{pixels:,}",
+             f"{coverage:.1f}%",
+             str(presence),
+         )
+
+     # Add ignored row if there are ignored pixels
+     if ignored_pixels > 0:
+         table.add_section()
+         table.add_row(
+             "[dim](ignored)[/dim]",
+             f"[dim]{ignored_pixels:,}[/dim]",
+             "[dim]-[/dim]",
+             f"[dim]{total_images}[/dim]",
+         )
+
+     console.print(table)
+
+     # Summary line
+     console.print(f"\n[green]Dataset: {dataset_path}[/green]")
+     console.print(
+         f"[green]Format: {dataset.format.value.upper()} | "
+         f"Task: {dataset.task.value}[/green]"
+     )
+
+     # Image counts per split
+     image_parts = []
+     for split in dataset.splits if dataset.splits else ["unsplit"]:
+         if split in image_counts:
+             image_parts.append(f"{split}: {image_counts[split]['total']}")

      if image_parts:
          console.print(f"[blue]Images: {' | '.join(image_parts)}[/blue]")
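For context on what the new _show_mask_stats table reports, here is a rough, self-contained sketch (not from the package) of how per-class pixel counts and coverage percentages can be derived from a single-channel class-ID mask with NumPy. The 4x4 mask values and the ignore index of 255 are assumptions for illustration only.

import numpy as np

# Hypothetical 4x4 mask: class IDs 0 and 1, plus 255 as an assumed ignore index.
mask = np.array(
    [
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 255],
        [255, 255, 0, 0],
    ],
    dtype=np.uint8,
)
ignore_index = 255

# Count pixels per class ID, analogous to what a get_pixel_counts()-style helper might return.
ids, counts = np.unique(mask, return_counts=True)
pixel_counts = dict(zip(ids.tolist(), counts.tolist()))

# Coverage is computed against non-ignored pixels only, as in the hunk above.
total_pixels = sum(c for cid, c in pixel_counts.items() if cid != ignore_index)
for cid, c in pixel_counts.items():
    if cid == ignore_index:
        continue
    print(f"class {cid}: {c} px, {c / total_pixels * 100:.1f}% coverage")

Running this prints class 0 at 69.2% and class 1 at 30.8% coverage of the 13 non-ignored pixels.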
@@ -238,19 +336,39 @@ def view(
              help="Specific split to view (train, val, test).",
          ),
      ] = None,
+     max_classes: Annotated[
+         int | None,
+         typer.Option(
+             "--max-classes",
+             "-m",
+             help="Maximum classes to show in grid (classification only).",
+         ),
+     ] = None,
+     opacity: Annotated[
+         float,
+         typer.Option(
+             "--opacity",
+             "-o",
+             help="Mask overlay opacity (0.0-1.0, mask datasets only).",
+             min=0.0,
+             max=1.0,
+         ),
+     ] = 0.5,
  ) -> None:
      """View annotated images in a dataset.

      Opens an interactive viewer to browse images with their annotations
      (bounding boxes and segmentation masks) overlaid.

+     For classification datasets, shows a grid view with one image per class.
+
      Controls:
-     - Right Arrow / N: Next image
-     - Left Arrow / P: Previous image
-     - Mouse Wheel: Zoom in/out
-     - Mouse Drag: Pan when zoomed
-     - R: Reset zoom
-     - T: Toggle annotations
+     - Right Arrow / N: Next image(s)
+     - Left Arrow / P: Previous image(s)
+     - Mouse Wheel: Zoom in/out (detection/segmentation only)
+     - Mouse Drag: Pan when zoomed (detection/segmentation only)
+     - R: Reset zoom / Reset to first images
+     - T: Toggle annotations (detection/segmentation only)
      - Q / ESC: Quit viewer
      """
      # Resolve path and validate
@@ -281,39 +399,111 @@ def view(
          )
          raise typer.Exit(1)

-     # Get image paths
-     with Progress(
-         SpinnerColumn(),
-         TextColumn("[progress.description]{task.description}"),
-         console=console,
-         transient=True,
-     ) as progress:
-         progress.add_task("Loading images...", total=None)
-         image_paths = dataset.get_image_paths(split)
+     # Generate consistent colors for each class
+     class_colors = _generate_class_colors(dataset.class_names)
+
+     # Handle mask datasets with overlay viewer
+     if dataset.format == DatasetFormat.MASK:
+         assert isinstance(dataset, MaskDataset)
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             console=console,
+             transient=True,
+         ) as progress:
+             progress.add_task("Loading images...", total=None)
+             image_paths = dataset.get_image_paths(split)
+
+         if not image_paths:
+             console.print("[yellow]No images found in the dataset.[/yellow]")
+             return

-     if not image_paths:
-         console.print("[yellow]No images found in the dataset.[/yellow]")
+         console.print(
+             f"[green]Found {len(image_paths)} images. "
+             f"Opening mask viewer...[/green]\n"
+             "[dim]Controls: \u2190 / \u2192 or P / N to navigate, "
+             "Mouse wheel to zoom, Drag to pan, R to reset, T to toggle overlay, "
+             "Q / ESC to quit[/dim]"
+         )
+
+         viewer = _MaskViewer(
+             image_paths=image_paths,
+             dataset=dataset,
+             class_colors=class_colors,
+             window_name=f"Argus Mask Viewer - {dataset_path.name}",
+             opacity=opacity,
+         )
+         viewer.run()
+         console.print("[green]Viewer closed.[/green]")
          return

-     console.print(
-         f"[green]Found {len(image_paths)} images. "
-         f"Opening viewer...[/green]\n"
-         "[dim]Controls: ← / → or P / N to navigate, "
-         "Mouse wheel to zoom, Drag to pan, R to reset, T to toggle annotations, "
-         "Q / ESC to quit[/dim]"
-     )
+     # Handle classification datasets with grid viewer
+     if dataset.task == TaskType.CLASSIFICATION:
+         # Use first split if specified, otherwise let get_images_by_class handle it
+         view_split = split if split else (dataset.splits[0] if dataset.splits else None)

-     # Generate consistent colors for each class
-     class_colors = _generate_class_colors(dataset.class_names)
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             console=console,
+             transient=True,
+         ) as progress:
+             progress.add_task("Loading images by class...", total=None)
+             images_by_class = dataset.get_images_by_class(view_split)

-     # Create and run the interactive viewer
-     viewer = _ImageViewer(
-         image_paths=image_paths,
-         dataset=dataset,
-         class_colors=class_colors,
-         window_name=f"Argus Viewer - {dataset_path.name}",
-     )
-     viewer.run()
+         total_images = sum(len(imgs) for imgs in images_by_class.values())
+         if total_images == 0:
+             console.print("[yellow]No images found in the dataset.[/yellow]")
+             return
+
+         num_classes = len(dataset.class_names)
+         display_classes = min(num_classes, max_classes) if max_classes else num_classes
+
+         console.print(
+             f"[green]Found {total_images} images across {num_classes} classes "
+             f"(showing {display_classes}). Opening grid viewer...[/green]\n"
+             "[dim]Controls: ← / → or P / N to navigate all classes, "
+             "R to reset, Q / ESC to quit[/dim]"
+         )
+
+         viewer = _ClassificationGridViewer(
+             images_by_class=images_by_class,
+             class_names=dataset.class_names,
+             class_colors=class_colors,
+             window_name=f"Argus Classification Viewer - {dataset_path.name}",
+             max_classes=max_classes,
+         )
+         viewer.run()
+     else:
+         # Detection/Segmentation viewer
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             console=console,
+             transient=True,
+         ) as progress:
+             progress.add_task("Loading images...", total=None)
+             image_paths = dataset.get_image_paths(split)
+
+         if not image_paths:
+             console.print("[yellow]No images found in the dataset.[/yellow]")
+             return
+
+         console.print(
+             f"[green]Found {len(image_paths)} images. "
+             f"Opening viewer...[/green]\n"
+             "[dim]Controls: ← / → or P / N to navigate, "
+             "Mouse wheel to zoom, Drag to pan, R to reset, T to toggle annotations, "
+             "Q / ESC to quit[/dim]"
+         )
+
+         viewer = _ImageViewer(
+             image_paths=image_paths,
+             dataset=dataset,
+             class_colors=class_colors,
+             window_name=f"Argus Viewer - {dataset_path.name}",
+         )
+         viewer.run()

      console.print("[green]Viewer closed.[/green]")

@@ -551,12 +741,16 @@ class _ImageViewer:
          info_text += " [Annotations: OFF]"

          cv2.putText(
-             display, info_text, (10, 30),
-             cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2
+             display,
+             info_text,
+             (10, 30),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.7,
+             (255, 255, 255),
+             2,
          )
          cv2.putText(
-             display, info_text, (10, 30),
-             cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
+             display, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
          )

          return display
@@ -656,6 +850,448 @@ class _ImageViewer:
          cv2.destroyAllWindows()


+ class _ClassificationGridViewer:
+     """Grid viewer for classification datasets showing one image per class."""
+
+     def __init__(
+         self,
+         images_by_class: dict[str, list[Path]],
+         class_names: list[str],
+         class_colors: dict[str, tuple[int, int, int]],
+         window_name: str,
+         max_classes: int | None = None,
+         tile_size: int = 300,
+     ):
+         # Limit classes if max_classes specified
+         if max_classes and len(class_names) > max_classes:
+             self.class_names = class_names[:max_classes]
+         else:
+             self.class_names = class_names
+
+         self.images_by_class = {
+             cls: images_by_class.get(cls, []) for cls in self.class_names
+         }
+         self.class_colors = class_colors
+         self.window_name = window_name
+         self.tile_size = tile_size
+
+         # Global image index (same for all classes)
+         self.current_index = 0
+
+         # Calculate max images across all classes
+         self.max_images = (
+             max(len(imgs) for imgs in self.images_by_class.values())
+             if self.images_by_class
+             else 0
+         )
+
+         # Calculate grid layout
+         self.cols, self.rows = self._calculate_grid_layout()
+
+     def _calculate_grid_layout(self) -> tuple[int, int]:
+         """Calculate optimal grid layout based on number of classes."""
+         n = len(self.class_names)
+         if n <= 0:
+             return 1, 1
+
+         # Try to make a roughly square grid
+         import math
+
+         cols = int(math.ceil(math.sqrt(n)))
+         rows = int(math.ceil(n / cols))
+         return cols, rows
+
+     def _create_tile(
+         self, class_name: str, image_path: Path | None, index: int, total: int
+     ) -> np.ndarray:
+         """Create a single tile for a class."""
+         tile = np.zeros((self.tile_size, self.tile_size, 3), dtype=np.uint8)
+
+         if image_path is not None and image_path.exists():
+             # Load and resize image
+             img = cv2.imread(str(image_path))
+             if img is not None:
+                 # Resize maintaining aspect ratio
+                 h, w = img.shape[:2]
+                 scale = min(self.tile_size / w, self.tile_size / h)
+                 new_w = int(w * scale)
+                 new_h = int(h * scale)
+                 resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+
+                 # Center in tile
+                 x_offset = (self.tile_size - new_w) // 2
+                 y_offset = (self.tile_size - new_h) // 2
+                 tile[y_offset : y_offset + new_h, x_offset : x_offset + new_w] = resized
+
+         # Draw label at top: "class_name (N/M)"
+         if image_path is not None:
+             label = f"{class_name} ({index + 1}/{total})"
+         else:
+             label = f"{class_name} (-/{total})"
+
+         font = cv2.FONT_HERSHEY_SIMPLEX
+         font_scale = 0.5
+         thickness = 1
+         (label_w, label_h), baseline = cv2.getTextSize(
+             label, font, font_scale, thickness
+         )
+
+         # Semi-transparent background for label
+         overlay = tile.copy()
+         label_bg_height = label_h + baseline + 10
+         cv2.rectangle(overlay, (0, 0), (self.tile_size, label_bg_height), (0, 0, 0), -1)
+         cv2.addWeighted(overlay, 0.6, tile, 0.4, 0, tile)
+
+         cv2.putText(
+             tile,
+             label,
+             (5, label_h + 5),
+             font,
+             font_scale,
+             (255, 255, 255),
+             thickness,
+         )
+
+         # Draw thin border
+         border_end = self.tile_size - 1
+         cv2.rectangle(tile, (0, 0), (border_end, border_end), (80, 80, 80), 1)
+
+         return tile
+
+     def _compose_grid(self) -> np.ndarray:
+         """Compose all tiles into a single grid image."""
+         grid_h = self.rows * self.tile_size
+         grid_w = self.cols * self.tile_size
+         grid = np.zeros((grid_h, grid_w, 3), dtype=np.uint8)
+
+         for i, class_name in enumerate(self.class_names):
+             row = i // self.cols
+             col = i % self.cols
+
+             images = self.images_by_class[class_name]
+             total = len(images)
+
+             # Use global index - show black tile if class doesn't have this image
+             if self.current_index < total:
+                 image_path = images[self.current_index]
+                 display_index = self.current_index
+             else:
+                 image_path = None
+                 display_index = self.current_index
+
+             tile = self._create_tile(class_name, image_path, display_index, total)
+
+             y_start = row * self.tile_size
+             x_start = col * self.tile_size
+             y_end = y_start + self.tile_size
+             x_end = x_start + self.tile_size
+             grid[y_start:y_end, x_start:x_end] = tile
+
+         return grid
+
+     def _next_images(self) -> None:
+         """Advance to next image index."""
+         if self.max_images > 0:
+             self.current_index = min(self.current_index + 1, self.max_images - 1)
+
+     def _prev_images(self) -> None:
+         """Go back to previous image index."""
+         self.current_index = max(self.current_index - 1, 0)
+
+     def _reset_indices(self) -> None:
+         """Reset to first image."""
+         self.current_index = 0
+
+     def run(self) -> None:
+         """Run the interactive grid viewer."""
+         if not self.class_names:
+             console.print("[yellow]No classes to display.[/yellow]")
+             return
+
+         cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
+
+         while True:
+             # Compose and display grid
+             grid = self._compose_grid()
+             cv2.imshow(self.window_name, grid)
+
+             # Wait for input
+             key = cv2.waitKey(30) & 0xFF
+
+             # Handle keyboard input
+             if key == ord("q") or key == 27:  # Q or ESC
+                 break
+             elif key == ord("n") or key == 83 or key == 3:  # N or Right arrow
+                 self._next_images()
+             elif key == ord("p") or key == 81 or key == 2:  # P or Left arrow
+                 self._prev_images()
+             elif key == ord("r"):  # R to reset
+                 self._reset_indices()
+
+         cv2.destroyAllWindows()
+
+
+ class _MaskViewer:
+     """Interactive viewer for semantic mask datasets with colored overlay."""
+
+     def __init__(
+         self,
+         image_paths: list[Path],
+         dataset: MaskDataset,
+         class_colors: dict[str, tuple[int, int, int]],
+         window_name: str,
+         opacity: float = 0.5,
+     ):
+         self.image_paths = image_paths
+         self.dataset = dataset
+         self.class_colors = class_colors
+         self.window_name = window_name
+         self.opacity = opacity
+
+         self.current_idx = 0
+         self.zoom = 1.0
+         self.pan_x = 0.0
+         self.pan_y = 0.0
+
+         # Mouse state for panning
+         self.dragging = False
+         self.drag_start_x = 0
+         self.drag_start_y = 0
+         self.pan_start_x = 0.0
+         self.pan_start_y = 0.0
+
+         # Current image cache
+         self.current_img: np.ndarray | None = None
+         self.overlay_img: np.ndarray | None = None
+
+         # Overlay visibility toggle
+         self.show_overlay = True
+
+         # Build class_id to color mapping
+         self._id_to_color: dict[int, tuple[int, int, int]] = {}
+         class_mapping = dataset.get_class_mapping()
+         for class_id, class_name in class_mapping.items():
+             if class_name in class_colors:
+                 self._id_to_color[class_id] = class_colors[class_name]
+
+     def _load_current_image(self) -> bool:
+         """Load current image and create mask overlay."""
+         image_path = self.image_paths[self.current_idx]
+
+         img = cv2.imread(str(image_path))
+         if img is None:
+             return False
+
+         mask = self.dataset.load_mask(image_path)
+         if mask is None:
+             console.print(f"[yellow]Warning: No mask for {image_path}[/yellow]")
+             self.current_img = img
+             self.overlay_img = img.copy()
+             return True
+
+         # Validate dimensions
+         if img.shape[:2] != mask.shape[:2]:
+             console.print(
+                 f"[red]Error: Dimension mismatch for {image_path.name}: "
+                 f"image={img.shape[:2]}, mask={mask.shape[:2]}[/red]"
+             )
+             return False
+
+         self.current_img = img
+         self.overlay_img = self._create_overlay(img, mask)
+         return True
+
+     def _create_overlay(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
+         """Create colored overlay from mask.
+
+         Args:
+             img: Original image (BGR).
+             mask: Grayscale mask with class IDs.
+
+         Returns:
+             Image with colored mask overlay.
+         """
+         # Create colored mask
+         h, w = mask.shape
+         colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
+
+         for class_id, color in self._id_to_color.items():
+             colored_mask[mask == class_id] = color
+
+         # Blend with original image
+         # Ignore pixels are fully transparent (not blended)
+         ignore_mask = mask == self.dataset.ignore_index
+         alpha = np.ones((h, w, 1), dtype=np.float32) * self.opacity
+         alpha[ignore_mask] = 0.0
+
+         # Blend: result = img * (1 - alpha) + colored_mask * alpha
+         blended = (
+             img.astype(np.float32) * (1 - alpha)
+             + colored_mask.astype(np.float32) * alpha
+         )
+         return blended.astype(np.uint8)
+
+     def _get_display_image(self) -> np.ndarray:
+         """Get the image transformed for current zoom/pan."""
+         if self.overlay_img is None:
+             return np.zeros((480, 640, 3), dtype=np.uint8)
+
+         if self.show_overlay:
+             img = self.overlay_img
+         elif self.current_img is not None:
+             img = self.current_img
+         else:
+             img = self.overlay_img
+
+         h, w = img.shape[:2]
+
+         if self.zoom == 1.0 and self.pan_x == 0.0 and self.pan_y == 0.0:
+             display = img.copy()
+         else:
+             # Calculate the visible region
+             view_w = int(w / self.zoom)
+             view_h = int(h / self.zoom)
+
+             # Center point with pan offset
+             cx = w / 2 + self.pan_x
+             cy = h / 2 + self.pan_y
+
+             # Calculate crop bounds
+             x1 = int(max(0, cx - view_w / 2))
+             y1 = int(max(0, cy - view_h / 2))
+             x2 = int(min(w, x1 + view_w))
+             y2 = int(min(h, y1 + view_h))
+
+             # Adjust if we hit boundaries
+             if x2 - x1 < view_w:
+                 x1 = max(0, x2 - view_w)
+             if y2 - y1 < view_h:
+                 y1 = max(0, y2 - view_h)
+
+             # Crop and resize
+             cropped = img[y1:y2, x1:x2]
+             display = cv2.resize(cropped, (w, h), interpolation=cv2.INTER_LINEAR)
+
+         # Add info overlay
+         image_path = self.image_paths[self.current_idx]
+         idx = self.current_idx + 1
+         total = len(self.image_paths)
+         info_text = f"[{idx}/{total}] {image_path.name}"
+         if self.zoom > 1.0:
+             info_text += f" (Zoom: {self.zoom:.1f}x)"
+         if not self.show_overlay:
+             info_text += " [Overlay: OFF]"
+
+         cv2.putText(
+             display,
+             info_text,
+             (10, 30),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.7,
+             (255, 255, 255),
+             2,
+         )
+         cv2.putText(
+             display, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
+         )
+
+         return display
+
+     def _mouse_callback(
+         self, event: int, x: int, y: int, flags: int, param: None
+     ) -> None:
+         """Handle mouse events for zoom and pan."""
+         if event == cv2.EVENT_MOUSEWHEEL:
+             # Zoom in/out
+             if flags > 0:
+                 self.zoom = min(10.0, self.zoom * 1.2)
+             else:
+                 self.zoom = max(1.0, self.zoom / 1.2)
+
+             # Reset pan if zoomed out to 1x
+             if self.zoom == 1.0:
+                 self.pan_x = 0.0
+                 self.pan_y = 0.0
+
+         elif event == cv2.EVENT_LBUTTONDOWN:
+             self.dragging = True
+             self.drag_start_x = x
+             self.drag_start_y = y
+             self.pan_start_x = self.pan_x
+             self.pan_start_y = self.pan_y
+
+         elif event == cv2.EVENT_MOUSEMOVE and self.dragging:
+             if self.zoom > 1.0 and self.overlay_img is not None:
+                 h, w = self.overlay_img.shape[:2]
+                 # Calculate pan delta (inverted for natural feel)
+                 dx = (self.drag_start_x - x) / self.zoom
+                 dy = (self.drag_start_y - y) / self.zoom
+
+                 # Update pan with limits
+                 max_pan_x = w * (1 - 1 / self.zoom) / 2
+                 max_pan_y = h * (1 - 1 / self.zoom) / 2
+
+                 self.pan_x = max(-max_pan_x, min(max_pan_x, self.pan_start_x + dx))
+                 self.pan_y = max(-max_pan_y, min(max_pan_y, self.pan_start_y + dy))
+
+         elif event == cv2.EVENT_LBUTTONUP:
+             self.dragging = False
+
+     def _reset_view(self) -> None:
+         """Reset zoom and pan to default."""
+         self.zoom = 1.0
+         self.pan_x = 0.0
+         self.pan_y = 0.0
+
+     def _next_image(self) -> None:
+         """Go to next image."""
+         self.current_idx = (self.current_idx + 1) % len(self.image_paths)
+         self._reset_view()
+
+     def _prev_image(self) -> None:
+         """Go to previous image."""
+         self.current_idx = (self.current_idx - 1) % len(self.image_paths)
+         self._reset_view()
+
+     def run(self) -> None:
+         """Run the interactive viewer."""
+         cv2.namedWindow(self.window_name, cv2.WINDOW_AUTOSIZE)
+         cv2.setMouseCallback(self.window_name, self._mouse_callback)
+
+         while True:
+             # Load image if needed
+             if self.overlay_img is None and not self._load_current_image():
+                 console.print(
+                     f"[yellow]Warning: Could not load "
+                     f"{self.image_paths[self.current_idx]}[/yellow]"
+                 )
+                 self._next_image()
+                 continue
+
+             # Display image
+             display = self._get_display_image()
+             cv2.imshow(self.window_name, display)
+
+             # Wait for input (short timeout for smooth panning)
+             key = cv2.waitKey(30) & 0xFF
+
+             # Handle keyboard input
+             if key == ord("q") or key == 27:  # Q or ESC
+                 break
+             elif key == ord("n") or key == 83 or key == 3:  # N or Right arrow
+                 self.overlay_img = None
+                 self._next_image()
+             elif key == ord("p") or key == 81 or key == 2:  # P or Left arrow
+                 self.overlay_img = None
+                 self._prev_image()
+             elif key == ord("r"):  # R to reset zoom
+                 self._reset_view()
+             elif key == ord("t"):  # T to toggle overlay
+                 self.show_overlay = not self.show_overlay
+
+         cv2.destroyAllWindows()
+
+
  def _generate_class_colors(class_names: list[str]) -> dict[str, tuple[int, int, int]]:
      """Generate consistent colors for each class name.

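The _create_overlay method added above blends a per-class colour image into the frame as result = img * (1 - alpha) + colored_mask * alpha, zeroing alpha on ignored pixels. Below is a minimal standalone sketch of that blend, illustrative only: it assumes a BGR uint8 image, a uint8 class-ID mask, and 255 as the ignore index, and the function name blend_overlay is hypothetical rather than part of the package.

import numpy as np

def blend_overlay(
    img: np.ndarray,          # HxWx3 uint8 image (BGR)
    mask: np.ndarray,         # HxW uint8 class-ID mask
    id_to_color: dict[int, tuple[int, int, int]],
    opacity: float = 0.5,
    ignore_index: int = 255,  # assumed ignore value
) -> np.ndarray:
    h, w = mask.shape
    colored = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in id_to_color.items():
        colored[mask == class_id] = color

    # Per-pixel alpha: constant opacity everywhere except ignored pixels, which stay untouched.
    alpha = np.full((h, w, 1), opacity, dtype=np.float32)
    alpha[mask == ignore_index] = 0.0

    blended = img.astype(np.float32) * (1 - alpha) + colored.astype(np.float32) * alpha
    return blended.astype(np.uint8)

# Tiny usage example with a synthetic 2x2 image and mask.
img = np.full((2, 2, 3), 100, dtype=np.uint8)
mask = np.array([[0, 1], [1, 255]], dtype=np.uint8)
print(blend_overlay(img, mask, {0: (0, 0, 255), 1: (0, 255, 0)}))

The ignored pixel keeps its original value, while the class pixels are mixed half-and-half with their class colour, matching the behaviour described in the diff's comments.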
@@ -738,9 +1374,13 @@ def _draw_annotations(
          )
          # Draw label text
          cv2.putText(
-             img, label,
+             img,
+             label,
              (x1 + 2, y1 - baseline - 2),
-             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.5,
+             (255, 255, 255),
+             1,
          )

      return img
@@ -799,17 +1439,25 @@ def _discover_datasets(root_path: Path, max_depth: int) -> list[Dataset]:


  def _detect_dataset(path: Path) -> Dataset | None:
-     """Try to detect a dataset at the given path."""
-     # Try YOLO first (more specific patterns)
+     """Try to detect a dataset at the given path.
+
+     Detection priority: YOLO -> COCO -> MaskDataset
+     """
+     # Try YOLO first (more specific patterns - requires data.yaml)
      dataset = YOLODataset.detect(path)
      if dataset:
          return dataset

-     # Try COCO
+     # Try COCO (requires annotations/*.json)
      dataset = COCODataset.detect(path)
      if dataset:
          return dataset

+     # Try MaskDataset (directory structure based)
+     dataset = MaskDataset.detect(path)
+     if dataset:
+         return dataset
+
      return None

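The extended _detect_dataset docstring above describes the detection order: YOLO, then COCO, then mask. The following rough sketch illustrates that kind of marker-based priority check under the assumptions stated in the diff's comments (data.yaml for YOLO, an annotations/ folder of JSON files for COCO, images/ plus masks/ directories for mask datasets). The helper name guess_dataset_format is hypothetical; the package's detect() classmethods are authoritative and may apply stricter validation.

from pathlib import Path

def guess_dataset_format(path: Path) -> str | None:
    """Illustrative priority check only, not the package's implementation."""
    if (path / "data.yaml").is_file():
        return "yolo"
    if (path / "annotations").is_dir() and any((path / "annotations").glob("*.json")):
        return "coco"
    if (path / "images").is_dir() and (path / "masks").is_dir():
        return "mask"
    return None

print(guess_dataset_format(Path(".")))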