bplusplus 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bplusplus might be problematic. Click here for more details.

Files changed (95) hide show
  1. bplusplus/__init__.py +5 -3
  2. bplusplus/{collect_images.py → collect.py} +3 -3
  3. bplusplus/prepare.py +573 -0
  4. bplusplus/train_validate.py +8 -64
  5. bplusplus/yolov5detect/__init__.py +1 -0
  6. bplusplus/yolov5detect/detect.py +444 -0
  7. bplusplus/yolov5detect/export.py +1530 -0
  8. bplusplus/yolov5detect/insect.yaml +8 -0
  9. bplusplus/yolov5detect/models/__init__.py +0 -0
  10. bplusplus/yolov5detect/models/common.py +1109 -0
  11. bplusplus/yolov5detect/models/experimental.py +130 -0
  12. bplusplus/yolov5detect/models/hub/anchors.yaml +56 -0
  13. bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +52 -0
  14. bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +42 -0
  15. bplusplus/yolov5detect/models/hub/yolov3.yaml +52 -0
  16. bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +49 -0
  17. bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +43 -0
  18. bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +55 -0
  19. bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +42 -0
  20. bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +57 -0
  21. bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +68 -0
  22. bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +49 -0
  23. bplusplus/yolov5detect/models/hub/yolov5l6.yaml +61 -0
  24. bplusplus/yolov5detect/models/hub/yolov5m6.yaml +61 -0
  25. bplusplus/yolov5detect/models/hub/yolov5n6.yaml +61 -0
  26. bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +50 -0
  27. bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +49 -0
  28. bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +49 -0
  29. bplusplus/yolov5detect/models/hub/yolov5s6.yaml +61 -0
  30. bplusplus/yolov5detect/models/hub/yolov5x6.yaml +61 -0
  31. bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +49 -0
  32. bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +49 -0
  33. bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +49 -0
  34. bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +49 -0
  35. bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +49 -0
  36. bplusplus/yolov5detect/models/tf.py +797 -0
  37. bplusplus/yolov5detect/models/yolo.py +495 -0
  38. bplusplus/yolov5detect/models/yolov5l.yaml +49 -0
  39. bplusplus/yolov5detect/models/yolov5m.yaml +49 -0
  40. bplusplus/yolov5detect/models/yolov5n.yaml +49 -0
  41. bplusplus/yolov5detect/models/yolov5s.yaml +49 -0
  42. bplusplus/yolov5detect/models/yolov5x.yaml +49 -0
  43. bplusplus/yolov5detect/utils/__init__.py +97 -0
  44. bplusplus/yolov5detect/utils/activations.py +134 -0
  45. bplusplus/yolov5detect/utils/augmentations.py +448 -0
  46. bplusplus/yolov5detect/utils/autoanchor.py +175 -0
  47. bplusplus/yolov5detect/utils/autobatch.py +70 -0
  48. bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
  49. bplusplus/yolov5detect/utils/aws/mime.sh +26 -0
  50. bplusplus/yolov5detect/utils/aws/resume.py +41 -0
  51. bplusplus/yolov5detect/utils/aws/userdata.sh +27 -0
  52. bplusplus/yolov5detect/utils/callbacks.py +72 -0
  53. bplusplus/yolov5detect/utils/dataloaders.py +1385 -0
  54. bplusplus/yolov5detect/utils/docker/Dockerfile +73 -0
  55. bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +40 -0
  56. bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +42 -0
  57. bplusplus/yolov5detect/utils/downloads.py +136 -0
  58. bplusplus/yolov5detect/utils/flask_rest_api/README.md +70 -0
  59. bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +17 -0
  60. bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +49 -0
  61. bplusplus/yolov5detect/utils/general.py +1294 -0
  62. bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +25 -0
  63. bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +6 -0
  64. bplusplus/yolov5detect/utils/google_app_engine/app.yaml +16 -0
  65. bplusplus/yolov5detect/utils/loggers/__init__.py +476 -0
  66. bplusplus/yolov5detect/utils/loggers/clearml/README.md +222 -0
  67. bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
  68. bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +230 -0
  69. bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +90 -0
  70. bplusplus/yolov5detect/utils/loggers/comet/README.md +250 -0
  71. bplusplus/yolov5detect/utils/loggers/comet/__init__.py +551 -0
  72. bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +151 -0
  73. bplusplus/yolov5detect/utils/loggers/comet/hpo.py +126 -0
  74. bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +135 -0
  75. bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
  76. bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +210 -0
  77. bplusplus/yolov5detect/utils/loss.py +259 -0
  78. bplusplus/yolov5detect/utils/metrics.py +381 -0
  79. bplusplus/yolov5detect/utils/plots.py +517 -0
  80. bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
  81. bplusplus/yolov5detect/utils/segment/augmentations.py +100 -0
  82. bplusplus/yolov5detect/utils/segment/dataloaders.py +366 -0
  83. bplusplus/yolov5detect/utils/segment/general.py +160 -0
  84. bplusplus/yolov5detect/utils/segment/loss.py +198 -0
  85. bplusplus/yolov5detect/utils/segment/metrics.py +225 -0
  86. bplusplus/yolov5detect/utils/segment/plots.py +152 -0
  87. bplusplus/yolov5detect/utils/torch_utils.py +482 -0
  88. bplusplus/yolov5detect/utils/triton.py +90 -0
  89. bplusplus-1.1.0.dist-info/METADATA +179 -0
  90. bplusplus-1.1.0.dist-info/RECORD +92 -0
  91. bplusplus/build_model.py +0 -38
  92. bplusplus-0.1.0.dist-info/METADATA +0 -91
  93. bplusplus-0.1.0.dist-info/RECORD +0 -8
  94. {bplusplus-0.1.0.dist-info → bplusplus-1.1.0.dist-info}/LICENSE +0 -0
  95. {bplusplus-0.1.0.dist-info → bplusplus-1.1.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1385 @@
1
+ # Ultralytics YOLOv5 🚀, AGPL-3.0 license
2
+ """Dataloaders and dataset utils."""
3
+
4
+ import contextlib
5
+ import glob
6
+ import hashlib
7
+ import json
8
+ import math
9
+ import os
10
+ import random
11
+ import shutil
12
+ import time
13
+ from itertools import repeat
14
+ from multiprocessing.pool import Pool, ThreadPool
15
+ from pathlib import Path
16
+ from threading import Thread
17
+ from urllib.parse import urlparse
18
+
19
+ import numpy as np
20
+ import psutil
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchvision
24
+ import yaml
25
+ from PIL import ExifTags, Image, ImageOps
26
+ from torch.utils.data import DataLoader, Dataset, dataloader, distributed
27
+ from tqdm import tqdm
28
+
29
+ from utils.augmentations import (
30
+ Albumentations,
31
+ augment_hsv,
32
+ classify_albumentations,
33
+ classify_transforms,
34
+ copy_paste,
35
+ letterbox,
36
+ mixup,
37
+ random_perspective,
38
+ )
39
+ from utils.general import (
40
+ DATASETS_DIR,
41
+ LOGGER,
42
+ NUM_THREADS,
43
+ TQDM_BAR_FORMAT,
44
+ check_dataset,
45
+ check_requirements,
46
+ check_yaml,
47
+ clean_str,
48
+ cv2,
49
+ is_colab,
50
+ is_kaggle,
51
+ segments2boxes,
52
+ unzip_file,
53
+ xyn2xy,
54
+ xywh2xyxy,
55
+ xywhn2xyxy,
56
+ xyxy2xywhn,
57
+ )
58
+ from utils.torch_utils import torch_distributed_zero_first
59
+
60
# Parameters
HELP_URL = "See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data"
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv"  # include video suffixes
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders

# Get orientation exif tag
# NOTE: this loop deliberately leaves the module-level name `orientation` bound to
# the numeric EXIF tag id for "Orientation"; exif_size() reads it later.
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == "Orientation":
        break
73
+
74
+
75
def get_hash(paths):
    """Generates a single SHA256 hash for a list of file or directory paths by combining their sizes and paths."""
    # Total on-disk size of every path that currently exists; missing paths contribute 0.
    total_size = 0
    for p in paths:
        if os.path.exists(p):
            total_size += os.path.getsize(p)
    digest = hashlib.sha256(str(total_size).encode())  # seed the hash with the combined size
    digest.update("".join(paths).encode())  # then mix in the path strings themselves
    return digest.hexdigest()
81
+
82
+
83
def exif_size(img):
    """Returns corrected PIL image size (width, height) considering EXIF orientation."""
    w, h = img.size  # PIL reports (width, height)
    # EXIF access can fail for many reasons (no EXIF block, missing tag, truncated
    # file); in every such case we fall back to the raw reported size.
    with contextlib.suppress(Exception):
        rotation = dict(img._getexif().items())[orientation]  # `orientation` is the module-level EXIF tag id
        if rotation in {6, 8}:  # 270 or 90 degree rotation swaps width/height
            w, h = h, w
    return w, h
91
+
92
+
93
def exif_transpose(image):
    """
    Transpose a PIL image accordingly if it has an EXIF Orientation tag.
    Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose().

    :param image: The image to transpose.
    :return: An image.
    """
    exif = image.getexif()
    orientation_value = exif.get(0x0112, 1)  # 0x0112 is the EXIF Orientation tag; 1 means "no transform"
    if orientation_value <= 1:
        return image  # nothing to undo

    # Map each EXIF orientation code to the PIL transpose operation that undoes it.
    method = {
        2: Image.FLIP_LEFT_RIGHT,
        3: Image.ROTATE_180,
        4: Image.FLIP_TOP_BOTTOM,
        5: Image.TRANSPOSE,
        6: Image.ROTATE_270,
        7: Image.TRANSVERSE,
        8: Image.ROTATE_90,
    }.get(orientation_value)
    if method is not None:
        image = image.transpose(method)
    # Strip the tag so downstream consumers do not re-apply the rotation.
    del exif[0x0112]
    image.info["exif"] = exif.tobytes()
    return image
118
+
119
+
120
def seed_worker(worker_id):
    """
    Sets the seed for a dataloader worker to ensure reproducibility, based on PyTorch's randomness notes.

    See https://pytorch.org/docs/stable/notes/randomness.html#dataloader.
    """
    # Derive a 32-bit seed from this worker's torch seed and propagate it to the
    # other RNGs so numpy/random-based augmentations are reproducible per worker.
    derived_seed = torch.initial_seed() % 2**32
    np.random.seed(derived_seed)
    random.seed(derived_seed)
129
+
130
+
131
+ # Inherit from DistributedSampler and override iterator
132
+ # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
133
class SmartDistributedSampler(distributed.DistributedSampler):
    """A distributed sampler ensuring deterministic shuffling and balanced data distribution across GPUs."""

    def __iter__(self):
        """Yields indices for distributed data sampling, shuffled deterministically based on epoch and seed."""
        # Seed a dedicated generator so every rank permutes identically for a given epoch.
        rng = torch.Generator()
        rng.manual_seed(self.seed + self.epoch)

        # Determine the eventual size (n) of this rank's index list (num_replicas == WORLD_SIZE).
        n = int((len(self.dataset) - self.rank - 1) / self.num_replicas) + 1
        perm = torch.randperm(n, generator=rng)
        if not self.shuffle:
            perm = perm.sort()[0]  # deterministic ascending order when shuffling is disabled

        indices = perm.tolist()
        if self.drop_last:
            indices = indices[: self.num_samples]
        else:
            # Pad by repeating indices until num_samples is reached, tiling the whole
            # list when the shortfall exceeds its current length.
            shortfall = self.num_samples - len(indices)
            if shortfall <= len(indices):
                indices += indices[:shortfall]
            else:
                indices += (indices * math.ceil(shortfall / len(indices)))[:shortfall]

        return iter(indices)
158
+
159
+
160
def create_dataloader(
    path,
    imgsz,
    batch_size,
    stride,
    single_cls=False,
    hyp=None,
    augment=False,
    cache=False,
    pad=0.0,
    rect=False,
    rank=-1,
    workers=8,
    image_weights=False,
    quad=False,
    prefix="",
    shuffle=False,
    seed=0,
):
    """Creates and returns a configured DataLoader instance for loading and processing image datasets."""
    if rect and shuffle:
        # Rectangular batching pre-sorts images by aspect ratio, which shuffling would undo.
        LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False

    # Build the dataset inside the DDP barrier so the *.cache file is created only once.
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(
            path,
            imgsz,
            batch_size,
            augment=augment,  # augmentation
            hyp=hyp,  # hyperparameters
            rect=rect,  # rectangular batches
            cache_images=cache,
            single_cls=single_cls,
            stride=int(stride),
            pad=pad,
            image_weights=image_weights,
            prefix=prefix,
            rank=rank,
        )

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
    # Only the vanilla DataLoader allows attribute updates (needed for image_weights).
    loader_cls = DataLoader if image_weights else InfiniteDataLoader
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + seed + RANK)
    loader = loader_cls(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and sampler is None,
        num_workers=nw,
        sampler=sampler,
        pin_memory=PIN_MEMORY,
        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
        worker_init_fn=seed_worker,
        generator=generator,
    )
    return loader, dataset
