bplusplus 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bplusplus/__init__.py +15 -0
- bplusplus/collect.py +523 -0
- bplusplus/detector.py +376 -0
- bplusplus/inference.py +1337 -0
- bplusplus/prepare.py +706 -0
- bplusplus/tracker.py +261 -0
- bplusplus/train.py +913 -0
- bplusplus/validation.py +580 -0
- bplusplus-2.0.4.dist-info/LICENSE +21 -0
- bplusplus-2.0.4.dist-info/METADATA +259 -0
- bplusplus-2.0.4.dist-info/RECORD +12 -0
- bplusplus-2.0.4.dist-info/WHEEL +4 -0
bplusplus/prepare.py
ADDED
|
@@ -0,0 +1,706 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
import shutil
|
|
4
|
+
import tempfile
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import requests
|
|
9
|
+
import torch
|
|
10
|
+
from PIL import Image, ImageFilter
|
|
11
|
+
from torch import serialization
|
|
12
|
+
from torch.nn import Module, ModuleDict, ModuleList
|
|
13
|
+
from torch.nn.modules.activation import LeakyReLU, ReLU, SiLU
|
|
14
|
+
# Add more modules to prevent further errors
|
|
15
|
+
from torch.nn.modules.batchnorm import BatchNorm2d
|
|
16
|
+
from torch.nn.modules.container import Sequential
|
|
17
|
+
from torch.nn.modules.conv import Conv2d
|
|
18
|
+
from torch.nn.modules.dropout import Dropout
|
|
19
|
+
from torch.nn.modules.linear import Linear
|
|
20
|
+
from torch.nn.modules.pooling import MaxPool2d
|
|
21
|
+
from torch.nn.modules.upsampling import Upsample
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
from ultralytics import YOLO
|
|
24
|
+
from ultralytics.nn.modules import SPPF, Bottleneck, C2f, Concat, Detect
|
|
25
|
+
from ultralytics.nn.modules.block import DFL
|
|
26
|
+
from ultralytics.nn.modules.conv import Conv
|
|
27
|
+
from ultralytics.nn.tasks import DetectionModel
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def prepare(input_directory: str, output_directory: str, img_size: int = 40, conf: float = 0.35, valid: float = 0.1, blur: Optional[float] = None):
    """
    Prepares a YOLO classification dataset by performing the following steps:
    1. Copies images from input directory to temporary directory and creates class mapping.
    2. Deletes corrupted images and downloads YOLO model weights if not present.
    3. Runs YOLO inference to generate detection labels (bounding boxes) for the images.
    4. Cleans up orphaned images, invalid labels, and updates labels with class indices.
    5. Crops detected objects from images based on bounding boxes and resizes them.
    6. Splits data into train/valid sets with classification folder structure (train/class_name/image.jpg).

    Args:
        input_directory (str): The path to the input directory containing the images.
        output_directory (str): The path to the output directory where the prepared classification dataset will be saved.
        img_size (int, optional): The target size for the smallest dimension of cropped images. Defaults to 40.
        conf (float, optional): YOLO detection confidence threshold. Defaults to 0.35.
        valid (float, optional): Fraction of data for validation (0.0 to 1.0).
            0 = no validation split, 0.1 = 10% validation. Defaults to 0.1.
        blur (float, optional): Gaussian blur as fraction of image size (0.0 to 1.0).
            Applied before resizing. 0.01 = 1% of smallest dimension.
            None or 0 means no blur. Defaults to None.

    Raises:
        ValueError: If ``valid`` or ``blur`` is outside the [0, 1] range.
    """
    # Validate the valid parameter
    if not 0 <= valid <= 1:
        raise ValueError(f"valid must be between 0 and 1, got {valid}")
    # Validate the blur parameter (None means "disabled" and skips the check)
    if blur is not None and not 0 <= blur <= 1:
        raise ValueError(f"blur must be between 0 and 1, got {blur}")
    input_directory = Path(input_directory)
    output_directory = Path(output_directory)

    print("="*60)
    print("STARTING BPLUSPLUS DATASET PREPARATION")
    print("="*60)
    print(f"Input directory: {input_directory}")
    print(f"Output directory: {output_directory}")
    print(f"Target image size: {img_size}px (smallest dimension)")
    print(f"YOLO confidence threshold: {conf}")
    if valid > 0:
        print(f"Validation split: {valid*100:.0f}% validation, {(1-valid)*100:.0f}% training")
    else:
        print("Validation split: disabled (all images to training)")
    if blur and blur > 0:
        print(f"Gaussian blur: {blur*100:.1f}% of image size")
    else:
        print("Gaussian blur: disabled")
    print()

    # All intermediate artifacts (copied images, YOLO predictions) live in a
    # temporary directory that is removed automatically when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        print(f"Using temporary directory: {temp_dir_path}")
        print()

        # Step 1: Setup directories and copy images
        print("STEP 1: Setting up directories and copying images...")
        print("-" * 50)
        class_mapping, original_image_count = _setup_directories_and_copy_images(
            input_directory, temp_dir_path
        )
        print(f"✓ Step 1 completed: {original_image_count} images copied from {len(class_mapping)} classes")
        print()

        # Step 2-3: Clean images and setup model
        print("STEP 2: Cleaning images and setting up YOLO model...")
        print("-" * 50)
        weights_path = _prepare_model_and_clean_images(temp_dir_path)
        print(f"✓ Step 2 completed: Model ready at {weights_path}")
        print()

        # Step 4: Run YOLO inference
        print("STEP 3: Running YOLO inference to detect objects...")
        print("-" * 50)
        labels_path = _run_yolo_inference(temp_dir_path, weights_path, conf)
        print(f"✓ Step 3 completed: Labels generated at {labels_path}")
        print()

        # Step 5-6: Clean up labels and update class mapping
        print("STEP 4: Cleaning up orphaned files and processing labels...")
        print("-" * 50)
        class_idxs = _cleanup_and_process_labels(
            temp_dir_path, labels_path, class_mapping
        )
        print(f"✓ Step 4 completed: Processed {len(class_idxs)} classes")
        print()

        # Step 7-9: Finalize dataset
        print("STEP 5: Creating classification dataset with cropped images...")
        print("-" * 50)
        _finalize_dataset(
            class_mapping, temp_dir_path, output_directory,
            class_idxs, original_image_count, img_size, valid, blur
        )
        print("✓ Step 5 completed: Classification dataset ready!")
        print()

    print("="*60)
    print("BPLUSPLUS DATASET PREPARATION COMPLETED SUCCESSFULLY!")
    print("="*60)
