argus-cv 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of argus-cv might be problematic. Click here for more details.
- argus/__init__.py +1 -1
- argus/cli.py +698 -50
- argus/core/__init__.py +3 -0
- argus/core/base.py +1 -0
- argus/core/coco.py +8 -6
- argus/core/mask.py +648 -0
- argus/core/yolo.py +376 -21
- {argus_cv-1.1.0.dist-info → argus_cv-1.3.0.dist-info}/METADATA +1 -1
- argus_cv-1.3.0.dist-info/RECORD +14 -0
- argus_cv-1.1.0.dist-info/RECORD +0 -13
- {argus_cv-1.1.0.dist-info → argus_cv-1.3.0.dist-info}/WHEEL +0 -0
- {argus_cv-1.1.0.dist-info → argus_cv-1.3.0.dist-info}/entry_points.txt +0 -0
argus/cli.py
CHANGED
|
@@ -11,7 +11,8 @@ from rich.console import Console
|
|
|
11
11
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
12
12
|
from rich.table import Table
|
|
13
13
|
|
|
14
|
-
from argus.core import COCODataset, Dataset, YOLODataset
|
|
14
|
+
from argus.core import COCODataset, Dataset, MaskDataset, YOLODataset
|
|
15
|
+
from argus.core.base import DatasetFormat, TaskType
|
|
15
16
|
from argus.core.split import (
|
|
16
17
|
is_coco_unsplit,
|
|
17
18
|
parse_ratio,
|
|
@@ -127,12 +128,19 @@ def stats(
|
|
|
127
128
|
dataset = _detect_dataset(dataset_path)
|
|
128
129
|
if not dataset:
|
|
129
130
|
console.print(
|
|
130
|
-
f"[red]Error: No
|
|
131
|
+
f"[red]Error: No dataset found at {dataset_path}[/red]\n"
|
|
131
132
|
"[yellow]Ensure the path points to a dataset root containing "
|
|
132
|
-
"data.yaml (YOLO)
|
|
133
|
+
"data.yaml (YOLO), annotations/ folder (COCO), or "
|
|
134
|
+
"images/ + masks/ directories (Mask).[/yellow]"
|
|
133
135
|
)
|
|
134
136
|
raise typer.Exit(1)
|
|
135
137
|
|
|
138
|
+
# Handle mask datasets with pixel statistics
|
|
139
|
+
if dataset.format == DatasetFormat.MASK:
|
|
140
|
+
assert isinstance(dataset, MaskDataset)
|
|
141
|
+
_show_mask_stats(dataset, dataset_path)
|
|
142
|
+
return
|
|
143
|
+
|
|
136
144
|
# Get instance counts with progress indicator
|
|
137
145
|
with Progress(
|
|
138
146
|
SpinnerColumn(),
|
|
@@ -211,10 +219,100 @@ def stats(
|
|
|
211
219
|
else:
|
|
212
220
|
image_parts.append(f"{split}: {img_total}")
|
|
213
221
|
|
|
214
|
-
console.print(
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
222
|
+
console.print(
|
|
223
|
+
f"\n[green]Dataset: {dataset.format.value.upper()} | "
|
|
224
|
+
f"Task: {dataset.task.value} | "
|
|
225
|
+
f"Classes: {len(sorted_classes)} | "
|
|
226
|
+
f"Total instances: {grand_total}[/green]"
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
if image_parts:
|
|
230
|
+
console.print(f"[blue]Images: {' | '.join(image_parts)}[/blue]")
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _show_mask_stats(dataset: MaskDataset, dataset_path: Path) -> None:
    """Print per-class pixel statistics for a mask dataset.

    Renders a table with one row per entry of the dataset's class mapping
    (ordered by class ID). Each row shows the class's total pixel count, its
    share of all non-ignored pixels, and the number of images containing it.
    When ignored pixels exist, a dimmed extra row reports them. A summary of
    the dataset path, format, task and per-split image totals follows.

    Args:
        dataset: The MaskDataset instance.
        dataset_path: Path to the dataset root.
    """
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as spinner:
        spinner.add_task("Analyzing mask dataset...", total=None)
        pixel_counts = dataset.get_pixel_counts()
        image_presence = dataset.get_image_class_presence()
        image_counts = dataset.get_image_counts()

    class_mapping = dataset.get_class_mapping()
    ignore_id = dataset.ignore_index

    # Coverage is relative to labelled pixels only, so ignored pixels are
    # excluded from the denominator.
    labelled_pixels = 0
    for cid, n_pixels in pixel_counts.items():
        if cid != ignore_id:
            labelled_pixels += n_pixels
    ignored_pixels = pixel_counts.get(ignore_id, 0)

    total_images = sum(split_counts["total"] for split_counts in image_counts.values())

    split_label = ", ".join(dataset.splits) if dataset.splits else "unsplit"
    table = Table(title=f"Class Statistics: {dataset_path.name} ({split_label})")
    table.add_column("Class", style="cyan")
    for header, colour in (
        ("Total Pixels", "green"),
        ("% Coverage", "magenta"),
        ("Images With", "yellow"),
    ):
        table.add_column(header, justify="right", style=colour)

    for cid in sorted(class_mapping):
        n_pixels = pixel_counts.get(cid, 0)
        share = (n_pixels / labelled_pixels * 100) if labelled_pixels > 0 else 0.0
        table.add_row(
            class_mapping[cid],
            f"{n_pixels:,}",
            f"{share:.1f}%",
            str(image_presence.get(cid, 0)),
        )

    if ignored_pixels > 0:
        # NOTE(review): the last column shows the total image count, not the
        # number of images that actually contain ignored pixels — confirm.
        table.add_section()
        table.add_row(
            "[dim](ignored)[/dim]",
            f"[dim]{ignored_pixels:,}[/dim]",
            "[dim]-[/dim]",
            f"[dim]{total_images}[/dim]",
        )

    console.print(table)

    # Summary line
    console.print(f"\n[green]Dataset: {dataset_path}[/green]")
    console.print(
        f"[green]Format: {dataset.format.value.upper()} | "
        f"Task: {dataset.task.value}[/green]"
    )

    # Per-split image totals; unsplit datasets report under "unsplit".
    split_names = dataset.splits or ["unsplit"]
    image_parts = [
        f"{name}: {image_counts[name]['total']}"
        for name in split_names
        if name in image_counts
    ]
    if image_parts:
        console.print(f"[blue]Images: {' | '.join(image_parts)}[/blue]")
|
|
@@ -238,19 +336,39 @@ def view(
|
|
|
238
336
|
help="Specific split to view (train, val, test).",
|
|
239
337
|
),
|
|
240
338
|
] = None,
|
|
339
|
+
max_classes: Annotated[
|
|
340
|
+
int | None,
|
|
341
|
+
typer.Option(
|
|
342
|
+
"--max-classes",
|
|
343
|
+
"-m",
|
|
344
|
+
help="Maximum classes to show in grid (classification only).",
|
|
345
|
+
),
|
|
346
|
+
] = None,
|
|
347
|
+
opacity: Annotated[
|
|
348
|
+
float,
|
|
349
|
+
typer.Option(
|
|
350
|
+
"--opacity",
|
|
351
|
+
"-o",
|
|
352
|
+
help="Mask overlay opacity (0.0-1.0, mask datasets only).",
|
|
353
|
+
min=0.0,
|
|
354
|
+
max=1.0,
|
|
355
|
+
),
|
|
356
|
+
] = 0.5,
|
|
241
357
|
) -> None:
|
|
242
358
|
"""View annotated images in a dataset.
|
|
243
359
|
|
|
244
360
|
Opens an interactive viewer to browse images with their annotations
|
|
245
361
|
(bounding boxes and segmentation masks) overlaid.
|
|
246
362
|
|
|
363
|
+
For classification datasets, shows a grid view with one image per class.
|
|
364
|
+
|
|
247
365
|
Controls:
|
|
248
|
-
- Right Arrow / N: Next image
|
|
249
|
-
- Left Arrow / P: Previous image
|
|
250
|
-
- Mouse Wheel: Zoom in/out
|
|
251
|
-
- Mouse Drag: Pan when zoomed
|
|
252
|
-
- R: Reset zoom
|
|
253
|
-
- T: Toggle annotations
|
|
366
|
+
- Right Arrow / N: Next image(s)
|
|
367
|
+
- Left Arrow / P: Previous image(s)
|
|
368
|
+
- Mouse Wheel: Zoom in/out (detection/segmentation only)
|
|
369
|
+
- Mouse Drag: Pan when zoomed (detection/segmentation only)
|
|
370
|
+
- R: Reset zoom / Reset to first images
|
|
371
|
+
- T: Toggle annotations (detection/segmentation only)
|
|
254
372
|
- Q / ESC: Quit viewer
|
|
255
373
|
"""
|
|
256
374
|
# Resolve path and validate
|
|
@@ -281,39 +399,111 @@ def view(
|
|
|
281
399
|
)
|
|
282
400
|
raise typer.Exit(1)
|
|
283
401
|
|
|
284
|
-
#
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
402
|
+
# Generate consistent colors for each class
|
|
403
|
+
class_colors = _generate_class_colors(dataset.class_names)
|
|
404
|
+
|
|
405
|
+
# Handle mask datasets with overlay viewer
|
|
406
|
+
if dataset.format == DatasetFormat.MASK:
|
|
407
|
+
assert isinstance(dataset, MaskDataset)
|
|
408
|
+
with Progress(
|
|
409
|
+
SpinnerColumn(),
|
|
410
|
+
TextColumn("[progress.description]{task.description}"),
|
|
411
|
+
console=console,
|
|
412
|
+
transient=True,
|
|
413
|
+
) as progress:
|
|
414
|
+
progress.add_task("Loading images...", total=None)
|
|
415
|
+
image_paths = dataset.get_image_paths(split)
|
|
416
|
+
|
|
417
|
+
if not image_paths:
|
|
418
|
+
console.print("[yellow]No images found in the dataset.[/yellow]")
|
|
419
|
+
return
|
|
293
420
|
|
|
294
|
-
|
|
295
|
-
|
|
421
|
+
console.print(
|
|
422
|
+
f"[green]Found {len(image_paths)} images. "
|
|
423
|
+
f"Opening mask viewer...[/green]\n"
|
|
424
|
+
"[dim]Controls: \u2190 / \u2192 or P / N to navigate, "
|
|
425
|
+
"Mouse wheel to zoom, Drag to pan, R to reset, T to toggle overlay, "
|
|
426
|
+
"Q / ESC to quit[/dim]"
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
viewer = _MaskViewer(
|
|
430
|
+
image_paths=image_paths,
|
|
431
|
+
dataset=dataset,
|
|
432
|
+
class_colors=class_colors,
|
|
433
|
+
window_name=f"Argus Mask Viewer - {dataset_path.name}",
|
|
434
|
+
opacity=opacity,
|
|
435
|
+
)
|
|
436
|
+
viewer.run()
|
|
437
|
+
console.print("[green]Viewer closed.[/green]")
|
|
296
438
|
return
|
|
297
439
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
"Mouse wheel to zoom, Drag to pan, R to reset, T to toggle annotations, "
|
|
303
|
-
"Q / ESC to quit[/dim]"
|
|
304
|
-
)
|
|
440
|
+
# Handle classification datasets with grid viewer
|
|
441
|
+
if dataset.task == TaskType.CLASSIFICATION:
|
|
442
|
+
# Use first split if specified, otherwise let get_images_by_class handle it
|
|
443
|
+
view_split = split if split else (dataset.splits[0] if dataset.splits else None)
|
|
305
444
|
|
|
306
|
-
|
|
307
|
-
|
|
445
|
+
with Progress(
|
|
446
|
+
SpinnerColumn(),
|
|
447
|
+
TextColumn("[progress.description]{task.description}"),
|
|
448
|
+
console=console,
|
|
449
|
+
transient=True,
|
|
450
|
+
) as progress:
|
|
451
|
+
progress.add_task("Loading images by class...", total=None)
|
|
452
|
+
images_by_class = dataset.get_images_by_class(view_split)
|
|
308
453
|
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
454
|
+
total_images = sum(len(imgs) for imgs in images_by_class.values())
|
|
455
|
+
if total_images == 0:
|
|
456
|
+
console.print("[yellow]No images found in the dataset.[/yellow]")
|
|
457
|
+
return
|
|
458
|
+
|
|
459
|
+
num_classes = len(dataset.class_names)
|
|
460
|
+
display_classes = min(num_classes, max_classes) if max_classes else num_classes
|
|
461
|
+
|
|
462
|
+
console.print(
|
|
463
|
+
f"[green]Found {total_images} images across {num_classes} classes "
|
|
464
|
+
f"(showing {display_classes}). Opening grid viewer...[/green]\n"
|
|
465
|
+
"[dim]Controls: ← / → or P / N to navigate all classes, "
|
|
466
|
+
"R to reset, Q / ESC to quit[/dim]"
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
viewer = _ClassificationGridViewer(
|
|
470
|
+
images_by_class=images_by_class,
|
|
471
|
+
class_names=dataset.class_names,
|
|
472
|
+
class_colors=class_colors,
|
|
473
|
+
window_name=f"Argus Classification Viewer - {dataset_path.name}",
|
|
474
|
+
max_classes=max_classes,
|
|
475
|
+
)
|
|
476
|
+
viewer.run()
|
|
477
|
+
else:
|
|
478
|
+
# Detection/Segmentation viewer
|
|
479
|
+
with Progress(
|
|
480
|
+
SpinnerColumn(),
|
|
481
|
+
TextColumn("[progress.description]{task.description}"),
|
|
482
|
+
console=console,
|
|
483
|
+
transient=True,
|
|
484
|
+
) as progress:
|
|
485
|
+
progress.add_task("Loading images...", total=None)
|
|
486
|
+
image_paths = dataset.get_image_paths(split)
|
|
487
|
+
|
|
488
|
+
if not image_paths:
|
|
489
|
+
console.print("[yellow]No images found in the dataset.[/yellow]")
|
|
490
|
+
return
|
|
491
|
+
|
|
492
|
+
console.print(
|
|
493
|
+
f"[green]Found {len(image_paths)} images. "
|
|
494
|
+
f"Opening viewer...[/green]\n"
|
|
495
|
+
"[dim]Controls: ← / → or P / N to navigate, "
|
|
496
|
+
"Mouse wheel to zoom, Drag to pan, R to reset, T to toggle annotations, "
|
|
497
|
+
"Q / ESC to quit[/dim]"
|
|
498
|
+
)
|
|
499
|
+
|
|
500
|
+
viewer = _ImageViewer(
|
|
501
|
+
image_paths=image_paths,
|
|
502
|
+
dataset=dataset,
|
|
503
|
+
class_colors=class_colors,
|
|
504
|
+
window_name=f"Argus Viewer - {dataset_path.name}",
|
|
505
|
+
)
|
|
506
|
+
viewer.run()
|
|
317
507
|
|
|
318
508
|
console.print("[green]Viewer closed.[/green]")
|
|
319
509
|
|
|
@@ -551,12 +741,16 @@ class _ImageViewer:
|
|
|
551
741
|
info_text += " [Annotations: OFF]"
|
|
552
742
|
|
|
553
743
|
cv2.putText(
|
|
554
|
-
display,
|
|
555
|
-
|
|
744
|
+
display,
|
|
745
|
+
info_text,
|
|
746
|
+
(10, 30),
|
|
747
|
+
cv2.FONT_HERSHEY_SIMPLEX,
|
|
748
|
+
0.7,
|
|
749
|
+
(255, 255, 255),
|
|
750
|
+
2,
|
|
556
751
|
)
|
|
557
752
|
cv2.putText(
|
|
558
|
-
display, info_text, (10, 30),
|
|
559
|
-
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
|
|
753
|
+
display, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
|
|
560
754
|
)
|
|
561
755
|
|
|
562
756
|
return display
|
|
@@ -656,6 +850,448 @@ class _ImageViewer:
|
|
|
656
850
|
cv2.destroyAllWindows()
|
|
657
851
|
|
|
658
852
|
|
|
853
|
+
class _ClassificationGridViewer:
|
|
854
|
+
"""Grid viewer for classification datasets showing one image per class."""
|
|
855
|
+
|
|
856
|
+
def __init__(
|
|
857
|
+
self,
|
|
858
|
+
images_by_class: dict[str, list[Path]],
|
|
859
|
+
class_names: list[str],
|
|
860
|
+
class_colors: dict[str, tuple[int, int, int]],
|
|
861
|
+
window_name: str,
|
|
862
|
+
max_classes: int | None = None,
|
|
863
|
+
tile_size: int = 300,
|
|
864
|
+
):
|
|
865
|
+
# Limit classes if max_classes specified
|
|
866
|
+
if max_classes and len(class_names) > max_classes:
|
|
867
|
+
self.class_names = class_names[:max_classes]
|
|
868
|
+
else:
|
|
869
|
+
self.class_names = class_names
|
|
870
|
+
|
|
871
|
+
self.images_by_class = {
|
|
872
|
+
cls: images_by_class.get(cls, []) for cls in self.class_names
|
|
873
|
+
}
|
|
874
|
+
self.class_colors = class_colors
|
|
875
|
+
self.window_name = window_name
|
|
876
|
+
self.tile_size = tile_size
|
|
877
|
+
|
|
878
|
+
# Global image index (same for all classes)
|
|
879
|
+
self.current_index = 0
|
|
880
|
+
|
|
881
|
+
# Calculate max images across all classes
|
|
882
|
+
self.max_images = (
|
|
883
|
+
max(len(imgs) for imgs in self.images_by_class.values())
|
|
884
|
+
if self.images_by_class
|
|
885
|
+
else 0
|
|
886
|
+
)
|
|
887
|
+
|
|
888
|
+
# Calculate grid layout
|
|
889
|
+
self.cols, self.rows = self._calculate_grid_layout()
|
|
890
|
+
|
|
891
|
+
def _calculate_grid_layout(self) -> tuple[int, int]:
|
|
892
|
+
"""Calculate optimal grid layout based on number of classes."""
|
|
893
|
+
n = len(self.class_names)
|
|
894
|
+
if n <= 0:
|
|
895
|
+
return 1, 1
|
|
896
|
+
|
|
897
|
+
# Try to make a roughly square grid
|
|
898
|
+
import math
|
|
899
|
+
|
|
900
|
+
cols = int(math.ceil(math.sqrt(n)))
|
|
901
|
+
rows = int(math.ceil(n / cols))
|
|
902
|
+
return cols, rows
|
|
903
|
+
|
|
904
|
+
def _create_tile(
|
|
905
|
+
self, class_name: str, image_path: Path | None, index: int, total: int
|
|
906
|
+
) -> np.ndarray:
|
|
907
|
+
"""Create a single tile for a class."""
|
|
908
|
+
tile = np.zeros((self.tile_size, self.tile_size, 3), dtype=np.uint8)
|
|
909
|
+
|
|
910
|
+
if image_path is not None and image_path.exists():
|
|
911
|
+
# Load and resize image
|
|
912
|
+
img = cv2.imread(str(image_path))
|
|
913
|
+
if img is not None:
|
|
914
|
+
# Resize maintaining aspect ratio
|
|
915
|
+
h, w = img.shape[:2]
|
|
916
|
+
scale = min(self.tile_size / w, self.tile_size / h)
|
|
917
|
+
new_w = int(w * scale)
|
|
918
|
+
new_h = int(h * scale)
|
|
919
|
+
resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
|
920
|
+
|
|
921
|
+
# Center in tile
|
|
922
|
+
x_offset = (self.tile_size - new_w) // 2
|
|
923
|
+
y_offset = (self.tile_size - new_h) // 2
|
|
924
|
+
tile[y_offset : y_offset + new_h, x_offset : x_offset + new_w] = resized
|
|
925
|
+
|
|
926
|
+
# Draw label at top: "class_name (N/M)"
|
|
927
|
+
if image_path is not None:
|
|
928
|
+
label = f"{class_name} ({index + 1}/{total})"
|
|
929
|
+
else:
|
|
930
|
+
label = f"{class_name} (-/{total})"
|
|
931
|
+
|
|
932
|
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
|
933
|
+
font_scale = 0.5
|
|
934
|
+
thickness = 1
|
|
935
|
+
(label_w, label_h), baseline = cv2.getTextSize(
|
|
936
|
+
label, font, font_scale, thickness
|
|
937
|
+
)
|
|
938
|
+
|
|
939
|
+
# Semi-transparent background for label
|
|
940
|
+
overlay = tile.copy()
|
|
941
|
+
label_bg_height = label_h + baseline + 10
|
|
942
|
+
cv2.rectangle(overlay, (0, 0), (self.tile_size, label_bg_height), (0, 0, 0), -1)
|
|
943
|
+
cv2.addWeighted(overlay, 0.6, tile, 0.4, 0, tile)
|
|
944
|
+
|
|
945
|
+
cv2.putText(
|
|
946
|
+
tile,
|
|
947
|
+
label,
|
|
948
|
+
(5, label_h + 5),
|
|
949
|
+
font,
|
|
950
|
+
font_scale,
|
|
951
|
+
(255, 255, 255),
|
|
952
|
+
thickness,
|
|
953
|
+
)
|
|
954
|
+
|
|
955
|
+
# Draw thin border
|
|
956
|
+
border_end = self.tile_size - 1
|
|
957
|
+
cv2.rectangle(tile, (0, 0), (border_end, border_end), (80, 80, 80), 1)
|
|
958
|
+
|
|
959
|
+
return tile
|
|
960
|
+
|
|
961
|
+
def _compose_grid(self) -> np.ndarray:
|
|
962
|
+
"""Compose all tiles into a single grid image."""
|
|
963
|
+
grid_h = self.rows * self.tile_size
|
|
964
|
+
grid_w = self.cols * self.tile_size
|
|
965
|
+
grid = np.zeros((grid_h, grid_w, 3), dtype=np.uint8)
|
|
966
|
+
|
|
967
|
+
for i, class_name in enumerate(self.class_names):
|
|
968
|
+
row = i // self.cols
|
|
969
|
+
col = i % self.cols
|
|
970
|
+
|
|
971
|
+
images = self.images_by_class[class_name]
|
|
972
|
+
total = len(images)
|
|
973
|
+
|
|
974
|
+
# Use global index - show black tile if class doesn't have this image
|
|
975
|
+
if self.current_index < total:
|
|
976
|
+
image_path = images[self.current_index]
|
|
977
|
+
display_index = self.current_index
|
|
978
|
+
else:
|
|
979
|
+
image_path = None
|
|
980
|
+
display_index = self.current_index
|
|
981
|
+
|
|
982
|
+
tile = self._create_tile(class_name, image_path, display_index, total)
|
|
983
|
+
|
|
984
|
+
y_start = row * self.tile_size
|
|
985
|
+
x_start = col * self.tile_size
|
|
986
|
+
y_end = y_start + self.tile_size
|
|
987
|
+
x_end = x_start + self.tile_size
|
|
988
|
+
grid[y_start:y_end, x_start:x_end] = tile
|
|
989
|
+
|
|
990
|
+
return grid
|
|
991
|
+
|
|
992
|
+
def _next_images(self) -> None:
|
|
993
|
+
"""Advance to next image index."""
|
|
994
|
+
if self.max_images > 0:
|
|
995
|
+
self.current_index = min(self.current_index + 1, self.max_images - 1)
|
|
996
|
+
|
|
997
|
+
def _prev_images(self) -> None:
|
|
998
|
+
"""Go back to previous image index."""
|
|
999
|
+
self.current_index = max(self.current_index - 1, 0)
|
|
1000
|
+
|
|
1001
|
+
def _reset_indices(self) -> None:
|
|
1002
|
+
"""Reset to first image."""
|
|
1003
|
+
self.current_index = 0
|
|
1004
|
+
|
|
1005
|
+
def run(self) -> None:
|
|
1006
|
+
"""Run the interactive grid viewer."""
|
|
1007
|
+
if not self.class_names:
|
|
1008
|
+
console.print("[yellow]No classes to display.[/yellow]")
|
|
1009
|
+
return
|
|
1010
|
+
|
|
1011
|
+
cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
|
|
1012
|
+
|
|
1013
|
+
while True:
|
|
1014
|
+
# Compose and display grid
|
|
1015
|
+
grid = self._compose_grid()
|
|
1016
|
+
cv2.imshow(self.window_name, grid)
|
|
1017
|
+
|
|
1018
|
+
# Wait for input
|
|
1019
|
+
key = cv2.waitKey(30) & 0xFF
|
|
1020
|
+
|
|
1021
|
+
# Handle keyboard input
|
|
1022
|
+
if key == ord("q") or key == 27: # Q or ESC
|
|
1023
|
+
break
|
|
1024
|
+
elif key == ord("n") or key == 83 or key == 3: # N or Right arrow
|
|
1025
|
+
self._next_images()
|
|
1026
|
+
elif key == ord("p") or key == 81 or key == 2: # P or Left arrow
|
|
1027
|
+
self._prev_images()
|
|
1028
|
+
elif key == ord("r"): # R to reset
|
|
1029
|
+
self._reset_indices()
|
|
1030
|
+
|
|
1031
|
+
cv2.destroyAllWindows()
|
|
1032
|
+
|
|
1033
|
+
|
|
1034
|
+
class _MaskViewer:
    """Interactive viewer for semantic mask datasets with colored overlay.

    Each image is displayed with its class-ID mask blended on top using
    per-class colors. Pixels equal to the dataset's ignore index are left
    unblended. Supports wrapping keyboard navigation, mouse-wheel zoom,
    drag-to-pan and toggling the overlay on and off.
    """

    # Zoom is clamped to [_MIN_ZOOM, _MAX_ZOOM]; each wheel step scales it by
    # _ZOOM_STEP.
    _MIN_ZOOM = 1.0
    _MAX_ZOOM = 10.0
    _ZOOM_STEP = 1.2

    def __init__(
        self,
        image_paths: list[Path],
        dataset: "MaskDataset",
        class_colors: dict[str, tuple[int, int, int]],
        window_name: str,
        opacity: float = 0.5,
    ):
        """Set up viewer state.

        Args:
            image_paths: Images to browse, in display order.
            dataset: Mask dataset used to load masks, map class IDs to names,
                and look up the ignore index.
            class_colors: Overlay color per class name. Mask class IDs with no
                color are blended toward black.
            window_name: Title of the OpenCV window.
            opacity: Blend weight of the colored mask (0 = invisible,
                1 = fully replaces the image).
        """
        self.image_paths = image_paths
        self.dataset = dataset
        self.class_colors = class_colors
        self.window_name = window_name
        self.opacity = opacity

        self.current_idx = 0
        self.zoom = 1.0
        self.pan_x = 0.0
        self.pan_y = 0.0

        # Mouse state for panning
        self.dragging = False
        self.drag_start_x = 0
        self.drag_start_y = 0
        self.pan_start_x = 0.0
        self.pan_start_y = 0.0

        # Current image cache
        self.current_img: np.ndarray | None = None
        self.overlay_img: np.ndarray | None = None

        # Overlay visibility toggle
        self.show_overlay = True

        # Direction of the most recent navigation (+1 forward, -1 backward).
        # Unreadable images are skipped in this direction.
        self._nav_step = 1

        # Build class_id to color mapping
        self._id_to_color: dict[int, tuple[int, int, int]] = {}
        class_mapping = dataset.get_class_mapping()
        for class_id, class_name in class_mapping.items():
            if class_name in class_colors:
                self._id_to_color[class_id] = class_colors[class_name]

    def _load_current_image(self) -> bool:
        """Load the current image and build its mask overlay.

        Returns:
            True if the image was read. If no mask is available, the raw image
            is used for both the plain and overlay views. False if the image
            could not be read or its size differs from the mask's.
        """
        image_path = self.image_paths[self.current_idx]

        img = cv2.imread(str(image_path))
        if img is None:
            return False

        mask = self.dataset.load_mask(image_path)
        if mask is None:
            console.print(f"[yellow]Warning: No mask for {image_path}[/yellow]")
            self.current_img = img
            self.overlay_img = img.copy()
            return True

        # Validate dimensions
        if img.shape[:2] != mask.shape[:2]:
            console.print(
                f"[red]Error: Dimension mismatch for {image_path.name}: "
                f"image={img.shape[:2]}, mask={mask.shape[:2]}[/red]"
            )
            return False

        self.current_img = img
        self.overlay_img = self._create_overlay(img, mask)
        return True

    def _create_overlay(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Create colored overlay from mask.

        Args:
            img: Original image (BGR).
            mask: 2-D mask of class IDs with the same height/width as ``img``.

        Returns:
            ``img * (1 - opacity) + color * opacity`` per pixel, except pixels
            equal to the ignore index, which keep their original value.
        """
        # Create colored mask
        h, w = mask.shape
        colored_mask = np.zeros((h, w, 3), dtype=np.uint8)

        for class_id, color in self._id_to_color.items():
            colored_mask[mask == class_id] = color

        # Ignore pixels get zero alpha, i.e. they are not blended at all.
        ignore_mask = mask == self.dataset.ignore_index
        alpha = np.ones((h, w, 1), dtype=np.float32) * self.opacity
        alpha[ignore_mask] = 0.0

        # Blend: result = img * (1 - alpha) + colored_mask * alpha
        blended = (
            img.astype(np.float32) * (1 - alpha)
            + colored_mask.astype(np.float32) * alpha
        )
        return blended.astype(np.uint8)

    def _get_display_image(self) -> np.ndarray:
        """Return the frame to show, with zoom/pan applied and an info banner.

        Uses the blended overlay when the overlay is enabled, otherwise the
        raw image. Returns a black 640x480 placeholder if nothing is loaded.
        """
        if self.overlay_img is None:
            return np.zeros((480, 640, 3), dtype=np.uint8)

        if self.show_overlay:
            img = self.overlay_img
        elif self.current_img is not None:
            img = self.current_img
        else:
            img = self.overlay_img

        h, w = img.shape[:2]

        if self.zoom == 1.0 and self.pan_x == 0.0 and self.pan_y == 0.0:
            display = img.copy()
        else:
            # Source region stretched to fill the frame. It is never smaller
            # than 1 px, so tiny images at high zoom still give a valid crop.
            view_w = max(1, int(w / self.zoom))
            view_h = max(1, int(h / self.zoom))

            # Center point with pan offset
            cx = w / 2 + self.pan_x
            cy = h / 2 + self.pan_y

            # Calculate crop bounds
            x1 = int(max(0, cx - view_w / 2))
            y1 = int(max(0, cy - view_h / 2))
            x2 = int(min(w, x1 + view_w))
            y2 = int(min(h, y1 + view_h))

            # Shift the window back inside the image if it ran past an edge.
            if x2 - x1 < view_w:
                x1 = max(0, x2 - view_w)
            if y2 - y1 < view_h:
                y1 = max(0, y2 - view_h)

            # Crop and resize
            cropped = img[y1:y2, x1:x2]
            display = cv2.resize(cropped, (w, h), interpolation=cv2.INTER_LINEAR)

        # Add info overlay
        image_path = self.image_paths[self.current_idx]
        idx = self.current_idx + 1
        total = len(self.image_paths)
        info_text = f"[{idx}/{total}] {image_path.name}"
        if self.zoom > 1.0:
            info_text += f" (Zoom: {self.zoom:.1f}x)"
        if not self.show_overlay:
            info_text += " [Overlay: OFF]"

        # White thick text with a thin black pass on top for legibility.
        cv2.putText(
            display,
            info_text,
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (255, 255, 255),
            2,
        )
        cv2.putText(
            display, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
        )

        return display

    def _step_zoom(self, zoom_in: bool) -> None:
        """Zoom in or out by one wheel step, clamped to [_MIN_ZOOM, _MAX_ZOOM].

        Zoom values within a tiny tolerance of 1x are snapped to exactly 1x.
        Without this, floating-point drift from repeated multiply/divide could
        leave the view slightly cropped with a stale pan. Pan is reset
        whenever the view returns to 1x.
        """
        if zoom_in:
            self.zoom = min(self._MAX_ZOOM, self.zoom * self._ZOOM_STEP)
        else:
            self.zoom = max(self._MIN_ZOOM, self.zoom / self._ZOOM_STEP)

        if self.zoom - self._MIN_ZOOM < 1e-6:
            self.zoom = self._MIN_ZOOM
            self.pan_x = 0.0
            self.pan_y = 0.0

    def _mouse_callback(
        self, event: int, x: int, y: int, flags: int, param: None
    ) -> None:
        """Handle mouse events: the wheel zooms, left-button drag pans when zoomed."""
        if event == cv2.EVENT_MOUSEWHEEL:
            # Positive wheel flags zoom in; anything else zooms out.
            self._step_zoom(flags > 0)

        elif event == cv2.EVENT_LBUTTONDOWN:
            self.dragging = True
            self.drag_start_x = x
            self.drag_start_y = y
            self.pan_start_x = self.pan_x
            self.pan_start_y = self.pan_y

        elif event == cv2.EVENT_MOUSEMOVE and self.dragging:
            if self.zoom > 1.0 and self.overlay_img is not None:
                h, w = self.overlay_img.shape[:2]
                # Calculate pan delta (inverted for natural feel)
                dx = (self.drag_start_x - x) / self.zoom
                dy = (self.drag_start_y - y) / self.zoom

                # Keep the visible window inside the image.
                max_pan_x = w * (1 - 1 / self.zoom) / 2
                max_pan_y = h * (1 - 1 / self.zoom) / 2

                self.pan_x = max(-max_pan_x, min(max_pan_x, self.pan_start_x + dx))
                self.pan_y = max(-max_pan_y, min(max_pan_y, self.pan_start_y + dy))

        elif event == cv2.EVENT_LBUTTONUP:
            self.dragging = False

    def _reset_view(self) -> None:
        """Reset zoom and pan to default."""
        self.zoom = 1.0
        self.pan_x = 0.0
        self.pan_y = 0.0

    def _next_image(self) -> None:
        """Go to the next image (wrapping around) and reset zoom/pan."""
        self._nav_step = 1
        self.current_idx = (self.current_idx + 1) % len(self.image_paths)
        self._reset_view()

    def _prev_image(self) -> None:
        """Go to the previous image (wrapping around) and reset zoom/pan."""
        self._nav_step = -1
        self.current_idx = (self.current_idx - 1) % len(self.image_paths)
        self._reset_view()

    def _skip_image(self) -> None:
        """Move one image in the direction of the most recent navigation."""
        if self._nav_step < 0:
            self._prev_image()
        else:
            self._next_image()

    def run(self) -> None:
        """Run the interactive viewer until Q or ESC is pressed.

        Images that cannot be loaded (unreadable, or size differs from the
        mask) are skipped in the direction of travel. If no image in the list
        can be loaded, an error is printed and the viewer returns instead of
        looping forever.
        """
        if not self.image_paths:
            console.print("[yellow]No images to display.[/yellow]")
            return

        cv2.namedWindow(self.window_name, cv2.WINDOW_AUTOSIZE)
        cv2.setMouseCallback(self.window_name, self._mouse_callback)

        # Consecutive load failures. Skips keep one direction within a streak,
        # so a streak of len(image_paths) means every image has been tried.
        failures = 0

        while True:
            # Load image if needed
            if self.overlay_img is None:
                if not self._load_current_image():
                    console.print(
                        f"[yellow]Warning: Could not load "
                        f"{self.image_paths[self.current_idx]}[/yellow]"
                    )
                    failures += 1
                    if failures >= len(self.image_paths):
                        console.print(
                            "[red]Error: None of the images could be loaded.[/red]"
                        )
                        break
                    self._skip_image()
                    continue
                failures = 0

            # Display image
            display = self._get_display_image()
            cv2.imshow(self.window_name, display)

            # Wait for input (short timeout for smooth panning)
            key = cv2.waitKey(30) & 0xFF

            # Handle keyboard input
            if key == ord("q") or key == 27:  # Q or ESC
                break
            elif key == ord("n") or key == 83 or key == 3:  # N or Right arrow
                self.overlay_img = None
                self._next_image()
            elif key == ord("p") or key == 81 or key == 2:  # P or Left arrow
                self.overlay_img = None
                self._prev_image()
            elif key == ord("r"):  # R to reset zoom
                self._reset_view()
            elif key == ord("t"):  # T to toggle overlay
                self.show_overlay = not self.show_overlay

        cv2.destroyAllWindows()
|
|
1293
|
+
|
|
1294
|
+
|
|
659
1295
|
def _generate_class_colors(class_names: list[str]) -> dict[str, tuple[int, int, int]]:
|
|
660
1296
|
"""Generate consistent colors for each class name.
|
|
661
1297
|
|
|
@@ -738,9 +1374,13 @@ def _draw_annotations(
|
|
|
738
1374
|
)
|
|
739
1375
|
# Draw label text
|
|
740
1376
|
cv2.putText(
|
|
741
|
-
img,
|
|
1377
|
+
img,
|
|
1378
|
+
label,
|
|
742
1379
|
(x1 + 2, y1 - baseline - 2),
|
|
743
|
-
cv2.FONT_HERSHEY_SIMPLEX,
|
|
1380
|
+
cv2.FONT_HERSHEY_SIMPLEX,
|
|
1381
|
+
0.5,
|
|
1382
|
+
(255, 255, 255),
|
|
1383
|
+
1,
|
|
744
1384
|
)
|
|
745
1385
|
|
|
746
1386
|
return img
|
|
@@ -799,17 +1439,25 @@ def _discover_datasets(root_path: Path, max_depth: int) -> list[Dataset]:
|
|
|
799
1439
|
|
|
800
1440
|
|
|
801
1441
|
def _detect_dataset(path: Path) -> Dataset | None:
    """Return the first dataset format that recognises ``path``, or None.

    Detectors are tried in priority order:

    1. YOLO — the most specific pattern; requires ``data.yaml``.
    2. COCO — requires ``annotations/*.json``.
    3. MaskDataset — based on the directory structure.

    The first truthy result of a ``detect`` call is returned.
    """
    for detector in (YOLODataset, COCODataset, MaskDataset):
        found = detector.detect(path)
        if found:
            return found
    return None
|
|
814
1462
|
|
|
815
1463
|
|