argus-cv 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of argus-cv might be problematic. Click here for more details.
- argus/__init__.py +1 -1
- argus/cli.py +448 -33
- argus/core/__init__.py +3 -0
- argus/core/base.py +1 -0
- argus/core/coco.py +8 -6
- argus/core/mask.py +648 -0
- argus/core/yolo.py +21 -12
- {argus_cv-1.2.0.dist-info → argus_cv-1.4.0.dist-info}/METADATA +9 -2
- argus_cv-1.4.0.dist-info/RECORD +14 -0
- argus_cv-1.2.0.dist-info/RECORD +0 -13
- {argus_cv-1.2.0.dist-info → argus_cv-1.4.0.dist-info}/WHEEL +0 -0
- {argus_cv-1.2.0.dist-info → argus_cv-1.4.0.dist-info}/entry_points.txt +0 -0
argus/core/mask.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
1
|
+
"""Mask dataset detection and handling for semantic segmentation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import cv2
|
|
8
|
+
import numpy as np
|
|
9
|
+
import yaml
|
|
10
|
+
|
|
11
|
+
from argus.core.base import Dataset, DatasetFormat, TaskType
|
|
12
|
+
|
|
13
|
+
# Module-level logger; used below to warn about images without a matching
# mask and about mask files that cannot be read.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ConfigurationError(Exception):
    """Raised when a mask dataset's configuration (e.g. classes.yaml) is invalid or missing."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Directory patterns for mask dataset detection, checked in order. Each entry
# is (images_dir_name, masks_dir_name), both relative to the dataset root.
DIRECTORY_PATTERNS = [
    ("images", "masks"),
    ("img", "gt"),
    ("leftImg8bit", "gtFine"),
]