128
|
+
def _setup_directories_and_copy_images(input_directory: Path, temp_dir_path: Path):
|
|
129
|
+
"""
|
|
130
|
+
Sets up temporary directories and copies images from input directory.
|
|
131
|
+
|
|
132
|
+
Returns:
|
|
133
|
+
tuple: (class_mapping dict, original_image_count int)
|
|
134
|
+
"""
|
|
135
|
+
images_path = temp_dir_path / "images"
|
|
136
|
+
images_path.mkdir(parents=True, exist_ok=True)
|
|
137
|
+
print(f" Created temporary images directory: {images_path}")
|
|
138
|
+
|
|
139
|
+
class_mapping = {}
|
|
140
|
+
total_copied = 0
|
|
141
|
+
|
|
142
|
+
print(" Scanning input directory for class folders...")
|
|
143
|
+
class_folders = [d for d in input_directory.iterdir() if d.is_dir()]
|
|
144
|
+
print(f" Found {len(class_folders)} class folders")
|
|
145
|
+
|
|
146
|
+
for folder_directory in class_folders:
|
|
147
|
+
images_names = []
|
|
148
|
+
if folder_directory.is_dir():
|
|
149
|
+
folder_name = folder_directory.name
|
|
150
|
+
image_files = list(folder_directory.glob("*.jpg")) + list(folder_directory.glob("*.png"))
|
|
151
|
+
print(f" Copying {len(image_files)} images from class '{folder_name}'...")
|
|
152
|
+
|
|
153
|
+
for image_file in image_files:
|
|
154
|
+
shutil.copy(image_file, images_path)
|
|
155
|
+
image_name = image_file.name
|
|
156
|
+
images_names.append(image_name)
|
|
157
|
+
total_copied += 1
|
|
158
|
+
|
|
159
|
+
class_mapping[folder_name] = images_names
|
|
160
|
+
print(f" ✓ {len(images_names)} images copied for class '{folder_name}'")
|
|
161
|
+
|
|
162
|
+
original_image_count = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.jpeg"))) + len(list(images_path.glob("*.png")))
|
|
163
|
+
print(f" Total images in temporary directory: {original_image_count}")
|
|
164
|
+
|
|
165
|
+
return class_mapping, original_image_count
|
|
166
|
+
|
|
167
|
+
def _prepare_model_and_clean_images(temp_dir_path: Path):
    """
    Cleans corrupted images and downloads/prepares the YOLO model.

    Args:
        temp_dir_path (Path): Temporary working directory (contains "images").

    Returns:
        Path: weights_path for the YOLO model
    """
    images_path = temp_dir_path / "images"

    def count_images():
        # Number of .jpg/.png files currently in the working image folder.
        return len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))

    # Drop unreadable images before they can break YOLO inference.
    print("  Checking for corrupted images...")
    before = count_images()
    __delete_corrupted_images(images_path)
    after = count_images()
    deleted_count = before - after
    print(f"  ✓ Cleaned {deleted_count} corrupted images ({after} images remain)")

    # The generic detector weights live next to this module; fetch them once.
    weights_path = Path(__file__).resolve().parent / 'gbif-generic.pt'
    github_release_url = 'https://github.com/Tvenver/Bplusplus/releases/download/weights/gbif-generic.pt'

    print(f"  Checking for YOLO model weights at: {weights_path}")
    if weights_path.exists():
        print("  ✓ Model weights already exist")
    else:
        print("  Model weights not found, downloading from GitHub...")
        __download_file_from_github_release(github_release_url, weights_path)
        print(f"  ✓ Model weights downloaded successfully")

    # Newer torch defaults to weights_only loading; allow-list every class the
    # checkpoint pickle needs so it can be deserialized.
    if hasattr(serialization, 'add_safe_globals'):
        serialization.add_safe_globals([
            DetectionModel, Sequential, Conv, Conv2d, BatchNorm2d,
            SiLU, ReLU, LeakyReLU, MaxPool2d, Linear, Dropout, Upsample,
            Module, ModuleList, ModuleDict,
            Bottleneck, C2f, SPPF, Detect, Concat, DFL,
            # Add torch internal classes
            torch.nn.parameter.Parameter,
            torch.Tensor,
            torch._utils._rebuild_tensor_v2,
            torch._utils._rebuild_parameter
        ])

    return weights_path
