argus-cv 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

argus/__init__.py CHANGED
@@ -1,3 +1,3 @@
  """Argus - Vision AI dataset toolkit."""
 
- __version__ = "1.2.0"
+ __version__ = "1.4.0"
argus/cli.py CHANGED
@@ -11,8 +11,8 @@ from rich.console import Console
  from rich.progress import Progress, SpinnerColumn, TextColumn
  from rich.table import Table
 
- from argus.core import COCODataset, Dataset, YOLODataset
- from argus.core.base import TaskType
+ from argus.core import COCODataset, Dataset, MaskDataset, YOLODataset
+ from argus.core.base import DatasetFormat, TaskType
  from argus.core.split import (
      is_coco_unsplit,
      parse_ratio,
@@ -128,12 +128,19 @@ def stats(
      dataset = _detect_dataset(dataset_path)
      if not dataset:
          console.print(
-             f"[red]Error: No YOLO or COCO dataset found at {dataset_path}[/red]\n"
+             f"[red]Error: No dataset found at {dataset_path}[/red]\n"
              "[yellow]Ensure the path points to a dataset root containing "
-             "data.yaml (YOLO) or annotations/ folder (COCO).[/yellow]"
+             "data.yaml (YOLO), annotations/ folder (COCO), or "
+             "images/ + masks/ directories (Mask).[/yellow]"
          )
          raise typer.Exit(1)
 
+     # Handle mask datasets with pixel statistics
+     if dataset.format == DatasetFormat.MASK:
+         assert isinstance(dataset, MaskDataset)
+         _show_mask_stats(dataset, dataset_path)
+         return
+
      # Get instance counts with progress indicator
      with Progress(
          SpinnerColumn(),
@@ -212,10 +219,100 @@ def stats(
          else:
              image_parts.append(f"{split}: {img_total}")
 
-     console.print(f"\n[green]Dataset: {dataset.format.value.upper()} | "
-                   f"Task: {dataset.task.value} | "
-                   f"Classes: {len(sorted_classes)} | "
-                   f"Total instances: {grand_total}[/green]")
+     console.print(
+         f"\n[green]Dataset: {dataset.format.value.upper()} | "
+         f"Task: {dataset.task.value} | "
+         f"Classes: {len(sorted_classes)} | "
+         f"Total instances: {grand_total}[/green]"
+     )
+
+     if image_parts:
+         console.print(f"[blue]Images: {' | '.join(image_parts)}[/blue]")
+
+
+ def _show_mask_stats(dataset: MaskDataset, dataset_path: Path) -> None:
+     """Show statistics for mask datasets with pixel-level information.
+
+     Args:
+         dataset: The MaskDataset instance.
+         dataset_path: Path to the dataset root.
+     """
+     with Progress(
+         SpinnerColumn(),
+         TextColumn("[progress.description]{task.description}"),
+         console=console,
+         transient=True,
+     ) as progress:
+         progress.add_task("Analyzing mask dataset...", total=None)
+         pixel_counts = dataset.get_pixel_counts()
+         image_presence = dataset.get_image_class_presence()
+         image_counts = dataset.get_image_counts()
+
+     # Get class mapping
+     class_mapping = dataset.get_class_mapping()
+
+     # Calculate total non-ignored pixels
+     total_pixels = sum(
+         count
+         for class_id, count in pixel_counts.items()
+         if class_id != dataset.ignore_index
+     )
+     ignored_pixels = pixel_counts.get(dataset.ignore_index, 0)
+
+     # Calculate total images
+     total_images = sum(ic["total"] for ic in image_counts.values())
+
+     # Create table
+     splits_str = ", ".join(dataset.splits) if dataset.splits else "unsplit"
+     title = f"Class Statistics: {dataset_path.name} ({splits_str})"
+     table = Table(title=title)
+     table.add_column("Class", style="cyan")
+     table.add_column("Total Pixels", justify="right", style="green")
+     table.add_column("% Coverage", justify="right", style="magenta")
+     table.add_column("Images With", justify="right", style="yellow")
+
+     # Sort classes by class_id
+     sorted_class_ids = sorted(class_mapping.keys())
+
+     for class_id in sorted_class_ids:
+         class_name = class_mapping[class_id]
+         pixels = pixel_counts.get(class_id, 0)
+         presence = image_presence.get(class_id, 0)
+
+         # Calculate coverage percentage
+         coverage = (pixels / total_pixels * 100) if total_pixels > 0 else 0.0
+
+         table.add_row(
+             class_name,
+             f"{pixels:,}",
+             f"{coverage:.1f}%",
+             str(presence),
+         )
+
+     # Add ignored row if there are ignored pixels
+     if ignored_pixels > 0:
+         table.add_section()
+         table.add_row(
+             "[dim](ignored)[/dim]",
+             f"[dim]{ignored_pixels:,}[/dim]",
+             "[dim]-[/dim]",
+             f"[dim]{total_images}[/dim]",
+         )
+
+     console.print(table)
+
+     # Summary line
+     console.print(f"\n[green]Dataset: {dataset_path}[/green]")
+     console.print(
+         f"[green]Format: {dataset.format.value.upper()} | "
+         f"Task: {dataset.task.value}[/green]"
+     )
+
+     # Image counts per split
+     image_parts = []
+     for split in dataset.splits if dataset.splits else ["unsplit"]:
+         if split in image_counts:
+             image_parts.append(f"{split}: {image_counts[split]['total']}")
 
      if image_parts:
          console.print(f"[blue]Images: {' | '.join(image_parts)}[/blue]")
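
The % Coverage column above is each class's share of all non-ignored pixels. A minimal standalone sketch of that arithmetic, with made-up counts and an assumed ignore value of 255 (the real ignore_index comes from the dataset):

# Hypothetical pixel counts for illustration: class_id -> pixel count.
pixel_counts = {0: 900_000, 1: 80_000, 2: 20_000, 255: 50_000}
ignore_index = 255  # assumed value for this sketch

# Ignored pixels are excluded from the denominator, as in _show_mask_stats.
total_pixels = sum(c for cid, c in pixel_counts.items() if cid != ignore_index)

for cid in sorted(pixel_counts):
    if cid == ignore_index:
        continue
    coverage = pixel_counts[cid] / total_pixels * 100 if total_pixels > 0 else 0.0
    print(f"class {cid}: {pixel_counts[cid]:,} px ({coverage:.1f}%)")
# class 0: 900,000 px (90.0%)
# class 1: 80,000 px (8.0%)
# class 2: 20,000 px (2.0%)
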
@@ -247,6 +344,16 @@ def view(
              help="Maximum classes to show in grid (classification only).",
          ),
      ] = None,
+     opacity: Annotated[
+         float,
+         typer.Option(
+             "--opacity",
+             "-o",
+             help="Mask overlay opacity (0.0-1.0, mask datasets only).",
+             min=0.0,
+             max=1.0,
+         ),
+     ] = 0.5,
  ) -> None:
      """View annotated images in a dataset.
 
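
Typer enforces the 0.0-1.0 range at argument-parsing time via the min/max keywords, so the viewer never sees an out-of-range opacity. A self-contained sketch of the same option pattern (the command name and echo body are illustrative, not part of argus-cv):

from typing import Annotated

import typer

app = typer.Typer()


@app.command()
def demo(
    opacity: Annotated[
        float,
        typer.Option("--opacity", "-o", min=0.0, max=1.0),
    ] = 0.5,
) -> None:
    # Values outside [0.0, 1.0] are rejected before this body runs.
    typer.echo(f"opacity={opacity}")


if __name__ == "__main__":
    app()
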
@@ -295,6 +402,41 @@ def view(
      # Generate consistent colors for each class
      class_colors = _generate_class_colors(dataset.class_names)
 
+     # Handle mask datasets with overlay viewer
+     if dataset.format == DatasetFormat.MASK:
+         assert isinstance(dataset, MaskDataset)
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             console=console,
+             transient=True,
+         ) as progress:
+             progress.add_task("Loading images...", total=None)
+             image_paths = dataset.get_image_paths(split)
+
+         if not image_paths:
+             console.print("[yellow]No images found in the dataset.[/yellow]")
+             return
+
+         console.print(
+             f"[green]Found {len(image_paths)} images. "
+             f"Opening mask viewer...[/green]\n"
+             "[dim]Controls: \u2190 / \u2192 or P / N to navigate, "
+             "Mouse wheel to zoom, Drag to pan, R to reset, T to toggle overlay, "
+             "Q / ESC to quit[/dim]"
+         )
+
+         viewer = _MaskViewer(
+             image_paths=image_paths,
+             dataset=dataset,
+             class_colors=class_colors,
+             window_name=f"Argus Mask Viewer - {dataset_path.name}",
+             opacity=opacity,
+         )
+         viewer.run()
+         console.print("[green]Viewer closed.[/green]")
+         return
+
      # Handle classification datasets with grid viewer
      if dataset.task == TaskType.CLASSIFICATION:
          # Use first split if specified, otherwise let get_images_by_class handle it
@@ -392,13 +534,6 @@ def split_dataset(
              help="Train/val/test ratio (e.g. 0.8,0.1,0.1).",
          ),
      ] = "0.8,0.1,0.1",
-     stratify: Annotated[
-         bool,
-         typer.Option(
-             "--stratify/--no-stratify",
-             help="Stratify by class distribution when splitting.",
-         ),
-     ] = True,
      seed: Annotated[
          int,
          typer.Option(
@@ -451,9 +586,7 @@ def split_dataset(
      ) as progress:
          progress.add_task("Creating YOLO splits...", total=None)
          try:
-             counts = split_yolo_dataset(
-                 dataset, output_path, ratios, stratify, seed
-             )
+             counts = split_yolo_dataset(dataset, output_path, ratios, True, seed)
          except ValueError as exc:
              console.print(f"[red]Error: {exc}[/red]")
              raise typer.Exit(1) from exc
@@ -486,7 +619,7 @@ def split_dataset(
                  annotation_file,
                  output_path,
                  ratios,
-                 stratify,
+                 True,
                  seed,
              )
          except ValueError as exc:
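
Note the behavioral change in these three hunks: the --stratify/--no-stratify flag is removed and the CLI now hard-codes stratify=True for both YOLO and COCO splits. Anyone who still wants an unstratified split would have to call the library function directly; a hedged sketch using the positional order visible above (the paths are placeholders, and parse_ratio's exact return type is an assumption):

from pathlib import Path

from argus.core import YOLODataset
from argus.core.split import parse_ratio, split_yolo_dataset

dataset = YOLODataset.detect(Path("./my_dataset"))  # placeholder path
if dataset is not None:
    ratios = parse_ratio("0.8,0.1,0.1")  # assumed to mirror the CLI's ratio string
    # Positional args per the diff: dataset, output dir, ratios, stratify, seed.
    counts = split_yolo_dataset(dataset, Path("./out"), ratios, False, 42)
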
@@ -599,12 +732,16 @@ class _ImageViewer:
              info_text += " [Annotations: OFF]"
 
          cv2.putText(
-             display, info_text, (10, 30),
-             cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2
+             display,
+             info_text,
+             (10, 30),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.7,
+             (255, 255, 255),
+             2,
          )
          cv2.putText(
-             display, info_text, (10, 30),
-             cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
+             display, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
          )
 
          return display
@@ -733,9 +870,11 @@ class _ClassificationGridViewer:
          self.current_index = 0
 
          # Calculate max images across all classes
-         self.max_images = max(
-             len(imgs) for imgs in self.images_by_class.values()
-         ) if self.images_by_class else 0
+         self.max_images = (
+             max(len(imgs) for imgs in self.images_by_class.values())
+             if self.images_by_class
+             else 0
+         )
 
          # Calculate grid layout
          self.cols, self.rows = self._calculate_grid_layout()
@@ -883,6 +1022,267 @@ class _ClassificationGridViewer:
          cv2.destroyAllWindows()
 
 
+ class _MaskViewer:
+     """Interactive viewer for semantic mask datasets with colored overlay."""
+
+     def __init__(
+         self,
+         image_paths: list[Path],
+         dataset: MaskDataset,
+         class_colors: dict[str, tuple[int, int, int]],
+         window_name: str,
+         opacity: float = 0.5,
+     ):
+         self.image_paths = image_paths
+         self.dataset = dataset
+         self.class_colors = class_colors
+         self.window_name = window_name
+         self.opacity = opacity
+
+         self.current_idx = 0
+         self.zoom = 1.0
+         self.pan_x = 0.0
+         self.pan_y = 0.0
+
+         # Mouse state for panning
+         self.dragging = False
+         self.drag_start_x = 0
+         self.drag_start_y = 0
+         self.pan_start_x = 0.0
+         self.pan_start_y = 0.0
+
+         # Current image cache
+         self.current_img: np.ndarray | None = None
+         self.overlay_img: np.ndarray | None = None
+
+         # Overlay visibility toggle
+         self.show_overlay = True
+
+         # Build class_id to color mapping
+         self._id_to_color: dict[int, tuple[int, int, int]] = {}
+         class_mapping = dataset.get_class_mapping()
+         for class_id, class_name in class_mapping.items():
+             if class_name in class_colors:
+                 self._id_to_color[class_id] = class_colors[class_name]
+
+     def _load_current_image(self) -> bool:
+         """Load current image and create mask overlay."""
+         image_path = self.image_paths[self.current_idx]
+
+         img = cv2.imread(str(image_path))
+         if img is None:
+             return False
+
+         mask = self.dataset.load_mask(image_path)
+         if mask is None:
+             console.print(f"[yellow]Warning: No mask for {image_path}[/yellow]")
+             self.current_img = img
+             self.overlay_img = img.copy()
+             return True
+
+         # Validate dimensions
+         if img.shape[:2] != mask.shape[:2]:
+             console.print(
+                 f"[red]Error: Dimension mismatch for {image_path.name}: "
+                 f"image={img.shape[:2]}, mask={mask.shape[:2]}[/red]"
+             )
+             return False
+
+         self.current_img = img
+         self.overlay_img = self._create_overlay(img, mask)
+         return True
+
+     def _create_overlay(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
+         """Create colored overlay from mask.
+
+         Args:
+             img: Original image (BGR).
+             mask: Grayscale mask with class IDs.
+
+         Returns:
+             Image with colored mask overlay.
+         """
+         # Create colored mask
+         h, w = mask.shape
+         colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
+
+         for class_id, color in self._id_to_color.items():
+             colored_mask[mask == class_id] = color
+
+         # Blend with original image
+         # Ignore pixels are fully transparent (not blended)
+         ignore_mask = mask == self.dataset.ignore_index
+         alpha = np.ones((h, w, 1), dtype=np.float32) * self.opacity
+         alpha[ignore_mask] = 0.0
+
+         # Blend: result = img * (1 - alpha) + colored_mask * alpha
+         blended = (
+             img.astype(np.float32) * (1 - alpha)
+             + colored_mask.astype(np.float32) * alpha
+         )
+         return blended.astype(np.uint8)
+
+     def _get_display_image(self) -> np.ndarray:
+         """Get the image transformed for current zoom/pan."""
+         if self.overlay_img is None:
+             return np.zeros((480, 640, 3), dtype=np.uint8)
+
+         if self.show_overlay:
+             img = self.overlay_img
+         elif self.current_img is not None:
+             img = self.current_img
+         else:
+             img = self.overlay_img
+
+         h, w = img.shape[:2]
+
+         if self.zoom == 1.0 and self.pan_x == 0.0 and self.pan_y == 0.0:
+             display = img.copy()
+         else:
+             # Calculate the visible region
+             view_w = int(w / self.zoom)
+             view_h = int(h / self.zoom)
+
+             # Center point with pan offset
+             cx = w / 2 + self.pan_x
+             cy = h / 2 + self.pan_y
+
+             # Calculate crop bounds
+             x1 = int(max(0, cx - view_w / 2))
+             y1 = int(max(0, cy - view_h / 2))
+             x2 = int(min(w, x1 + view_w))
+             y2 = int(min(h, y1 + view_h))
+
+             # Adjust if we hit boundaries
+             if x2 - x1 < view_w:
+                 x1 = max(0, x2 - view_w)
+             if y2 - y1 < view_h:
+                 y1 = max(0, y2 - view_h)
+
+             # Crop and resize
+             cropped = img[y1:y2, x1:x2]
+             display = cv2.resize(cropped, (w, h), interpolation=cv2.INTER_LINEAR)
+
+         # Add info overlay
+         image_path = self.image_paths[self.current_idx]
+         idx = self.current_idx + 1
+         total = len(self.image_paths)
+         info_text = f"[{idx}/{total}] {image_path.name}"
+         if self.zoom > 1.0:
+             info_text += f" (Zoom: {self.zoom:.1f}x)"
+         if not self.show_overlay:
+             info_text += " [Overlay: OFF]"
+
+         cv2.putText(
+             display,
+             info_text,
+             (10, 30),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.7,
+             (255, 255, 255),
+             2,
+         )
+         cv2.putText(
+             display, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1
+         )
+
+         return display
+
+     def _mouse_callback(
+         self, event: int, x: int, y: int, flags: int, param: None
+     ) -> None:
+         """Handle mouse events for zoom and pan."""
+         if event == cv2.EVENT_MOUSEWHEEL:
+             # Zoom in/out
+             if flags > 0:
+                 self.zoom = min(10.0, self.zoom * 1.2)
+             else:
+                 self.zoom = max(1.0, self.zoom / 1.2)
+
+             # Reset pan if zoomed out to 1x
+             if self.zoom == 1.0:
+                 self.pan_x = 0.0
+                 self.pan_y = 0.0
+
+         elif event == cv2.EVENT_LBUTTONDOWN:
+             self.dragging = True
+             self.drag_start_x = x
+             self.drag_start_y = y
+             self.pan_start_x = self.pan_x
+             self.pan_start_y = self.pan_y
+
+         elif event == cv2.EVENT_MOUSEMOVE and self.dragging:
+             if self.zoom > 1.0 and self.overlay_img is not None:
+                 h, w = self.overlay_img.shape[:2]
+                 # Calculate pan delta (inverted for natural feel)
+                 dx = (self.drag_start_x - x) / self.zoom
+                 dy = (self.drag_start_y - y) / self.zoom
+
+                 # Update pan with limits
+                 max_pan_x = w * (1 - 1 / self.zoom) / 2
+                 max_pan_y = h * (1 - 1 / self.zoom) / 2
+
+                 self.pan_x = max(-max_pan_x, min(max_pan_x, self.pan_start_x + dx))
+                 self.pan_y = max(-max_pan_y, min(max_pan_y, self.pan_start_y + dy))
+
+         elif event == cv2.EVENT_LBUTTONUP:
+             self.dragging = False
+
+     def _reset_view(self) -> None:
+         """Reset zoom and pan to default."""
+         self.zoom = 1.0
+         self.pan_x = 0.0
+         self.pan_y = 0.0
+
+     def _next_image(self) -> None:
+         """Go to next image."""
+         self.current_idx = (self.current_idx + 1) % len(self.image_paths)
+         self._reset_view()
+
+     def _prev_image(self) -> None:
+         """Go to previous image."""
+         self.current_idx = (self.current_idx - 1) % len(self.image_paths)
+         self._reset_view()
+
+     def run(self) -> None:
+         """Run the interactive viewer."""
+         cv2.namedWindow(self.window_name, cv2.WINDOW_AUTOSIZE)
+         cv2.setMouseCallback(self.window_name, self._mouse_callback)
+
+         while True:
+             # Load image if needed
+             if self.overlay_img is None and not self._load_current_image():
+                 console.print(
+                     f"[yellow]Warning: Could not load "
+                     f"{self.image_paths[self.current_idx]}[/yellow]"
+                 )
+                 self._next_image()
+                 continue
+
+             # Display image
+             display = self._get_display_image()
+             cv2.imshow(self.window_name, display)
+
+             # Wait for input (short timeout for smooth panning)
+             key = cv2.waitKey(30) & 0xFF
+
+             # Handle keyboard input
+             if key == ord("q") or key == 27:  # Q or ESC
+                 break
+             elif key == ord("n") or key == 83 or key == 3:  # N or Right arrow
+                 self.overlay_img = None
+                 self._next_image()
+             elif key == ord("p") or key == 81 or key == 2:  # P or Left arrow
+                 self.overlay_img = None
+                 self._prev_image()
+             elif key == ord("r"):  # R to reset zoom
+                 self._reset_view()
+             elif key == ord("t"):  # T to toggle overlay
+                 self.show_overlay = not self.show_overlay
+
+         cv2.destroyAllWindows()
+
+
  def _generate_class_colors(class_names: list[str]) -> dict[str, tuple[int, int, int]]:
      """Generate consistent colors for each class name.
 
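
The core of _create_overlay is per-pixel alpha blending with the alpha forced to zero wherever the mask equals ignore_index. The same math on a tiny synthetic array (shapes, colors, and the 255 ignore value are illustrative):

import numpy as np

img = np.full((2, 2, 3), 100, dtype=np.uint8)        # flat gray "photo"
mask = np.array([[1, 1], [0, 255]], dtype=np.uint8)  # class ids; 255 = ignore
id_to_color = {1: (0, 0, 255)}                       # BGR red, viewer-style
opacity = 0.5

colored = np.zeros((2, 2, 3), dtype=np.uint8)
for class_id, color in id_to_color.items():
    colored[mask == class_id] = color

alpha = np.full((2, 2, 1), opacity, dtype=np.float32)
alpha[mask == 255] = 0.0  # ignore pixels keep the original image untouched

blended = (img * (1 - alpha) + colored * alpha).astype(np.uint8)
print(blended)
# Class-1 pixels -> (50, 50, 177); class 0 has no color entry, so it blends
# toward black -> (50, 50, 50); the ignore pixel stays (100, 100, 100).
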
@@ -943,9 +1343,12 @@ def _draw_annotations(
          overlay = img.copy()
          cv2.fillPoly(overlay, [pts], color)
          cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)
+         # Draw small points at polygon vertices
+         for pt in pts:
+             cv2.circle(img, tuple(pt), radius=3, color=color, thickness=-1)
 
-     # Draw bounding box
-     if bbox:
+     # Draw bounding box (only for detection, not segmentation)
+     if bbox and not polygon:
          x, y, w, h = bbox
          x1, y1 = int(x), int(y)
          x2, y2 = int(x + w), int(y + h)
@@ -965,9 +1368,13 @@ def _draw_annotations(
          )
          # Draw label text
          cv2.putText(
-             img, label,
+             img,
+             label,
              (x1 + 2, y1 - baseline - 2),
-             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.5,
+             (255, 255, 255),
+             1,
          )
 
      return img
@@ -1026,17 +1433,25 @@ def _discover_datasets(root_path: Path, max_depth: int) -> list[Dataset]:
 
 
  def _detect_dataset(path: Path) -> Dataset | None:
-     """Try to detect a dataset at the given path."""
-     # Try YOLO first (more specific patterns)
+     """Try to detect a dataset at the given path.
+
+     Detection priority: YOLO -> COCO -> MaskDataset
+     """
+     # Try YOLO first (more specific patterns - requires data.yaml)
      dataset = YOLODataset.detect(path)
      if dataset:
          return dataset
 
-     # Try COCO
+     # Try COCO (requires annotations/*.json)
      dataset = COCODataset.detect(path)
      if dataset:
          return dataset
 
+     # Try MaskDataset (directory structure based)
+     dataset = MaskDataset.detect(path)
+     if dataset:
+         return dataset
+
      return None
 
 
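
MaskDataset.detect itself is not part of this diff. Going only by the error message in stats ("images/ + masks/ directories") and the comment above, a plausible shape for the layout check might look like the following guess; the real implementation lives in argus/core/mask.py and may differ:

from pathlib import Path


def looks_like_mask_dataset(path: Path) -> bool:
    """Hypothetical layout probe: images/ paired with masks/, either at the
    root or under split directories (e.g. train/, val/)."""
    if (path / "images").is_dir() and (path / "masks").is_dir():
        return True
    if not path.is_dir():
        return False
    return any(
        (child / "images").is_dir() and (child / "masks").is_dir()
        for child in path.iterdir()
        if child.is_dir()
    )
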
argus/core/__init__.py CHANGED
@@ -2,6 +2,7 @@
 
  from argus.core.base import Dataset
  from argus.core.coco import COCODataset
+ from argus.core.mask import ConfigurationError, MaskDataset
  from argus.core.split import split_coco_dataset, split_yolo_dataset
  from argus.core.yolo import YOLODataset
 
@@ -9,6 +10,8 @@ __all__ = [
      "Dataset",
      "YOLODataset",
      "COCODataset",
+     "MaskDataset",
+     "ConfigurationError",
      "split_coco_dataset",
      "split_yolo_dataset",
  ]
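
With the package root re-exporting the new names, downstream code can import them directly. A small usage sketch; the dataset path is a placeholder, and the assumption that detect() raises ConfigurationError on a malformed dataset config is mine, not the diff's:

from pathlib import Path

from argus.core import ConfigurationError, MaskDataset

try:
    dataset = MaskDataset.detect(Path("./my_mask_dataset"))  # placeholder path
except ConfigurationError as exc:
    print(f"Invalid mask dataset configuration: {exc}")
else:
    if dataset is not None:
        print(dataset.get_class_mapping())  # method shown in the cli.py diff
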
argus/core/base.py CHANGED
@@ -11,6 +11,7 @@ class DatasetFormat(str, Enum):
 
      YOLO = "yolo"
      COCO = "coco"
+     MASK = "mask"
 
 
  class TaskType(str, Enum):
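
Because DatasetFormat subclasses str, the new member compares equal to its raw string value, which is what lets cli.py both dispatch on `dataset.format == DatasetFormat.MASK` and print `format.value.upper()` directly. Reproducing the enum for illustration:

from enum import Enum


class DatasetFormat(str, Enum):
    YOLO = "yolo"
    COCO = "coco"
    MASK = "mask"


assert DatasetFormat.MASK == "mask"                # str subclass: equal to raw value
assert DatasetFormat.MASK.value.upper() == "MASK"  # as in the stats summary line
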
argus/core/coco.py CHANGED
@@ -453,12 +453,14 @@ class COCODataset(Dataset):
                  for i in range(0, len(coords), 2):
                      polygon.append((float(coords[i]), float(coords[i + 1])))
 
-                 annotations.append({
-                     "class_name": class_name,
-                     "class_id": cat_id,
-                     "bbox": bbox_tuple,
-                     "polygon": polygon,
-                 })
+                 annotations.append(
+                     {
+                         "class_name": class_name,
+                         "class_id": cat_id,
+                         "bbox": bbox_tuple,
+                         "polygon": polygon,
+                     }
+                 )
 
          except (json.JSONDecodeError, OSError):
              continue
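
For context on the segmentation handling above: COCO stores each polygon as a flat [x1, y1, x2, y2, ...] list, and the stride-2 loop pairs consecutive values into vertices. The pairing can be spot-checked in isolation:

# Flat COCO segmentation list -> (x, y) vertices, same loop as in COCODataset.
coords = [10.0, 20.0, 30.0, 20.0, 20.0, 40.0]

polygon: list[tuple[float, float]] = []
for i in range(0, len(coords), 2):
    polygon.append((float(coords[i]), float(coords[i + 1])))

assert polygon == [(10.0, 20.0), (30.0, 20.0), (20.0, 40.0)]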