# Recognized image file extensions, compared against lowercased file suffixes.
# ".tif" is the common short spelling of ".tiff" and is accepted as well.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class MaskDataset(Dataset):
    """Dataset format for folder-based semantic segmentation masks.

    Supports directory structures like:
    - images/ + masks/
    - img/ + gt/
    - leftImg8bit/ + gtFine/ (Cityscapes-style)

    Each pattern expects parallel split subdirectories:
        dataset/
        ├── images/
        │   ├── train/
        │   └── val/
        ├── masks/
        │   ├── train/
        │   └── val/
        └── classes.yaml  # Optional for grayscale, required for RGB

    Images and masks may also sit directly in the image/mask roots with no
    split subdirectories. Only the top level of each (split) directory is
    scanned; subfolders are not traversed. Masks must be ``.png`` files.

    Note:
        ``load_mask`` and the pixel statistics (``get_instance_counts``,
        ``get_pixel_counts``, ``get_image_class_presence``) read masks with
        ``cv2.IMREAD_GRAYSCALE``. For true RGB palette masks those values are
        grayscale conversions of the colors, not class IDs, and
        ``_palette_mapping`` is currently never populated by detection.
        TODO(review): decode palette colors once the ``palette`` entry schema
        (which key holds the color) is settled.
    """

    # Name of the images directory relative to ``path`` (e.g. "images").
    images_dir: str = ""
    # Name of the masks directory relative to ``path`` (e.g. "masks").
    masks_dir: str = ""
    # Parsed classes.yaml / classes.yml contents, if present.
    class_config: dict | None = None
    # Mask value treated as void/ignored; None disables ignoring.
    ignore_index: int | None = 255
    format: DatasetFormat = field(default=DatasetFormat.MASK, init=False)
    task: TaskType = field(default=TaskType.SEGMENTATION, init=False)

    # Internal: maps class_id -> class_name
    _class_mapping: dict[int, str] = field(default_factory=dict, repr=False)
    # Internal: maps RGB tuple -> (class_id, class_name) for palette masks
    _palette_mapping: dict[tuple[int, int, int], tuple[int, str]] = field(
        default_factory=dict, repr=False
    )
    # Internal: flag for RGB palette mode
    _is_rgb_palette: bool = field(default=False, repr=False)

    @classmethod
    def detect(cls, path: Path) -> "MaskDataset | None":
        """Detect if the given path contains a mask dataset.

        The patterns in ``DIRECTORY_PATTERNS`` are tried in order. The first
        pattern is used whose image and mask roots both exist and contain
        files, either in ``train``/``val``/``test`` splits or directly in the
        roots.

        Args:
            path: Directory path to check for dataset.

        Returns:
            MaskDataset instance if detected, None otherwise.

        Raises:
            ConfigurationError: If the classes config cannot be parsed or is
                malformed, or if RGB masks are found without a config file.
        """
        path = Path(path)

        if not path.is_dir():
            return None

        for images_name, masks_name in DIRECTORY_PATTERNS:
            images_root = path / images_name
            masks_root = path / masks_name

            if not (images_root.is_dir() and masks_root.is_dir()):
                continue

            # Prefer split subdirectories present in both roots.
            splits = cls._detect_splits(images_root, masks_root)

            # Without splits, accept an unsplit layout where images and masks
            # sit directly in their roots.
            if not splits:
                has_images = any(
                    f.suffix.lower() in IMAGE_EXTENSIONS for f in images_root.iterdir()
                )
                has_masks = any(
                    f.suffix.lower() == ".png" for f in masks_root.iterdir()
                )
                if not (has_images and has_masks):
                    continue

            class_config = cls._load_class_config(path)

            is_rgb, palette_mapping = cls._detect_mask_type(path, masks_root, splits)

            # RGB colors cannot be mapped to class IDs without a config file.
            if is_rgb and not class_config:
                raise ConfigurationError(
                    f"RGB palette masks detected in {path} but no classes.yaml found. "
                    "RGB masks require a classes.yaml config file with palette mapping."
                )

            class_mapping, ignore_idx = cls._build_class_mapping(
                path, masks_root, splits, class_config, is_rgb, palette_mapping
            )

            # The ignore index marks void pixels and is not reported as a class.
            class_names = [
                class_mapping[i]
                for i in sorted(class_mapping.keys())
                if i != ignore_idx
            ]

            dataset = cls(
                path=path,
                num_classes=len(class_names),
                class_names=class_names,
                splits=splits,
                images_dir=images_name,
                masks_dir=masks_name,
                class_config=class_config,
                ignore_index=ignore_idx,
            )
            dataset._class_mapping = class_mapping
            dataset._palette_mapping = palette_mapping if is_rgb else {}
            dataset._is_rgb_palette = is_rgb

            return dataset

        return None

    @classmethod
    def _detect_splits(cls, images_root: Path, masks_root: Path) -> list[str]:
        """Detect available splits from directory structure.

        A split counts only if both ``images_root/<split>`` and
        ``masks_root/<split>`` exist. The image directory must hold at least
        one recognized image file and the mask directory at least one
        ``.png``.

        Args:
            images_root: Root directory containing images.
            masks_root: Root directory containing masks.

        Returns:
            List of split names (subset of train/val/test, in that order)
            found in both images and masks directories.
        """
        splits = []

        for split_name in ["train", "val", "test"]:
            images_split = images_root / split_name
            masks_split = masks_root / split_name

            if not (images_split.is_dir() and masks_split.is_dir()):
                continue

            # Verify there are actual files, not just empty directories.
            has_images = any(
                f.suffix.lower() in IMAGE_EXTENSIONS for f in images_split.iterdir()
            )
            has_masks = any(f.suffix.lower() == ".png" for f in masks_split.iterdir())

            if has_images and has_masks:
                splits.append(split_name)

        return splits

    @classmethod
    def _load_class_config(cls, path: Path) -> dict | None:
        """Load classes.yaml (or classes.yml) configuration if present.

        Args:
            path: Dataset root path.

        Returns:
            Parsed config dict, or None if no config file exists or the file
            is empty.

        Raises:
            ConfigurationError: If the file cannot be read or parsed, or its
                top level is not a mapping.
        """
        config_path = path / "classes.yaml"
        if not config_path.exists():
            config_path = path / "classes.yml"

        if not config_path.exists():
            return None

        try:
            with open(config_path, encoding="utf-8") as f:
                config = yaml.safe_load(f)
        except (yaml.YAMLError, OSError) as e:
            raise ConfigurationError(f"Failed to parse {config_path}: {e}") from e

        # An empty file parses to None: treat it like a missing config.
        if config is None:
            return None

        # Downstream code uses dict access (.get, "key" in config).
        if not isinstance(config, dict):
            raise ConfigurationError(
                f"Invalid {config_path}: expected a mapping at the top level, "
                f"got {type(config).__name__}"
            )

        return config

    @classmethod
    def _detect_mask_type(
        cls,
        path: Path,
        masks_root: Path,
        splits: list[str],
    ) -> tuple[bool, dict[tuple[int, int, int], tuple[int, str]]]:
        """Detect if masks are grayscale or RGB palette format.

        Up to five masks (sorted by path) from the first split, or from the
        masks root when unsplit, are considered. The first readable one
        decides the encoding. A multi-channel mask whose B, G and R channels
        are identical is treated as grayscale, since it decodes to the same
        values via ``cv2.IMREAD_GRAYSCALE``.

        Args:
            path: Dataset root path (unused; kept for interface compatibility).
            masks_root: Masks directory root.
            splits: List of split names.

        Returns:
            Tuple of (is_rgb_palette, palette_mapping). The palette mapping is
            currently always empty.
        """
        sample_dir = masks_root / splits[0] if splits else masks_root
        sample_masks = sorted(sample_dir.glob("*.png"))[:5]

        for mask_path in sample_masks:
            mask = cv2.imread(str(mask_path), cv2.IMREAD_UNCHANGED)
            if mask is None:
                # Unreadable file: try the next sample instead of giving up.
                continue

            # Single-channel (or gray+alpha) masks are grayscale.
            if mask.ndim == 2 or mask.shape[2] < 3:
                return False, {}

            # Grayscale masks are often saved as 3/4-channel PNGs with equal
            # color channels. Those are not palette masks.
            if np.array_equal(mask[..., 0], mask[..., 1]) and np.array_equal(
                mask[..., 1], mask[..., 2]
            ):
                return False, {}

            return True, {}

        return False, {}

    @classmethod
    def _build_class_mapping(
        cls,
        path: Path,
        masks_root: Path,
        splits: list[str],
        class_config: dict | None,
        is_rgb: bool,
        palette_mapping: dict,
    ) -> tuple[dict[int, str], int | None]:
        """Build class ID to name mapping.

        Resolution order:

        1. ``names`` (dict or list) from the config.
        2. ``palette`` entries (``id``/``name``) for RGB masks.
        3. For grayscale masks, auto-detection from sampled mask values.

        Args:
            path: Dataset root path.
            masks_root: Masks directory root.
            splits: List of split names.
            class_config: Parsed classes.yaml or None.
            is_rgb: Whether masks use RGB palette encoding.
            palette_mapping: RGB to class mapping (if RGB).

        Returns:
            Tuple of (class_id_to_name dict, ignore_index or None).

        Raises:
            ConfigurationError: If ``ignore_index``, ``names`` or ``palette``
                in the config are malformed.
        """
        ignore_index: int | None = 255

        if class_config:
            # Ignore index can be null/None in the config to disable it.
            raw_ignore = class_config.get("ignore_index", 255)
            try:
                ignore_index = None if raw_ignore is None else int(raw_ignore)
            except (TypeError, ValueError) as e:
                raise ConfigurationError(
                    f"Invalid ignore_index in classes config: {raw_ignore!r}"
                ) from e

            names = class_config.get("names")
            try:
                if isinstance(names, dict):
                    return {int(k): v for k, v in names.items()}, ignore_index
                if isinstance(names, list):
                    return dict(enumerate(names)), ignore_index

                # RGB palette: take IDs and names from the palette entries.
                if is_rgb and "palette" in class_config:
                    mapping = {
                        entry["id"]: entry["name"] for entry in class_config["palette"]
                    }
                    return mapping, ignore_index
            except (KeyError, TypeError, ValueError) as e:
                raise ConfigurationError(
                    f"Malformed class definitions in classes config: {e!r}"
                ) from e

        # Auto-detect classes from grayscale masks.
        if not is_rgb:
            return cls._auto_detect_classes(masks_root, splits, ignore_index)

        # RGB without usable class definitions: nothing to map.
        return {}, ignore_index

    @classmethod
    def _auto_detect_classes(
        cls,
        masks_root: Path,
        splits: list[str],
        ignore_index: int | None,
    ) -> tuple[dict[int, str], int | None]:
        """Auto-detect class IDs from grayscale mask values.

        Sampling takes up to 20 masks (sorted by path) per split, or from the
        masks root when unsplit, capped at 50 in total. Class values that
        appear only in unsampled masks are not detected.

        Args:
            masks_root: Masks directory root.
            splits: List of split names.
            ignore_index: Index to treat as ignored/void, or None.

        Returns:
            Tuple of (class_id_to_name dict with names ``class_<id>``,
            ignore_index or None).
        """
        mask_dirs = [masks_root / split for split in splits] if splits else [masks_root]

        sample_masks: list[Path] = []
        for mask_dir in mask_dirs:
            sample_masks.extend(sorted(mask_dir.glob("*.png"))[:20])

        unique_values: set[int] = set()
        for mask_path in sample_masks[:50]:  # Limit sampling
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask is not None:
                unique_values.update(np.unique(mask).tolist())

        mapping = {
            val: f"class_{val}"
            for val in sorted(unique_values)
            if val != ignore_index
        }
        return mapping, ignore_index

    def get_image_paths(self, split: str | None = None) -> list[Path]:
        """Get all image file paths for a split or the entire dataset.

        Only returns images that have corresponding masks. A warning is logged
        for each image without one.

        Args:
            split: Specific split to get images from. If None, returns all images.

        Returns:
            List of image file paths sorted by file name.
        """
        images_root = self.path / self.images_dir
        image_paths: list[Path] = []

        splits_to_search = (
            [split] if split else (self.splits if self.splits else [None])
        )

        for s in splits_to_search:
            image_dir = images_root / s if s else images_root

            if not image_dir.is_dir():
                continue

            for img_file in image_dir.iterdir():
                if img_file.suffix.lower() not in IMAGE_EXTENSIONS:
                    continue

                # Only include if mask exists
                mask_path = self.get_mask_path(img_file)
                if mask_path and mask_path.exists():
                    image_paths.append(img_file)
                else:
                    logger.warning("No mask found for image: %s", img_file)

        return sorted(image_paths, key=lambda p: p.name)

    def get_mask_path(self, image_path: Path) -> Path | None:
        """Return corresponding mask path for an image.

        The mask directory mirrors the image's location under the images
        root, e.g. ``<root>/images/train/a.jpg`` -> ``<root>/masks/train/``.
        If the image is not under this dataset's images root, the last path
        component named like ``images_dir`` is swapped for ``masks_dir``.
        If there is none, the masks root is searched.

        Naming conventions are tried in order:
            1. Exact stem match: image.jpg -> image.png
            2. With _mask suffix: image.jpg -> image_mask.png
            3. With _gt suffix: image.jpg -> image_gt.png
            4. With _label suffix: image.jpg -> image_label.png

        Args:
            image_path: Path to the image file.

        Returns:
            Path to corresponding mask, or None if not found.
        """
        image_path = Path(image_path)
        stem = image_path.stem
        images_root = self.path / self.images_dir
        masks_root = self.path / self.masks_dir

        # Preferred: resolve relative to this dataset's images root, so that
        # ancestor directories sharing the images_dir name cannot interfere.
        try:
            relative_parent = image_path.parent.relative_to(images_root)
        except ValueError:
            relative_parent = None
        if relative_parent is not None:
            return self._find_mask_with_patterns(masks_root / relative_parent, stem)

        # Fallback: swap the *last* directory component named images_dir.
        # The filename itself (parts[-1]) is excluded.
        parts = image_path.parts
        for idx in range(len(parts) - 2, -1, -1):
            if parts[idx] == self.images_dir:
                mask_parts = list(parts[:-1])  # Exclude filename
                mask_parts[idx] = self.masks_dir
                return self._find_mask_with_patterns(Path(*mask_parts), stem)

        # Last resort: look in masks root
        return self._find_mask_with_patterns(masks_root, stem)

    def _find_mask_with_patterns(self, masks_dir: Path, stem: str) -> Path | None:
        """Find mask file trying multiple naming patterns.

        Args:
            masks_dir: Directory containing masks.
            stem: Image filename stem (without extension).

        Returns:
            Path to the first existing mask among ``<stem>.png``,
            ``<stem>_mask.png``, ``<stem>_gt.png`` and ``<stem>_label.png``,
            or None if none exists.
        """
        # Suffixes in order of preference: exact, _mask, ground truth, label.
        for suffix in ("", "_mask", "_gt", "_label"):
            mask_path = masks_dir / f"{stem}{suffix}.png"
            if mask_path.exists():
                return mask_path

        return None

    def get_class_mapping(self) -> dict[int, str]:
        """Return class ID to name mapping.

        Returns:
            A copy of the dictionary mapping class IDs to class names.
        """
        return self._class_mapping.copy()

    def get_instance_counts(self) -> dict[str, dict[str, int]]:
        """Get pixel counts per class, per split.

        For mask datasets, this returns the total number of pixels for each
        class across all ``.png`` masks in each split. Masks are read as
        grayscale, and pixels equal to ``ignore_index`` are skipped. Unsplit
        datasets report under the key ``"unsplit"``.

        Returns:
            Dictionary mapping split name to dict of class name to pixel count.
        """
        counts: dict[str, dict[str, int]] = {}

        splits_to_process = self.splits if self.splits else ["unsplit"]
        masks_root = self.path / self.masks_dir

        for split in splits_to_process:
            split_counts: dict[str, int] = {name: 0 for name in self.class_names}

            mask_dir = masks_root if split == "unsplit" else masks_root / split

            if not mask_dir.is_dir():
                continue

            for mask_path in mask_dir.glob("*.png"):
                mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    logger.warning("Could not read mask: %s", mask_path)
                    continue

                # Count pixels for each class
                unique, pixel_counts = np.unique(mask, return_counts=True)

                for class_id, count in zip(unique, pixel_counts, strict=True):
                    class_id_int = int(class_id)
                    if class_id_int == self.ignore_index:
                        continue
                    class_name = self._class_mapping.get(class_id_int)
                    if class_name and class_name in split_counts:
                        split_counts[class_name] += int(count)

            counts[split] = split_counts

        return counts

    def get_image_counts(self) -> dict[str, dict[str, int]]:
        """Get image counts per split.

        Only images with a matching mask are counted, as determined by
        ``get_image_paths``.

        Returns:
            Dictionary mapping split name to dict with "total" and "background" counts.
            For mask datasets, "background" is always 0.
        """
        splits_to_process = self.splits if self.splits else ["unsplit"]

        return {
            split: {
                "total": len(
                    self.get_image_paths(split if split != "unsplit" else None)
                ),
                "background": 0,
            }
            for split in splits_to_process
        }

    def get_image_class_presence(self, split: str | None = None) -> dict[int, int]:
        """Return count of images containing each class.

        Masks are read as grayscale. Values equal to ``ignore_index`` or not
        in the class mapping are not counted.

        Args:
            split: Specific split to analyze. If None, analyzes all splits.

        Returns:
            Dictionary mapping class ID to count of images containing that class.
        """
        presence: dict[int, int] = {class_id: 0 for class_id in self._class_mapping}

        splits_to_process = (
            [split] if split else (self.splits if self.splits else [None])
        )
        masks_root = self.path / self.masks_dir

        for s in splits_to_process:
            mask_dir = masks_root / s if s else masks_root

            if not mask_dir.is_dir():
                continue

            for mask_path in mask_dir.glob("*.png"):
                mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    continue

                # Each distinct value contributes at most once per image.
                for class_id in np.unique(mask).tolist():
                    if class_id != self.ignore_index and class_id in presence:
                        presence[class_id] += 1

        return presence

    def get_pixel_counts(self, split: str | None = None) -> dict[int, int]:
        """Return total pixel count per class ID.

        Masks are read as grayscale. When ``ignore_index`` is set, the
        returned dict also contains it as a key holding the ignored-pixel
        total. Values outside the class mapping are dropped.

        Args:
            split: Specific split to analyze. If None, analyzes all splits.

        Returns:
            Dictionary mapping class ID to total pixel count.
        """
        pixel_counts: dict[int | None, int] = {
            class_id: 0 for class_id in self._class_mapping
        }
        # Track ignored pixels if ignore_index is set
        if self.ignore_index is not None:
            pixel_counts[self.ignore_index] = 0

        splits_to_process = (
            [split] if split else (self.splits if self.splits else [None])
        )
        masks_root = self.path / self.masks_dir

        for s in splits_to_process:
            mask_dir = masks_root / s if s else masks_root

            if not mask_dir.is_dir():
                continue

            for mask_path in mask_dir.glob("*.png"):
                mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    continue

                unique, counts = np.unique(mask, return_counts=True)
                for class_id, count in zip(unique, counts, strict=True):
                    class_id_int = int(class_id)
                    # The ignore index (when set) is already a key here.
                    if class_id_int in pixel_counts:
                        pixel_counts[class_id_int] += int(count)

        return pixel_counts

    def get_annotations_for_image(self, image_path: Path) -> list[dict]:
        """Get annotations for a specific image.

        For mask datasets, this returns an empty list since annotations
        are stored as pixel masks rather than discrete objects.
        Use get_mask_path() to get the mask file directly.

        Args:
            image_path: Path to the image file.

        Returns:
            Empty list (masks don't have discrete annotations).
        """
        # The mask itself IS the annotation.
        return []

    def load_mask(self, image_path: Path) -> np.ndarray | None:
        """Load the mask for a given image.

        Args:
            image_path: Path to the image file.

        Returns:
            Mask as numpy array (read with cv2.IMREAD_GRAYSCALE), or None if
            no mask was found or it could not be read.
        """
        mask_path = self.get_mask_path(image_path)
        if mask_path is None or not mask_path.exists():
            return None

        return cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

    def validate_dimensions(
        self, image_path: Path
    ) -> tuple[bool, tuple[int, int] | None, tuple[int, int] | None]:
        """Check if image and mask dimensions match.

        Args:
            image_path: Path to the image file.

        Returns:
            Tuple of (dimensions_match, image_shape, mask_shape). Shapes are
            (height, width). The result is (False, None, None) when the mask
            is missing or either file cannot be read.
        """
        mask_path = self.get_mask_path(image_path)
        if mask_path is None or not mask_path.exists():
            return False, None, None

        img = cv2.imread(str(image_path))
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

        if img is None or mask is None:
            return False, None, None

        img_shape = (img.shape[0], img.shape[1])
        mask_shape = (mask.shape[0], mask.shape[1])

        return img_shape == mask_shape, img_shape, mask_shape