213
|
+
def _run_yolo_inference(temp_dir_path: Path, weights_path: Path, conf: float):
    """
    Runs YOLO inference on all images to generate labels.

    Args:
        temp_dir_path (Path): Path to the working temp directory.
        weights_path (Path): Path to YOLO weights.
        conf (float): YOLO detection confidence threshold.

    Returns:
        Path: labels_path where the generated labels are stored
    """
    images_path = temp_dir_path / "images"
    labels_path = temp_dir_path / "predict" / "labels"

    try:
        print(f"  Loading YOLO model from: {weights_path}")
        model = YOLO(weights_path)
        print("  ✓ YOLO model loaded successfully")

        # Get list of all image files
        image_files = list(images_path.glob('*.jpg')) + list(images_path.glob('*.png'))
        print(f"  Found {len(image_files)} images to process with YOLO")

        # Ensure predict directory exists
        predict_dir = temp_dir_path / "predict"
        predict_dir.mkdir(exist_ok=True)
        labels_path.mkdir(parents=True, exist_ok=True)
        print(f"  Created prediction output directory: {predict_dir}")

        result_count = 0
        error_count = 0

        print("  Starting YOLO inference...")
        print(f"  Progress: 0/{len(image_files)} images processed", end="", flush=True)

        for i, img_path in enumerate(image_files, 1):
            try:
                # save_txt=True makes ultralytics write one YOLO-format .txt
                # per image into <temp_dir>/predict/labels; the returned
                # results object itself is not used.
                results = model.predict(
                    source=str(img_path),
                    conf=conf,
                    save=True,
                    save_txt=True,
                    project=temp_dir_path,
                    name="predict",
                    exist_ok=True,
                    verbose=False  # Set to False to reduce YOLO's own output
                )

                result_count += 1

                # Update progress every 10% or every 100 images, whichever is smaller
                update_interval = max(1, min(100, len(image_files) // 10))
                if i % update_interval == 0 or i == len(image_files):
                    print(f"\r  Progress: {i}/{len(image_files)} images processed", end="", flush=True)

            except Exception as e:
                # One bad image should not abort the whole inference run.
                error_count += 1
                print(f"\n  Error processing {img_path.name}: {e}")
                continue

        print()  # New line after progress
        print(f"  ✓ YOLO inference completed: {result_count} successful, {error_count} failed")

        # Verify labels were created
        label_files = list(labels_path.glob("*.txt"))
        print(f"  Generated {len(label_files)} label files")

        if len(label_files) == 0:
            print("WARNING: No label files were created by the model prediction!")

    except Exception as e:
        # Setup failures (missing/bad weights, model load errors) are reported
        # but deliberately not re-raised; the caller receives the (possibly
        # empty) labels directory and later cleanup steps simply find no labels.
        print(f"Error during model prediction setup: {e}")
        import traceback
        traceback.print_exc()

    return labels_path
291
|
+
def _cleanup_and_process_labels(temp_dir_path: Path, labels_path: Path, class_mapping: dict):
    """
    Cleans up orphaned images and invalid labels, then creates class index mapping.

    Args:
        temp_dir_path (Path): Temporary working directory (contains "images").
        labels_path (Path): Directory holding the YOLO-generated *.txt labels.
        class_mapping (dict): Class name -> list of image names; its key order
            defines the class indices.

    Returns:
        dict: class_idxs mapping class indices to class names
    """
    images_path = temp_dir_path / "images"

    def count_images():
        # .jpg/.png files currently present in the images folder.
        return len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))

    def count_labels():
        return len(list(labels_path.glob("*.txt")))

    print("  Cleaning up orphaned images and labels...")
    images_before, labels_before = count_images(), count_labels()

    __delete_orphaned_images_and_inferences(images_path, labels_path)
    __delete_invalid_txt_files(images_path, labels_path)

    images_after, labels_after = count_images(), count_labels()

    deleted_images = images_before - images_after
    deleted_labels = labels_before - labels_after
    print(f"  ✓ Cleaned up {deleted_images} orphaned images and {deleted_labels} invalid labels")
    print(f"  Final counts: {images_after} images, {labels_after} valid labels")

    # Class index -> class name, in the insertion order of class_mapping.
    class_idxs = dict(enumerate(class_mapping))

    print(f"  Created class mapping for {len(class_idxs)} classes: {list(class_idxs.values())}")

    return class_idxs
