argus-cv 1.2.0__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of argus-cv might be problematic. Click here for more details.

Files changed (38)
  1. {argus_cv-1.2.0 → argus_cv-1.3.0}/.github/workflows/ci.yml +3 -14
  2. argus_cv-1.3.0/.pre-commit-config.yaml +7 -0
  3. {argus_cv-1.2.0 → argus_cv-1.3.0}/CHANGELOG.md +8 -0
  4. {argus_cv-1.2.0 → argus_cv-1.3.0}/PKG-INFO +1 -1
  5. {argus_cv-1.2.0 → argus_cv-1.3.0}/pyproject.toml +2 -1
  6. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/__init__.py +1 -1
  7. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/cli.py +441 -20
  8. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/core/__init__.py +3 -0
  9. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/core/base.py +1 -0
  10. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/core/coco.py +8 -6
  11. argus_cv-1.3.0/src/argus/core/mask.py +648 -0
  12. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/core/yolo.py +21 -12
  13. {argus_cv-1.2.0 → argus_cv-1.3.0}/tests/conftest.py +271 -3
  14. {argus_cv-1.2.0 → argus_cv-1.3.0}/tests/test_classification.py +2 -6
  15. argus_cv-1.3.0/tests/test_mask.py +400 -0
  16. {argus_cv-1.2.0 → argus_cv-1.3.0}/tests/test_split_command.py +1 -3
  17. {argus_cv-1.2.0 → argus_cv-1.3.0}/tests/test_stats_command.py +1 -1
  18. {argus_cv-1.2.0 → argus_cv-1.3.0}/uv.lock +27 -1
  19. {argus_cv-1.2.0 → argus_cv-1.3.0}/.github/workflows/docs.yml +0 -0
  20. {argus_cv-1.2.0 → argus_cv-1.3.0}/.github/workflows/release.yml +0 -0
  21. {argus_cv-1.2.0 → argus_cv-1.3.0}/.gitignore +0 -0
  22. {argus_cv-1.2.0 → argus_cv-1.3.0}/README.md +0 -0
  23. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/assets/javascripts/extra.js +0 -0
  24. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/assets/stylesheets/extra.css +0 -0
  25. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/getting-started/installation.md +0 -0
  26. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/getting-started/quickstart.md +0 -0
  27. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/guides/datasets.md +0 -0
  28. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/guides/listing.md +0 -0
  29. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/guides/splitting.md +0 -0
  30. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/guides/stats.md +0 -0
  31. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/guides/viewer.md +0 -0
  32. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/index.md +0 -0
  33. {argus_cv-1.2.0 → argus_cv-1.3.0}/docs/reference/cli.md +0 -0
  34. {argus_cv-1.2.0 → argus_cv-1.3.0}/mkdocs.yml +0 -0
  35. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/__main__.py +0 -0
  36. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/commands/__init__.py +0 -0
  37. {argus_cv-1.2.0 → argus_cv-1.3.0}/src/argus/core/split.py +0 -0
  38. {argus_cv-1.2.0 → argus_cv-1.3.0}/tests/test_list_command.py +0 -0
@@ -11,23 +11,12 @@ on:
11
11
  jobs:
12
12
  lint:
13
13
  runs-on: ubuntu-latest
14
-
14
+
15
15
  steps:
16
16
  - name: Checkout code
17
17
  uses: actions/checkout@v4
18
-
19
- - name: Install uv
20
- uses: astral-sh/setup-uv@v4
21
- with:
22
- enable-cache: true
23
-
24
- - name: Set up Python
25
- run: uv python install 3.12
26
-
27
- - name: Run linting
28
- run: |
29
- uv sync --only-group dev
30
- uv run ruff check .
18
+
19
+ - uses: j178/prek-action@v1
31
20
 
32
21
  test:
33
22
  runs-on: ubuntu-latest
@@ -0,0 +1,7 @@
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ rev: v0.11.13
4
+ hooks:
5
+ - id: ruff
6
+ args: [--fix]
7
+ - id: ruff-format
@@ -2,6 +2,14 @@
2
2
 
3
3
  <!-- version list -->
4
4
 