218
+
219
+
220
class InfiniteDataLoader(dataloader.DataLoader):
    """
    Dataloader that reuses workers.

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        """Initializes an InfiniteDataLoader that reuses workers with standard DataLoader syntax, augmenting with a
        repeating sampler.
        """
        super().__init__(*args, **kwargs)
        # DataLoader treats batch_sampler as read-only after __init__, so bypass
        # __setattr__ to install a sampler that never exhausts; this keeps worker
        # processes alive across epochs instead of respawning them.
        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()  # single persistent iterator, reused by __iter__

    def __len__(self):
        """Returns the length of the batch sampler's sampler in the InfiniteDataLoader."""
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Yields batches of data indefinitely in a loop by resetting the sampler when exhausted."""
        # One "epoch" = len(self) batches pulled from the everlasting iterator.
        for _ in range(len(self)):
            yield next(self.iterator)
243
+
244
+
245
+ class _RepeatSampler:
246
+ """
247
+ Sampler that repeats forever.
248
+
249
+ Args:
250
+ sampler (Sampler)
251
+ """
252
+
253
+ def __init__(self, sampler):
254
+ """Initializes a perpetual sampler wrapping a provided `Sampler` instance for endless data iteration."""
255
+ self.sampler = sampler
256
+
257
+ def __iter__(self):
258
+ """Returns an infinite iterator over the dataset by repeatedly yielding from the given sampler."""
259
+ while True:
260
+ yield from iter(self.sampler)
261
+
262
+
263
class LoadScreenshots:
    """Loads and processes screenshots for YOLOv5 detection from specified screen regions using mss."""

    def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
        """
        Initializes a screenshot dataloader for YOLOv5 with specified source region, image size, stride, auto, and
        transforms.

        Source = [screen_number left top width height] (pixels)
        """
        check_requirements("mss")
        import mss  # deferred import: mss is only required for the screenshot source

        source, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
        if len(params) == 1:  # "N" -> screen number only, full monitor
            self.screen = int(params[0])
        elif len(params) == 4:  # "left top width height" relative to screen 0
            left, top, width, height = (int(x) for x in params)
        elif len(params) == 5:  # "N left top width height"
            self.screen, left, top, width, height = (int(x) for x in params)
        self.img_size = img_size
        self.stride = stride
        self.transforms = transforms
        self.auto = auto
        self.mode = "stream"
        self.frame = 0  # running count of captured frames
        self.sct = mss.mss()

        # Parse monitor shape: user offsets are relative to the selected monitor's origin.
        monitor = self.sct.monitors[self.screen]
        self.top = monitor["top"] if top is None else (monitor["top"] + top)
        self.left = monitor["left"] if left is None else (monitor["left"] + left)
        self.width = width or monitor["width"]
        self.height = height or monitor["height"]
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        """Iterates over itself, enabling use in loops and iterable contexts."""
        return self

    def __next__(self):
        """Captures and returns the next screen frame as a BGR numpy array, cropping to only the first three channels
        from BGRA.
        """
        im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous
        self.frame += 1
        return str(self.screen), im, im0, None, s  # screen, img, original img, im0s, s
319
+
320
+
321
class LoadImages:
    """YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`."""

    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        """Initializes YOLOv5 loader for images/videos, supporting glob patterns, directories, and lists of paths."""
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
            path = Path(path).read_text().rsplit()
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if "*" in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, "*.*"))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f"{p} does not exist")

        # Partition by extension into still images and videos.
        images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = "image"
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, (
            f"No images or videos found in {p}. "
            f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
        )

    def __iter__(self):
        """Initializes iterator by resetting count and returns the iterator object itself."""
        self.count = 0
        return self

    def __next__(self):
        """Advances to the next file in the dataset, raising StopIteration if at the end."""
        # FIX: skipping unreadable images is done with an iterative loop instead of the
        # previous `return self.__next__()` self-recursion, which could raise
        # RecursionError after ~1000 consecutive corrupt/missing images.
        while True:
            if self.count == self.nf:
                raise StopIteration
            path = self.files[self.count]

            if self.video_flag[self.count]:
                # Read video
                self.mode = "video"
                for _ in range(self.vid_stride):
                    self.cap.grab()
                ret_val, im0 = self.cap.retrieve()
                while not ret_val:  # current video exhausted: advance to the next one
                    self.count += 1
                    self.cap.release()
                    if self.count == self.nf:  # last video
                        raise StopIteration
                    path = self.files[self.count]
                    self._new_video(path)
                    ret_val, im0 = self.cap.read()

                self.frame += 1
                # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
                s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
                break
            else:
                # Read image
                self.count += 1
                try:
                    im0 = cv2.imread(path)  # BGR
                    if im0 is None:
                        raise FileNotFoundError(f"Image Not Found: {path}")
                except Exception as e:
                    print(f"Error loading image {path}: {e}")
                    continue  # Skip to the next image
                s = f"image {self.count}/{self.nf} {path}: "
                break

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

    def _new_video(self, path):
        """Initializes a new video capture object with path, frame count adjusted by stride, and orientation
        metadata.
        """
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        """Rotates a cv2 image based on its orientation; supports 0, 90, and 180 degrees rotations."""
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        """Returns the number of files in the dataset."""
        return self.nf  # number of files
436
+
437
+
438
class LoadStreams:
    """Loads and processes video streams for YOLOv5, supporting various sources including YouTube and IP cameras."""

    def __init__(self, sources="file.streams", img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        """Initializes a stream loader for processing video streams with YOLOv5, supporting various sources including
        YouTube.
        """
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.mode = "stream"
        self.img_size = img_size
        self.stride = stride
        self.vid_stride = vid_stride  # video frame-rate stride
        # A *.streams file lists one source per line; otherwise treat `sources` as a single source.
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f"{i + 1}/{n}: {s}... "
            if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
                check_requirements(("pafy", "youtube_dl==2020.12.2"))
                import pafy

                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            # NOTE(review): eval() on the source string is risky if sources are untrusted;
            # int(s) would cover the numeric local-webcam case safely.
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0:
                assert not is_colab(), "--source 0 webcam unsupported on Colab. Rerun command in a local environment."
                assert not is_kaggle(), "--source 0 webcam unsupported on Kaggle. Rerun command in a local environment."
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f"{st}Failed to open {s}"
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float("inf")  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            # Daemon reader thread keeps self.imgs[i] updated with the latest frame.
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info("")  # newline

        # check for common shapes
        s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        self.auto = auto and self.rect
        self.transforms = transforms  # optional
        if not self.rect:
            LOGGER.warning("WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.")

    def update(self, i, cap, stream):
        """Reads frames from stream `i`, updating imgs array; handles stream reopening on signal loss."""
        n, f = 0, self.frames[i]  # frame number, frame array
        while cap.isOpened() and n < f:
            n += 1
            cap.grab()  # .read() = .grab() followed by .retrieve()
            if n % self.vid_stride == 0:  # only decode every vid_stride-th frame
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
                    self.imgs[i] = np.zeros_like(self.imgs[i])  # blank frame until the stream recovers
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(0.0)  # wait time; yields the GIL to other reader threads without delaying capture

    def __iter__(self):
        """Resets and returns the iterator for iterating over video frames or images in a dataset."""
        self.count = -1
        return self

    def __next__(self):
        """Iterates over video frames or images, halting on thread stop or 'q' key press, raising `StopIteration` when
        done.
        """
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord("q"):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        im0 = self.imgs.copy()  # shallow copy so the reader threads can keep writing
        if self.transforms:
            im = np.stack([self.transforms(x) for x in im0])  # transforms
        else:
            im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0])  # resize
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous

        return self.sources, im, im0, None, ""

    def __len__(self):
        """Returns the number of sources in the dataset, supporting up to 32 streams at 30 FPS over 30 years."""
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
532
+
533
+
534
def img2label_paths(img_paths):
    """Generates label file paths from corresponding image file paths by replacing `/images/` with `/labels/` and
    extension with `.txt`.
    """
    img_dir, label_dir = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
    label_paths = []
    for img_path in img_paths:
        # Replace only the LAST /images/ path component, then swap the extension for .txt.
        stem = label_dir.join(img_path.rsplit(img_dir, 1)).rsplit(".", 1)[0]
        label_paths.append(stem + ".txt")
    return label_paths
540
+
541
+
542
+ class LoadImagesAndLabels(Dataset):
543
+ """Loads images and their corresponding labels for training and validation in YOLOv5."""
544
+
545
+ cache_version = 0.6 # dataset labels *.cache version
546
+ rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
547
+
548
+ def __init__(
549
+ self,
550
+ path,
551
+ img_size=640,
552
+ batch_size=16,
553
+ augment=False,
554
+ hyp=None,
555
+ rect=False,
556
+ image_weights=False,
557
+ cache_images=False,
558
+ single_cls=False,
559
+ stride=32,
560
+ pad=0.0,
561
+ min_items=0,
562
+ prefix="",
563
+ rank=-1,
564
+ seed=0,
565
+ ):
566
+ """Initializes the YOLOv5 dataset loader, handling images and their labels, caching, and preprocessing."""
567
+ self.img_size = img_size
568
+ self.augment = augment
569
+ self.hyp = hyp
570
+ self.image_weights = image_weights
571
+ self.rect = False if image_weights else rect
572
+ self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
573
+ self.mosaic_border = [-img_size // 2, -img_size // 2]
574
+ self.stride = stride
575
+ self.path = path
576
+ self.albumentations = Albumentations(size=img_size) if augment else None
577
+
578
+ try:
579
+ f = [] # image files
580
+ for p in path if isinstance(path, list) else [path]:
581
+ p = Path(p) # os-agnostic
582
+ if p.is_dir(): # dir
583
+ f += glob.glob(str(p / "**" / "*.*"), recursive=True)
584
+ # f = list(p.rglob('*.*')) # pathlib
585
+ elif p.is_file(): # file
586
+ with open(p) as t:
587
+ t = t.read().strip().splitlines()
588
+ parent = str(p.parent) + os.sep
589
+ f += [x.replace("./", parent, 1) if x.startswith("./") else x for x in t] # to global path
590
+ # f += [p.parent / x.lstrip(os.sep) for x in t] # to global path (pathlib)
591
+ else:
592
+ raise FileNotFoundError(f"{prefix}{p} does not exist")
593
+ self.im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
594
+ # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
595
+ assert self.im_files, f"{prefix}No images found"
596
+ except Exception as e:
597
+ raise Exception(f"{prefix}Error loading data from {path}: {e}\n{HELP_URL}") from e
598
+
599
+ # Check cache
600
+ self.label_files = img2label_paths(self.im_files) # labels
601
+ cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix(".cache")
602
+ try:
603
+ cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
604
+ assert cache["version"] == self.cache_version # matches current version
605
+ assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
606
+ except Exception:
607
+ cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
608
+
609
+ # Display cache
610
+ nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
611
+ if exists and LOCAL_RANK in {-1, 0}:
612
+ d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
613
+ tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
614
+ if cache["msgs"]:
615
+ LOGGER.info("\n".join(cache["msgs"])) # display warnings
616
+ assert nf > 0 or not augment, f"{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
617
+
618
+ # Read cache
619
+ [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
620
+ labels, shapes, self.segments = zip(*cache.values())
621
+ nl = len(np.concatenate(labels, 0)) # number of labels
622
+ assert nl > 0 or not augment, f"{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
623
+ self.labels = list(labels)
624
+ self.shapes = np.array(shapes)
625
+ self.im_files = list(cache.keys()) # update
626
+ self.label_files = img2label_paths(cache.keys()) # update
627
+
628
+ # Filter images
629
+ if min_items:
630
+ include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
631
+ LOGGER.info(f"{prefix}{n - len(include)}/{n} images filtered from dataset")
632
+ self.im_files = [self.im_files[i] for i in include]
633
+ self.label_files = [self.label_files[i] for i in include]
634
+ self.labels = [self.labels[i] for i in include]
635
+ self.segments = [self.segments[i] for i in include]
636
+ self.shapes = self.shapes[include] # wh
637
+
638
+ # Create indices
639
+ n = len(self.shapes) # number of images
640
+ bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
641
+ nb = bi[-1] + 1 # number of batches
642
+ self.batch = bi # batch index of image
643
+ self.n = n
644
+ self.indices = np.arange(n)
645
+ if rank > -1: # DDP indices (see: SmartDistributedSampler)
646
+ # force each rank (i.e. GPU process) to sample the same subset of data on every epoch
647
+ self.indices = self.indices[np.random.RandomState(seed=seed).permutation(n) % WORLD_SIZE == RANK]
648
+
649
+ # Update labels
650
+ include_class = [] # filter labels to include only these classes (optional)
651
+ self.segments = list(self.segments)
652
+ include_class_array = np.array(include_class).reshape(1, -1)
653
+ for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
654
+ if include_class:
655
+ j = (label[:, 0:1] == include_class_array).any(1)
656
+ self.labels[i] = label[j]
657
+ if segment:
658
+ self.segments[i] = [segment[idx] for idx, elem in enumerate(j) if elem]
659
+ if single_cls: # single-class training, merge all classes into 0
660
+ self.labels[i][:, 0] = 0
661
+
662
+ # Rectangular Training
663
+ if self.rect:
664
+ # Sort by aspect ratio
665
+ s = self.shapes # wh
666
+ ar = s[:, 1] / s[:, 0] # aspect ratio
667
+ irect = ar.argsort()
668
+ self.im_files = [self.im_files[i] for i in irect]
669
+ self.label_files = [self.label_files[i] for i in irect]
670
+ self.labels = [self.labels[i] for i in irect]
671
+ self.segments = [self.segments[i] for i in irect]
672
+ self.shapes = s[irect] # wh
673
+ ar = ar[irect]
674
+
675
+ # Set training image shapes
676
+ shapes = [[1, 1]] * nb
677
+ for i in range(nb):
678
+ ari = ar[bi == i]
679
+ mini, maxi = ari.min(), ari.max()
680
+ if maxi < 1:
681
+ shapes[i] = [maxi, 1]
682
+ elif mini > 1:
683
+ shapes[i] = [1, 1 / mini]
684
+
685
+ self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
686
+
687
+ # Cache images into RAM/disk for faster training
688
+ if cache_images == "ram" and not self.check_cache_ram(prefix=prefix):
689
+ cache_images = False
690
+ self.ims = [None] * n
691
+ self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
692
+ if cache_images:
693
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
694
+ self.im_hw0, self.im_hw = [None] * n, [None] * n
695
+ fcn = self.cache_images_to_disk if cache_images == "disk" else self.load_image
696
+ results = ThreadPool(NUM_THREADS).imap(lambda i: (i, fcn(i)), self.indices)
697
+ pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
698
+ for i, x in pbar:
699
+ if cache_images == "disk":
700
+ b += self.npy_files[i].stat().st_size
701
+ else: # 'ram'
702
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
703
+ b += self.ims[i].nbytes * WORLD_SIZE
704
+ pbar.desc = f"{prefix}Caching images ({b / gb:.1f}GB {cache_images})"
705
+ pbar.close()
706
+
707
+ def check_cache_ram(self, safety_margin=0.1, prefix=""):
708
+ """Checks if available RAM is sufficient for caching images, adjusting for a safety margin."""
709
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
710
+ n = min(self.n, 30) # extrapolate from 30 random images
711
+ for _ in range(n):
712
+ im = cv2.imread(random.choice(self.im_files)) # sample image
713
+ ratio = self.img_size / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
714
+ b += im.nbytes * ratio**2
715
+ mem_required = b * self.n / n # GB required to cache dataset into RAM
716
+ mem = psutil.virtual_memory()
717
+ cache = mem_required * (1 + safety_margin) < mem.available # to cache or not to cache, that is the question
718
+ if not cache:
719
+ LOGGER.info(
720
+ f'{prefix}{mem_required / gb:.1f}GB RAM required, '
721
+ f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
722
+ f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
723
+ )
724
+ return cache
725
+
726
    def cache_labels(self, path=Path("./labels.cache"), prefix=""):
        """Caches dataset labels, verifies images, reads shapes, and tracks dataset integrity.

        Args:
            path (Path): Destination file for the cache (written via np.save).
            prefix (str): Logging prefix prepended to progress and warning messages.

        Returns:
            dict: Mapping of image path -> [labels, shape, segments], plus bookkeeping keys
                "hash", "results" (nf, nm, ne, nc, total), "msgs", and "version".
        """
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning {path.parent / path.stem}..."
        # Verify every image/label pair in parallel; each result carries per-file counters
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(
                pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
                desc=desc,
                total=len(self.im_files),
                bar_format=TQDM_BAR_FORMAT,
            )
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # only record pairs whose image verified successfully
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"

        pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        x["version"] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time; np.save appends a .npy suffix
            path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
            LOGGER.info(f"{prefix}New cache created: {path}")
        except Exception as e:
            LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}")  # not writeable
        return x
765
+
766
+ def __len__(self):
767
+ """Returns the number of images in the dataset."""
768
+ return len(self.im_files)
769
+
770
+ # def __iter__(self):
771
+ # self.count = -1
772
+ # print('ran dataset iter')
773
+ # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
774
+ # return self
775
+
776
    def __getitem__(self, index):
        """Fetches the dataset item at the given index, considering linear, shuffled, or weighted sampling.

        Returns:
            tuple: (image tensor CHW/RGB, (nl, 6) label tensor [img_idx, cls, x, y, w, h],
            image file path, letterbox shapes — None when a mosaic was used).
        """
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp["mosaic"]
        if mosaic:
            # Load mosaic
            img, labels = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp["mixup"]:
                img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels = random_perspective(
                    img,
                    labels,
                    degrees=hyp["degrees"],
                    translate=hyp["translate"],
                    scale=hyp["scale"],
                    shear=hyp["shear"],
                    perspective=hyp["perspective"],
                )

        nl = len(labels)  # number of labels
        if nl:
            # back to normalized xywh, clipped to the image with a small eps margin
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)

        if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])

            # Flip up-down
            if random.random() < hyp["flipud"]:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]  # mirror normalized y center

            # Flip left-right
            if random.random() < hyp["fliplr"]:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]  # mirror normalized x center

            # Cutouts
            # labels = cutout(img, labels, p=0.5)
            # nl = len(labels)  # update after cutout

        # column 0 is reserved for the batch image index, filled in by collate_fn
        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.im_files[index], shapes
852
+
853
+ def load_image(self, i):
854
+ """
855
+ Loads an image by index, returning the image, its original dimensions, and resized dimensions.
856
+
857
+ Returns (im, original hw, resized hw)
858
+ """
859
+ im, f, fn = (
860
+ self.ims[i],
861
+ self.im_files[i],
862
+ self.npy_files[i],
863
+ )
864
+ if im is None: # not cached in RAM
865
+ if fn.exists(): # load npy
866
+ im = np.load(fn)
867
+ else: # read image
868
+ im = cv2.imread(f) # BGR
869
+ assert im is not None, f"Image Not Found {f}"
870
+ h0, w0 = im.shape[:2] # orig hw
871
+ r = self.img_size / max(h0, w0) # ratio
872
+ if r != 1: # if sizes are not equal
873
+ interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
874
+ im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
875
+ return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
876
+ return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
877
+
878
+ def cache_images_to_disk(self, i):
879
+ """Saves an image to disk as an *.npy file for quicker loading, identified by index `i`."""
880
+ f = self.npy_files[i]
881
+ if not f.exists():
882
+ np.save(f.as_posix(), cv2.imread(self.im_files[i]))
883
+
884
    def load_mosaic(self, index):
        """Loads a 4-image mosaic for YOLOv5, combining 1 selected and 3 random images, with labels and segments.

        The mosaic canvas is 2s x 2s (s = self.img_size); a random center (xc, yc) splits it into
        four tiles, one per source image. Labels are converted to pixel xyxy in mosaic coordinates.
        """
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4; (x1a..y2a) is the tile region on the big canvas,
            # (x1b..y2b) the matching crop of the small image
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b  # offset from small-image coords to mosaic coords
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp["copy_paste"])
        img4, labels4 = random_perspective(
            img4,
            labels4,
            segments4,
            degrees=self.hyp["degrees"],
            translate=self.hyp["translate"],
            scale=self.hyp["scale"],
            shear=self.hyp["shear"],
            perspective=self.hyp["perspective"],
            border=self.mosaic_border,
        )  # border to remove

        return img4, labels4
943
+
944
    def load_mosaic9(self, index):
        """Loads 1 image + 8 random images into a 9-image mosaic for augmented YOLOv5 training, returning labels and
        segments.

        A 3s x 3s canvas is tiled around the center image; a random 2s x 2s window is then cropped
        out before the usual copy-paste / random-perspective augmentation.
        """
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
        hp, wp = -1, -1  # height, width previous
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img9; c = destination box in canvas coordinates for this tile
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padx, pady = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
            labels9.append(labels)
            segments9.extend(segments)

            # Image
            img9[y1:y2, x1:x2] = img[y1 - pady :, x1 - padx :]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous

        # Offset: crop a random 2s x 2s window from the 3s x 3s canvas
        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
        img9 = img9[yc : yc + 2 * s, xc : xc + 2 * s]

        # Concat/clip labels — shift all coordinates into the cropped window's frame
        labels9 = np.concatenate(labels9, 0)
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc
        c = np.array([xc, yc])  # centers
        segments9 = [x - c for x in segments9]

        for x in (labels9[:, 1:], *segments9):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img9, labels9 = replicate(img9, labels9)  # replicate

        # Augment
        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp["copy_paste"])
        img9, labels9 = random_perspective(
            img9,
            labels9,
            segments9,
            degrees=self.hyp["degrees"],
            translate=self.hyp["translate"],
            scale=self.hyp["scale"],
            shear=self.hyp["shear"],
            perspective=self.hyp["perspective"],
            border=self.mosaic_border,
        )  # border to remove

        return img9, labels9
1024
+
1025
+ @staticmethod
1026
+ def collate_fn(batch):
1027
+ """Batches images, labels, paths, and shapes, assigning unique indices to targets in merged label tensor."""
1028
+ im, label, path, shapes = zip(*batch) # transposed
1029
+ for i, lb in enumerate(label):
1030
+ lb[:, 0] = i # add target image index for build_targets()
1031
+ return torch.stack(im, 0), torch.cat(label, 0), path, shapes
1032
+
1033
    @staticmethod
    def collate_fn4(batch):
        """Bundles a batch's data by quartering the number of shapes and paths, preparing it for model input.

        Each group of 4 images becomes one output image: either the first image upscaled 2x,
        or all 4 tiled into a 2x2 grid (labels rescaled/offset to match).
        """
        im, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        # label-offset rows: ho shifts y by 1 image-height, wo shifts x by 1 image-width
        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280) # BCHW
            i *= 4
            if random.random() < 0.5:
                # take the first image of the quartet, bilinearly upscaled 2x
                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode="bilinear", align_corners=False)[
                    0
                ].type(im[i].type())
                lb = label[i]
            else:
                # tile the 4 images into a 2x2 grid and merge their labels with offsets + 0.5x scale
                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            im4.append(im1)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # add target image index for build_targets()

        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
1060
+
1061
+
1062
+ # Ancillary functions --------------------------------------------------------------------------------------------------
1063
+ def flatten_recursive(path=DATASETS_DIR / "coco128"):
1064
+ """Flattens a directory by copying all files from subdirectories to a new top-level directory, preserving
1065
+ filenames.
1066
+ """
1067
+ new_path = Path(f"{str(path)}_flat")
1068
+ if os.path.exists(new_path):
1069
+ shutil.rmtree(new_path) # delete output folder
1070
+ os.makedirs(new_path) # make new output folder
1071
+ for file in tqdm(glob.glob(f"{str(Path(path))}/**/*.*", recursive=True)):
1072
+ shutil.copyfile(file, new_path / Path(file).name)
1073
+
1074
+
1075
def extract_boxes(path=DATASETS_DIR / "coco128"):
    """
    Converts a detection dataset to a classification dataset, creating a directory for each class and extracting
    bounding boxes.

    Example: from utils.dataloaders import *; extract_boxes()
    """
    path = Path(path)  # images dir
    shutil.rmtree(path / "classification") if (path / "classification").is_dir() else None  # remove existing
    files = list(path.rglob("*.*"))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in IMG_FORMATS:
            # image
            # NOTE(review): channels are reversed here but cv2.imwrite below expects BGR —
            # crops appear to be written with swapped channels; confirm intent upstream
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file) as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / "classification") / f"{c}" / f"{path.stem}_{im_file.stem}_{j}.jpg"  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box: normalized xywh -> pixel xywh
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1] : b[3], b[0] : b[2]]), f"box failure in {f}"
1112
+
1113
+
1114
def autosplit(path=DATASETS_DIR / "coco128/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
    """Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.dataloaders import *; autosplit().

    Arguments:
        path: Path to images directory
        weights: Train, val, test weights (list, tuple)
        annotated_only: Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for name in txt:
        split_file = path.parent / name
        if split_file.exists():
            split_file.unlink()  # remove existing

    print(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for split_id, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[split_id], "a") as f:
                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file
1139
+
1140
+
1141
def verify_image_label(args):
    """Verifies a single image-label pair, ensuring image format, size, and legal label values.

    Args:
        args: Tuple of (image file path, label file path, logging prefix).

    Returns:
        tuple: (im_file, labels, shape, segments, nm, nf, ne, nc, msg) where nm/nf/ne/nc are
            0/1 flags for missing/found/empty/corrupt and `msg` holds any warning text.
            On failure the first four fields are None.
    """
    im_file, lb_file, prefix = args
    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, "", []  # number (missing, found, empty, corrupt), message, segments
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
        if im.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)  # the last two bytes of a valid JPEG are the EOI marker
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
                assert (lb[:, 1:] <= 1).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, 5), dtype=np.float32)
        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        # return a tuple (not a list) so both success and failure paths yield the same type
        return None, None, None, None, nm, nf, ne, nc, msg
1191
+
1192
+
1193
class HUBDatasetStats:
    """
    Class for generating HUB dataset JSON and `-hub` dataset directory.

    Arguments:
        path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
        autodownload: Attempt to download dataset if not found locally

    Usage
        from utils.dataloaders import HUBDatasetStats
        stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # usage 1
        stats = HUBDatasetStats('path/to/coco128.zip')  # usage 2
        stats.get_json(save=False)
        stats.process_images()
    """

    def __init__(self, path="coco128.yaml", autodownload=False):
        """Initializes HUBDatasetStats with optional auto-download for datasets, given a path to dataset YAML or ZIP
        file.

        Raises:
            Exception: wrapped as "error/HUB/dataset_stats/yaml_load" if the dataset YAML cannot be parsed.
        """
        zipped, data_dir, yaml_path = self._unzip(Path(path))
        try:
            with open(check_yaml(yaml_path), errors="ignore") as f:
                data = yaml.safe_load(f)  # data dict
                if zipped:
                    data["path"] = data_dir  # point the YAML at the unzipped directory
        except Exception as e:
            raise Exception("error/HUB/dataset_stats/yaml_load") from e

        check_dataset(data, autodownload)  # download dataset if missing
        self.hub_dir = Path(data["path"] + "-hub")
        self.im_dir = self.hub_dir / "images"
        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
        self.stats = {"nc": data["nc"], "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _find_yaml(dir):
        """Finds and returns the path to a single '.yaml' file in the specified directory, preferring files that match
        the directory name.
        """
        files = list(dir.glob("*.yaml")) or list(dir.rglob("*.yaml"))  # try root level first and then recursive
        assert files, f"No *.yaml file found in {dir}"
        if len(files) > 1:
            files = [f for f in files if f.stem == dir.stem]  # prefer *.yaml files that match dir name
            assert files, f"Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed"
        assert len(files) == 1, f"Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}"
        return files[0]

    def _unzip(self, path):
        """Unzips a .zip file at 'path', returning success status, unzipped directory, and path to YAML file within."""
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        assert Path(path).is_file(), f"Error unzipping {path}, file not found"
        unzip_file(path, path=path.parent)
        dir = path.with_suffix("")  # dataset directory == zip name
        assert dir.is_dir(), f"Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/"
        return True, str(dir), self._find_yaml(dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f, max_dim=1920):
        """Resizes and saves an image at reduced quality for web/app viewing, supporting both PIL and OpenCV."""
        f_new = self.im_dir / Path(f).name  # dataset-hub image filename
        try:  # use PIL
            im = Image.open(f)
            r = max_dim / max(im.height, im.width)  # ratio
            if r < 1.0:  # image too large
                im = im.resize((int(im.width * r), int(im.height * r)))
            im.save(f_new, "JPEG", quality=50, optimize=True)  # save
        except Exception as e:  # use OpenCV as fallback when PIL cannot handle the file
            LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
            im = cv2.imread(f)
            im_height, im_width = im.shape[:2]
            r = max_dim / max(im_height, im_width)  # ratio
            if r < 1.0:  # image too large
                im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
            cv2.imwrite(str(f_new), im)

    def get_json(self, save=False, verbose=False):
        """Generates dataset JSON for Ultralytics HUB, optionally saves or prints it; save=bool, verbose=bool."""

        def _round(labels):
            """Rounds class labels to integers and coordinates to 4 decimal places for improved label accuracy."""
            return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]

        for split in "train", "val", "test":
            if self.data.get(split) is None:
                self.stats[split] = None  # i.e. no test set
                continue
            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
            # per-image class histogram, shape (num_images, nc)
            x = np.array(
                [
                    np.bincount(label[:, 0].astype(int), minlength=self.data["nc"])
                    for label in tqdm(dataset.labels, total=dataset.n, desc="Statistics")
                ]
            )  # shape(128x80)
            self.stats[split] = {
                "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                "image_stats": {
                    "total": dataset.n,
                    "unlabelled": int(np.all(x == 0, 1).sum()),
                    "per_class": (x > 0).sum(0).tolist(),
                },
                "labels": [{str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)],
            }

        # Save, print and return
        if save:
            stats_path = self.hub_dir / "stats.json"
            print(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            print(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        """Compresses images for Ultralytics HUB across 'train', 'val', 'test' splits and saves to specified
        directory.
        """
        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
            desc = f"{split} images"
            for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
                pass
        print(f"Done. All images saved to {self.im_dir}")
        return self.im_dir
1321
+
1322
+
1323
+ # Classification dataloaders -------------------------------------------------------------------------------------------
1324
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLOv5 Classification Dataset.

    Arguments:
        root: Dataset path
        transform: torchvision transforms, used by default
        album_transform: Albumentations transforms, used if installed
    """

    def __init__(self, root, augment, imgsz, cache=False):
        """Initializes YOLOv5 Classification Dataset with optional caching, augmentations, and transforms for image
        classification.

        Args:
            root: ImageFolder-style dataset root directory.
            augment: Truthy to enable Albumentations-based augmentation transforms.
            imgsz (int): Target image size used by the transforms.
            cache: False, True/"ram" to cache decoded images in memory, or "disk" to cache as .npy files.
        """
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == "ram"
        self.cache_disk = cache == "disk"
        # extend each ImageFolder sample to [file, class index, npy cache path, lazily cached image]
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im

    def __getitem__(self, i):
        """Fetches and transforms an image sample by index, supporting RAM/disk caching and Augmentations."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram and im is None:
            # decode once and keep the array in the samples list for later hits
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
        else:
            sample = self.torch_transforms(im)
        return sample, j
1361
+
1362
+
1363
def create_classification_dataloader(
    path, imgsz=224, batch_size=16, augment=True, cache=False, rank=-1, workers=8, shuffle=True
):
    """Creates a DataLoader for image classification, supporting caching, augmentation, and distributed training.

    Returns an InfiniteDataLoader over a ClassificationDataset; in DDP mode (rank > -1) a
    DistributedSampler is attached and shuffling is delegated to it.
    """
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
    batch_size = min(batch_size, len(dataset))
    device_count = torch.cuda.device_count()
    # cap workers by CPUs per device, batch size, and the requested limit
    worker_count = min([os.cpu_count() // max(device_count, 1), batch_size if batch_size > 1 else 0, workers])
    sampler = distributed.DistributedSampler(dataset, shuffle=shuffle) if rank != -1 else None
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)  # fixed base seed offset by process rank
    return InfiniteDataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and sampler is None,
        num_workers=worker_count,
        sampler=sampler,
        pin_memory=PIN_MEMORY,
        worker_init_fn=seed_worker,
        generator=generator,
    )  # or DataLoader(persistent_workers=True)