|
324
|
+
def _finalize_dataset(class_mapping: dict, temp_dir_path: Path, output_directory: Path,
                      class_idxs: dict, original_image_count: int, img_size: int,
                      valid_fraction: float = 0.1, blur: Optional[float] = None):
    """
    Finalizes the dataset by creating cropped classification images and splitting into train/valid sets.

    Args:
        class_mapping (dict): Class name -> list of image names.
        temp_dir_path (Path): Temporary working directory.
        output_directory (Path): Destination for the train/valid folder structure.
        class_idxs (dict): Class index -> class name mapping (not used here;
            kept for interface stability).
        original_image_count (int): Number of images originally copied in,
            used for the success-rate statistic.
        img_size (int): Target size for the smallest dimension of cropped images.
        valid_fraction: Fraction of data for validation (0.0 to 1.0). 0 = no validation split.
        blur: Gaussian blur as fraction of image size (0-1). None or 0 means no blur.
    """
    # Split data into train/valid with cropped classification images
    __classification_split(class_mapping, temp_dir_path, output_directory, img_size, valid_fraction, blur)

    # Generate final report
    print("  Generating final statistics...")
    final_image_count = count_images_across_splits(output_directory)
    # Guard against ZeroDivisionError when the input directory had no images.
    success_rate = (final_image_count / original_image_count * 100) if original_image_count else 0.0
    print(f"  Dataset Statistics:")
    print(f"    - Original images: {original_image_count}")
    print(f"    - Final cropped images: {final_image_count}")
    print(f"    - Success rate: {success_rate:.1f}%")
    print(f"    - Output directory: {output_directory}")
