ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +527 -67
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +44 -37
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +84 -56
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.28.dist-info/METADATA +0 -373
  244. ultralytics-8.1.28.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
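The detailed hunks below are the diff for ultralytics/data/dataset.py (item 98 above); the remaining files are only summarized in the list.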
@@ -1,5 +1,7 @@
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- import contextlib
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import json
+ from collections import defaultdict
  from itertools import repeat
  from multiprocessing.pool import ThreadPool
  from pathlib import Path
@@ -7,14 +9,34 @@ from pathlib import Path
  import cv2
  import numpy as np
  import torch
- import torchvision
  from PIL import Image
+ from torch.utils.data import ConcatDataset

- from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
+ from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
  from ultralytics.utils.ops import resample_segments
- from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
+ from ultralytics.utils.torch_utils import TORCHVISION_0_18
+
+ from .augment import (
+     Compose,
+     Format,
+     Instances,
+     LetterBox,
+     RandomLoadText,
+     classify_augmentations,
+     classify_transforms,
+     v8_transforms,
+ )
  from .base import BaseDataset
- from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
+ from .utils import (
+     HELP_URL,
+     LOGGER,
+     get_hash,
+     img2label_paths,
+     load_dataset_cache_file,
+     save_dataset_cache_file,
+     verify_image,
+     verify_image_label,
+ )

  # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
  DATASET_CACHE_VERSION = "1.0.3"
@@ -46,7 +68,7 @@ class YOLODataset(BaseDataset):
          Cache dataset labels, check images and read shapes.

          Args:
-             path (Path): Path where to save the cache file. Default is Path('./labels.cache').
+             path (Path): Path where to save the cache file. Default is Path("./labels.cache").

          Returns:
              (dict): labels.
@@ -56,7 +78,7 @@ class YOLODataset(BaseDataset):
          desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
          total = len(self.im_files)
          nkpt, ndim = self.data.get("kpt_shape", (0, 0))
-         if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
+         if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
              raise ValueError(
                  "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                  "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
@@ -82,16 +104,16 @@ class YOLODataset(BaseDataset):
                  nc += nc_f
                  if im_file:
                      x["labels"].append(
-                         dict(
-                             im_file=im_file,
-                             shape=shape,
-                             cls=lb[:, 0:1],  # n, 1
-                             bboxes=lb[:, 1:],  # n, 4
-                             segments=segments,
-                             keypoints=keypoint,
-                             normalized=True,
-                             bbox_format="xywh",
-                         )
+                         {
+                             "im_file": im_file,
+                             "shape": shape,
+                             "cls": lb[:, 0:1],  # n, 1
+                             "bboxes": lb[:, 1:],  # n, 4
+                             "segments": segments,
+                             "keypoints": keypoint,
+                             "normalized": True,
+                             "bbox_format": "xywh",
+                         }
                      )
                  if msg:
                      msgs.append(msg)
@@ -105,7 +127,7 @@ class YOLODataset(BaseDataset):
          x["hash"] = get_hash(self.label_files + self.im_files)
          x["results"] = nf, nm, ne, nc, len(self.im_files)
          x["msgs"] = msgs  # warnings
-         save_dataset_cache_file(self.prefix, path, x)
+         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
          return x

      def get_labels(self):
@@ -121,7 +143,7 @@ class YOLODataset(BaseDataset):

          # Display cache
          nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
-         if exists and LOCAL_RANK in (-1, 0):
+         if exists and LOCAL_RANK in {-1, 0}:
              d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
              TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
              if cache["msgs"]:
@@ -167,6 +189,7 @@ class YOLODataset(BaseDataset):
                  batch_idx=True,
                  mask_ratio=hyp.mask_ratio,
                  mask_overlap=hyp.overlap_mask,
+                 bgr=hyp.bgr if self.augment else 0.0,  # only affect training.
              )
          )
          return transforms
@@ -195,8 +218,10 @@ class YOLODataset(BaseDataset):
          # NOTE: do NOT resample oriented boxes
          segment_resamples = 100 if self.use_obb else 1000
          if len(segments) > 0:
-             # list[np.array(1000, 2)] * num_samples
-             # (N, 1000, 2)
+             # make sure segments interpolate correctly if original length is greater than segment_resamples
+             max_len = max(len(s) for s in segments)
+             segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
+             # list[np.array(segment_resamples, 2)] * num_samples
              segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
          else:
              segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
@@ -213,7 +238,7 @@ class YOLODataset(BaseDataset):
              value = values[i]
              if k == "img":
                  value = torch.stack(value, 0)
-             if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
+             if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                  value = torch.cat(value, 0)
              new_batch[k] = value
          new_batch["batch_idx"] = list(new_batch["batch_idx"])
@@ -223,8 +248,145 @@ class YOLODataset(BaseDataset):
          return new_batch


- # Classification dataloaders -------------------------------------------------------------------------------------------
- class ClassificationDataset(torchvision.datasets.ImageFolder):
+ class YOLOMultiModalDataset(YOLODataset):
+     """
+     Dataset class for loading object detection and/or segmentation labels in YOLO format.
+
+     Args:
+         data (dict, optional): A dataset YAML dictionary. Defaults to None.
+         task (str): An explicit arg to point current task, Defaults to 'detect'.
+
+     Returns:
+         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
+     """
+
+     def __init__(self, *args, data=None, task="detect", **kwargs):
+         """Initializes a dataset object for object detection tasks with optional specifications."""
+         super().__init__(*args, data=data, task=task, **kwargs)
+
+     def update_labels_info(self, label):
+         """Add texts information for multi-modal model training."""
+         labels = super().update_labels_info(label)
+         # NOTE: some categories are concatenated with its synonyms by `/`.
+         labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
+         return labels
+
+     def build_transforms(self, hyp=None):
+         """Enhances data transformations with optional text augmentation for multi-modal training."""
+         transforms = super().build_transforms(hyp)
+         if self.augment:
+             # NOTE: hard-coded the args for now.
+             transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
+         return transforms
+
+
+ class GroundingDataset(YOLODataset):
+     """Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""
+
+     def __init__(self, *args, task="detect", json_file, **kwargs):
+         """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
+         assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
+         self.json_file = json_file
+         super().__init__(*args, task=task, data={}, **kwargs)
+
+     def get_img_files(self, img_path):
+         """The image files would be read in `get_labels` function, return empty list here."""
+         return []
+
+     def get_labels(self):
+         """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
+         labels = []
+         LOGGER.info("Loading annotation file...")
+         with open(self.json_file) as f:
+             annotations = json.load(f)
+         images = {f"{x['id']:d}": x for x in annotations["images"]}
+         img_to_anns = defaultdict(list)
+         for ann in annotations["annotations"]:
+             img_to_anns[ann["image_id"]].append(ann)
+         for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
+             img = images[f"{img_id:d}"]
+             h, w, f = img["height"], img["width"], img["file_name"]
+             im_file = Path(self.img_path) / f
+             if not im_file.exists():
+                 continue
+             self.im_files.append(str(im_file))
+             bboxes = []
+             cat2id = {}
+             texts = []
+             for ann in anns:
+                 if ann["iscrowd"]:
+                     continue
+                 box = np.array(ann["bbox"], dtype=np.float32)
+                 box[:2] += box[2:] / 2
+                 box[[0, 2]] /= float(w)
+                 box[[1, 3]] /= float(h)
+                 if box[2] <= 0 or box[3] <= 0:
+                     continue
+
+                 caption = img["caption"]
+                 cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]])
+                 if cat_name not in cat2id:
+                     cat2id[cat_name] = len(cat2id)
+                     texts.append([cat_name])
+                 cls = cat2id[cat_name]  # class
+                 box = [cls] + box.tolist()
+                 if box not in bboxes:
+                     bboxes.append(box)
+             lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
+             labels.append(
+                 {
+                     "im_file": im_file,
+                     "shape": (h, w),
+                     "cls": lb[:, 0:1],  # n, 1
+                     "bboxes": lb[:, 1:],  # n, 4
+                     "normalized": True,
+                     "bbox_format": "xywh",
+                     "texts": texts,
+                 }
+             )
+         return labels
+
+     def build_transforms(self, hyp=None):
+         """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
+         transforms = super().build_transforms(hyp)
+         if self.augment:
+             # NOTE: hard-coded the args for now.
+             transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
+         return transforms
+
+
+ class YOLOConcatDataset(ConcatDataset):
+     """
+     Dataset as a concatenation of multiple datasets.
+
+     This class is useful to assemble different existing datasets.
+     """
+
+     @staticmethod
+     def collate_fn(batch):
+         """Collates data samples into batches."""
+         return YOLODataset.collate_fn(batch)
+
+
+ # TODO: support semantic segmentation
+ class SemanticDataset(BaseDataset):
+     """
+     Semantic Segmentation Dataset.
+
+     This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
+     from the BaseDataset class.
+
+     Note:
+         This class is currently a placeholder and needs to be populated with methods and attributes for supporting
+         semantic segmentation tasks.
+     """
+
+     def __init__(self):
+         """Initialize a SemanticDataset object."""
+         super().__init__()
+
+
+ class ClassificationDataset:
      """
      Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
      augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
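The hunk above introduces YOLOMultiModalDataset, GroundingDataset, and YOLOConcatDataset. Below is a minimal usage sketch, based only on the constructor signatures visible in the hunk, of how the new grounding and concat classes might be combined; the dataset paths are hypothetical placeholders and all other arguments keep their defaults.

# Sketch only: arguments mirror the signatures shown above; both paths are
# hypothetical and must point at a real COCO-style grounding dataset.
from torch.utils.data import DataLoader

from ultralytics.data.dataset import GroundingDataset, YOLOConcatDataset

grounding = GroundingDataset(
    img_path="datasets/grounding/images",       # hypothetical image directory
    json_file="datasets/grounding/annotations.json",  # hypothetical annotation JSON
    task="detect",  # the class asserts that only "detect" is supported
    imgsz=640,
)

# YOLOConcatDataset merges several datasets while delegating batching to YOLODataset.collate_fn.
train_set = YOLOConcatDataset([grounding])
loader = DataLoader(train_set, batch_size=16, collate_fn=YOLOConcatDataset.collate_fn)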
@@ -256,12 +418,28 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
              prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
                  debugging. Default is an empty string.
          """
-         super().__init__(root=root)
+         import torchvision  # scope for faster 'import ultralytics'
+
+         # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
+         if TORCHVISION_0_18:  # 'allow_empty' argument first introduced in torchvision 0.18
+             self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
+         else:
+             self.base = torchvision.datasets.ImageFolder(root=root)
+         self.samples = self.base.samples
+         self.root = self.base.root
+
+         # Initialize attributes
          if augment and args.fraction < 1.0:  # reduce training fraction
              self.samples = self.samples[: round(len(self.samples) * args.fraction)]
          self.prefix = colorstr(f"{prefix}: ") if prefix else ""
-         self.cache_ram = args.cache is True or args.cache == "ram"  # cache images into RAM
-         self.cache_disk = args.cache == "disk"  # cache images on hard drive as uncompressed *.npy files
+         self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
+         if self.cache_ram:
+             LOGGER.warning(
+                 "WARNING ⚠️ Classification `cache_ram` training has known memory leak in "
+                 "https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`."
+             )
+             self.cache_ram = False
+         self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
          self.samples = self.verify_images()  # filter out bad images
          self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
          scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
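In the hunk above, ClassificationDataset now builds its torchvision.datasets.ImageFolder by composition inside __init__ rather than by inheritance, deferring the slow torchvision import until a dataset is actually constructed, and RAM caching is force-disabled because of the memory leak referenced in issue 9824. A generic illustration of that deferred-import-by-composition pattern (not Ultralytics code) follows.

# Generic illustration: keep a heavy dependency out of module import time by
# importing it inside the constructor and delegating to an instance attribute
# instead of subclassing it.
class LazyImageFolder:
    def __init__(self, root: str):
        import torchvision  # deferred: paid only when a dataset is actually built

        self.base = torchvision.datasets.ImageFolder(root=root)
        self.samples = self.base.samples  # expose only the pieces callers need
        self.root = self.base.root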
@@ -284,8 +462,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
      def __getitem__(self, i):
          """Returns subset of data and targets corresponding to given indices."""
          f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
-         if self.cache_ram and im is None:
-             im = self.samples[i][3] = cv2.imread(f)
+         if self.cache_ram:
+             if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
+                 im = self.samples[i][3] = cv2.imread(f)
          elif self.cache_disk:
              if not fn.exists():  # load npy
                  np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
@@ -306,77 +485,37 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
          desc = f"{self.prefix}Scanning {self.root}..."
          path = Path(self.root).with_suffix(".cache")  # *.cache file path

-         with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+         try:
              cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
              assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
              assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
              nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
-             if LOCAL_RANK in (-1, 0):
+             if LOCAL_RANK in {-1, 0}:
                  d = f"{desc} {nf} images, {nc} corrupt"
                  TQDM(None, desc=d, total=n, initial=n)
                  if cache["msgs"]:
                      LOGGER.info("\n".join(cache["msgs"]))  # display warnings
              return samples

-         # Run scan if *.cache retrieval failed
-         nf, nc, msgs, samples, x = 0, 0, [], [], {}
-         with ThreadPool(NUM_THREADS) as pool:
-             results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
-             pbar = TQDM(results, desc=desc, total=len(self.samples))
-             for sample, nf_f, nc_f, msg in pbar:
-                 if nf_f:
-                     samples.append(sample)
-                 if msg:
-                     msgs.append(msg)
-                 nf += nf_f
-                 nc += nc_f
-                 pbar.desc = f"{desc} {nf} images, {nc} corrupt"
-             pbar.close()
-         if msgs:
-             LOGGER.info("\n".join(msgs))
-         x["hash"] = get_hash([x[0] for x in self.samples])
-         x["results"] = nf, nc, len(samples), samples
-         x["msgs"] = msgs  # warnings
-         save_dataset_cache_file(self.prefix, path, x)
-         return samples
-
-
- def load_dataset_cache_file(path):
-     """Load an Ultralytics *.cache dictionary from path."""
-     import gc
-
-     gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
-     cache = np.load(str(path), allow_pickle=True).item()  # load dict
-     gc.enable()
-     return cache
-
-
- def save_dataset_cache_file(prefix, path, x):
-     """Save an Ultralytics dataset *.cache dictionary x to path."""
-     x["version"] = DATASET_CACHE_VERSION  # add cache version
-     if is_dir_writeable(path.parent):
-         if path.exists():
-             path.unlink()  # remove *.cache file if exists
-         np.save(str(path), x)  # save cache for next time
-         path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
-         LOGGER.info(f"{prefix}New cache created: {path}")
-     else:
-         LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
-
-
- # TODO: support semantic segmentation
- class SemanticDataset(BaseDataset):
-     """
-     Semantic Segmentation Dataset.
-
-     This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
-     from the BaseDataset class.
-
-     Note:
-         This class is currently a placeholder and needs to be populated with methods and attributes for supporting
-         semantic segmentation tasks.
-     """
-
-     def __init__(self):
-         """Initialize a SemanticDataset object."""
-         super().__init__()
+         except (FileNotFoundError, AssertionError, AttributeError):
+             # Run scan if *.cache retrieval failed
+             nf, nc, msgs, samples, x = 0, 0, [], [], {}
+             with ThreadPool(NUM_THREADS) as pool:
+                 results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
+                 pbar = TQDM(results, desc=desc, total=len(self.samples))
+                 for sample, nf_f, nc_f, msg in pbar:
+                     if nf_f:
+                         samples.append(sample)
+                     if msg:
+                         msgs.append(msg)
+                     nf += nf_f
+                     nc += nc_f
+                     pbar.desc = f"{desc} {nf} images, {nc} corrupt"
+                 pbar.close()
+             if msgs:
+                 LOGGER.info("\n".join(msgs))
+             x["hash"] = get_hash([x[0] for x in self.samples])
+             x["results"] = nf, nc, len(samples), samples
+             x["msgs"] = msgs  # warnings
+             save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+             return samples
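The final hunk replaces the contextlib.suppress block with an explicit try/except and removes the module-level load_dataset_cache_file / save_dataset_cache_file helpers (along with the duplicated SemanticDataset definition) from dataset.py; per the import hunk at the top, the helpers are now imported from ultralytics.data.utils, and every save call passes DATASET_CACHE_VERSION explicitly. A small sketch of the implied call pattern follows; the signatures are inferred from the call sites in this diff rather than from separate documentation.

# Sketch: cache helpers as imported from ultralytics.data.utils in 8.3.x.
# Signatures are inferred from the call sites shown in the diff above.
from pathlib import Path

from ultralytics.data.utils import load_dataset_cache_file, save_dataset_cache_file

DATASET_CACHE_VERSION = "1.0.3"
cache_path = Path("labels.cache")

x = {"labels": [], "hash": "", "results": (0, 0, 0, 0, 0), "msgs": []}
save_dataset_cache_file("demo: ", cache_path, x, DATASET_CACHE_VERSION)  # version is now an explicit argument
cache = load_dataset_cache_file(cache_path)  # returns the same dict with a "version" key added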