bplusplus-2.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bplusplus/prepare.py ADDED
@@ -0,0 +1,706 @@
+ import os
+ import random
+ import shutil
+ import tempfile
+ from pathlib import Path
+ from typing import Optional
+
+ import requests
+ import torch
+ from PIL import Image, ImageFilter
+ from torch import serialization
+ from torch.nn import Module, ModuleDict, ModuleList
+ from torch.nn.modules.activation import LeakyReLU, ReLU, SiLU
+ # Layer classes that must be registered as safe globals for torch deserialization
+ from torch.nn.modules.batchnorm import BatchNorm2d
+ from torch.nn.modules.container import Sequential
+ from torch.nn.modules.conv import Conv2d
+ from torch.nn.modules.dropout import Dropout
+ from torch.nn.modules.linear import Linear
+ from torch.nn.modules.pooling import MaxPool2d
+ from torch.nn.modules.upsampling import Upsample
+ from tqdm import tqdm
+ from ultralytics import YOLO
+ from ultralytics.nn.modules import SPPF, Bottleneck, C2f, Concat, Detect
+ from ultralytics.nn.modules.block import DFL
+ from ultralytics.nn.modules.conv import Conv
+ from ultralytics.nn.tasks import DetectionModel
+
+
+ def prepare(input_directory: str, output_directory: str, img_size: int = 40, conf: float = 0.35, valid: float = 0.1, blur: Optional[float] = None):
+     """
+     Prepares a YOLO classification dataset by performing the following steps:
+     1. Copies images from the input directory to a temporary directory and creates a class mapping.
+     2. Deletes corrupted images and downloads YOLO model weights if not present.
+     3. Runs YOLO inference to generate detection labels (bounding boxes) for the images.
+     4. Cleans up orphaned images and invalid labels, and updates labels with class indices.
+     5. Crops detected objects from images based on bounding boxes and resizes them.
+     6. Splits data into train/valid sets with a classification folder structure (train/class_name/image.jpg).
+
+     Args:
+         input_directory (str): The path to the input directory containing the images.
+         output_directory (str): The path to the output directory where the prepared classification dataset will be saved.
+         img_size (int, optional): The target size for the smallest dimension of cropped images. Defaults to 40.
+         conf (float, optional): YOLO detection confidence threshold. Defaults to 0.35.
+         valid (float, optional): Fraction of data for validation (0.0 to 1.0).
+             0 = no validation split, 0.1 = 10% validation. Defaults to 0.1.
+         blur (float, optional): Gaussian blur as a fraction of image size (0.0 to 1.0).
+             Applied before resizing. 0.01 = 1% of the smallest dimension.
+             None or 0 means no blur. Defaults to None.
+     """
+     # Validate the valid parameter
+     if not 0 <= valid <= 1:
+         raise ValueError(f"valid must be between 0 and 1, got {valid}")
+     # Validate the blur parameter
+     if blur is not None and not 0 <= blur <= 1:
+         raise ValueError(f"blur must be between 0 and 1, got {blur}")
+     input_directory = Path(input_directory)
+     output_directory = Path(output_directory)
+
+     print("=" * 60)
+     print("STARTING BPLUSPLUS DATASET PREPARATION")
+     print("=" * 60)
+     print(f"Input directory: {input_directory}")
+     print(f"Output directory: {output_directory}")
+     print(f"Target image size: {img_size}px (smallest dimension)")
+     print(f"YOLO confidence threshold: {conf}")
+     if valid > 0:
+         print(f"Validation split: {valid*100:.0f}% validation, {(1-valid)*100:.0f}% training")
+     else:
+         print("Validation split: disabled (all images go to training)")
+     if blur and blur > 0:
+         print(f"Gaussian blur: {blur*100:.1f}% of image size")
+     else:
+         print("Gaussian blur: disabled")
+     print()
+
+     with tempfile.TemporaryDirectory() as temp_dir:
+         temp_dir_path = Path(temp_dir)
+         print(f"Using temporary directory: {temp_dir_path}")
+         print()
+
+         # Step 1: Set up directories and copy images
+         print("STEP 1: Setting up directories and copying images...")
+         print("-" * 50)
+         class_mapping, original_image_count = _setup_directories_and_copy_images(
+             input_directory, temp_dir_path
+         )
+         print(f"✓ Step 1 completed: {original_image_count} images copied from {len(class_mapping)} classes")
+         print()
+
+         # Steps 2-3: Clean images and set up the model
+         print("STEP 2: Cleaning images and setting up YOLO model...")
+         print("-" * 50)
+         weights_path = _prepare_model_and_clean_images(temp_dir_path)
+         print(f"✓ Step 2 completed: Model ready at {weights_path}")
+         print()
+
+         # Step 4: Run YOLO inference
+         print("STEP 3: Running YOLO inference to detect objects...")
+         print("-" * 50)
+         labels_path = _run_yolo_inference(temp_dir_path, weights_path, conf)
+         print(f"✓ Step 3 completed: Labels generated at {labels_path}")
+         print()
+
+         # Steps 5-6: Clean up labels and update the class mapping
+         print("STEP 4: Cleaning up orphaned files and processing labels...")
+         print("-" * 50)
+         class_idxs = _cleanup_and_process_labels(
+             temp_dir_path, labels_path, class_mapping
+         )
+         print(f"✓ Step 4 completed: Processed {len(class_idxs)} classes")
+         print()
+
+         # Steps 7-9: Finalize the dataset
+         print("STEP 5: Creating classification dataset with cropped images...")
+         print("-" * 50)
+         _finalize_dataset(
+             class_mapping, temp_dir_path, output_directory,
+             class_idxs, original_image_count, img_size, valid, blur
+         )
+         print("✓ Step 5 completed: Classification dataset ready!")
+         print()
+
+     print("=" * 60)
+     print("BPLUSPLUS DATASET PREPARATION COMPLETED SUCCESSFULLY!")
+     print("=" * 60)
+
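For reference, a minimal sketch of invoking this pipeline; the paths, class folders, and parameter values are illustrative, not taken from the package:

    from bplusplus.prepare import prepare

    # Assumed input layout: insects/<class_name>/*.jpg|*.png
    prepare(
        input_directory="insects",         # hypothetical source folder
        output_directory="insects_prepared",
        img_size=40,    # smallest dimension of each saved crop
        conf=0.35,      # YOLO detection confidence threshold
        valid=0.1,      # 10% of crops go to the valid/ split
        blur=None,      # no Gaussian blur before resizing
    )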
+ def _setup_directories_and_copy_images(input_directory: Path, temp_dir_path: Path):
+     """
+     Sets up temporary directories and copies images from the input directory.
+
+     Returns:
+         tuple: (class_mapping dict, original_image_count int)
+     """
+     images_path = temp_dir_path / "images"
+     images_path.mkdir(parents=True, exist_ok=True)
+     print(f" Created temporary images directory: {images_path}")
+
+     class_mapping = {}
+     total_copied = 0
+
+     print(" Scanning input directory for class folders...")
+     class_folders = [d for d in input_directory.iterdir() if d.is_dir()]
+     print(f" Found {len(class_folders)} class folders")
+
+     for folder_directory in class_folders:
+         folder_name = folder_directory.name
+         images_names = []
+         image_files = list(folder_directory.glob("*.jpg")) + list(folder_directory.glob("*.jpeg")) + list(folder_directory.glob("*.png"))
+         print(f" Copying {len(image_files)} images from class '{folder_name}'...")
+
+         for image_file in image_files:
+             shutil.copy(image_file, images_path)
+             images_names.append(image_file.name)
+             total_copied += 1
+
+         class_mapping[folder_name] = images_names
+         print(f" ✓ {len(images_names)} images copied for class '{folder_name}'")
+
+     original_image_count = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.jpeg"))) + len(list(images_path.glob("*.png")))
+     print(f" Total images in temporary directory: {original_image_count}")
+
+     return class_mapping, original_image_count
+
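The mapping built here simply mirrors the input folder structure. For a hypothetical input tree, the returned values would be:

    # insects/
    #     bee/  -> bee_001.jpg, bee_002.png
    #     wasp/ -> wasp_001.jpg
    class_mapping, original_image_count = _setup_directories_and_copy_images(
        Path("insects"), Path("/tmp/bpp_work")
    )
    # class_mapping == {"bee": ["bee_001.jpg", "bee_002.png"], "wasp": ["wasp_001.jpg"]}
    # original_image_count == 3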
+ def _prepare_model_and_clean_images(temp_dir_path: Path):
+     """
+     Cleans corrupted images and downloads/prepares the YOLO model.
+
+     Returns:
+         Path: weights_path for the YOLO model
+     """
+     images_path = temp_dir_path / "images"
+
+     # Clean corrupted images
+     print(" Checking for corrupted images...")
+     images_before = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))
+     __delete_corrupted_images(images_path)
+     images_after = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))
+     deleted_count = images_before - images_after
+     print(f" ✓ Cleaned {deleted_count} corrupted images ({images_after} images remain)")
+
+     # Set up model weights
+     current_dir = Path(__file__).resolve().parent
+     weights_path = current_dir / 'gbif-generic.pt'
+     github_release_url = 'https://github.com/Tvenver/Bplusplus/releases/download/weights/gbif-generic.pt'
+
+     print(f" Checking for YOLO model weights at: {weights_path}")
+     if not weights_path.exists():
+         print(" Model weights not found, downloading from GitHub...")
+         __download_file_from_github_release(github_release_url, weights_path)
+         print(" ✓ Model weights downloaded successfully")
+     else:
+         print(" ✓ Model weights already exist")
+
+     # Register all classes required to deserialize the checkpoint as safe globals
+     if hasattr(serialization, 'add_safe_globals'):
+         serialization.add_safe_globals([
+             DetectionModel, Sequential, Conv, Conv2d, BatchNorm2d,
+             SiLU, ReLU, LeakyReLU, MaxPool2d, Linear, Dropout, Upsample,
+             Module, ModuleList, ModuleDict,
+             Bottleneck, C2f, SPPF, Detect, Concat, DFL,
+             # torch internals used when unpickling tensors and parameters
+             torch.nn.parameter.Parameter,
+             torch.Tensor,
+             torch._utils._rebuild_tensor_v2,
+             torch._utils._rebuild_parameter
+         ])
+
+     return weights_path
+
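The add_safe_globals guard above exists because recent PyTorch versions restrict torch.load: under weights_only=True (the default since PyTorch 2.6), only allow-listed classes may be unpickled. A standalone sketch of the mechanism, separate from this package; the checkpoint path is hypothetical:

    import torch
    from torch import serialization

    # Every class reachable from the checkpoint's pickle must be allow-listed
    # before a full model object can load under weights_only=True.
    if hasattr(serialization, "add_safe_globals"):  # available since PyTorch 2.4
        serialization.add_safe_globals([torch.nn.Conv2d])

    state = torch.load("checkpoint.pt", weights_only=True)  # hypothetical file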
+ def _run_yolo_inference(temp_dir_path: Path, weights_path: Path, conf: float):
+     """
+     Runs YOLO inference on all images to generate labels.
+
+     Args:
+         temp_dir_path (Path): Path to the working temp directory.
+         weights_path (Path): Path to YOLO weights.
+         conf (float): YOLO detection confidence threshold.
+
+     Returns:
+         Path: labels_path where the generated labels are stored
+     """
+     images_path = temp_dir_path / "images"
+     labels_path = temp_dir_path / "predict" / "labels"
+
+     try:
+         print(f" Loading YOLO model from: {weights_path}")
+         model = YOLO(weights_path)
+         print(" ✓ YOLO model loaded successfully")
+
+         # Get list of all image files
+         image_files = list(images_path.glob('*.jpg')) + list(images_path.glob('*.png'))
+         print(f" Found {len(image_files)} images to process with YOLO")
+
+         # Ensure predict directory exists
+         predict_dir = temp_dir_path / "predict"
+         predict_dir.mkdir(exist_ok=True)
+         labels_path.mkdir(parents=True, exist_ok=True)
+         print(f" Created prediction output directory: {predict_dir}")
+
+         result_count = 0
+         error_count = 0
+
+         # Update progress every 10% or every 100 images, whichever is smaller
+         update_interval = max(1, min(100, len(image_files) // 10))
+
+         print(" Starting YOLO inference...")
+         print(f" Progress: 0/{len(image_files)} images processed", end="", flush=True)
+
+         for i, img_path in enumerate(image_files, 1):
+             try:
+                 model.predict(
+                     source=str(img_path),
+                     conf=conf,
+                     save=True,
+                     save_txt=True,
+                     project=temp_dir_path,
+                     name="predict",
+                     exist_ok=True,
+                     verbose=False  # reduce YOLO's own console output
+                 )
+
+                 result_count += 1
+
+                 if i % update_interval == 0 or i == len(image_files):
+                     print(f"\r Progress: {i}/{len(image_files)} images processed", end="", flush=True)
+
+             except Exception as e:
+                 error_count += 1
+                 print(f"\n Error processing {img_path.name}: {e}")
+                 continue
+
+         print()  # New line after progress
+         print(f" ✓ YOLO inference completed: {result_count} successful, {error_count} failed")
+
+         # Verify labels were created
+         label_files = list(labels_path.glob("*.txt"))
+         print(f" Generated {len(label_files)} label files")
+
+         if len(label_files) == 0:
+             print("WARNING: No label files were created by the model prediction!")
+
+     except Exception as e:
+         print(f"Error during model prediction setup: {e}")
+         import traceback
+         traceback.print_exc()
+
+     return labels_path
+
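Each label file written by save_txt holds one line per detection in YOLO format: a class index followed by normalized center-x, center-y, width, and height. A small sketch of turning such a line into a pixel box, mirroring the cropping step later in this module (the values are made up):

    line = "0 0.512 0.430 0.210 0.180"
    x_c, y_c, w, h = map(float, line.split()[1:5])

    img_w, img_h = 640, 480  # hypothetical image size
    x_min = max(0, int((x_c - w / 2) * img_w))
    y_min = max(0, int((y_c - h / 2) * img_h))
    x_max = min(img_w, int((x_c + w / 2) * img_w))
    y_max = min(img_h, int((y_c + h / 2) * img_h))
    # -> box usable with PIL's Image.crop((x_min, y_min, x_max, y_max))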
+ def _cleanup_and_process_labels(temp_dir_path: Path, labels_path: Path, class_mapping: dict):
+     """
+     Cleans up orphaned images and invalid labels, then creates the class index mapping.
+
+     Returns:
+         dict: class_idxs mapping class indices to class names
+     """
+     images_path = temp_dir_path / "images"
+
+     print(" Cleaning up orphaned images and labels...")
+     images_before = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))
+     labels_before = len(list(labels_path.glob("*.txt")))
+
+     __delete_orphaned_images_and_inferences(images_path, labels_path)
+     __delete_invalid_txt_files(images_path, labels_path)
+
+     images_after = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))
+     labels_after = len(list(labels_path.glob("*.txt")))
+
+     deleted_images = images_before - images_after
+     deleted_labels = labels_before - labels_after
+     print(f" ✓ Cleaned up {deleted_images} orphaned images and {deleted_labels} invalid labels")
+     print(f" Final counts: {images_after} images, {labels_after} valid labels")
+
+     # Create class index mapping for classification
+     class_idxs = dict(enumerate(class_mapping.keys()))
+
+     print(f" Created class mapping for {len(class_idxs)} classes: {list(class_idxs.values())}")
+
+     return class_idxs
+
+ def _finalize_dataset(class_mapping: dict, temp_dir_path: Path, output_directory: Path,
+                       class_idxs: dict, original_image_count: int, img_size: int,
+                       valid_fraction: float = 0.1, blur: Optional[float] = None):
+     """
+     Finalizes the dataset by creating cropped classification images and splitting into train/valid sets.
+
+     Args:
+         valid_fraction: Fraction of data for validation (0.0 to 1.0). 0 = no validation split.
+         blur: Gaussian blur as a fraction of image size (0-1). None or 0 means no blur.
+     """
+     # Split data into train/valid with cropped classification images
+     __classification_split(class_mapping, temp_dir_path, output_directory, img_size, valid_fraction, blur)
+
+     # Generate final report
+     print(" Generating final statistics...")
+     final_image_count = count_images_across_splits(output_directory)
+     success_rate = final_image_count / original_image_count * 100 if original_image_count else 0.0
+     print(" Dataset Statistics:")
+     print(f" - Original images: {original_image_count}")
+     print(f" - Final cropped images: {final_image_count}")
+     print(f" - Success rate: {success_rate:.1f}%")
+     print(f" - Output directory: {output_directory}")
+
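For the hypothetical two-class input used in the earlier examples, the finalized dataset follows the train/<class>/<image> layout, with a valid/ split present only when valid_fraction > 0:

    insects_prepared/
        train/
            bee/   (cropped .jpg files)
            wasp/  (cropped .jpg files)
        valid/
            bee/
            wasp/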
346
+ def __delete_corrupted_images(images_path: Path):
347
+
348
+ """
349
+ Deletes corrupted images from the specified directory.
350
+
351
+ Args:
352
+ images_path (Path): The path to the directory containing images.
353
+
354
+ This function iterates through all the image files in the specified directory
355
+ and attempts to open each one. If an image file is found to be corrupted (i.e.,
356
+ it cannot be opened), the function deletes the corrupted image file.
357
+ """
358
+
359
+ for pattern in ["*.jpg", "*.png"]:
360
+ for image_file in images_path.glob(pattern):
361
+ try:
362
+ Image.open(image_file)
363
+ except IOError:
364
+ image_file.unlink()
365
+
+ def __download_file_from_github_release(url: str, dest_path: Path):
+     """
+     Downloads a file from a given GitHub release URL and saves it to the specified
+     destination path, with a progress bar displayed in the terminal.
+
+     Args:
+         url (str): The URL of the file to download.
+         dest_path (Path): The destination path where the file will be saved.
+
+     Raises:
+         Exception: If the file download fails.
+     """
+     response = requests.get(url, stream=True, timeout=30)
+     total_size = int(response.headers.get('content-length', 0))
+     block_size = 1024  # 1 Kibibyte
+     progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
+
+     if response.status_code == 200:
+         with open(dest_path, 'wb') as f:
+             for chunk in response.iter_content(chunk_size=block_size):
+                 progress_bar.update(len(chunk))
+                 f.write(chunk)
+         progress_bar.close()
+     else:
+         progress_bar.close()
+         raise Exception(f"Failed to download file from {url}")
+
+ def __delete_orphaned_images_and_inferences(images_path: Path, labels_path: Path):
+     """
+     Deletes label files without a corresponding image, and images without a
+     corresponding label file.
+
+     Args:
+         images_path (Path): The path to the directory containing images.
+         labels_path (Path): The path to the directory containing label files.
+
+     This function first removes every label file that has no matching image,
+     then removes every image that has no matching label file, so that only
+     matched image/label pairs remain.
+     """
+     for txt_file in labels_path.glob("*.txt"):
+         image_file_jpg = images_path / (txt_file.stem + ".jpg")
+         image_file_jpeg = images_path / (txt_file.stem + ".jpeg")
+         image_file_png = images_path / (txt_file.stem + ".png")
+
+         if not (image_file_jpg.exists() or image_file_jpeg.exists() or image_file_png.exists()):
+             # No corresponding image file
+             txt_file.unlink()
+
+     label_stems = {txt_file.stem for txt_file in labels_path.glob("*.txt")}
+     image_files = list(images_path.glob("*.jpg")) + list(images_path.glob("*.jpeg")) + list(images_path.glob("*.png"))
+
+     for image_file in image_files:
+         if image_file.stem not in label_stems:
+             # Orphaned image with no label
+             image_file.unlink()
+
+ def __delete_invalid_txt_files(images_path: Path, labels_path: Path):
+     """
+     Deletes invalid label files and their corresponding image files.
+
+     Args:
+         images_path (Path): The path to the directory containing images.
+         labels_path (Path): The path to the directory containing label files.
+
+     This function iterates through all the label files in the specified directory
+     and checks whether they contain exactly one detection. If a label file holds
+     zero detections or more than one, the function deletes the invalid label file
+     along with its corresponding image file.
+     """
+     for txt_file in labels_path.glob("*.txt"):
+         with open(txt_file, 'r') as file:
+             lines = file.readlines()
+
+         if len(lines) != 1:
+             # Invalid label: zero detections or multiple detections
+             txt_file.unlink()
+
+             image_file_jpg = images_path / (txt_file.stem + ".jpg")
+             image_file_jpeg = images_path / (txt_file.stem + ".jpeg")
+             image_file_png = images_path / (txt_file.stem + ".png")
+
+             if image_file_jpg.exists():
+                 image_file_jpg.unlink()
+             elif image_file_jpeg.exists():
+                 image_file_jpeg.unlink()
+             elif image_file_png.exists():
+                 image_file_png.unlink()
+
+
+ def __classification_split(class_mapping: dict, temp_dir_path: Path, output_directory: Path, img_size: int, valid_fraction: float = 0.1, blur: Optional[float] = None):
+     """
+     Splits the data into train and validation sets for classification tasks,
+     cropping images according to their YOLO labels while preserving the original class structure.
+
+     Args:
+         class_mapping (dict): A dictionary mapping class names to image file names.
+         temp_dir_path (Path): The path to the temporary directory containing the images.
+         output_directory (Path): The path to the output directory where train and valid splits will be created.
+         img_size (int): The target size for the smallest dimension of cropped images.
+         valid_fraction (float): Fraction of data for validation (0.0 to 1.0). 0 = no validation split.
+         blur (float, optional): Gaussian blur as a fraction of image size (0-1). None or 0 means no blur.
+     """
+     images_dir = temp_dir_path / "images"
+     labels_dir = temp_dir_path / "predict" / "labels"
+
+     create_valid = valid_fraction > 0
+
+     # Create train directory (and optionally valid)
+     train_dir = output_directory / 'train'
+     train_dir.mkdir(parents=True, exist_ok=True)
+
+     if create_valid:
+         valid_dir = output_directory / 'valid'
+         valid_dir.mkdir(parents=True, exist_ok=True)
+         print(f" Creating train and validation directories for {len(class_mapping)} classes...")
+     else:
+         valid_dir = None
+         print(f" Creating train directory for {len(class_mapping)} classes (no validation split)...")
+
+     # Create class directories based on class_mapping
+     for class_name in class_mapping:
+         (train_dir / class_name).mkdir(exist_ok=True)
+         if create_valid:
+             (valid_dir / class_name).mkdir(exist_ok=True)
+         print(f" ✓ Created directories for class: {class_name}")
+
+     # Process each class folder and its images
+     valid_images = []
+
+     # First, collect all valid (non-empty) label files
+     valid_label_stems = {label_file.stem for label_file in labels_dir.glob("*.txt")
+                          if os.path.getsize(label_file) > 0}
+
+     print(f" Found {len(valid_label_stems)} valid label files for cropping")
+
+     print(" Starting image cropping and resizing...")
+     total_processed = 0
+     total_valid = 0
+
+     for class_name, image_names in class_mapping.items():
+         print(f" Processing class '{class_name}' ({len(image_names)} images)...")
+         class_processed = 0
+         class_valid = 0
+
+         for image_name in image_names:
+             # Check if the image exists in the images directory
+             image_path = images_dir / image_name
+             class_processed += 1
+             total_processed += 1
+
+             if not image_path.exists():
+                 continue
+
+             # Skip images that don't have a valid label
+             if image_path.stem not in valid_label_stems:
+                 continue
+
+             label_file = labels_dir / (image_path.stem + '.txt')
+
+             try:
+                 img = Image.open(image_path)
+
+                 if label_file.exists():
+                     # If a label exists, crop the image to the detection box
+                     with open(label_file, 'r') as f:
+                         lines = f.readlines()
+                     if lines:
+                         parts = lines[0].strip().split()
+                         if len(parts) >= 5:
+                             x_center, y_center, width, height = map(float, parts[1:5])
+
+                             img_width, img_height = img.size
+                             x_min = int((x_center - width / 2) * img_width)
+                             y_min = int((y_center - height / 2) * img_height)
+                             x_max = int((x_center + width / 2) * img_width)
+                             y_max = int((y_center + height / 2) * img_height)
+
+                             x_min = max(0, x_min)
+                             y_min = max(0, y_min)
+                             x_max = min(img_width, x_max)
+                             y_max = min(img_height, y_max)
+
+                             img = img.crop((x_min, y_min, x_max, y_max))
+
+                 # Apply Gaussian blur if specified (blur is a fraction of the smallest dimension)
+                 if blur and blur > 0:
+                     img_width, img_height = img.size
+                     blur_radius = blur * min(img_width, img_height)
+                     img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
+
+                 # Resize so the smallest dimension equals img_size, preserving aspect ratio
+                 img_width, img_height = img.size
+                 if img_width < img_height:
+                     # Width is smaller, set to img_size
+                     new_width = img_size
+                     new_height = int((img_height / img_width) * img_size)
+                 else:
+                     # Height is smaller, set to img_size
+                     new_height = img_size
+                     new_width = int((img_width / img_height) * img_size)
+
+                 img = img.resize((new_width, new_height), Image.LANCZOS)
+
+                 valid_images.append((image_path, img, class_name))
+                 class_valid += 1
+                 total_valid += 1
+             except Exception as e:
+                 print(f" Error processing {image_path}: {e}")
+
+         print(f" ✓ Class '{class_name}': {class_valid} valid images from {class_processed} processed")
+
+     print(f" ✓ Successfully processed {total_valid} valid images from {total_processed} total images")
+
+     # Shuffle images
+     random.shuffle(valid_images)
+
+     # Split into train/valid or put all in train
+     if create_valid:
+         train_fraction = 1.0 - valid_fraction
+         print(f" Shuffling and splitting images ({train_fraction*100:.0f}% train, {valid_fraction*100:.0f}% validation)...")
+         split_idx = int(len(valid_images) * train_fraction)
+         train_images = valid_images[:split_idx]
+         valid_images_split = valid_images[split_idx:]
+         print(f" Split: {len(train_images)} training images, {len(valid_images_split)} validation images")
+     else:
+         print(" Shuffling images (no validation split)...")
+         train_images = valid_images
+         valid_images_split = []
+         print(f" All {len(train_images)} images will be used for training")
+
+     # Save images to train/valid directories
+     print(" Saving cropped and resized images...")
+     saved_train = 0
+     saved_valid = 0
+
+     # Build list of (image_set, dest_dir, split_name) tuples
+     save_tasks = [(train_images, train_dir, "train")]
+     if create_valid and valid_images_split:
+         save_tasks.append((valid_images_split, valid_dir, "valid"))
+
+     for image_set, dest_dir, split_name in save_tasks:
+         print(f" Saving {len(image_set)} images to {split_name} set...")
+         for orig_file, img, class_name in image_set:
+             output_path = dest_dir / class_name / (orig_file.stem + '.jpg')
+
+             # Convert any non-RGB mode to RGB before saving
+             if img.mode != 'RGB':
+                 img = img.convert('RGB')
+
+             img.save(output_path, format='JPEG', quality=95)
+
+             if split_name == "train":
+                 saved_train += 1
+             else:
+                 saved_valid += 1
+
+     if create_valid:
+         print(f" ✓ Saved {saved_train} train images and {saved_valid} validation images")
+     else:
+         print(f" ✓ Saved {saved_train} training images (no validation split)")
+
+     # Print detailed summary table
+     print(" Final dataset summary:")
+     print()
+
+     # Calculate column widths for proper alignment
+     max_class_name_length = max(len(class_name) for class_name in class_mapping.keys())
+     class_col_width = max(max_class_name_length, len("Class"))
+
+     # Print table header (with or without Valid column)
+     if create_valid:
+         print(f" {'Class':<{class_col_width}} | {'Train':<7} | {'Valid':<7} | {'Total':<7}")
+         print(f" {'-' * class_col_width}-+-{'-' * 7}-+-{'-' * 7}-+-{'-' * 7}")
+     else:
+         print(f" {'Class':<{class_col_width}} | {'Train':<7}")
+         print(f" {'-' * class_col_width}-+-{'-' * 7}")
+
+     # Print data for each class and calculate totals
+     total_train = 0
+     total_valid = 0
+
+     for class_name in sorted(class_mapping.keys()):  # Sort for consistent output
+         train_count = len(list((train_dir / class_name).glob('*.*')))
+
+         if create_valid:
+             valid_count = len(list((valid_dir / class_name).glob('*.*')))
+             class_total = train_count + valid_count
+             print(f" {class_name:<{class_col_width}} | {train_count:<7} | {valid_count:<7} | {class_total:<7}")
+             total_valid += valid_count
+         else:
+             print(f" {class_name:<{class_col_width}} | {train_count:<7}")
+
+         total_train += train_count
+
+     # Print totals row
+     if create_valid:
+         total_overall = total_train + total_valid
+         print(f" {'-' * class_col_width}-+-{'-' * 7}-+-{'-' * 7}-+-{'-' * 7}")
+         print(f" {'TOTAL':<{class_col_width}} | {total_train:<7} | {total_valid:<7} | {total_overall:<7}")
+     else:
+         print(f" {'-' * class_col_width}-+-{'-' * 7}")
+         print(f" {'TOTAL':<{class_col_width}} | {total_train:<7}")
+     print()
+
+     print(f" ✓ Classification dataset created successfully at: {output_directory}")
+
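A worked example of the blur and resize arithmetic above, under assumed values (a 320x240 crop, blur=0.01, img_size=40):

    crop_w, crop_h = 320, 240               # hypothetical crop size
    blur, img_size = 0.01, 40

    blur_radius = blur * min(crop_w, crop_h)      # 0.01 * 240 = 2.4 px
    # Height is the smaller side, so it becomes img_size:
    new_h = img_size                              # 40
    new_w = int((crop_w / crop_h) * img_size)     # int(1.333... * 40) = 53
    # Result: a 53x40 JPEG, blurred with a 2.4 px radius before resizing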
+ def count_images_across_splits(output_directory: Path) -> int:
+     """
+     Counts the total number of images across the train and validation splits of a classification dataset.
+
+     Args:
+         output_directory (Path): The path to the output directory containing the split data.
+
+     Returns:
+         int: The total number of images across all splits.
+     """
+     total_images = 0
+     for split in ['train', 'valid']:
+         split_dir = output_directory / split
+         if split_dir.exists():
+             # Count all images in all class subdirectories
+             for class_dir in split_dir.iterdir():
+                 if class_dir.is_dir():
+                     total_images += len(list(class_dir.glob("*.jpg"))) + len(list(class_dir.glob("*.jpeg"))) + len(list(class_dir.glob("*.png")))
+
+     return total_images
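Because the output uses the standard train/<class>/<image> convention, it can be consumed by any folder-based loader. A sanity-check sketch using torchvision, which is an assumption here rather than a dependency of this module:

    from pathlib import Path
    from torchvision import datasets

    out = Path("insects_prepared")  # hypothetical output directory
    train_ds = datasets.ImageFolder(out / "train")
    print(train_ds.classes, len(train_ds))  # should match the summary table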