ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/data/dataset.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
|
-
# Ultralytics
|
2
|
-
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import json
|
4
|
+
from collections import defaultdict
|
3
5
|
from itertools import repeat
|
4
6
|
from multiprocessing.pool import ThreadPool
|
5
7
|
from pathlib import Path
|
@@ -7,14 +9,34 @@ from pathlib import Path
|
|
7
9
|
import cv2
|
8
10
|
import numpy as np
|
9
11
|
import torch
|
10
|
-
import torchvision
|
11
12
|
from PIL import Image
|
13
|
+
from torch.utils.data import ConcatDataset
|
12
14
|
|
13
|
-
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
|
15
|
+
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
|
14
16
|
from ultralytics.utils.ops import resample_segments
|
15
|
-
from .
|
17
|
+
from ultralytics.utils.torch_utils import TORCHVISION_0_18
|
18
|
+
|
19
|
+
from .augment import (
|
20
|
+
Compose,
|
21
|
+
Format,
|
22
|
+
Instances,
|
23
|
+
LetterBox,
|
24
|
+
RandomLoadText,
|
25
|
+
classify_augmentations,
|
26
|
+
classify_transforms,
|
27
|
+
v8_transforms,
|
28
|
+
)
|
16
29
|
from .base import BaseDataset
|
17
|
-
from .utils import
|
30
|
+
from .utils import (
|
31
|
+
HELP_URL,
|
32
|
+
LOGGER,
|
33
|
+
get_hash,
|
34
|
+
img2label_paths,
|
35
|
+
load_dataset_cache_file,
|
36
|
+
save_dataset_cache_file,
|
37
|
+
verify_image,
|
38
|
+
verify_image_label,
|
39
|
+
)
|
18
40
|
|
19
41
|
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
|
20
42
|
DATASET_CACHE_VERSION = "1.0.3"
|
@@ -46,7 +68,7 @@ class YOLODataset(BaseDataset):
|
|
46
68
|
Cache dataset labels, check images and read shapes.
|
47
69
|
|
48
70
|
Args:
|
49
|
-
path (Path): Path where to save the cache file. Default is Path(
|
71
|
+
path (Path): Path where to save the cache file. Default is Path("./labels.cache").
|
50
72
|
|
51
73
|
Returns:
|
52
74
|
(dict): labels.
|
@@ -56,7 +78,7 @@ class YOLODataset(BaseDataset):
|
|
56
78
|
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
57
79
|
total = len(self.im_files)
|
58
80
|
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
|
59
|
-
if self.use_keypoints and (nkpt <= 0 or ndim not in
|
81
|
+
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
|
60
82
|
raise ValueError(
|
61
83
|
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
62
84
|
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
|
@@ -82,16 +104,16 @@ class YOLODataset(BaseDataset):
|
|
82
104
|
nc += nc_f
|
83
105
|
if im_file:
|
84
106
|
x["labels"].append(
|
85
|
-
|
86
|
-
im_file
|
87
|
-
shape
|
88
|
-
cls
|
89
|
-
bboxes
|
90
|
-
segments
|
91
|
-
keypoints
|
92
|
-
normalized
|
93
|
-
bbox_format
|
94
|
-
|
107
|
+
{
|
108
|
+
"im_file": im_file,
|
109
|
+
"shape": shape,
|
110
|
+
"cls": lb[:, 0:1], # n, 1
|
111
|
+
"bboxes": lb[:, 1:], # n, 4
|
112
|
+
"segments": segments,
|
113
|
+
"keypoints": keypoint,
|
114
|
+
"normalized": True,
|
115
|
+
"bbox_format": "xywh",
|
116
|
+
}
|
95
117
|
)
|
96
118
|
if msg:
|
97
119
|
msgs.append(msg)
|
@@ -105,7 +127,7 @@ class YOLODataset(BaseDataset):
|
|
105
127
|
x["hash"] = get_hash(self.label_files + self.im_files)
|
106
128
|
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
107
129
|
x["msgs"] = msgs # warnings
|
108
|
-
save_dataset_cache_file(self.prefix, path, x)
|
130
|
+
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
109
131
|
return x
|
110
132
|
|
111
133
|
def get_labels(self):
|
@@ -121,7 +143,7 @@ class YOLODataset(BaseDataset):
|
|
121
143
|
|
122
144
|
# Display cache
|
123
145
|
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
124
|
-
if exists and LOCAL_RANK in
|
146
|
+
if exists and LOCAL_RANK in {-1, 0}:
|
125
147
|
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
126
148
|
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
|
127
149
|
if cache["msgs"]:
|
@@ -167,6 +189,7 @@ class YOLODataset(BaseDataset):
|
|
167
189
|
batch_idx=True,
|
168
190
|
mask_ratio=hyp.mask_ratio,
|
169
191
|
mask_overlap=hyp.overlap_mask,
|
192
|
+
bgr=hyp.bgr if self.augment else 0.0, # only affect training.
|
170
193
|
)
|
171
194
|
)
|
172
195
|
return transforms
|
@@ -195,8 +218,10 @@ class YOLODataset(BaseDataset):
|
|
195
218
|
# NOTE: do NOT resample oriented boxes
|
196
219
|
segment_resamples = 100 if self.use_obb else 1000
|
197
220
|
if len(segments) > 0:
|
198
|
-
#
|
199
|
-
|
221
|
+
# make sure segments interpolate correctly if original length is greater than segment_resamples
|
222
|
+
max_len = max(len(s) for s in segments)
|
223
|
+
segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
|
224
|
+
# list[np.array(segment_resamples, 2)] * num_samples
|
200
225
|
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
|
201
226
|
else:
|
202
227
|
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
|
@@ -213,7 +238,7 @@ class YOLODataset(BaseDataset):
|
|
213
238
|
value = values[i]
|
214
239
|
if k == "img":
|
215
240
|
value = torch.stack(value, 0)
|
216
|
-
if k in
|
241
|
+
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
|
217
242
|
value = torch.cat(value, 0)
|
218
243
|
new_batch[k] = value
|
219
244
|
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
@@ -223,8 +248,145 @@ class YOLODataset(BaseDataset):
|
|
223
248
|
return new_batch
|
224
249
|
|
225
250
|
|
226
|
-
|
227
|
-
|
251
|
+
class YOLOMultiModalDataset(YOLODataset):
|
252
|
+
"""
|
253
|
+
Dataset class for loading object detection and/or segmentation labels in YOLO format.
|
254
|
+
|
255
|
+
Args:
|
256
|
+
data (dict, optional): A dataset YAML dictionary. Defaults to None.
|
257
|
+
task (str): An explicit arg to point current task, Defaults to 'detect'.
|
258
|
+
|
259
|
+
Returns:
|
260
|
+
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
|
261
|
+
"""
|
262
|
+
|
263
|
+
def __init__(self, *args, data=None, task="detect", **kwargs):
|
264
|
+
"""Initializes a dataset object for object detection tasks with optional specifications."""
|
265
|
+
super().__init__(*args, data=data, task=task, **kwargs)
|
266
|
+
|
267
|
+
def update_labels_info(self, label):
|
268
|
+
"""Add texts information for multi-modal model training."""
|
269
|
+
labels = super().update_labels_info(label)
|
270
|
+
# NOTE: some categories are concatenated with its synonyms by `/`.
|
271
|
+
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
|
272
|
+
return labels
|
273
|
+
|
274
|
+
def build_transforms(self, hyp=None):
|
275
|
+
"""Enhances data transformations with optional text augmentation for multi-modal training."""
|
276
|
+
transforms = super().build_transforms(hyp)
|
277
|
+
if self.augment:
|
278
|
+
# NOTE: hard-coded the args for now.
|
279
|
+
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
|
280
|
+
return transforms
|
281
|
+
|
282
|
+
|
283
|
+
class GroundingDataset(YOLODataset):
|
284
|
+
"""Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""
|
285
|
+
|
286
|
+
def __init__(self, *args, task="detect", json_file, **kwargs):
|
287
|
+
"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
|
288
|
+
assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
|
289
|
+
self.json_file = json_file
|
290
|
+
super().__init__(*args, task=task, data={}, **kwargs)
|
291
|
+
|
292
|
+
def get_img_files(self, img_path):
|
293
|
+
"""The image files would be read in `get_labels` function, return empty list here."""
|
294
|
+
return []
|
295
|
+
|
296
|
+
def get_labels(self):
|
297
|
+
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
|
298
|
+
labels = []
|
299
|
+
LOGGER.info("Loading annotation file...")
|
300
|
+
with open(self.json_file) as f:
|
301
|
+
annotations = json.load(f)
|
302
|
+
images = {f"{x['id']:d}": x for x in annotations["images"]}
|
303
|
+
img_to_anns = defaultdict(list)
|
304
|
+
for ann in annotations["annotations"]:
|
305
|
+
img_to_anns[ann["image_id"]].append(ann)
|
306
|
+
for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
|
307
|
+
img = images[f"{img_id:d}"]
|
308
|
+
h, w, f = img["height"], img["width"], img["file_name"]
|
309
|
+
im_file = Path(self.img_path) / f
|
310
|
+
if not im_file.exists():
|
311
|
+
continue
|
312
|
+
self.im_files.append(str(im_file))
|
313
|
+
bboxes = []
|
314
|
+
cat2id = {}
|
315
|
+
texts = []
|
316
|
+
for ann in anns:
|
317
|
+
if ann["iscrowd"]:
|
318
|
+
continue
|
319
|
+
box = np.array(ann["bbox"], dtype=np.float32)
|
320
|
+
box[:2] += box[2:] / 2
|
321
|
+
box[[0, 2]] /= float(w)
|
322
|
+
box[[1, 3]] /= float(h)
|
323
|
+
if box[2] <= 0 or box[3] <= 0:
|
324
|
+
continue
|
325
|
+
|
326
|
+
caption = img["caption"]
|
327
|
+
cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]])
|
328
|
+
if cat_name not in cat2id:
|
329
|
+
cat2id[cat_name] = len(cat2id)
|
330
|
+
texts.append([cat_name])
|
331
|
+
cls = cat2id[cat_name] # class
|
332
|
+
box = [cls] + box.tolist()
|
333
|
+
if box not in bboxes:
|
334
|
+
bboxes.append(box)
|
335
|
+
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
|
336
|
+
labels.append(
|
337
|
+
{
|
338
|
+
"im_file": im_file,
|
339
|
+
"shape": (h, w),
|
340
|
+
"cls": lb[:, 0:1], # n, 1
|
341
|
+
"bboxes": lb[:, 1:], # n, 4
|
342
|
+
"normalized": True,
|
343
|
+
"bbox_format": "xywh",
|
344
|
+
"texts": texts,
|
345
|
+
}
|
346
|
+
)
|
347
|
+
return labels
|
348
|
+
|
349
|
+
def build_transforms(self, hyp=None):
|
350
|
+
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
|
351
|
+
transforms = super().build_transforms(hyp)
|
352
|
+
if self.augment:
|
353
|
+
# NOTE: hard-coded the args for now.
|
354
|
+
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
|
355
|
+
return transforms
|
356
|
+
|
357
|
+
|
358
|
+
class YOLOConcatDataset(ConcatDataset):
|
359
|
+
"""
|
360
|
+
Dataset as a concatenation of multiple datasets.
|
361
|
+
|
362
|
+
This class is useful to assemble different existing datasets.
|
363
|
+
"""
|
364
|
+
|
365
|
+
@staticmethod
|
366
|
+
def collate_fn(batch):
|
367
|
+
"""Collates data samples into batches."""
|
368
|
+
return YOLODataset.collate_fn(batch)
|
369
|
+
|
370
|
+
|
371
|
+
# TODO: support semantic segmentation
|
372
|
+
class SemanticDataset(BaseDataset):
|
373
|
+
"""
|
374
|
+
Semantic Segmentation Dataset.
|
375
|
+
|
376
|
+
This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
|
377
|
+
from the BaseDataset class.
|
378
|
+
|
379
|
+
Note:
|
380
|
+
This class is currently a placeholder and needs to be populated with methods and attributes for supporting
|
381
|
+
semantic segmentation tasks.
|
382
|
+
"""
|
383
|
+
|
384
|
+
def __init__(self):
|
385
|
+
"""Initialize a SemanticDataset object."""
|
386
|
+
super().__init__()
|
387
|
+
|
388
|
+
|
389
|
+
class ClassificationDataset:
|
228
390
|
"""
|
229
391
|
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
|
230
392
|
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
|
@@ -256,12 +418,28 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|
256
418
|
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
|
257
419
|
debugging. Default is an empty string.
|
258
420
|
"""
|
259
|
-
|
421
|
+
import torchvision # scope for faster 'import ultralytics'
|
422
|
+
|
423
|
+
# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
|
424
|
+
if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18
|
425
|
+
self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
|
426
|
+
else:
|
427
|
+
self.base = torchvision.datasets.ImageFolder(root=root)
|
428
|
+
self.samples = self.base.samples
|
429
|
+
self.root = self.base.root
|
430
|
+
|
431
|
+
# Initialize attributes
|
260
432
|
if augment and args.fraction < 1.0: # reduce training fraction
|
261
433
|
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
|
262
434
|
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
|
263
|
-
self.cache_ram = args.cache is True or args.cache == "ram" # cache images into RAM
|
264
|
-
self.
|
435
|
+
self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM
|
436
|
+
if self.cache_ram:
|
437
|
+
LOGGER.warning(
|
438
|
+
"WARNING ⚠️ Classification `cache_ram` training has known memory leak in "
|
439
|
+
"https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`."
|
440
|
+
)
|
441
|
+
self.cache_ram = False
|
442
|
+
self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
|
265
443
|
self.samples = self.verify_images() # filter out bad images
|
266
444
|
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
267
445
|
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
|
@@ -284,8 +462,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|
284
462
|
def __getitem__(self, i):
|
285
463
|
"""Returns subset of data and targets corresponding to given indices."""
|
286
464
|
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
287
|
-
if self.cache_ram
|
288
|
-
im
|
465
|
+
if self.cache_ram:
|
466
|
+
if im is None: # Warning: two separate if statements required here, do not combine this with previous line
|
467
|
+
im = self.samples[i][3] = cv2.imread(f)
|
289
468
|
elif self.cache_disk:
|
290
469
|
if not fn.exists(): # load npy
|
291
470
|
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
|
@@ -306,77 +485,37 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|
306
485
|
desc = f"{self.prefix}Scanning {self.root}..."
|
307
486
|
path = Path(self.root).with_suffix(".cache") # *.cache file path
|
308
487
|
|
309
|
-
|
488
|
+
try:
|
310
489
|
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
|
311
490
|
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
312
491
|
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
313
492
|
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
314
|
-
if LOCAL_RANK in
|
493
|
+
if LOCAL_RANK in {-1, 0}:
|
315
494
|
d = f"{desc} {nf} images, {nc} corrupt"
|
316
495
|
TQDM(None, desc=d, total=n, initial=n)
|
317
496
|
if cache["msgs"]:
|
318
497
|
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
319
498
|
return samples
|
320
499
|
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
def load_dataset_cache_file(path):
|
345
|
-
"""Load an Ultralytics *.cache dictionary from path."""
|
346
|
-
import gc
|
347
|
-
|
348
|
-
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
|
349
|
-
cache = np.load(str(path), allow_pickle=True).item() # load dict
|
350
|
-
gc.enable()
|
351
|
-
return cache
|
352
|
-
|
353
|
-
|
354
|
-
def save_dataset_cache_file(prefix, path, x):
|
355
|
-
"""Save an Ultralytics dataset *.cache dictionary x to path."""
|
356
|
-
x["version"] = DATASET_CACHE_VERSION # add cache version
|
357
|
-
if is_dir_writeable(path.parent):
|
358
|
-
if path.exists():
|
359
|
-
path.unlink() # remove *.cache file if exists
|
360
|
-
np.save(str(path), x) # save cache for next time
|
361
|
-
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
362
|
-
LOGGER.info(f"{prefix}New cache created: {path}")
|
363
|
-
else:
|
364
|
-
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
|
365
|
-
|
366
|
-
|
367
|
-
# TODO: support semantic segmentation
|
368
|
-
class SemanticDataset(BaseDataset):
|
369
|
-
"""
|
370
|
-
Semantic Segmentation Dataset.
|
371
|
-
|
372
|
-
This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
|
373
|
-
from the BaseDataset class.
|
374
|
-
|
375
|
-
Note:
|
376
|
-
This class is currently a placeholder and needs to be populated with methods and attributes for supporting
|
377
|
-
semantic segmentation tasks.
|
378
|
-
"""
|
379
|
-
|
380
|
-
def __init__(self):
|
381
|
-
"""Initialize a SemanticDataset object."""
|
382
|
-
super().__init__()
|
500
|
+
except (FileNotFoundError, AssertionError, AttributeError):
|
501
|
+
# Run scan if *.cache retrieval failed
|
502
|
+
nf, nc, msgs, samples, x = 0, 0, [], [], {}
|
503
|
+
with ThreadPool(NUM_THREADS) as pool:
|
504
|
+
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
|
505
|
+
pbar = TQDM(results, desc=desc, total=len(self.samples))
|
506
|
+
for sample, nf_f, nc_f, msg in pbar:
|
507
|
+
if nf_f:
|
508
|
+
samples.append(sample)
|
509
|
+
if msg:
|
510
|
+
msgs.append(msg)
|
511
|
+
nf += nf_f
|
512
|
+
nc += nc_f
|
513
|
+
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
|
514
|
+
pbar.close()
|
515
|
+
if msgs:
|
516
|
+
LOGGER.info("\n".join(msgs))
|
517
|
+
x["hash"] = get_hash([x[0] for x in self.samples])
|
518
|
+
x["results"] = nf, nc, len(samples), samples
|
519
|
+
x["msgs"] = msgs # warnings
|
520
|
+
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
|
521
|
+
return samples
|