argus-cv 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of argus-cv might be problematic. Click here for more details.
- argus/__init__.py +1 -1
- argus/cli.py +345 -1
- argus/core/__init__.py +20 -0
- argus/core/coco.py +46 -8
- argus/core/convert.py +277 -0
- argus/core/filter.py +670 -0
- argus/core/yolo.py +29 -0
- {argus_cv-1.4.0.dist-info → argus_cv-1.5.1.dist-info}/METADATA +1 -1
- argus_cv-1.5.1.dist-info/RECORD +16 -0
- argus_cv-1.4.0.dist-info/RECORD +0 -14
- {argus_cv-1.4.0.dist-info → argus_cv-1.5.1.dist-info}/WHEEL +0 -0
- {argus_cv-1.4.0.dist-info → argus_cv-1.5.1.dist-info}/entry_points.txt +0 -0
argus/core/filter.py
ADDED
|
@@ -0,0 +1,670 @@
|
|
|
1
|
+
"""Dataset filtering utilities."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import shutil
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
import yaml
|
|
11
|
+
|
|
12
|
+
from argus.core.base import TaskType
|
|
13
|
+
from argus.core.coco import COCODataset
|
|
14
|
+
from argus.core.mask import MaskDataset
|
|
15
|
+
from argus.core.yolo import YOLODataset
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def filter_yolo_dataset(
    dataset: YOLODataset,
    output_path: Path,
    classes: list[str],
    no_background: bool = False,
    use_symlinks: bool = False,
    progress_callback: Callable[[int, int], None] | None = None,
) -> dict[str, int]:
    """Filter a YOLO dataset so that only the requested classes remain.

    Classification datasets are filtered folder-by-folder; detection and
    segmentation datasets are filtered label-line by label-line with class
    IDs renumbered.

    Args:
        dataset: Source YOLODataset to filter.
        output_path: Directory to write filtered dataset.
        classes: List of class names to keep.
        no_background: If True, exclude images with no annotations after
            filtering. Not used for classification datasets.
        use_symlinks: If True, create symlinks instead of copying images.
        progress_callback: Optional callback for progress updates
            (current, total).

    Returns:
        Dictionary with statistics: images, labels, skipped.

    Raises:
        ValueError: If none of ``classes`` exist in the dataset.
    """
    if dataset.task != TaskType.CLASSIFICATION:
        return _filter_yolo_detection_segmentation(
            dataset,
            output_path,
            classes,
            no_background,
            use_symlinks,
            progress_callback,
        )

    # Classification images carry no per-image annotations, so there is no
    # notion of a "background" image and ``no_background`` is not forwarded.
    return _filter_yolo_classification(
        dataset, output_path, classes, use_symlinks, progress_callback
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _filter_yolo_detection_segmentation(
    dataset: YOLODataset,
    output_path: Path,
    classes: list[str],
    no_background: bool,
    use_symlinks: bool,
    progress_callback: Callable[[int, int], None] | None,
) -> dict[str, int]:
    """Filter a YOLO detection/segmentation dataset by class name.

    Kept classes receive new IDs ``0..k-1`` following their order in
    ``dataset.class_names``. Each label line whose class is kept is rewritten
    with the new ID; lines of other classes, lines with fewer than five
    fields, and lines whose class field is not an integer are dropped.
    Images are copied (or symlinked) to ``images/[<split>/]`` and rewritten
    labels are written to ``labels/[<split>/]``; finally ``data.yaml`` is
    generated for the new class list.

    Returns:
        Dictionary with statistics: images, labels, skipped.

    Raises:
        ValueError: If none of ``classes`` exist in the dataset.
    """
    kept = [
        (old_id, name)
        for old_id, name in enumerate(dataset.class_names)
        if name in classes
    ]
    if not kept:
        raise ValueError(f"No matching classes found. Available: {dataset.class_names}")

    # old class id -> new dense class id (0-based, dataset order)
    id_remap = {old_id: new_id for new_id, (old_id, _) in enumerate(kept)}
    kept_names = [name for _, name in kept]

    output_path.mkdir(parents=True, exist_ok=True)

    has_splits = bool(dataset.splits)
    split_names = dataset.splits if has_splits else [""]
    suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}

    def layout(root: Path, kind: str, split: str) -> Path:
        # "<root>/<kind>/<split>" for split datasets, "<root>/<kind>" otherwise.
        return root / kind / split if has_splits else root / kind

    # Gather every (image, label, split) triple first so progress has a total.
    pairs: list[tuple[Path, Path, str]] = []
    for split in split_names:
        src_images = layout(dataset.path, "images", split)
        if not src_images.is_dir():
            continue
        src_labels = layout(dataset.path, "labels", split)
        pairs.extend(
            (image, src_labels / f"{image.stem}.txt", split)
            for image in src_images.iterdir()
            if image.suffix.lower() in suffixes
        )

    counters = {"images": 0, "labels": 0, "skipped": 0}
    total = len(pairs)

    for position, (image, label, split) in enumerate(pairs):
        if progress_callback:
            progress_callback(position, total)

        rewritten: list[str] = []
        if label.exists():
            with open(label, encoding="utf-8") as handle:
                for raw in handle:
                    fields = raw.split()
                    if len(fields) < 5:
                        continue
                    try:
                        old_id = int(fields[0])
                    except ValueError:
                        continue
                    if old_id not in id_remap:
                        continue
                    fields[0] = str(id_remap[old_id])
                    rewritten.append(" ".join(fields))

        if no_background and not rewritten:
            counters["skipped"] += 1
            continue

        dst_images = layout(output_path, "images", split)
        dst_labels = layout(output_path, "labels", split)
        dst_images.mkdir(parents=True, exist_ok=True)
        dst_labels.mkdir(parents=True, exist_ok=True)

        # Never overwrite an image that is already present in the output.
        dst_image = dst_images / image.name
        if not dst_image.exists():
            if use_symlinks:
                dst_image.symlink_to(image.resolve())
            else:
                shutil.copy2(image, dst_image)

        # An image whose labels were all dropped still gets an empty label file.
        text = "\n".join(rewritten) + ("\n" if rewritten else "")
        (dst_labels / f"{image.stem}.txt").write_text(text, encoding="utf-8")

        counters["images"] += 1
        counters["labels"] += 1

    if progress_callback:
        progress_callback(total, total)

    _create_yolo_yaml(output_path, kept_names, dataset.splits if has_splits else [])

    return counters
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _filter_yolo_classification(
    dataset: YOLODataset,
    output_path: Path,
    classes: list[str],
    use_symlinks: bool,
    progress_callback: Callable[[int, int], None] | None,
) -> dict[str, int]:
    """Filter a YOLO classification dataset by copying kept class folders.

    With splits, images are read from ``images/<split>/<class>/`` and written
    to the same relative location under ``output_path``. Without splits, the
    class folders sit directly under the dataset root. Missing class folders
    are skipped. No label files are produced, so ``labels`` and ``skipped``
    in the returned statistics stay 0.

    Returns:
        Dictionary with statistics: images, labels, skipped.

    Raises:
        ValueError: If none of ``classes`` exist in the dataset.
    """
    kept_names = [name for name in dataset.class_names if name in classes]
    if not kept_names:
        raise ValueError(f"No matching classes found. Available: {dataset.class_names}")

    output_path.mkdir(parents=True, exist_ok=True)

    suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}

    # (source folder, destination folder) for every kept class.
    if dataset.splits:
        folder_pairs = [
            (
                dataset.path / "images" / split / name,
                output_path / "images" / split / name,
            )
            for split in dataset.splits
            for name in kept_names
        ]
    else:
        folder_pairs = [(dataset.path / name, output_path / name) for name in kept_names]

    # List each existing source folder once; the combined count is the
    # progress total.
    jobs = [
        (dst, [item for item in src.iterdir() if item.suffix.lower() in suffixes])
        for src, dst in folder_pairs
        if src.is_dir()
    ]
    total = sum(len(images) for _, images in jobs)

    stats = {"images": 0, "labels": 0, "skipped": 0}
    done = 0
    for dst_dir, images in jobs:
        dst_dir.mkdir(parents=True, exist_ok=True)
        for image in images:
            if progress_callback:
                progress_callback(done, total)
            done += 1

            # Never overwrite an image that is already present in the output.
            target = dst_dir / image.name
            if not target.exists():
                if use_symlinks:
                    target.symlink_to(image.resolve())
                else:
                    shutil.copy2(image, target)

            stats["images"] += 1

    if progress_callback:
        progress_callback(total, total)

    return stats
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def _create_yolo_yaml(
    output_path: Path, class_names: list[str], splits: list[str]
) -> None:
    """Write ``data.yaml`` describing a filtered YOLO dataset.

    Args:
        output_path: Dataset root; ``data.yaml`` is written directly inside it.
        class_names: Class names in new-ID order (list index == class ID).
        splits: Split directory names present under ``images/``. Empty for
            datasets without splits, in which case no split keys are written.
    """
    config: dict = {
        "path": ".",
        "names": dict(enumerate(class_names)),
    }

    # Map the standard YOLO keys onto the split directories that were
    # actually written. Some exports (e.g. Roboflow) name the validation
    # split "valid"; without this alias the generated data.yaml would have
    # no ``val`` entry at all. An exact "val" directory takes precedence.
    split_aliases = {
        "train": ("train",),
        "val": ("val", "valid"),
        "test": ("test",),
    }
    for key, candidates in split_aliases.items():
        for split_dir in candidates:
            if split_dir in splits:
                config[key] = f"images/{split_dir}"
                break

    with open(output_path / "data.yaml", "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def filter_coco_dataset(
    dataset: COCODataset,
    output_path: Path,
    classes: list[str],
    no_background: bool = False,
    use_symlinks: bool = False,
    progress_callback: Callable[[int, int], None] | None = None,
) -> dict[str, int]:
    """Filter a COCO dataset by class names.

    Every annotation file of ``dataset`` is filtered independently:
    - Kept categories are renumbered from 1, in their original order.
    - Kept images and annotations are also renumbered from 1.
    - The filtered JSON is written under ``output_path``: inside the split
      directory for Roboflow-style layouts, otherwise under ``annotations/``.
    - The referenced images are copied or symlinked.

    Args:
        dataset: Source COCODataset to filter.
        output_path: Directory to write filtered dataset.
        classes: List of class names to keep.
        no_background: If True, exclude images with no annotations after filtering.
        use_symlinks: If True, create symlinks instead of copying images.
        progress_callback: Optional callback for progress updates
            (current, total). Progress is reported per annotation file.

    Returns:
        Dictionary with statistics: images, annotations, skipped.
        ``annotations`` counts annotations actually written, i.e. whose
        category is kept and whose image is present in the output JSON.

    Raises:
        ValueError: If an annotation file contains none of ``classes``.
    """
    output_path.mkdir(parents=True, exist_ok=True)

    stats = {"images": 0, "annotations": 0, "skipped": 0}

    ann_files = list(dataset.annotation_files)
    total = len(ann_files)

    for idx, ann_file in enumerate(ann_files):
        if progress_callback:
            progress_callback(idx, total)

        with open(ann_file, encoding="utf-8") as f:
            data = json.load(f)

        old_categories = data.get("categories", [])
        new_categories, cat_remap = _remap_coco_categories(old_categories, classes)
        if not new_categories:
            available = {
                cat["id"]: cat["name"]
                for cat in old_categories
                if isinstance(cat, dict) and "id" in cat and "name" in cat
            }
            raise ValueError(
                f"No matching classes found. Available: {list(available.values())}"
            )

        # Annotations whose category survives the filter.
        kept_anns: list[dict] = []
        annotated_image_ids: set = set()
        for ann in data.get("annotations", []):
            if not isinstance(ann, dict) or ann.get("category_id") not in cat_remap:
                continue
            kept_anns.append(ann)
            if ann.get("image_id") is not None:
                annotated_image_ids.add(ann["image_id"])

        # Images, optionally dropping those left without any kept annotation.
        new_images: list[dict] = []
        kept_images: list[dict] = []  # original entries, used to locate files
        img_remap: dict = {}
        for img in data.get("images", []):
            if not isinstance(img, dict) or "id" not in img:
                continue
            if no_background and img["id"] not in annotated_image_ids:
                stats["skipped"] += 1
                continue
            new_img_id = len(new_images) + 1
            img_remap[img["id"]] = new_img_id
            new_images.append({**img, "id": new_img_id})
            kept_images.append(img)
        stats["images"] += len(new_images)

        # Drop annotations whose image is not in the output, then renumber.
        final_anns: list[dict] = []
        for ann in kept_anns:
            old_img_id = ann.get("image_id")
            if old_img_id not in img_remap:
                continue
            final_anns.append(
                {
                    **ann,
                    "id": len(final_anns) + 1,
                    "category_id": cat_remap[ann["category_id"]],
                    "image_id": img_remap[old_img_id],
                }
            )
        stats["annotations"] += len(final_anns)

        split = COCODataset._get_split_from_filename(
            ann_file.stem, ann_file.parent.name
        )
        # Roboflow exports keep the annotation JSON inside the split directory.
        is_roboflow = ann_file.parent.name.lower() in ("train", "valid", "val", "test")

        if is_roboflow:
            out_ann_dir = output_path / split
            out_img_dir = output_path / split
        else:
            out_ann_dir = output_path / "annotations"
            out_img_dir = output_path / "images" / split
        out_ann_dir.mkdir(parents=True, exist_ok=True)

        new_data = {
            **data,
            "categories": new_categories,
            "annotations": final_anns,
            "images": new_images,
        }
        with open(out_ann_dir / ann_file.name, "w", encoding="utf-8") as f:
            json.dump(new_data, f, indent=2)

        for img in kept_images:
            file_name = img.get("file_name")
            if not file_name:
                continue

            src_path = _find_coco_image(dataset.path, ann_file, split, file_name)
            if src_path is None:
                continue

            out_img = out_img_dir / file_name
            # file_name may contain sub-directories (e.g. "batch1/img.jpg"),
            # so create the image's own parent, not just the split folder.
            out_img.parent.mkdir(parents=True, exist_ok=True)
            if out_img.exists():
                continue
            if use_symlinks:
                out_img.symlink_to(src_path.resolve())
            else:
                shutil.copy2(src_path, out_img)

    if progress_callback:
        progress_callback(total, total)

    return stats


def _remap_coco_categories(
    categories: list, classes: list[str]
) -> tuple[list[dict], dict]:
    """Keep categories named in ``classes`` and renumber their IDs from 1.

    Returns:
        ``(new_categories, old_to_new)``: copies of the kept category dicts
        with sequential IDs starting at 1, and a map from original category
        ID to new ID. Entries that are not dicts or lack ``id``/``name`` are
        ignored.
    """
    new_categories: list[dict] = []
    old_to_new: dict = {}
    for cat in categories:
        if not isinstance(cat, dict) or "id" not in cat or "name" not in cat:
            continue
        if cat["name"] not in classes:
            continue
        new_id = len(new_categories) + 1  # COCO IDs typically start at 1
        old_to_new[cat["id"]] = new_id
        new_categories.append({**cat, "id": new_id})
    return new_categories, old_to_new


def _find_coco_image(
    dataset_root: Path, ann_file: Path, split: str, file_name: str
) -> Path | None:
    """Return the first existing location of ``file_name``, or ``None``.

    Checks, in order: ``images/<split>/``, ``images/``, ``<split>/``, the
    dataset root, and the annotation file's own directory (Roboflow layout).
    """
    candidates = (
        dataset_root / "images" / split / file_name,
        dataset_root / "images" / file_name,
        dataset_root / split / file_name,
        dataset_root / file_name,
        ann_file.parent / file_name,
    )
    return next((path for path in candidates if path.exists()), None)
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
def filter_mask_dataset(
    dataset: MaskDataset,
    output_path: Path,
    classes: list[str],
    no_background: bool = False,
    use_symlinks: bool = False,
    progress_callback: Callable[[int, int], None] | None = None,
) -> dict[str, int]:
    """Filter a mask dataset by class names.

    Kept classes are renumbered densely from 0 in order of their original
    pixel values. In each written mask:
    - pixels of a kept class carry its new ID;
    - every other pixel is set to 255. This includes a background class, if
      the dataset defines one and it is not listed in ``classes``.

    The generated ``classes.yaml`` declares 255 as ``ignore_index``.

    Args:
        dataset: Source MaskDataset to filter.
        output_path: Directory to write filtered dataset.
        classes: List of class names to keep.
        no_background: If True, exclude images with no annotations after filtering.
        use_symlinks: If True, create symlinks instead of copying images.
        progress_callback: Optional callback for progress updates (current, total).

    Returns:
        Dictionary with statistics: images, masks, skipped. ``masks`` counts
        masks that were actually written.

    Raises:
        ValueError: If none of ``classes`` exist in the dataset.
    """
    old_mapping = dataset.get_class_mapping()

    # Renumber in order of original ID, like the YOLO/COCO filters. Walking
    # the dataset's own mapping (rather than ``classes``) means duplicates in
    # ``classes`` cannot give one class two IDs. A class with the lowest
    # original ID (e.g. background) also stays first.
    old_to_new: dict[int, int] = {}
    new_class_names: list[str] = []
    for old_id, name in sorted(old_mapping.items()):
        if name in classes:
            old_to_new[old_id] = len(new_class_names)
            new_class_names.append(name)

    if not new_class_names:
        raise ValueError(
            f"No matching classes found. Available: {list(old_mapping.values())}"
        )

    output_path.mkdir(parents=True, exist_ok=True)

    # Fill value for every pixel not covered by a kept class.
    new_ignore_index = 255

    stats = {"images": 0, "masks": 0, "skipped": 0}

    image_paths = dataset.get_image_paths()
    total = len(image_paths)

    for idx, img_path in enumerate(image_paths):
        if progress_callback:
            progress_callback(idx, total)

        mask = dataset.load_mask(img_path)
        if mask is None:
            stats["skipped"] += 1
            continue

        new_mask = np.full(mask.shape, new_ignore_index, dtype=np.uint8)
        has_annotations = False
        for old_id, new_id in old_to_new.items():
            class_pixels = mask == old_id
            if np.any(class_pixels):
                has_annotations = True
                new_mask[class_pixels] = new_id

        if no_background and not has_annotations:
            stats["skipped"] += 1
            continue

        split = _find_mask_split(img_path, dataset.images_dir, dataset.splits)
        if split:
            out_images_dir = output_path / dataset.images_dir / split
            out_masks_dir = output_path / dataset.masks_dir / split
        else:
            out_images_dir = output_path / dataset.images_dir
            out_masks_dir = output_path / dataset.masks_dir
        out_images_dir.mkdir(parents=True, exist_ok=True)
        out_masks_dir.mkdir(parents=True, exist_ok=True)

        out_img = out_images_dir / img_path.name
        if not out_img.exists():
            if use_symlinks:
                out_img.symlink_to(img_path.resolve())
            else:
                shutil.copy2(img_path, out_img)

        mask_path = dataset.get_mask_path(img_path)
        # cv2.imwrite reports failure through its boolean return value, so
        # only count masks that were really written.
        if mask_path and cv2.imwrite(str(out_masks_dir / mask_path.name), new_mask):
            stats["masks"] += 1

        stats["images"] += 1

    if progress_callback:
        progress_callback(total, total)

    # Always declare the ignore value: filtered masks use 255 for unmapped
    # pixels even when the source dataset had no ignore index.
    _create_mask_classes_yaml(output_path, new_class_names, new_ignore_index)

    return stats


def _find_mask_split(
    img_path: Path, images_dir: str, splits: list[str]
) -> str | None:
    """Return the split folder directly after ``images_dir`` in ``img_path``.

    The search runs from the innermost directory outwards. An ancestor folder
    that shares ``images_dir``'s name (e.g. a dataset stored under
    ``/data/images/``) therefore cannot hide the real ``images/<split>/``
    component. Returns ``None`` when no known split folder is found.
    """
    dir_parts = img_path.parts[:-1]  # drop the file name
    for i in range(len(dir_parts) - 2, -1, -1):
        if dir_parts[i] == images_dir and dir_parts[i + 1] in splits:
            return dir_parts[i + 1]
    return None
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
def _create_mask_classes_yaml(
    output_path: Path, class_names: list[str], ignore_index: int | None
) -> None:
    """Write ``classes.yaml`` for a filtered mask dataset.

    Args:
        output_path: Dataset root; ``classes.yaml`` is written inside it.
        class_names: Class names in new-ID order (list index == pixel value).
        ignore_index: Only checked against ``None``. When set, the file
            declares ``ignore_index: 255`` regardless of the value passed.
    """
    names = dict(enumerate(class_names))
    config: dict = {"names": names}

    if ignore_index is not None:
        # 255 is the fill value that filter_mask_dataset uses for unmapped
        # pixels, independent of the source dataset's ignore value.
        config["ignore_index"] = 255

    target = output_path / "classes.yaml"
    with target.open("w", encoding="utf-8") as stream:
        yaml.dump(config, stream, default_flow_style=False, sort_keys=False)
|