argus-cv 1.5.0-py3-none-any.whl → 1.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of argus-cv has been flagged as potentially problematic.

argus/__init__.py CHANGED
@@ -1,3 +1,3 @@
  """Argus - Vision AI dataset toolkit."""
 
- __version__ = "1.5.0"
+ __version__ = "1.5.1"
argus/cli.py CHANGED
@@ -20,6 +20,11 @@ from rich.table import Table
  from argus.core import COCODataset, Dataset, MaskDataset, YOLODataset
  from argus.core.base import DatasetFormat, TaskType
  from argus.core.convert import convert_mask_to_yolo_seg
+ from argus.core.filter import (
+     filter_coco_dataset,
+     filter_mask_dataset,
+     filter_yolo_dataset,
+ )
  from argus.core.split import (
      is_coco_unsplit,
      parse_ratio,
@@ -781,6 +786,196 @@ def convert_dataset(
      console.print(f"\n[cyan]Output dataset: {output_path}[/cyan]")
 
 
+ @app.command(name="filter")
+ def filter_dataset(
+     dataset_path: Annotated[
+         Path,
+         typer.Option(
+             "--dataset-path",
+             "-d",
+             help="Path to the dataset root directory.",
+         ),
+     ] = Path("."),
+     output_path: Annotated[
+         Path,
+         typer.Option(
+             "--output",
+             "-o",
+             help="Output directory for filtered dataset.",
+         ),
+     ] = Path("filtered"),
+     classes: Annotated[
+         str,
+         typer.Option(
+             "--classes",
+             "-c",
+             help="Comma-separated list of class names to keep.",
+         ),
+     ] = "",
+     no_background: Annotated[
+         bool,
+         typer.Option(
+             "--no-background",
+             help="Exclude images with no annotations after filtering.",
+         ),
+     ] = False,
+     use_symlinks: Annotated[
+         bool,
+         typer.Option(
+             "--symlinks",
+             help="Use symlinks instead of copying images.",
+         ),
+     ] = False,
+ ) -> None:
+     """Filter a dataset by class names.
+
+     Creates a filtered copy of the dataset containing only the specified classes.
+     Class IDs are remapped to sequential values (0, 1, 2, ...).
+
+     Examples:
+         argus-cv filter -d dataset -o output --classes ball --no-background
+         argus-cv filter -d dataset -o output --classes ball,player
+         argus-cv filter -d dataset -o output --classes ball --symlinks
+     """
+     # Resolve path and validate
+     dataset_path = dataset_path.resolve()
+     if not dataset_path.exists():
+         console.print(f"[red]Error: Path does not exist: {dataset_path}[/red]")
+         raise typer.Exit(1)
+     if not dataset_path.is_dir():
+         console.print(f"[red]Error: Path is not a directory: {dataset_path}[/red]")
+         raise typer.Exit(1)
+
+     # Parse classes
+     if not classes:
+         console.print(
+             "[red]Error: No classes specified. "
+             "Use --classes to specify classes to keep.[/red]"
+         )
+         raise typer.Exit(1)
+
+     class_list = [c.strip() for c in classes.split(",") if c.strip()]
+     if not class_list:
+         console.print("[red]Error: No valid class names provided.[/red]")
+         raise typer.Exit(1)
+
+     # Detect dataset
+     dataset = _detect_dataset(dataset_path)
+     if not dataset:
+         console.print(
+             f"[red]Error: No dataset found at {dataset_path}[/red]\n"
+             "[yellow]Ensure the path points to a dataset root containing "
+             "data.yaml (YOLO), annotations/ folder (COCO), or "
+             "images/ + masks/ directories (Mask).[/yellow]"
+         )
+         raise typer.Exit(1)
+
+     # Validate classes exist in dataset
+     missing_classes = [c for c in class_list if c not in dataset.class_names]
+     if missing_classes:
+         available = ", ".join(dataset.class_names)
+         missing = ", ".join(missing_classes)
+         console.print(
+             f"[red]Error: Classes not found in dataset: {missing}[/red]\n"
+             f"[yellow]Available classes: {available}[/yellow]"
+         )
+         raise typer.Exit(1)
+
+     # Resolve output path
+     if not output_path.is_absolute():
+         output_path = dataset_path.parent / output_path
+     output_path = output_path.resolve()
+
+     # Check if output already exists
+     if output_path.exists() and any(output_path.iterdir()):
+         console.print(
+             f"[red]Error: Output directory already exists and is not empty: "
+             f"{output_path}[/red]"
+         )
+         raise typer.Exit(1)
+
+     # Show filter info
+     console.print(f"[cyan]Filtering {dataset.format.value.upper()} dataset[/cyan]")
+     console.print(f" Source: {dataset_path}")
+     console.print(f" Output: {output_path}")
+     console.print(f" Classes to keep: {', '.join(class_list)}")
+     console.print(f" Exclude background: {no_background}")
+     console.print(f" Use symlinks: {use_symlinks}")
+     console.print()
+
+     # Run filtering with progress bar
+     with Progress(
+         SpinnerColumn(),
+         TextColumn("[progress.description]{task.description}"),
+         BarColumn(),
+         TaskProgressColumn(),
+         console=console,
+     ) as progress:
+         task = progress.add_task("Filtering dataset...", total=None)
+
+         def update_progress(current: int, total: int) -> None:
+             progress.update(task, completed=current, total=total)
+
+         try:
+             if dataset.format == DatasetFormat.YOLO:
+                 assert isinstance(dataset, YOLODataset)
+                 stats = filter_yolo_dataset(
+                     dataset=dataset,
+                     output_path=output_path,
+                     classes=class_list,
+                     no_background=no_background,
+                     use_symlinks=use_symlinks,
+                     progress_callback=update_progress,
+                 )
+             elif dataset.format == DatasetFormat.COCO:
+                 assert isinstance(dataset, COCODataset)
+                 stats = filter_coco_dataset(
+                     dataset=dataset,
+                     output_path=output_path,
+                     classes=class_list,
+                     no_background=no_background,
+                     use_symlinks=use_symlinks,
+                     progress_callback=update_progress,
+                 )
+             elif dataset.format == DatasetFormat.MASK:
+                 assert isinstance(dataset, MaskDataset)
+                 stats = filter_mask_dataset(
+                     dataset=dataset,
+                     output_path=output_path,
+                     classes=class_list,
+                     no_background=no_background,
+                     use_symlinks=use_symlinks,
+                     progress_callback=update_progress,
+                 )
+             else:
+                 console.print(
+                     f"[red]Error: Unsupported dataset format: {dataset.format}[/red]"
+                 )
+                 raise typer.Exit(1)
+         except ValueError as exc:
+             console.print(f"[red]Error: {exc}[/red]")
+             raise typer.Exit(1) from exc
+         except Exception as exc:
+             console.print(f"[red]Error during filtering: {exc}[/red]")
+             raise typer.Exit(1) from exc
+
+     # Show results
+     console.print()
+     console.print("[green]Filtering complete![/green]")
+     console.print(f" Images: {stats.get('images', 0)}")
+     if "labels" in stats:
+         console.print(f" Labels: {stats['labels']}")
+     if "annotations" in stats:
+         console.print(f" Annotations: {stats['annotations']}")
+     if "masks" in stats:
+         console.print(f" Masks: {stats['masks']}")
+     if stats.get("skipped", 0) > 0:
+         skipped = stats["skipped"]
+         console.print(f" [yellow]Skipped: {skipped} (background images)[/yellow]")
+
+     console.print(f"\n[cyan]Output dataset: {output_path}[/cyan]")
+
+
  class _ImageViewer:
      """Interactive image viewer with zoom and pan support."""
 
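
The new filter command follows the same validate → detect → dispatch shape as the existing convert command: it resolves and checks the dataset path, parses the class list, detects the dataset format, then hands off to the matching filter function under a Rich progress bar. For reference, the invocations from the command's own docstring (dataset, output, and class names are placeholders):

    argus-cv filter -d dataset -o output --classes ball --no-background
    argus-cv filter -d dataset -o output --classes ball,player
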
argus/core/__init__.py CHANGED
@@ -9,6 +9,11 @@ from argus.core.convert import (
      convert_mask_to_yolo_seg,
      mask_to_polygons,
  )
+ from argus.core.filter import (
+     filter_coco_dataset,
+     filter_mask_dataset,
+     filter_yolo_dataset,
+ )
  from argus.core.mask import ConfigurationError, MaskDataset
  from argus.core.split import split_coco_dataset, split_yolo_dataset
  from argus.core.yolo import YOLODataset
@@ -21,6 +26,9 @@ __all__ = [
      "ConfigurationError",
      "split_coco_dataset",
      "split_yolo_dataset",
+     "filter_yolo_dataset",
+     "filter_coco_dataset",
+     "filter_mask_dataset",
      "ConversionParams",
      "Polygon",
      "mask_to_polygons",
argus/core/filter.py ADDED
@@ -0,0 +1,670 @@
+ """Dataset filtering utilities."""
+
+ import json
+ import shutil
+ from collections.abc import Callable
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ import yaml
+
+ from argus.core.base import TaskType
+ from argus.core.coco import COCODataset
+ from argus.core.mask import MaskDataset
+ from argus.core.yolo import YOLODataset
+
+
+ def filter_yolo_dataset(
+     dataset: YOLODataset,
+     output_path: Path,
+     classes: list[str],
+     no_background: bool = False,
+     use_symlinks: bool = False,
+     progress_callback: Callable[[int, int], None] | None = None,
+ ) -> dict[str, int]:
+     """Filter a YOLO dataset by class names.
+
+     Args:
+         dataset: Source YOLODataset to filter.
+         output_path: Directory to write filtered dataset.
+         classes: List of class names to keep.
+         no_background: If True, exclude images with no annotations after filtering.
+         use_symlinks: If True, create symlinks instead of copying images.
+         progress_callback: Optional callback for progress updates (current, total).
+
+     Returns:
+         Dictionary with statistics: images, labels, skipped.
+     """
+     if dataset.task == TaskType.CLASSIFICATION:
+         return _filter_yolo_classification(
+             dataset, output_path, classes, use_symlinks, progress_callback
+         )
+     else:
+         return _filter_yolo_detection_segmentation(
+             dataset,
+             output_path,
+             classes,
+             no_background,
+             use_symlinks,
+             progress_callback,
+         )
+
+
+ def _filter_yolo_detection_segmentation(
+     dataset: YOLODataset,
+     output_path: Path,
+     classes: list[str],
+     no_background: bool,
+     use_symlinks: bool,
+     progress_callback: Callable[[int, int], None] | None,
+ ) -> dict[str, int]:
+     """Filter YOLO detection/segmentation dataset."""
+     # Build class ID mapping: old_id -> new_id
+     # New IDs are sequential starting from 0
+     old_to_new: dict[int, int] = {}
+     new_class_names: list[str] = []
+
+     for i, name in enumerate(dataset.class_names):
+         if name in classes:
+             new_id = len(new_class_names)
+             old_to_new[i] = new_id
+             new_class_names.append(name)
+
+     if not new_class_names:
+         raise ValueError(f"No matching classes found. Available: {dataset.class_names}")
+
+     # Create output structure
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     # Determine splits
+     splits = dataset.splits if dataset.splits else [""]
+     has_splits = bool(dataset.splits)
+
+     stats = {"images": 0, "labels": 0, "skipped": 0}
+
+     # Collect all image/label pairs
+     all_pairs: list[tuple[Path, Path, str]] = []
+     labels_root = dataset.path / "labels"
+
+     for split in splits:
+         if has_splits:
+             images_dir = dataset.path / "images" / split
+             labels_dir = labels_root / split
+         else:
+             images_dir = dataset.path / "images"
+             labels_dir = labels_root
+
+         if not images_dir.is_dir():
+             continue
+
+         for img_file in images_dir.iterdir():
+             if img_file.suffix.lower() not in {
+                 ".jpg",
+                 ".jpeg",
+                 ".png",
+                 ".bmp",
+                 ".tiff",
+                 ".webp",
+             }:
+                 continue
+
+             label_file = labels_dir / f"{img_file.stem}.txt"
+             all_pairs.append((img_file, label_file, split))
+
+     total = len(all_pairs)
+
+     for idx, (img_file, label_file, split) in enumerate(all_pairs):
+         if progress_callback:
+             progress_callback(idx, total)
+
+         # Read and filter label file
+         filtered_lines: list[str] = []
+
+         if label_file.exists():
+             with open(label_file, encoding="utf-8") as f:
+                 for line in f:
+                     line = line.strip()
+                     if not line:
+                         continue
+
+                     parts = line.split()
+                     if len(parts) < 5:
+                         continue
+
+                     try:
+                         old_class_id = int(parts[0])
+                     except ValueError:
+                         continue
+
+                     if old_class_id in old_to_new:
+                         new_class_id = old_to_new[old_class_id]
+                         parts[0] = str(new_class_id)
+                         filtered_lines.append(" ".join(parts))
+
+         # Skip if no annotations and no_background is True
+         if no_background and not filtered_lines:
+             stats["skipped"] += 1
+             continue
+
+         # Create output directories
+         if has_splits:
+             out_images_dir = output_path / "images" / split
+             out_labels_dir = output_path / "labels" / split
+         else:
+             out_images_dir = output_path / "images"
+             out_labels_dir = output_path / "labels"
+
+         out_images_dir.mkdir(parents=True, exist_ok=True)
+         out_labels_dir.mkdir(parents=True, exist_ok=True)
+
+         # Copy/symlink image
+         out_img = out_images_dir / img_file.name
+         if use_symlinks:
+             if not out_img.exists():
+                 out_img.symlink_to(img_file.resolve())
+         else:
+             if not out_img.exists():
+                 shutil.copy2(img_file, out_img)
+
+         # Write filtered label
+         out_label = out_labels_dir / f"{img_file.stem}.txt"
+         with open(out_label, "w", encoding="utf-8") as f:
+             f.write("\n".join(filtered_lines))
+             if filtered_lines:
+                 f.write("\n")
+
+         stats["images"] += 1
+         stats["labels"] += 1
+
+     if progress_callback:
+         progress_callback(total, total)
+
+     # Create data.yaml
+     _create_yolo_yaml(output_path, new_class_names, splits if has_splits else [])
+
+     return stats
+
+
+ def _filter_yolo_classification(
+     dataset: YOLODataset,
+     output_path: Path,
+     classes: list[str],
+     use_symlinks: bool,
+     progress_callback: Callable[[int, int], None] | None,
+ ) -> dict[str, int]:
+     """Filter YOLO classification dataset."""
+     # Filter to only requested classes that exist
+     new_class_names = [name for name in dataset.class_names if name in classes]
+
+     if not new_class_names:
+         raise ValueError(f"No matching classes found. Available: {dataset.class_names}")
+
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     stats = {"images": 0, "labels": 0, "skipped": 0}
+
+     # Count total images for progress
+     total = 0
+     if dataset.splits:
+         for split in dataset.splits:
+             for class_name in new_class_names:
+                 class_dir = dataset.path / "images" / split / class_name
+                 if class_dir.is_dir():
+                     total += sum(
+                         1
+                         for f in class_dir.iterdir()
+                         if f.suffix.lower()
+                         in {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
+                     )
+     else:
+         # Flat structure
+         for class_name in new_class_names:
+             class_dir = dataset.path / class_name
+             if class_dir.is_dir():
+                 total += sum(
+                     1
+                     for f in class_dir.iterdir()
+                     if f.suffix.lower()
+                     in {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
+                 )
+
+     current = 0
+
+     if dataset.splits:
+         for split in dataset.splits:
+             for class_name in new_class_names:
+                 src_dir = dataset.path / "images" / split / class_name
+                 dst_dir = output_path / "images" / split / class_name
+
+                 if not src_dir.is_dir():
+                     continue
+
+                 dst_dir.mkdir(parents=True, exist_ok=True)
+
+                 for img_file in src_dir.iterdir():
+                     if img_file.suffix.lower() not in {
+                         ".jpg",
+                         ".jpeg",
+                         ".png",
+                         ".bmp",
+                         ".tiff",
+                         ".webp",
+                     }:
+                         continue
+
+                     if progress_callback:
+                         progress_callback(current, total)
+                     current += 1
+
+                     dst_file = dst_dir / img_file.name
+                     if use_symlinks:
+                         if not dst_file.exists():
+                             dst_file.symlink_to(img_file.resolve())
+                     else:
+                         if not dst_file.exists():
+                             shutil.copy2(img_file, dst_file)
+
+                     stats["images"] += 1
+     else:
+         # Flat structure
+         for class_name in new_class_names:
+             src_dir = dataset.path / class_name
+             dst_dir = output_path / class_name
+
+             if not src_dir.is_dir():
+                 continue
+
+             dst_dir.mkdir(parents=True, exist_ok=True)
+
+             for img_file in src_dir.iterdir():
+                 if img_file.suffix.lower() not in {
+                     ".jpg",
+                     ".jpeg",
+                     ".png",
+                     ".bmp",
+                     ".tiff",
+                     ".webp",
+                 }:
+                     continue
+
+                 if progress_callback:
+                     progress_callback(current, total)
+                 current += 1
+
+                 dst_file = dst_dir / img_file.name
+                 if use_symlinks:
+                     if not dst_file.exists():
+                         dst_file.symlink_to(img_file.resolve())
+                 else:
+                     if not dst_file.exists():
+                         shutil.copy2(img_file, dst_file)
+
+                 stats["images"] += 1
+
+     if progress_callback:
+         progress_callback(total, total)
+
+     return stats
+
+
+ def _create_yolo_yaml(
+     output_path: Path, class_names: list[str], splits: list[str]
+ ) -> None:
+     """Create data.yaml for YOLO dataset."""
+     config: dict = {
+         "path": ".",
+         "names": {i: name for i, name in enumerate(class_names)},
+     }
+
+     if splits:
+         if "train" in splits:
+             config["train"] = "images/train"
+         if "val" in splits:
+             config["val"] = "images/val"
+         if "test" in splits:
+             config["test"] = "images/test"
+
+     with open(output_path / "data.yaml", "w", encoding="utf-8") as f:
+         yaml.dump(config, f, default_flow_style=False, sort_keys=False)
+
+
+ def filter_coco_dataset(
+     dataset: COCODataset,
+     output_path: Path,
+     classes: list[str],
+     no_background: bool = False,
+     use_symlinks: bool = False,
+     progress_callback: Callable[[int, int], None] | None = None,
+ ) -> dict[str, int]:
+     """Filter a COCO dataset by class names.
+
+     Args:
+         dataset: Source COCODataset to filter.
+         output_path: Directory to write filtered dataset.
+         classes: List of class names to keep.
+         no_background: If True, exclude images with no annotations after filtering.
+         use_symlinks: If True, create symlinks instead of copying images.
+         progress_callback: Optional callback for progress updates (current, total).
+
+     Returns:
+         Dictionary with statistics: images, annotations, skipped.
+     """
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     stats = {"images": 0, "annotations": 0, "skipped": 0}
+
+     # Process each annotation file
+     for ann_file in dataset.annotation_files:
+         with open(ann_file, encoding="utf-8") as f:
+             data = json.load(f)
+
+         # Build category mappings
+         old_categories = data.get("categories", [])
+         old_id_to_name: dict[int, str] = {}
+         for cat in old_categories:
+             if isinstance(cat, dict) and "id" in cat and "name" in cat:
+                 old_id_to_name[cat["id"]] = cat["name"]
+
+         # Create new category list with remapped IDs
+         old_to_new: dict[int, int] = {}
+         new_categories: list[dict] = []
+         new_id = 1  # COCO IDs typically start at 1
+
+         for cat in old_categories:
+             if isinstance(cat, dict) and "name" in cat and cat["name"] in classes:
+                 old_id = cat["id"]
+                 old_to_new[old_id] = new_id
+                 new_cat = cat.copy()
+                 new_cat["id"] = new_id
+                 new_categories.append(new_cat)
+                 new_id += 1
+
+         if not new_categories:
+             raise ValueError(
+                 f"No matching classes found. Available: {list(old_id_to_name.values())}"
+             )
+
+         # Filter annotations
+         old_annotations = data.get("annotations", [])
+         new_annotations: list[dict] = []
+         images_with_annotations: set[int] = set()
+         new_ann_id = 1
+
+         for ann in old_annotations:
+             if not isinstance(ann, dict):
+                 continue
+
+             old_cat_id = ann.get("category_id")
+             if old_cat_id not in old_to_new:
+                 continue
+
+             new_ann = ann.copy()
+             new_ann["id"] = new_ann_id
+             new_ann["category_id"] = old_to_new[old_cat_id]
+             new_annotations.append(new_ann)
+             new_ann_id += 1
+             stats["annotations"] += 1
+
+             image_id = ann.get("image_id")
+             if image_id is not None:
+                 images_with_annotations.add(image_id)
+
+         # Filter images
+         old_images = data.get("images", [])
+         new_images: list[dict] = []
+         included_image_ids: set[int] = set()
+         new_img_id = 1
+
+         # Build image ID mapping for annotation update
+         old_to_new_img_id: dict[int, int] = {}
+
+         for img in old_images:
+             if not isinstance(img, dict) or "id" not in img:
+                 continue
+
+             old_img_id = img["id"]
+
+             # Skip if no_background and no annotations
+             if no_background and old_img_id not in images_with_annotations:
+                 stats["skipped"] += 1
+                 continue
+
+             old_to_new_img_id[old_img_id] = new_img_id
+             new_img = img.copy()
+             new_img["id"] = new_img_id
+             new_images.append(new_img)
+             included_image_ids.add(old_img_id)
+             new_img_id += 1
+             stats["images"] += 1
+
+         # Update annotation image IDs and filter out annotations for excluded images
+         final_annotations: list[dict] = []
+         for ann in new_annotations:
+             old_img_id = ann.get("image_id")
+             if old_img_id in old_to_new_img_id:
+                 ann["image_id"] = old_to_new_img_id[old_img_id]
+                 final_annotations.append(ann)
+
+         # Determine split from annotation file
+         split = COCODataset._get_split_from_filename(
+             ann_file.stem, ann_file.parent.name
+         )
+
+         # Check if this is Roboflow format (annotation in split directory)
+         is_roboflow = ann_file.parent.name.lower() in ("train", "valid", "val", "test")
+
+         # Create output annotation
+         new_data = data.copy()
+         new_data["categories"] = new_categories
+         new_data["annotations"] = final_annotations
+         new_data["images"] = new_images
+
+         # Write annotation file
+         if is_roboflow:
+             # Roboflow format: annotations in split directories
+             out_ann_dir = output_path / split
+             out_ann_dir.mkdir(parents=True, exist_ok=True)
+             out_ann_file = out_ann_dir / ann_file.name
+         else:
+             # Standard format: annotations in annotations/ directory
+             out_ann_dir = output_path / "annotations"
+             out_ann_dir.mkdir(parents=True, exist_ok=True)
+             out_ann_file = out_ann_dir / ann_file.name
+
+         with open(out_ann_file, "w", encoding="utf-8") as f:
+             json.dump(new_data, f, indent=2)
+
+         # Copy/symlink images
+         for img in old_images:
+             if not isinstance(img, dict) or "id" not in img:
+                 continue
+
+             if img["id"] not in included_image_ids:
+                 continue
+
+             file_name = img.get("file_name")
+             if not file_name:
+                 continue
+
+             # Find source image
+             possible_paths = [
+                 dataset.path / "images" / split / file_name,
+                 dataset.path / "images" / file_name,
+                 dataset.path / split / file_name,
+                 dataset.path / file_name,
+                 ann_file.parent / file_name,  # Roboflow format
+             ]
+
+             src_path = None
+             for p in possible_paths:
+                 if p.exists():
+                     src_path = p
+                     break
+
+             if src_path is None:
+                 continue
+
+             # Determine output directory
+             if is_roboflow:
+                 out_img_dir = output_path / split
+             else:
+                 out_img_dir = output_path / "images" / split
+             out_img_dir.mkdir(parents=True, exist_ok=True)
+
+             out_img = out_img_dir / file_name
+             if use_symlinks:
+                 if not out_img.exists():
+                     out_img.symlink_to(src_path.resolve())
+             else:
+                 if not out_img.exists():
+                     shutil.copy2(src_path, out_img)
+
+     return stats
+
+
+ def filter_mask_dataset(
+     dataset: MaskDataset,
+     output_path: Path,
+     classes: list[str],
+     no_background: bool = False,
+     use_symlinks: bool = False,
+     progress_callback: Callable[[int, int], None] | None = None,
+ ) -> dict[str, int]:
+     """Filter a mask dataset by class names.
+
+     Args:
+         dataset: Source MaskDataset to filter.
+         output_path: Directory to write filtered dataset.
+         classes: List of class names to keep.
+         no_background: If True, exclude images with no annotations after filtering.
+         use_symlinks: If True, create symlinks instead of copying images.
+         progress_callback: Optional callback for progress updates (current, total).
+
+     Returns:
+         Dictionary with statistics: images, masks, skipped.
+     """
+     # Build class ID mapping
+     old_mapping = dataset.get_class_mapping()
+     old_name_to_id: dict[str, int] = {name: id for id, name in old_mapping.items()}
+
+     # Create new mapping: old_id -> new_id
+     old_to_new: dict[int, int] = {}
+     new_class_names: list[str] = []
+
+     # Start from 0 for background, then 1, 2, ... for other classes
+     # If "background" is in classes, include it
+     new_id = 0
+     for name in classes:
+         if name in old_name_to_id:
+             old_id = old_name_to_id[name]
+             old_to_new[old_id] = new_id
+             new_class_names.append(name)
+             new_id += 1
+
+     if not new_class_names:
+         raise ValueError(
+             f"No matching classes found. Available: {list(old_mapping.values())}"
+         )
+
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     stats = {"images": 0, "masks": 0, "skipped": 0}
+
+     # Get all image paths
+     image_paths = dataset.get_image_paths()
+     total = len(image_paths)
+
+     for idx, img_path in enumerate(image_paths):
+         if progress_callback:
+             progress_callback(idx, total)
+
+         # Load mask
+         mask = dataset.load_mask(img_path)
+         if mask is None:
+             stats["skipped"] += 1
+             continue
+
+         # Create filtered mask
+         # Set all pixels to ignore_index first, then fill in kept classes
+         new_ignore_index = 255
+         new_mask = np.full(mask.shape, new_ignore_index, dtype=np.uint8)
+
+         has_annotations = False
+         for old_id, new_id in old_to_new.items():
+             mask_pixels = mask == old_id
+             if np.any(mask_pixels):
+                 has_annotations = True
+                 new_mask[mask_pixels] = new_id
+
+         # Skip if no_background and no kept annotations
+         if no_background and not has_annotations:
+             stats["skipped"] += 1
+             continue
+
+         # Determine split from image path
+         img_parts = img_path.parts
+         images_dir_idx = None
+         for i, part in enumerate(img_parts):
+             if part == dataset.images_dir:
+                 images_dir_idx = i
+                 break
+
+         if images_dir_idx is not None and images_dir_idx + 1 < len(img_parts) - 1:
+             split = img_parts[images_dir_idx + 1]
+             if split not in dataset.splits:
+                 split = None
+         else:
+             split = None
+
+         # Create output directories
+         if split:
+             out_images_dir = output_path / dataset.images_dir / split
+             out_masks_dir = output_path / dataset.masks_dir / split
+         else:
+             out_images_dir = output_path / dataset.images_dir
+             out_masks_dir = output_path / dataset.masks_dir
+
+         out_images_dir.mkdir(parents=True, exist_ok=True)
+         out_masks_dir.mkdir(parents=True, exist_ok=True)
+
+         # Copy/symlink image
+         out_img = out_images_dir / img_path.name
+         if use_symlinks:
+             if not out_img.exists():
+                 out_img.symlink_to(img_path.resolve())
+         else:
+             if not out_img.exists():
+                 shutil.copy2(img_path, out_img)
+
+         # Write filtered mask
+         mask_path = dataset.get_mask_path(img_path)
+         if mask_path:
+             out_mask = out_masks_dir / mask_path.name
+             cv2.imwrite(str(out_mask), new_mask)
+
+         stats["images"] += 1
+         stats["masks"] += 1
+
+     if progress_callback:
+         progress_callback(total, total)
+
+     # Create classes.yaml
+     _create_mask_classes_yaml(output_path, new_class_names, dataset.ignore_index)
+
+     return stats
+
+
+ def _create_mask_classes_yaml(
+     output_path: Path, class_names: list[str], ignore_index: int | None
+ ) -> None:
+     """Create classes.yaml for mask dataset."""
+     config: dict = {
+         "names": {i: name for i, name in enumerate(class_names)},
+     }
+
+     if ignore_index is not None:
+         config["ignore_index"] = 255  # Use standard ignore index
+
+     with open(output_path / "classes.yaml", "w", encoding="utf-8") as f:
+         yaml.dump(config, f, default_flow_style=False, sort_keys=False)
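
All three public entry points share the calling convention used by the CLI dispatch in argus/cli.py. A minimal programmatic sketch, assuming a YOLO dataset; the dataset root, class name, and the YOLODataset(path) constructor are illustrative assumptions, not shown in this diff:

    from pathlib import Path

    from argus.core import filter_yolo_dataset
    from argus.core.yolo import YOLODataset

    # Hypothetical dataset root and class name; YOLODataset(path) is an
    # assumed constructor used here for illustration.
    dataset = YOLODataset(Path("datasets/soccer"))

    stats = filter_yolo_dataset(
        dataset=dataset,
        output_path=Path("datasets/soccer_ball_only"),
        classes=["ball"],
        no_background=True,  # drop images left with no annotations
        use_symlinks=False,  # copy images rather than symlinking
        progress_callback=lambda cur, tot: print(f"{cur}/{tot}"),
    )
    print(stats)  # e.g. {"images": ..., "labels": ..., "skipped": ...}

Note that kept class IDs are remapped to sequential values starting at 0 and a fresh data.yaml is written, so downstream training configs should point at the filtered copy rather than reuse the original class map.
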
argus_cv-1.5.0.dist-info/METADATA → argus_cv-1.5.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: argus-cv
- Version: 1.5.0
+ Version: 1.5.1
  Summary: CLI tool for working with vision AI datasets
  Requires-Python: >=3.10
  Requires-Dist: numpy>=1.24.0
argus_cv-1.5.0.dist-info/RECORD → argus_cv-1.5.1.dist-info/RECORD CHANGED
@@ -1,15 +1,16 @@
- argus/__init__.py,sha256=-NDJwMF-NWlPd0dIFWTu3SjgVWZy8SJxBD9g3YQXfrY,64
+ argus/__init__.py,sha256=Yo7UDKujDodxid1b2g022IqmD1bwc9POtFSl8iolq5c,64
  argus/__main__.py,sha256=63ezHx8eL_lCMoZrCbKhmpao0fmdvYVw1chbknGg-oI,104
- argus/cli.py,sha256=hQ4t69E-clFvn9ZIeQ4Rf7cAqC0TgPtz1HEAFqNajcg,52706
+ argus/cli.py,sha256=Xri1KFVOMS-YhbwkRE0eB5HOf49kfQuftu0IJU4gAjA,59605
  argus/commands/__init__.py,sha256=i2oor9hpVpF-_1qZWCGDLwwi1pZGJfZnUKJZ_NMBG18,30
- argus/core/__init__.py,sha256=sP206E44GdnnjKwyWNvuWntvO7D8oy0qs1yUUaPDThI,738
+ argus/core/__init__.py,sha256=L5Onny8UjJtok5hOBKftqwnOhPvgQS5MZT_kgLpes1o,928
  argus/core/base.py,sha256=WBrB7XWz125YZ1UQfHQwsYAuIFY_XGEhG_0ybgPhn6s,3696
  argus/core/coco.py,sha256=V3Ifh6KUbifBTLefUuMxQkejgkwsPZNfKLn0newDZJ4,17539
  argus/core/convert.py,sha256=cHuw1E9B4vyozpikS2PJnFfiJ_eRMPIHblizyeZz1Ps,8471
+ argus/core/filter.py,sha256=7BRefzYcKIxU0GkFNHiJJAijc9UIhrvNKdYgXE_22ig,21945
  argus/core/mask.py,sha256=m7Ztf4lAZx5ITpk3F3mETcvCC6hGydlxK0-2nCjeTfU,21835
  argus/core/split.py,sha256=kEWtbdg6bH-WiNFf83HkqZD90EL4gsavw6JiefuAETs,10776
  argus/core/yolo.py,sha256=Vtw2sga40VooaRE8bmjKtr_aYhfoV7ZcVijFjg1DVwo,29644
- argus_cv-1.5.0.dist-info/METADATA,sha256=9iwY-3C6t-vzZOA9wBvrvIY10YBUaHgsDRN5x5Uk_8c,1353
- argus_cv-1.5.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- argus_cv-1.5.0.dist-info/entry_points.txt,sha256=dvJFH7BkrOxJnifSjPhwq1YCafPaqdngWyBuFYE73yY,43
- argus_cv-1.5.0.dist-info/RECORD,,
+ argus_cv-1.5.1.dist-info/METADATA,sha256=71pcUGzx6s0uCiJjZNUh-4p-z50xCL5xqQC8-JXhNaI,1353
+ argus_cv-1.5.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ argus_cv-1.5.1.dist-info/entry_points.txt,sha256=dvJFH7BkrOxJnifSjPhwq1YCafPaqdngWyBuFYE73yY,43
+ argus_cv-1.5.1.dist-info/RECORD,,