|
argus/core/yolo.py
CHANGED
|
@@ -89,6 +89,11 @@ class YOLODataset(Dataset):
|
|
|
89
89
|
if "names" not in config:
|
|
90
90
|
continue
|
|
91
91
|
|
|
92
|
+
# Skip if this looks like a mask dataset config
|
|
93
|
+
# (has ignore_index or palette keys which are mask-specific)
|
|
94
|
+
if "ignore_index" in config or "palette" in config:
|
|
95
|
+
continue
|
|
96
|
+
|
|
92
97
|
names = config["names"]
|
|
93
98
|
|
|
94
99
|
# Extract class names
|
|
@@ -692,12 +697,14 @@ class YOLODataset(Dataset):
|
|
|
692
697
|
x = x_center - width / 2
|
|
693
698
|
y = y_center - height / 2
|
|
694
699
|
|
|
695
|
-
annotations.append(
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
700
|
+
annotations.append(
|
|
701
|
+
{
|
|
702
|
+
"class_name": class_name,
|
|
703
|
+
"class_id": class_id,
|
|
704
|
+
"bbox": (x, y, width, height),
|
|
705
|
+
"polygon": None,
|
|
706
|
+
}
|
|
707
|
+
)
|
|
701
708
|
else:
|
|
702
709
|
# Segmentation: class x1 y1 x2 y2 ... xn yn
|
|
703
710
|
coords = [float(p) for p in parts[1:]]
|
|
@@ -715,12 +722,14 @@ class YOLODataset(Dataset):
|
|
|
715
722
|
width = max(xs) - x
|
|
716
723
|
height = max(ys) - y
|
|
717
724
|
|
|
718
|
-
annotations.append(
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
725
|
+
annotations.append(
|
|
726
|
+
{
|
|
727
|
+
"class_name": class_name,
|
|
728
|
+
"class_id": class_id,
|
|
729
|
+
"bbox": (x, y, width, height),
|
|
730
|
+
"polygon": polygon,
|
|
731
|
+
}
|
|
732
|
+
)
|
|
724
733
|
|
|
725
734
|
except (ValueError, IndexError):
|
|
726
735
|
continue
|