bplusplus 0.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of bplusplus might be problematic.
- bplusplus/__init__.py +7 -3
- bplusplus/{collect_images.py → collect.py} +71 -7
- bplusplus/hierarchical/test.py +670 -0
- bplusplus/hierarchical/train.py +676 -0
- bplusplus/prepare.py +737 -0
- bplusplus/resnet/test.py +473 -0
- bplusplus/resnet/train.py +329 -0
- bplusplus/train_validate.py +8 -64
- bplusplus-1.2.0.dist-info/METADATA +249 -0
- bplusplus-1.2.0.dist-info/RECORD +12 -0
- bplusplus/build_model.py +0 -38
- bplusplus-0.1.1.dist-info/METADATA +0 -97
- bplusplus-0.1.1.dist-info/RECORD +0 -8
- {bplusplus-0.1.1.dist-info → bplusplus-1.2.0.dist-info}/LICENSE +0 -0
- {bplusplus-0.1.1.dist-info → bplusplus-1.2.0.dist-info}/WHEEL +0 -0
bplusplus/prepare.py
ADDED
@@ -0,0 +1,737 @@

import os
import random
from typing import Any, Optional
import requests
import tempfile
from .collect import Group, collect
from pathlib import Path
from ultralytics import YOLO
import shutil
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
from prettytable import PrettyTable
import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
import torch
from torch import serialization
from ultralytics.nn.tasks import DetectionModel
from torch.nn.modules.container import Sequential
from ultralytics.nn.modules.conv import Conv
from torch.nn.modules.conv import Conv2d
# Additional modules registered below so torch can safely deserialize the checkpoint
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.activation import SiLU, ReLU, LeakyReLU
from torch.nn.modules.pooling import MaxPool2d
from torch.nn.modules.linear import Linear
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.upsampling import Upsample
from torch.nn import Module, ModuleList, ModuleDict
from ultralytics.nn.modules import (
    Bottleneck, C2f, SPPF, Detect, Concat
)
from ultralytics.nn.modules.block import DFL
import numpy as np

def prepare(input_directory: str, output_directory: str, one_stage: bool = False, with_background: bool = False, size_filter: bool = False, sizes: list = None):
    """
    Prepares the dataset for training by performing the following steps:
    1. Copies images from the input directory to a temporary directory.
    2. Deletes corrupted images.
    3. Downloads YOLOv5 weights if not already present.
    4. Runs YOLOv5 inference to generate labels for the images.
    5. Deletes orphaned images and inferences.
    6. Updates labels based on class mapping.
    7. Splits the data into train, test, and validation sets.
    8. Counts the total number of images across all splits.
    9. Makes a YAML configuration file for YOLOv8.

    Args:
        input_directory (str): The path to the input directory containing one sub-directory of images per class.
        output_directory (str): The path to the output directory where the prepared dataset will be saved.
        one_stage (bool): If True, builds a YOLO detection dataset; otherwise builds a classification dataset.
        with_background (bool): If True, collects plant images via collect() and adds them to the splits as background images.
        size_filter (bool): If True, keeps only images whose single detection falls into one of the requested size buckets.
        sizes (list): Up to two of "small", "medium", "large"; the size buckets to keep when size_filter is True.
    """

    input_directory = Path(input_directory)
    output_directory = Path(output_directory)

    class_mapping = {}

    with tempfile.TemporaryDirectory() as temp_dir:

        temp_dir_path = Path(temp_dir)
        images_path = temp_dir_path / "images"

        images_path.mkdir(parents=True, exist_ok=True)

        for folder_directory in input_directory.iterdir():
            images_names = []
            if folder_directory.is_dir():
                folder_name = folder_directory.name
                for image_file in folder_directory.glob("*.jpg"):
                    shutil.copy(image_file, images_path)
                    image_name = image_file.name
                    images_names.append(image_name)

                class_mapping[folder_name] = images_names

        original_image_count = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.jpeg")))

        __delete_corrupted_images(images_path)

        current_dir = Path(__file__).resolve().parent

        weights_path = current_dir / 'small-generic.pt'

        github_release_url = 'https://github.com/orlandocloss/TwoStageInsectDetection/releases/download/models/small-generic.pt'

        if not weights_path.exists():
            __download_file_from_github_release(github_release_url, weights_path)

        # Register all required classes as safe globals so the checkpoint can be
        # deserialized under torch.load's weights_only default.
        serialization.add_safe_globals([
            DetectionModel, Sequential, Conv, Conv2d, BatchNorm2d,
            SiLU, ReLU, LeakyReLU, MaxPool2d, Linear, Dropout, Upsample,
            Module, ModuleList, ModuleDict,
            Bottleneck, C2f, SPPF, Detect, Concat, DFL
        ])

        model = YOLO(weights_path)
        model.predict(images_path, conf=0.25, save=True, save_txt=True, project=temp_dir_path)
        labels_path = temp_dir_path / "predict" / "labels"

        if size_filter and sizes and len(sizes) <= 2:  # guard against sizes=None
            filtered = filter_by_size(images_path, labels_path, sizes)
            print(f"\nFiltered {len(list(images_path.glob('*.jpg')))} images by size out of {original_image_count} input images.\nNOTE: Some images may be filtered out due to corruption or inaccurate labels.")

        if one_stage:

            __delete_orphaned_images_and_inferences(images_path, labels_path)
            __delete_invalid_txt_files(images_path, labels_path)
            class_idxs = update_labels(class_mapping, labels_path)
            __split_data(class_mapping, temp_dir_path, output_directory)

            # __save_class_idx_to_file(class_idxs, output_directory)
            final_image_count = count_images_across_splits(output_directory)
            print(f"\nOut of {original_image_count} input images, {final_image_count} are eligible for detection.\nThese are saved across the train, test and valid splits in {output_directory}.")
            __generate_sample_images_with_detections(output_directory, class_idxs)

            if with_background:
                print("\nCollecting and splitting background images.")

                bg_images = int(final_image_count * 0.06)

                search: dict[str, Any] = {
                    "scientificName": ["Plantae"]
                }

                collect(
                    group_by_key=Group.scientificName,
                    search_parameters=search,
                    images_per_group=bg_images,
                    output_directory=temp_dir_path
                )

                __delete_corrupted_images(temp_dir_path / "Plantae")

                __split_background_images(temp_dir_path / "Plantae", output_directory)

            __count_classes_and_output_table(output_directory, class_idxs)

            __make_yaml_file(output_directory, class_idxs)
        else:
            try:
                sized_dir = temp_dir_path / "sized"
                sized_dir.mkdir(parents=True, exist_ok=True)
                __two_stage_update(class_mapping, filtered, sized_dir, images_path)
                __classification_split(sized_dir, output_directory)
                __count_classification_split(output_directory, class_mapping)
            except Exception:
                # Fall back to splitting the unfiltered images, e.g. when size
                # filtering was skipped and `filtered` was never assigned.
                __classification_split(images_path, output_directory)
                __count_classification_split(output_directory, class_mapping)

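For orientation, here is a minimal sketch of how the new entry point might be invoked; the directory names are hypothetical, and the import path assumes you call into the module directly rather than through whatever the package root re-exports.

from bplusplus.prepare import prepare

# Build a YOLO detection dataset (train/test/valid plus dataset.yaml),
# keeping only small and medium detections and adding background images.
prepare(
    input_directory="raw_images",       # hypothetical: one sub-folder per class
    output_directory="detection_data",  # hypothetical output root
    one_stage=True,
    with_background=True,
    size_filter=True,
    sizes=["small", "medium"],
)
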
def __count_classification_split(output_directory: Path, class_mapping: dict):
    """
    Counts the number of images in the train and valid splits for each class.

    Args:
        output_directory (Path): Path to the output directory containing the train and valid splits.
        class_mapping (dict): Dictionary mapping class names to image file names.
    """
    class_counts = {}
    train_counts = {}
    valid_counts = {}

    for class_name in class_mapping.keys():
        train_dir = output_directory / 'train' / class_name
        valid_dir = output_directory / 'valid' / class_name

        train_count = len(list(train_dir.glob("*.jpg"))) if train_dir.exists() else 0
        valid_count = len(list(valid_dir.glob("*.jpg"))) if valid_dir.exists() else 0
        total_count = train_count + valid_count

        class_counts[class_name] = total_count
        train_counts[class_name] = train_count
        valid_counts[class_name] = valid_count

    table = PrettyTable()
    table.field_names = ["Class", "Train", "Valid", "Total"]
    for class_name in class_mapping.keys():
        table.add_row([
            class_name,
            train_counts[class_name],
            valid_counts[class_name],
            class_counts[class_name]
        ])
    print(table)
    print(f"Saved in {output_directory}")

def __classification_split(input_directory: str, output_directory: str):
    """
    Splits the data into train and validation sets for classification tasks.

    Args:
        input_directory (str): Path to the input directory containing one sub-directory per class.
        output_directory (str): Path to the output directory where the train and valid splits will be created.
    """
    input_directory = Path(input_directory)
    output_directory = Path(output_directory)

    # Create train and valid directories
    train_dir = output_directory / 'train'
    valid_dir = output_directory / 'valid'

    train_dir.mkdir(parents=True, exist_ok=True)
    valid_dir.mkdir(parents=True, exist_ok=True)

    # Process each class directory
    for class_dir in input_directory.iterdir():
        if not class_dir.is_dir():
            continue

        class_name = class_dir.name
        print(f"Processing class: {class_name}")

        # Create corresponding class directories in train and valid
        (train_dir / class_name).mkdir(exist_ok=True)
        (valid_dir / class_name).mkdir(exist_ok=True)

        # Get all image files
        image_files = list(class_dir.glob('*.jpg')) + list(class_dir.glob('*.jpeg')) + list(class_dir.glob('*.png'))

        if not image_files:
            print(f"Warning: No images found in {class_dir}")
            continue

        # Shuffle the files to ensure a random distribution
        np.random.shuffle(image_files)

        # Split into train (90%) and valid (10%)
        split_idx = int(len(image_files) * 0.9)
        train_files = image_files[:split_idx]
        valid_files = image_files[split_idx:]

        # Copy files to their respective directories
        for img_file in train_files:
            shutil.copy(img_file, train_dir / class_name / img_file.name)

        for img_file in valid_files:
            shutil.copy(img_file, valid_dir / class_name / img_file.name)

        print(f" - {len(train_files)} images in train, {len(valid_files)} images in valid")

    print(f"\nData split complete. Train and validation sets created in {output_directory}")

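The resulting layout is the usual folder-per-class format that image-classification loaders expect; with a hypothetical class name it looks like:

detection_data/
    train/
        Apis_mellifera/     (hypothetical class folder, ~90% of its images)
    valid/
        Apis_mellifera/     (remaining ~10%)
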
def filter_by_size(images_path: Path, labels_path: Path, sizes: list):
    """
    Filters images by detection size and updates the labels accordingly.

    Args:
        images_path (Path): The path to the directory containing images.
        labels_path (Path): The path to the directory containing labels.
        sizes (list): The size buckets to keep ("small", "medium", "large").
    """
    # Buckets are expressed as normalized box width/height ranges.
    size_map = {
        "small": [0, 0.15],
        "medium": [0.15, 0.3],
        "large": [0.3, 1],
    }

    filtered_images = []
    for image_file in images_path.glob("*.jpg"):
        label_file = labels_path / (image_file.stem + ".txt")
        image_name = image_file.name

        if label_file.exists():
            with open(label_file, 'r') as file:
                lines = file.readlines()
            # Only single-detection images are eligible for the size filter.
            if len(lines) != 1:
                continue
            parts = lines[0].split()
            _, _, width, height = map(float, parts[1:])
            for size in sizes:
                low, high = size_map[size]
                if low <= width < high and low <= height < high:
                    filtered_images.append(image_name)

    # Remove every image (and its label) that matched no requested bucket.
    for image_file in images_path.glob("*.jpg"):
        label_file = labels_path / (image_file.stem + ".txt")
        image_name = image_file.name
        if image_name not in filtered_images:
            image_file.unlink()
            try:
                label_file.unlink()
            except FileNotFoundError:
                pass
    return filtered_images

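To make the bucket logic concrete (label values invented for illustration): a YOLO label line stores normalized class, center-x, center-y, width, height, and an image survives the filter only if its single box's width and height both fall inside one requested bucket.

# Hypothetical single-detection label: box 12% wide and 9% tall.
line = "0 0.48 0.52 0.12 0.09"
_, _, width, height = map(float, line.split()[1:])
# Both dimensions sit in [0, 0.15), so the image counts as "small".
# A 0.12 x 0.20 box would match no bucket: width is "small" but height is "medium".
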
def __two_stage_update(class_mapping: dict, filtered_images: list, output_directory: Path, images_path: Path):
    """
    Prepares one folder per class name containing that class's size-filtered images.
    """

    for class_name, images in class_mapping.items():
        for image_name in images:
            if image_name in filtered_images:
                (output_directory / class_name).mkdir(parents=True, exist_ok=True)
                shutil.copy(images_path / image_name, output_directory / class_name / image_name)

def __delete_corrupted_images(images_path: Path):
    """
    Deletes corrupted images from the specified directory.

    Args:
        images_path (Path): The path to the directory containing images.

    This function iterates through all the image files in the specified directory
    and attempts to open and verify each one. If an image file is found to be
    corrupted (i.e., it cannot be opened and verified), the file is deleted.
    """

    for image_file in images_path.glob("*.jpg"):
        try:
            # verify() forces PIL to actually parse the file; a bare open()
            # is lazy and can miss truncated or corrupted data.
            with Image.open(image_file) as image:
                image.verify()
        except (IOError, SyntaxError):
            image_file.unlink()

def __download_file_from_github_release(url, dest_path):
    """
    Downloads a file from a given GitHub release URL and saves it to the specified
    destination path, with a progress bar displayed in the terminal.

    Args:
        url (str): The URL of the file to download.
        dest_path (Path): The destination path where the file will be saved.

    Raises:
        Exception: If the file download fails.
    """

    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KiB
    progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)

    if response.status_code == 200:
        with open(dest_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=block_size):
                progress_bar.update(len(chunk))
                f.write(chunk)
        progress_bar.close()
    else:
        progress_bar.close()
        raise Exception(f"Failed to download file from {url}")

def __delete_orphaned_images_and_inferences(images_path: Path, labels_path: Path):
    """
    Deletes orphaned label files and orphaned images.

    Args:
        images_path (Path): The path to the directory containing images.
        labels_path (Path): The path to the directory containing label files.

    This function first removes label files whose corresponding image is missing,
    then removes image files that have no corresponding label file.
    """

    for txt_file in labels_path.glob("*.txt"):
        image_file_jpg = images_path / (txt_file.stem + ".jpg")
        image_file_jpeg = images_path / (txt_file.stem + ".jpeg")

        if not (image_file_jpg.exists() or image_file_jpeg.exists()):
            print(f"Deleting {txt_file.name} - No corresponding image file")
            txt_file.unlink()

    label_stems = {txt_file.stem for txt_file in labels_path.glob("*.txt")}
    image_files = list(images_path.glob("*.jpg")) + list(images_path.glob("*.jpeg"))

    for image_file in image_files:
        if image_file.stem not in label_stems:
            print(f"Deleting orphaned image: {image_file.name}")
            image_file.unlink()

    print("Orphaned image files without corresponding labels have been deleted.")

def __delete_invalid_txt_files(images_path: Path, labels_path: Path):
    """
    Deletes invalid label files and their corresponding image files.

    Args:
        images_path (Path): The path to the directory containing images.
        labels_path (Path): The path to the directory containing label files.

    This function iterates through all the label files in the specified directory
    and checks whether they contain exactly one detection. A label file with zero
    or more than one detection is deleted along with its corresponding image file.
    """

    for txt_file in labels_path.glob("*.txt"):
        with open(txt_file, 'r') as file:
            lines = file.readlines()

        if len(lines) == 0 or len(lines) > 1:
            print(f"Deleting {txt_file.name} - Invalid file")
            txt_file.unlink()

            image_file_jpg = images_path / (txt_file.stem + ".jpg")
            image_file_jpeg = images_path / (txt_file.stem + ".jpeg")

            if image_file_jpg.exists():
                image_file_jpg.unlink()
                print(f"Deleted corresponding image file: {image_file_jpg.name}")
            elif image_file_jpeg.exists():
                image_file_jpeg.unlink()
                print(f"Deleted corresponding image file: {image_file_jpeg.name}")

    print("Invalid label files and their corresponding image files have been deleted.")

def __split_data(class_mapping: dict, temp_dir_path: Path, output_directory: Path):
    """
    Splits the data into train, test, and validation sets.

    Args:
        class_mapping (dict): A dictionary mapping class names to image file names.
        temp_dir_path (Path): The path to the temporary directory containing the images.
        output_directory (Path): The path to the output directory where the split data will be saved.
    """
    images_dir = temp_dir_path / "images"
    labels_dir = temp_dir_path / "predict" / "labels"

    def create_dirs(split):
        (output_directory / split).mkdir(parents=True, exist_ok=True)
        (output_directory / split / "images").mkdir(parents=True, exist_ok=True)
        (output_directory / split / "labels").mkdir(parents=True, exist_ok=True)

    def copy_files(file_list, split):
        for image_file in file_list:
            image_file_path = images_dir / image_file

            if not image_file_path.exists():
                continue

            shutil.copy(image_file_path, output_directory / split / "images" / image_file_path.name)

            label_file = labels_dir / (image_file_path.stem + ".txt")
            if label_file.exists():
                shutil.copy(label_file, output_directory / split / "labels" / label_file.name)

    for split in ["train", "test", "valid"]:
        create_dirs(split)

    # Shuffle and split each class independently: 80% train, 10% test, rest valid.
    for _, files in class_mapping.items():
        random.shuffle(files)
        num_files = len(files)

        train_count = int(0.8 * num_files)
        test_count = int(0.1 * num_files)
        valid_count = num_files - train_count - test_count

        train_files = files[:train_count]
        test_files = files[train_count:train_count + test_count]
        valid_files = files[train_count + test_count:]

        copy_files(train_files, "train")
        copy_files(test_files, "test")
        copy_files(valid_files, "valid")

    print("Data has been split into train, test, and valid sets.")

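Since the counts are truncated with int(), the validation split absorbs the remainder. For a hypothetical class of 25 images:

num_files = 25                                       # hypothetical class size
train_count = int(0.8 * num_files)                   # 20
test_count = int(0.1 * num_files)                    # int(2.5) -> 2
valid_count = num_files - train_count - test_count   # 3 (remainder goes to valid)
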
def __save_class_idx_to_file(class_idxs: dict, output_directory: Path):
    """
    Saves the class indices to a file.

    Args:
        class_idxs (dict): A dictionary mapping class indices to class names.
        output_directory (Path): The path to the output directory where the class index file will be saved.
    """
    class_idx_file = output_directory / "class_idx.txt"
    with open(class_idx_file, 'w') as f:
        for idx, class_name in class_idxs.items():
            f.write(f"{idx}: {class_name}\n")
    print(f"Class indices have been saved to {class_idx_file}")

def __generate_sample_images_with_detections(main_dir: Path, class_idxs: dict):
    """
    Generates one sample image with detections drawn for each of the train, test,
    and valid splits, combining up to 6 images into one output.

    Args:
        main_dir (Path): The main directory containing the train, test, and valid splits.
        class_idxs (dict): A dictionary mapping class indices to class names.
    """

    def resize_and_contain(image, target_size):
        image.thumbnail(target_size, Image.LANCZOS)
        new_image = Image.new("RGB", target_size, (0, 0, 0))
        new_image.paste(image, ((target_size[0] - image.width) // 2, (target_size[1] - image.height) // 2))
        return new_image

    def draw_bounding_boxes(image, labels_path, class_mapping, color_map):
        draw = ImageDraw.Draw(image)
        img_width, img_height = image.size
        try:
            font = ImageFont.truetype("DejaVuSans-Bold.ttf", 20)
        except IOError:
            font = ImageFont.load_default()

        if labels_path.exists():
            with open(labels_path, 'r') as label_file:
                for line in label_file.readlines():
                    parts = line.strip().split()
                    class_idx = int(parts[0])
                    center_x, center_y, width, height = map(float, parts[1:])
                    # Convert normalized YOLO coordinates to pixel corners.
                    x_min = int((center_x - width / 2) * img_width)
                    y_min = int((center_y - height / 2) * img_height)
                    x_max = int((center_x + width / 2) * img_width)
                    y_max = int((center_y + height / 2) * img_height)
                    class_name = class_mapping.get(class_idx, str(class_idx))
                    color = color_map[class_idx]
                    draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=3)
                    draw.text((x_min, y_min - 20), class_name, fill=color, font=font)
        return image

    def combine_images(images, grid_size=(3, 2), target_size=(416, 416)):
        resized_images = [resize_and_contain(img, target_size) for img in images]
        width, height = target_size
        combined_image = Image.new('RGB', (width * grid_size[0], height * grid_size[1]))

        for i, img in enumerate(resized_images):
            row = i // grid_size[0]
            col = i % grid_size[0]
            combined_image.paste(img, (col * width, row * height))

        return combined_image

    def generate_color_map(class_mapping):
        colors = ['red', 'blue', 'green', 'purple', 'orange', 'yellow', 'pink', 'cyan', 'magenta']
        color_map = {idx: random.choice(colors) for idx in class_mapping.keys()}
        return color_map

    splits = ['train', 'test', 'valid']
    class_mapping = class_idxs
    color_map = generate_color_map(class_mapping)

    for split in splits:
        images_dir = Path(main_dir) / split / 'images'
        labels_dir = Path(main_dir) / split / 'labels'
        image_files = list(images_dir.glob("*.jpg"))
        if not image_files:
            continue

        sample_images = []
        for image_file in image_files[:6]:
            label_file = labels_dir / (image_file.stem + '.txt')
            image = Image.open(image_file)
            image_with_boxes = draw_bounding_boxes(image, label_file, class_mapping, color_map)
            sample_images.append(image_with_boxes)

        if sample_images:
            combined_image = combine_images(sample_images, grid_size=(3, 2), target_size=(416, 416))
            combined_image_path = Path(main_dir) / split / f"{split}_sample_with_detections.jpg"
            combined_image.save(combined_image_path)

def __split_background_images(background_dir: Path, output_directory: Path):
    """
    Splits the background images into train, test, and validation sets.

    Args:
        background_dir (Path): The path to the directory containing the background images.
        output_directory (Path): The path to the output directory where the split background images will be saved.
    """

    image_files = list(Path(background_dir).glob("*.jpg"))
    random.shuffle(image_files)

    num_images = len(image_files)
    train_split = int(0.8 * num_images)
    valid_split = int(0.1 * num_images)

    train_files = image_files[:train_split]
    valid_files = image_files[train_split:train_split + valid_split]
    test_files = image_files[train_split + valid_split:]

    def copy_files(image_list, split):
        for image_file in image_list:
            shutil.copy(image_file, Path(output_directory) / split / 'images' / image_file.name)

            # Background images get an empty label file: no boxes to learn from.
            label_file = Path(output_directory) / split / 'labels' / (image_file.stem + ".txt")
            label_file.touch()

    copy_files(train_files, 'train')
    copy_files(valid_files, 'valid')
    copy_files(test_files, 'test')

    print(f"Background data has been split: {len(train_files)} train, {len(valid_files)} valid, {len(test_files)} test")

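Worth noting: each background image is paired with an empty label file via touch(). In the YOLO label format an empty file means zero boxes, so these images act as pure negatives during training, and __count_classes_and_output_table later reports them under the 'null' class. With a hypothetical file name:

# detection_data/train/images/plantae_0042.jpg   (hypothetical background image)
# detection_data/train/labels/plantae_0042.txt   (zero bytes -> no objects)
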
def __count_classes_and_output_table(output_directory: Path, class_idxs: dict):
    """
    Counts the number of labeled instances per class (plus background files) and outputs a table.

    Args:
        output_directory (Path): The path to the output directory containing the split data.
        class_idxs (dict): A dictionary mapping class indices to class names.
    """

    def count_classes_in_split(labels_dir):
        class_counts = defaultdict(int)
        for label_file in os.listdir(labels_dir):
            if label_file.endswith(".txt"):
                label_path = os.path.join(labels_dir, label_file)
                with open(label_path, 'r') as f:
                    lines = f.readlines()
                    if not lines:
                        # Count empty files as the 'null' class (background images)
                        class_counts['null'] += 1
                    else:
                        for line in lines:
                            class_index = int(line.split()[0])
                            class_counts[class_index] += 1
        return class_counts

    splits = ['train', 'test', 'valid']
    total_counts = defaultdict(int)

    table = PrettyTable()
    table.field_names = ["Class", "Class Index", "Train Count", "Test Count", "Valid Count", "Total"]

    split_counts = {split: defaultdict(int) for split in splits}

    for split in splits:
        labels_dir = output_directory / split / 'labels'
        if not os.path.exists(labels_dir):
            print(f"Warning: {labels_dir} does not exist, skipping {split}.")
            continue

        class_counts = count_classes_in_split(labels_dir)
        for class_index, count in class_counts.items():
            split_counts[split][class_index] = count
            total_counts[class_index] += count

    for class_index, total in total_counts.items():
        class_name = class_idxs.get(class_index, "Background" if class_index == 'null' else f"Class {class_index}")
        train_count = split_counts['train'].get(class_index, 0)
        test_count = split_counts['test'].get(class_index, 0)
        valid_count = split_counts['valid'].get(class_index, 0)
        table.add_row([class_name, class_index, train_count, test_count, valid_count, total])

    print(table)

def update_labels(class_mapping: dict, labels_path: Path) -> dict:
    """
    Updates the label files, replacing the detector's class index with the index
    of the class folder each image came from.

    Args:
        class_mapping (dict): A dictionary mapping class names to image file names.
        labels_path (Path): The path to the directory containing the label files.

    Returns:
        dict: A dictionary mapping class indices to class names.
    """
    class_index_mapping = {}
    class_index_definition = {}

    for idx, (class_name, images) in enumerate(class_mapping.items()):
        class_index_definition[idx] = class_name
        for image_name in images:
            class_index_mapping[image_name] = idx

    for txt_file in labels_path.glob("*.txt"):
        image_name_jpg = txt_file.stem + ".jpg"
        image_name_jpeg = txt_file.stem + ".jpeg"

        if image_name_jpg in class_index_mapping:
            class_index = class_index_mapping[image_name_jpg]
        elif image_name_jpeg in class_index_mapping:
            class_index = class_index_mapping[image_name_jpeg]
        else:
            print(f"Warning: No corresponding image found for {txt_file.name}")
            continue

        with open(txt_file, 'r') as file:
            lines = file.readlines()

        updated_lines = []
        for line in lines:
            parts = line.split()
            if len(parts) > 0:
                parts[0] = str(class_index)
                updated_lines.append(" ".join(parts))

        with open(txt_file, 'w') as file:
            file.write("\n".join(updated_lines))

    print("Labels updated successfully")
    return class_index_definition

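A concrete before/after (values hypothetical): if bee_001.jpg came from the second class folder (index 1), the leading class id of its label is rewritten while the box geometry is left untouched.

# before: "0 0.51 0.47 0.22 0.18"   <- generic detector's class id
# after:  "1 0.51 0.47 0.22 0.18"   <- index of the image's class folder
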
def count_images_across_splits(output_directory: Path) -> int:
    """
    Counts the total number of images across the train, test, and validation splits.

    Args:
        output_directory (Path): The path to the output directory containing the split data.

    Returns:
        int: The total number of images across all splits.
    """
    total_images = 0
    for split in ['train', 'test', 'valid']:
        split_dir = output_directory / split / 'images'
        total_images += len(list(split_dir.glob("*.jpg"))) + len(list(split_dir.glob("*.jpeg")))

    return total_images

def __make_yaml_file(output_directory: Path, class_idxs: dict):
    """
    Creates a YAML configuration file for YOLOv8.

    Args:
        output_directory (Path): The path to the output directory where the YAML file will be saved.
        class_idxs (dict): A dictionary mapping class indices to class names.
    """

    # Define the structure of the YAML file
    yaml_content = {
        'path': str(output_directory.resolve()),
        'train': 'train/images',
        'val': 'valid/images',
        'test': 'test/images',
        'names': {idx: name for idx, name in class_idxs.items()}
    }

    # Write the YAML content to a file
    yaml_file_path = output_directory / 'dataset.yaml'
    with open(yaml_file_path, 'w') as yaml_file:
        yaml.dump(yaml_content, yaml_file, default_flow_style=False, sort_keys=False)

    print(f"YOLOv8 YAML file created at {yaml_file_path}")
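For reference, with two hypothetical classes the generated dataset.yaml comes out roughly as the comment below sketches, and it can be fed straight to an Ultralytics training run (the yolov8n.pt starting checkpoint is an assumption, not something this diff pins down):

# dataset.yaml (sketch, paths and names hypothetical):
#   path: /abs/path/to/detection_data
#   train: train/images
#   val: valid/images
#   test: test/images
#   names:
#     0: Apis_mellifera
#     1: Bombus_terrestris

from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # assumption: any YOLOv8 checkpoint
model.train(data="detection_data/dataset.yaml", epochs=100)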