346
|
+
def __delete_corrupted_images(images_path: Path):
    """
    Deletes corrupted images from the specified directory.

    Args:
        images_path (Path): The path to the directory containing images.

    This function iterates through all the .jpg/.png files in the specified
    directory and verifies each one. If an image file is found to be corrupted
    (it cannot be opened, or its data fails integrity verification), the
    function deletes the corrupted image file.
    """
    for pattern in ["*.jpg", "*.png"]:
        for image_file in images_path.glob(pattern):
            try:
                # Image.open is lazy and only parses the header; verify()
                # forces a full integrity check of the file contents. The
                # context manager also closes the file handle, which the
                # previous implementation leaked.
                with Image.open(image_file) as img:
                    img.verify()
            except (OSError, SyntaxError):
                # OSError covers IOError/UnidentifiedImageError; Pillow's
                # verify() raises SyntaxError for some broken files.
                image_file.unlink()
366
|
+
def __download_file_from_github_release(url, dest_path):
    """
    Downloads a file from a given GitHub release URL and saves it to the specified destination path,
    with a progress bar displayed in the terminal.

    Args:
        url (str): The URL of the file to download.
        dest_path (Path): The destination path where the file will be saved.

    Raises:
        Exception: If the file download fails.
    """
    # Stream the response and close the connection deterministically; the
    # previous implementation never closed the response and opened a progress
    # bar even when the request had already failed.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            raise Exception(f"Failed to download file from {url}")

        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 Kibibyte

        with tqdm(total=total_size, unit='iB', unit_scale=True) as progress_bar:
            with open(dest_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=block_size):
                    progress_bar.update(len(chunk))
                    f.write(chunk)