5
+ ## v1.3.0 (2026-01-24)
6
+
7
+ ### Features
8
+
9
+ - Add MaskDataset class for semantic segmentation masks
10
+ ([`bbba4f2`](https://github.com/pirnerjonas/argus/commit/bbba4f2c9d476d426b6c805ad2120eca9c38c855))
11
+
12
+
5
13
  ## v1.2.0 (2026-01-15)
6
14
 
7
15
  ### Code Style
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: argus-cv
3
- Version: 1.2.0
3
+ Version: 1.3.0
4
4
  Summary: CLI tool for working with vision AI datasets
5
5
  Requires-Python: >=3.10
6
6
  Requires-Dist: numpy>=1.24.0
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "argus-cv"
3
- version = "1.2.0"
3
+ version = "1.3.0"
4
4
  description = "CLI tool for working with vision AI datasets"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"
@@ -19,6 +19,7 @@ dev = [
19
19
  "ruff>=0.9.9",
20
20
  "python-semantic-release>=9.0.0",
21
21
  "build>=1.0.0",
22
+ "prek>=0.2.30",
22
23
  ]
23
24
  docs = [
24
25
  "mkdocs-material>=9.5.0",
@@ -1,3 +1,3 @@
1
1
  """Argus - Vision AI dataset toolkit."""
2
2
 
3
- __version__ = "1.2.0"
3
+ __version__ = "1.3.0"
@@ -11,8 +11,8 @@ from rich.console import Console
11
11
  from rich.progress import Progress, SpinnerColumn, TextColumn
12
12
  from rich.table import Table
13
13
 
14
- from argus.core import COCODataset, Dataset, YOLODataset
15
- from argus.core.base import TaskType
14
+ from argus.core import COCODataset, Dataset, MaskDataset, YOLODataset
15
+ from argus.core.base import DatasetFormat, TaskType
16
16
  from argus.core.split import (
17
17
  is_coco_unsplit,
18
18
  parse_ratio,
@@ -128,12 +128,19 @@ def stats(
128
128
  dataset = _detect_dataset(dataset_path)
129
129
  if not dataset:
130
130
  console.print(
131
- f"[red]Error: No YOLO or COCO dataset found at {dataset_path}[/red]\n"
131
+ f"[red]Error: No dataset found at {dataset_path}[/red]\n"
132
132
  "[yellow]Ensure the path points to a dataset root containing "
133
- "data.yaml (YOLO) or annotations/ folder (COCO).[/yellow]"
133
+ "data.yaml (YOLO), annotations/ folder (COCO), or "
134
+ "images/ + masks/ directories (Mask).[/yellow]"
134
135
  )
135
136
  raise typer.Exit(1)
136
137
 
138
+ # Handle mask datasets with pixel statistics
139
+ if dataset.format == DatasetFormat.MASK:
140
+ assert isinstance(dataset, MaskDataset)
141
+ _show_mask_stats(dataset, dataset_path)
142
+ return
143
+
137
144
  # Get instance counts with progress indicator
138
145
  with Progress(
139
146
  SpinnerColumn(),
@@ -212,10 +219,100 @@ def stats(
212
219
  else:
213
220
  image_parts.append(f"{split}: {img_total}")
214
221
 
215
- console.print(f"\n[green]Dataset: {dataset.format.value.upper()} | "
216
- f"Task: {dataset.task.value} | "
217
- f"Classes: {len(sorted_classes)} | "
218
- f"Total instances: {grand_total}[/green]")
222
+ console.print(
223
+ f"\n[green]Dataset: {dataset.format.value.upper()} | "
224
+ f"Task: {dataset.task.value} | "
225
+ f"Classes: {len(sorted_classes)} | "
226
+ f"Total instances: {grand_total}[/green]"
227
+ )
228
+
229
+ if image_parts:
230
+ console.print(f"[blue]Images: {' | '.join(image_parts)}[/blue]")
231
+
232
+
233
+ def _show_mask_stats(dataset: MaskDataset, dataset_path: Path) -> None:
234
+ """Show statistics for mask datasets with pixel-level information.
235
+
236
+ Args:
237
+ dataset: The MaskDataset instance.
238
+ dataset_path: Path to the dataset root.
239
+ """
240
+ with Progress(
241
+ SpinnerColumn(),
242
+ TextColumn("[progress.description]{task.description}"),
243
+ console=console,
244
+ transient=True,
245
+ ) as progress:
246
+ progress.add_task("Analyzing mask dataset...", total=None)
247
+ pixel_counts = dataset.get_pixel_counts()
248
+ image_presence = dataset.get_image_class_presence()
249
+ image_counts = dataset.get_image_counts()
250
+
251
+ # Get class mapping
252
+ class_mapping = dataset.get_class_mapping()
253
+
254
+ # Calculate total non-ignored pixels
255
+ total_pixels = sum(
256
+ count
257
+ for class_id, count in pixel_counts.items()
258
+ if class_id != dataset.ignore_index
259
+ )
260
+ ignored_pixels = pixel_counts.get(dataset.ignore_index, 0)
261
+
262
+ # Calculate total images
263
+ total_images = sum(ic["total"] for ic in image_counts.values())
264
+
265
+ # Create table
266
+ splits_str = ", ".join(dataset.splits) if dataset.splits else "unsplit"
267
+ title = f"Class Statistics: {dataset_path.name} ({splits_str})"
268
+ table = Table(title=title)
269
+ table.add_column("Class", style="cyan")
270
+ table.add_column("Total Pixels", justify="right", style="green")
271
+ table.add_column("% Coverage", justify="right", style="magenta")
272
+ table.add_column("Images With", justify="right", style="yellow")
273
+
274
+ # Sort classes by class_id
275
+ sorted_class_ids = sorted(class_mapping.keys())
276
+
277
+ for class_id in sorted_class_ids:
278
+ class_name = class_mapping[class_id]
279
+ pixels = pixel_counts.get(class_id, 0)
280
+ presence = image_presence.get(class_id, 0)
281
+
282
+ # Calculate coverage percentage
283
+ coverage = (pixels / total_pixels * 100) if total_pixels > 0 else 0.0
284
+
285
+ table.add_row(
286
+ class_name,
287
+ f"{pixels:,}",
288
+ f"{coverage:.1f}%",
289
+ str(presence),
290
+ )
291
+
292
+ # Add ignored row if there are ignored pixels
293
+ if ignored_pixels > 0:
294
+ table.add_section()
295
+ table.add_row(
296
+ "[dim](ignored)[/dim]",
297
+ f"[dim]{ignored_pixels:,}[/dim]",
298
+ "[dim]-[/dim]",
299
+ f"[dim]{total_images}[/dim]",
300
+ )
301
+
302
+ console.print(table)
303
+
304
+ # Summary line
305
+ console.print(f"\n[green]Dataset: {dataset_path}[/green]")
306
+ console.print(
307
+ f"[green]Format: {dataset.format.value.upper()} | "
308
+ f"Task: {dataset.task.value}[/green]"
309
+ )
310
+
311
+ # Image counts per split
312
+ image_parts = []
313
+ for split in dataset.splits if dataset.splits else ["unsplit"]:
314
+ if split in image_counts:
315
+ image_parts.append(f"{split}: {image_counts[split]['total']}")
219
316
 
220
317
  if image_parts:
221
318
  console.print(f"[blue]Images: {' | '.join(image_parts)}[/blue]")
@@ -247,6 +344,16 @@ def view(
247
344
  help="Maximum classes to show in grid (classification only).",
248
345
  ),
249
346
  ] = None,
347
+ opacity: Annotated[
348
+ float,
349
+ typer.Option(
350
+ "--opacity",
351
+ "-o",
352
+ help="Mask overlay opacity (0.0-1.0, mask datasets only).",
353
+ min=0.0,
354
+ max=1.0,
355
+ ),
356
+ ] = 0.5,
250
357
  ) -> None:
251
358
  """View annotated images in a dataset.
252
359
 
@@ -295,6 +402,41 @@ def view(
295
402
  # Generate consistent colors for each class
296
403
  class_colors = _generate_class_colors(dataset.class_names)
297
404
 
405
+ # Handle mask datasets with overlay viewer
406
+ if dataset.format == DatasetFormat.MASK:
407
+ assert isinstance(dataset, MaskDataset)
408
+ with Progress(
409
+ SpinnerColumn(),
410
+ TextColumn("[progress.description]{task.description}"),
411
+ console=console,
412
+ transient=True,
413
+ ) as progress:
414
+ progress.add_task("Loading images...", total=None)
415
+ image_paths = dataset.get_image_paths(split)
416
+
417
+ if not image_paths:
418
+ console.print("[yellow]No images found in the dataset.[/yellow]")
419
+ return
420
+
421
+ console.print(
422
+ f"[green]Found {len(image_paths)} images. "
423
+ f"Opening mask viewer...[/green]\n"
424
+ "[dim]Controls: \u2190 / \u2192 or P / N to navigate, "
425
+ "Mouse wheel to zoom, Drag to pan, R to reset, T to toggle overlay, "
426
+ "Q / ESC to quit[/dim]"
427
+ )
428
+
429
+ viewer = _MaskViewer(
430
+ image_paths=image_paths,
431
+ dataset=dataset,
432
+ class_colors=class_colors,
433
+ window_name=f"Argus Mask Viewer - {dataset_path.name}",
434
+ opacity=opacity,
435
+ )
436
+ viewer.run()
437
+ console.print("[green]Viewer closed.[/green]")
438
+ return
439
+
298
440
  # Handle classification datasets with grid viewer
299
441
  if dataset.task == TaskType.CLASSIFICATION:
300
442
  # Use first split if specified, otherwise let get_images_by_class handle it
@@ -599,12 +741,16 @@ class _ImageViewer:
599
741
  info_text += " [Annotations: OFF]"
600
742
 
601
743
  cv2.putText(
602
- display, info_text, (10, 30),
603
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2
744
+ display,
745
+ info_text,
746
+ (10, 30),
747
+ cv2.FONT_HERSHEY_SIMPLEX,
748
+ 0.7,
749
+ (255, 255, 255),
750
+ 2,
604
751
  )
605
752
  cv2.putText(
606
- display, info_text, (10, 30),
607
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
753
+ display, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
608
754
  )
609
755
 
610
756
  return display
@@ -733,9 +879,11 @@ class _ClassificationGridViewer:
733
879
  self.current_index = 0
734
880
 
735
881
  # Calculate max images across all classes
736
- self.max_images = max(
737
- len(imgs) for imgs in self.images_by_class.values()
738
- ) if self.images_by_class else 0
882
+ self.max_images = (
883
+ max(len(imgs) for imgs in self.images_by_class.values())
884
+ if self.images_by_class
885
+ else 0
886
+ )
739
887
 
740
888
  # Calculate grid layout
741
889
  self.cols, self.rows = self._calculate_grid_layout()
@@ -883,6 +1031,267 @@ class _ClassificationGridViewer:
883
1031
  cv2.destroyAllWindows()
884
1032
 
885
1033
 
1034
+ class _MaskViewer:
1035
+ """Interactive viewer for semantic mask datasets with colored overlay."""
1036
+
1037
+ def __init__(
1038
+ self,
1039
+ image_paths: list[Path],
1040
+ dataset: MaskDataset,
1041
+ class_colors: dict[str, tuple[int, int, int]],
1042
+ window_name: str,
1043
+ opacity: float = 0.5,
1044
+ ):
1045
+ self.image_paths = image_paths
1046
+ self.dataset = dataset
1047
+ self.class_colors = class_colors
1048
+ self.window_name = window_name
1049
+ self.opacity = opacity
1050
+
1051
+ self.current_idx = 0
1052
+ self.zoom = 1.0
1053
+ self.pan_x = 0.0
1054
+ self.pan_y = 0.0
1055
+
1056
+ # Mouse state for panning
1057
+ self.dragging = False
1058
+ self.drag_start_x = 0
1059
+ self.drag_start_y = 0
1060
+ self.pan_start_x = 0.0
1061
+ self.pan_start_y = 0.0
1062
+
1063
+ # Current image cache
1064
+ self.current_img: np.ndarray | None = None
1065
+ self.overlay_img: np.ndarray | None = None
1066
+
1067
+ # Overlay visibility toggle
1068
+ self.show_overlay = True
1069
+
1070
+ # Build class_id to color mapping
1071
+ self._id_to_color: dict[int, tuple[int, int, int]] = {}
1072
+ class_mapping = dataset.get_class_mapping()
1073
+ for class_id, class_name in class_mapping.items():
1074
+ if class_name in class_colors:
1075
+ self._id_to_color[class_id] = class_colors[class_name]
1076
+
1077
+ def _load_current_image(self) -> bool:
1078
+ """Load current image and create mask overlay."""
1079
+ image_path = self.image_paths[self.current_idx]
1080
+
1081
+ img = cv2.imread(str(image_path))
1082
+ if img is None:
1083
+ return False
1084
+
1085
+ mask = self.dataset.load_mask(image_path)
1086
+ if mask is None:
1087
+ console.print(f"[yellow]Warning: No mask for {image_path}[/yellow]")
1088
+ self.current_img = img
1089
+ self.overlay_img = img.copy()
1090
+ return True
1091
+
1092
+ # Validate dimensions
1093
+ if img.shape[:2] != mask.shape[:2]:
1094
+ console.print(
1095
+ f"[red]Error: Dimension mismatch for {image_path.name}: "
1096
+ f"image={img.shape[:2]}, mask={mask.shape[:2]}[/red]"
1097
+ )
1098
+ return False
1099
+
1100
+ self.current_img = img
1101
+ self.overlay_img = self._create_overlay(img, mask)
1102
+ return True
1103
+
1104
+ def _create_overlay(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
1105
+ """Create colored overlay from mask.
1106
+
1107
+ Args:
1108
+ img: Original image (BGR).
1109
+ mask: Grayscale mask with class IDs.
1110
+
1111
+ Returns:
1112
+ Image with colored mask overlay.
1113
+ """
1114
+ # Create colored mask
1115
+ h, w = mask.shape
1116
+ colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
1117
+
1118
+ for class_id, color in self._id_to_color.items():
1119
+ colored_mask[mask == class_id] = color
1120
+
1121
+ # Blend with original image
1122
+ # Ignore pixels are fully transparent (not blended)
1123
+ ignore_mask = mask == self.dataset.ignore_index
1124
+ alpha = np.ones((h, w, 1), dtype=np.float32) * self.opacity
1125
+ alpha[ignore_mask] = 0.0
1126
+
1127
+ # Blend: result = img * (1 - alpha) + colored_mask * alpha
1128
+ blended = (
1129
+ img.astype(np.float32) * (1 - alpha)
1130
+ + colored_mask.astype(np.float32) * alpha
1131
+ )
1132
+ return blended.astype(np.uint8)
1133
+
1134
+ def _get_display_image(self) -> np.ndarray:
1135
+ """Get the image transformed for current zoom/pan."""
1136
+ if self.overlay_img is None:
1137
+ return np.zeros((480, 640, 3), dtype=np.uint8)
1138
+
1139
+ if self.show_overlay:
1140
+ img = self.overlay_img
1141
+ elif self.current_img is not None:
1142
+ img = self.current_img
1143
+ else:
1144
+ img = self.overlay_img
1145
+
1146
+ h, w = img.shape[:2]
1147
+
1148
+ if self.zoom == 1.0 and self.pan_x == 0.0 and self.pan_y == 0.0:
1149
+ display = img.copy()
1150
+ else:
1151
+ # Calculate the visible region
1152
+ view_w = int(w / self.zoom)
1153
+ view_h = int(h / self.zoom)
1154
+
1155
+ # Center point with pan offset
1156
+ cx = w / 2 + self.pan_x
1157
+ cy = h / 2 + self.pan_y
1158
+
1159
+ # Calculate crop bounds
1160
+ x1 = int(max(0, cx - view_w / 2))
1161
+ y1 = int(max(0, cy - view_h / 2))
1162
+ x2 = int(min(w, x1 + view_w))
1163
+ y2 = int(min(h, y1 + view_h))
1164
+
1165
+ # Adjust if we hit boundaries
1166
+ if x2 - x1 < view_w:
1167
+ x1 = max(0, x2 - view_w)
1168
+ if y2 - y1 < view_h:
1169
+ y1 = max(0, y2 - view_h)
1170
+
1171
+ # Crop and resize
1172
+ cropped = img[y1:y2, x1:x2]
1173
+ display = cv2.resize(cropped, (w, h), interpolation=cv2.INTER_LINEAR)
1174
+
1175
+ # Add info overlay
1176
+ image_path = self.image_paths[self.current_idx]
1177
+ idx = self.current_idx + 1
1178
+ total = len(self.image_paths)
1179
+ info_text = f"[{idx}/{total}] {image_path.name}"
1180
+ if self.zoom > 1.0:
1181
+ info_text += f" (Zoom: {self.zoom:.1f}x)"
1182
+ if not self.show_overlay:
1183
+ info_text += " [Overlay: OFF]"
1184
+
1185
+ cv2.putText(
1186
+ display,
1187
+ info_text,
1188
+ (10, 30),
1189
+ cv2.FONT_HERSHEY_SIMPLEX,
1190
+ 0.7,
1191
+ (255, 255, 255),
1192
+ 2,
1193
+ )
1194
+ cv2.putText(
1195
+ display, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
1196
+ )
1197
+
1198
+ return display
1199
+
1200
+ def _mouse_callback(
1201
+ self, event: int, x: int, y: int, flags: int, param: None
1202
+ ) -> None:
1203
+ """Handle mouse events for zoom and pan."""
1204
+ if event == cv2.EVENT_MOUSEWHEEL:
1205
+ # Zoom in/out
1206
+ if flags > 0:
1207
+ self.zoom = min(10.0, self.zoom * 1.2)
1208
+ else:
1209
+ self.zoom = max(1.0, self.zoom / 1.2)
1210
+
1211
+ # Reset pan if zoomed out to 1x
1212
+ if self.zoom == 1.0:
1213
+ self.pan_x = 0.0
1214
+ self.pan_y = 0.0
1215
+
1216
+ elif event == cv2.EVENT_LBUTTONDOWN:
1217
+ self.dragging = True
1218
+ self.drag_start_x = x
1219
+ self.drag_start_y = y
1220
+ self.pan_start_x = self.pan_x
1221
+ self.pan_start_y = self.pan_y
1222
+
1223
+ elif event == cv2.EVENT_MOUSEMOVE and self.dragging:
1224
+ if self.zoom > 1.0 and self.overlay_img is not None:
1225
+ h, w = self.overlay_img.shape[:2]
1226
+ # Calculate pan delta (inverted for natural feel)
1227
+ dx = (self.drag_start_x - x) / self.zoom
1228
+ dy = (self.drag_start_y - y) / self.zoom
1229
+
1230
+ # Update pan with limits
1231
+ max_pan_x = w * (1 - 1 / self.zoom) / 2
1232
+ max_pan_y = h * (1 - 1 / self.zoom) / 2
1233
+
1234
+ self.pan_x = max(-max_pan_x, min(max_pan_x, self.pan_start_x + dx))
1235
+ self.pan_y = max(-max_pan_y, min(max_pan_y, self.pan_start_y + dy))
1236
+
1237
+ elif event == cv2.EVENT_LBUTTONUP:
1238
+ self.dragging = False
1239
+
1240
+ def _reset_view(self) -> None:
1241
+ """Reset zoom and pan to default."""
1242
+ self.zoom = 1.0
1243
+ self.pan_x = 0.0
1244
+ self.pan_y = 0.0
1245
+
1246
+ def _next_image(self) -> None:
1247
+ """Go to next image."""
1248
+ self.current_idx = (self.current_idx + 1) % len(self.image_paths)
1249
+ self._reset_view()
1250
+
1251
+ def _prev_image(self) -> None:
1252
+ """Go to previous image."""
1253
+ self.current_idx = (self.current_idx - 1) % len(self.image_paths)
1254
+ self._reset_view()
1255
+
1256
+ def run(self) -> None:
1257
+ """Run the interactive viewer."""
1258
+ cv2.namedWindow(self.window_name, cv2.WINDOW_AUTOSIZE)
1259
+ cv2.setMouseCallback(self.window_name, self._mouse_callback)
1260
+
1261
+ while True:
1262
+ # Load image if needed
1263
+ if self.overlay_img is None and not self._load_current_image():
1264
+ console.print(
1265
+ f"[yellow]Warning: Could not load "
1266
+ f"{self.image_paths[self.current_idx]}[/yellow]"
1267
+ )
1268
+ self._next_image()
1269
+ continue
1270
+
1271
+ # Display image
1272
+ display = self._get_display_image()
1273
+ cv2.imshow(self.window_name, display)
1274
+
1275
+ # Wait for input (short timeout for smooth panning)
1276
+ key = cv2.waitKey(30) & 0xFF
1277
+
1278
+ # Handle keyboard input
1279
+ if key == ord("q") or key == 27: # Q or ESC
1280
+ break
1281
+ elif key == ord("n") or key == 83 or key == 3: # N or Right arrow
1282
+ self.overlay_img = None
1283
+ self._next_image()
1284
+ elif key == ord("p") or key == 81 or key == 2: # P or Left arrow
1285
+ self.overlay_img = None
1286
+ self._prev_image()
1287
+ elif key == ord("r"): # R to reset zoom
1288
+ self._reset_view()
1289
+ elif key == ord("t"): # T to toggle overlay
1290
+ self.show_overlay = not self.show_overlay
1291
+
1292
+ cv2.destroyAllWindows()
1293
+
1294
+
886
1295
  def _generate_class_colors(class_names: list[str]) -> dict[str, tuple[int, int, int]]:
887
1296
  """Generate consistent colors for each class name.
888
1297
 
@@ -965,9 +1374,13 @@ def _draw_annotations(
965
1374
  )
966
1375
  # Draw label text
967
1376
  cv2.putText(
968
- img, label,
1377
+ img,
1378
+ label,
969
1379
  (x1 + 2, y1 - baseline - 2),
970
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
1380
+ cv2.FONT_HERSHEY_SIMPLEX,
1381
+ 0.5,
1382
+ (255, 255, 255),
1383
+ 1,
971
1384
  )
972
1385
 
973
1386
  return img
@@ -1026,17 +1439,25 @@ def _discover_datasets(root_path: Path, max_depth: int) -> list[Dataset]:
1026
1439
 
1027
1440
 
1028
1441
  def _detect_dataset(path: Path) -> Dataset | None:
1029
- """Try to detect a dataset at the given path."""
1030
- # Try YOLO first (more specific patterns)
1442
+ """Try to detect a dataset at the given path.
1443
+
1444
+ Detection priority: YOLO -> COCO -> MaskDataset
1445
+ """
1446
+ # Try YOLO first (more specific patterns - requires data.yaml)
1031
1447
  dataset = YOLODataset.detect(path)
1032
1448
  if dataset:
1033
1449
  return dataset
1034
1450
 
1035
- # Try COCO
1451
+ # Try COCO (requires annotations/*.json)
1036
1452
  dataset = COCODataset.detect(path)
1037
1453
  if dataset:
1038
1454
  return dataset
1039
1455
 
1456
+ # Try MaskDataset (directory structure based)
1457
+ dataset = MaskDataset.detect(path)
1458
+ if dataset:
1459
+ return dataset
1460
+
1040
1461
  return None
1041
1462
 
1042
1463
 
@@ -2,6 +2,7 @@
2
2
 
3
3
  from argus.core.base import Dataset
4
4
  from argus.core.coco import COCODataset
5
+ from argus.core.mask import ConfigurationError, MaskDataset
5
6
  from argus.core.split import split_coco_dataset, split_yolo_dataset
6
7
  from argus.core.yolo import YOLODataset
7
8
 
@@ -9,6 +10,8 @@ __all__ = [
9
10
  "Dataset",
10
11
  "YOLODataset",
11
12
  "COCODataset",
13
+ "MaskDataset",
14
+ "ConfigurationError",
12
15
  "split_coco_dataset",
13
16
  "split_yolo_dataset",
14
17
  ]
@@ -11,6 +11,7 @@ class DatasetFormat(str, Enum):
11
11
 
12
12
  YOLO = "yolo"
13
13
  COCO = "coco"
14
+ MASK = "mask"
14
15
 
15
16
 
16
17
  class TaskType(str, Enum):
@@ -453,12 +453,14 @@ class COCODataset(Dataset):
453
453
  for i in range(0, len(coords), 2):
454
454
  polygon.append((float(coords[i]), float(coords[i + 1])))
455
455
 
456
- annotations.append({
457
- "class_name": class_name,
458
- "class_id": cat_id,
459
- "bbox": bbox_tuple,
460
- "polygon": polygon,
461
- })
456
+ annotations.append(
457
+ {
458
+ "class_name": class_name,
459
+ "class_id": cat_id,
460
+ "bbox": bbox_tuple,
461
+ "polygon": polygon,
462
+ }
463
+ )
462
464
 
463
465
  except (json.JSONDecodeError, OSError):
464
466
  continue