395
|
+
def __delete_orphaned_images_and_inferences(images_path: Path, labels_path: Path):
|
|
396
|
+
|
|
397
|
+
"""
|
|
398
|
+
Deletes orphaned images and their corresponding inference files if they do not have a label file.
|
|
399
|
+
|
|
400
|
+
Args:
|
|
401
|
+
images_path (Path): The path to the directory containing images.
|
|
402
|
+
inference_path (Path): The path to the directory containing inference files.
|
|
403
|
+
labels_path (Path): The path to the directory containing label files.
|
|
404
|
+
|
|
405
|
+
This function iterates through all the image files in the specified directory
|
|
406
|
+
and checks if there is a corresponding label file. If an image file does not
|
|
407
|
+
have a corresponding label file, the function deletes the orphaned image file
|
|
408
|
+
and its corresponding inference file.
|
|
409
|
+
"""
|
|
410
|
+
|
|
411
|
+
for txt_file in labels_path.glob("*.txt"):
|
|
412
|
+
image_file_jpg = images_path / (txt_file.stem + ".jpg")
|
|
413
|
+
image_file_jpeg = images_path / (txt_file.stem + ".jpeg")
|
|
414
|
+
image_file_png = images_path / (txt_file.stem + ".png")
|
|
415
|
+
|
|
416
|
+
if not (image_file_jpg.exists() or image_file_jpeg.exists() or image_file_png.exists()):
|
|
417
|
+
# print(f"Deleting {txt_file.name} - No corresponding image file")
|
|
418
|
+
txt_file.unlink()
|
|
419
|
+
|
|
420
|
+
label_stems = {txt_file.stem for txt_file in labels_path.glob("*.txt")}
|
|
421
|
+
image_files = list(images_path.glob("*.jpg")) + list(images_path.glob("*.jpeg")) + list(images_path.glob("*.png"))
|
|
422
|
+
|
|
423
|
+
for image_file in image_files:
|
|
424
|
+
if image_file.stem not in label_stems:
|
|
425
|
+
# print(f"Deleting orphaned image: {image_file.name}")
|
|
426
|
+
image_file.unlink()
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def __delete_invalid_txt_files(images_path: Path, labels_path: Path):
|
|
431
|
+
|
|
432
|
+
"""
|
|
433
|
+
Deletes invalid text files and their corresponding image and inference files.
|
|
434
|
+
|
|
435
|
+
Args:
|
|
436
|
+
images_path (Path): The path to the directory containing images.
|
|
437
|
+
inference_path (Path): The path to the directory containing inference files.
|
|
438
|
+
labels_path (Path): The path to the directory containing label files.
|
|
439
|
+
|
|
440
|
+
This function iterates through all the text files in the specified directory
|
|
441
|
+
and checks if they have 0 or more than one detections. If a text file is invalid,
|
|
442
|
+
the function deletes the invalid text file and its corresponding image and inference files.
|
|
443
|
+
"""
|
|
444
|
+
|
|
445
|
+
for txt_file in labels_path.glob("*.txt"):
|
|
446
|
+
with open(txt_file, 'r') as file:
|
|
447
|
+
lines = file.readlines()
|
|
448
|
+
|
|
449
|
+
if len(lines) == 0 or len(lines) > 1:
|
|
450
|
+
# print(f"Deleting {txt_file.name} - Invalid file")
|
|
451
|
+
txt_file.unlink()
|
|
452
|
+
|
|
453
|
+
image_file_jpg = images_path / (txt_file.stem + ".jpg")
|
|
454
|
+
image_file_jpeg = images_path / (txt_file.stem + ".jpeg")
|
|
455
|
+
image_file_png = images_path / (txt_file.stem + ".png")
|
|
456
|
+
|
|
457
|
+
if image_file_jpg.exists():
|
|
458
|
+
image_file_jpg.unlink()
|
|
459
|
+
# print(f"Deleted corresponding image file: {image_file_jpg.name}")
|
|
460
|
+
elif image_file_jpeg.exists():
|
|
461
|
+
image_file_jpeg.unlink()
|
|
462
|
+
# print(f"Deleted corresponding image file: {image_file_jpeg.name}")
|
|
463
|
+
elif image_file_png.exists():
|
|
464
|
+
image_file_png.unlink()
|
|
465
|
+
# print(f"Deleted corresponding image file: {image_file_png.name}")
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def __classification_split(class_mapping: dict, temp_dir_path: Path, output_directory: Path, img_size: int, valid_fraction: float = 0.1, blur: Optional[float] = None):
    """
    Splits the data into train and validation sets for classification tasks,
    cropping images according to their YOLO labels but preserving original class structure.

    Args:
        class_mapping (dict): A dictionary mapping class names to image file names.
        temp_dir_path (Path): The path to the temporary directory containing the images.
        output_directory (Path): The path to the output directory where train and valid splits will be created.
        img_size (int): The target size for the smallest dimension of cropped images.
        valid_fraction (float): Fraction of data for validation (0.0 to 1.0). 0 = no validation split.
        blur (float, optional): Gaussian blur as fraction of image size (0-1). None or 0 means no blur.
    """
    images_dir = temp_dir_path / "images"
    labels_dir = temp_dir_path / "predict" / "labels"

    create_valid = valid_fraction > 0

    # Create train directory (and optionally valid)
    train_dir = output_directory / 'train'
    train_dir.mkdir(parents=True, exist_ok=True)

    if create_valid:
        valid_dir = output_directory / 'valid'
        valid_dir.mkdir(parents=True, exist_ok=True)
        print(f"  Creating train and validation directories for {len(class_mapping)} classes...")
    else:
        valid_dir = None
        print(f"  Creating train directory for {len(class_mapping)} classes (no validation split)...")

    # Create class directories based on class_mapping
    for class_name in class_mapping:
        (train_dir / class_name).mkdir(exist_ok=True)
        if create_valid:
            (valid_dir / class_name).mkdir(exist_ok=True)
        print(f"    ✓ Created directories for class: {class_name}")

    # Process each class folder and its images.
    # Each entry is (original path, processed PIL image, class name).
    valid_images = []

    # First, collect all valid label files (non-empty .txt files only)
    valid_label_stems = {label_file.stem for label_file in labels_dir.glob("*.txt")
                        if label_file.exists() and os.path.getsize(label_file) > 0}

    print(f"  Found {len(valid_label_stems)} valid label files for cropping")

    print("  Starting image cropping and resizing...")
    total_processed = 0
    total_valid = 0

    for class_name, image_names in class_mapping.items():
        print(f"    Processing class '{class_name}' ({len(image_names)} images)...")
        class_processed = 0
        class_valid = 0

        for image_name in image_names:
            # Check if the image exists in the images directory
            image_path = images_dir / image_name
            class_processed += 1
            total_processed += 1

            if not image_path.exists():
                continue

            # Skip images that don't have a valid label
            if image_path.stem not in valid_label_stems:
                continue

            label_file = labels_dir / (image_path.stem + '.txt')

            try:
                img = Image.open(image_path)

                if label_file.exists():
                    # If label exists, crop the image to the first detection.
                    # Label line format: "<class> <x_center> <y_center> <w> <h>"
                    # with coordinates normalized to [0, 1].
                    with open(label_file, 'r') as f:
                        lines = f.readlines()
                        if lines:
                            parts = lines[0].strip().split()
                            if len(parts) >= 5:
                                x_center, y_center, width, height = map(float, parts[1:5])

                                # Convert normalized box to pixel corners.
                                img_width, img_height = img.size
                                x_min = int((x_center - width/2) * img_width)
                                y_min = int((y_center - height/2) * img_height)
                                x_max = int((x_center + width/2) * img_width)
                                y_max = int((y_center + height/2) * img_height)

                                # Clamp the box to the image bounds.
                                x_min = max(0, x_min)
                                y_min = max(0, y_min)
                                x_max = min(img_width, x_max)
                                y_max = min(img_height, y_max)

                                img = img.crop((x_min, y_min, x_max, y_max))

                # Apply Gaussian blur if specified (blur is fraction of smallest dimension)
                if blur and blur > 0:
                    img_width, img_height = img.size
                    blur_radius = blur * min(img_width, img_height)
                    img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

                # Scale so the smallest dimension becomes img_size, keeping
                # the aspect ratio.
                img_width, img_height = img.size
                if img_width < img_height:
                    # Width is smaller, set to img_size
                    new_width = img_size
                    new_height = int((img_height / img_width) * img_size)
                else:
                    # Height is smaller, set to img_size
                    new_height = img_size
                    new_width = int((img_width / img_height) * img_size)

                # Resize the image
                img = img.resize((new_width, new_height), Image.LANCZOS)

                valid_images.append((image_path, img, class_name))
                class_valid += 1
                total_valid += 1
            except Exception as e:
                # A single unreadable/broken image is skipped, not fatal.
                print(f"    Error processing {image_path}: {e}")

        print(f"    ✓ Class '{class_name}': {class_valid} valid images from {class_processed} processed")

    print(f"  ✓ Successfully processed {total_valid} valid images from {total_processed} total images")

    # Shuffle images
    random.shuffle(valid_images)

    # Split into train/valid or put all in train
    if create_valid:
        train_fraction = 1.0 - valid_fraction
        print(f"  Shuffling and splitting images ({train_fraction*100:.0f}% train, {valid_fraction*100:.0f}% validation)...")
        split_idx = int(len(valid_images) * train_fraction)
        train_images = valid_images[:split_idx]
        valid_images_split = valid_images[split_idx:]
        print(f"  Split: {len(train_images)} training images, {len(valid_images_split)} validation images")
    else:
        print("  Shuffling images (no validation split)...")
        train_images = valid_images
        valid_images_split = []
        print(f"  All {len(train_images)} images will be used for training")

    # Save images to train/valid directories
    print("  Saving cropped and resized images...")
    saved_train = 0
    saved_valid = 0

    # Build list of (image_set, dest_dir, split_name) tuples
    save_tasks = [(train_images, train_dir, "train")]
    if create_valid and valid_images_split:
        save_tasks.append((valid_images_split, valid_dir, "valid"))

    for image_set, dest_dir, split_name in save_tasks:
        print(f"    Saving {len(image_set)} images to {split_name} set...")
        for orig_file, img, class_name in image_set:
            # All output is normalized to .jpg regardless of source format.
            output_path = dest_dir / class_name / (orig_file.stem + '.jpg')

            # Convert any non-RGB mode to RGB before saving
            if img.mode != 'RGB':
                img = img.convert('RGB')

            img.save(output_path, format='JPEG', quality=95)

            if split_name == "train":
                saved_train += 1
            else:
                saved_valid += 1

    if create_valid:
        print(f"  ✓ Saved {saved_train} train images and {saved_valid} validation images")
    else:
        print(f"  ✓ Saved {saved_train} training images (no validation split)")

    # Print detailed summary table
    print(f"  Final dataset summary:")
    print()

    # Calculate column widths for proper alignment
    max_class_name_length = max(len(class_name) for class_name in class_mapping.keys())
    class_col_width = max(max_class_name_length, len("Class"))

    # Print table header (with or without Valid column)
    if create_valid:
        print(f"  {'Class':<{class_col_width}} | {'Train':<7} | {'Valid':<7} | {'Total':<7}")
        print(f"  {'-' * class_col_width}-+-{'-' * 7}-+-{'-' * 7}-+-{'-' * 7}")
    else:
        print(f"  {'Class':<{class_col_width}} | {'Train':<7}")
        print(f"  {'-' * class_col_width}-+-{'-' * 7}")

    # Print data for each class and calculate totals.
    # NOTE: total_valid is intentionally re-used here as the table's
    # validation-image counter, overwriting the cropping counter above.
    total_train = 0
    total_valid = 0

    for class_name in sorted(class_mapping.keys()):  # Sort for consistent output
        train_count = len(list((train_dir / class_name).glob('*.*')))

        if create_valid:
            valid_count = len(list((valid_dir / class_name).glob('*.*')))
            class_total = train_count + valid_count
            print(f"  {class_name:<{class_col_width}} | {train_count:<7} | {valid_count:<7} | {class_total:<7}")
            total_valid += valid_count
        else:
            print(f"  {class_name:<{class_col_width}} | {train_count:<7}")

        total_train += train_count

    # Print totals row
    if create_valid:
        total_overall = total_train + total_valid
        print(f"  {'-' * class_col_width}-+-{'-' * 7}-+-{'-' * 7}-+-{'-' * 7}")
        print(f"  {'TOTAL':<{class_col_width}} | {total_train:<7} | {total_valid:<7} | {total_overall:<7}")
    else:
        print(f"  {'-' * class_col_width}-+-{'-' * 7}")
        print(f"  {'TOTAL':<{class_col_width}} | {total_train:<7}")
    print()

    print(f"  ✓ Classification dataset created successfully at: {output_directory}")
687
|
+
def count_images_across_splits(output_directory: Path) -> int:
    """
    Counts the total number of images across train and validation splits for
    a classification dataset.

    Args:
        output_directory (Path): The path to the output directory containing the split data.

    Returns:
        int: The total number of .jpg/.jpeg/.png images across all splits.
    """
    patterns = ("*.jpg", "*.jpeg", "*.png")
    total_images = 0

    for split_name in ('train', 'valid'):
        split_dir = output_directory / split_name
        if not split_dir.exists():
            continue
        # Each class lives in its own sub-directory of the split.
        for class_dir in split_dir.iterdir():
            if class_dir.is_dir():
                total_images += sum(len(list(class_dir.glob(p))) for p in patterns)

    return total_images
|