ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
This version of ultralytics has been flagged as potentially problematic.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
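
The headline functional addition in this file list is the new ultralytics/cfg/datasets/dota8.yaml config (+34 lines), which pairs with the oriented-bounding-box (OBB) support threaded through ultralytics/models/yolo/obb/* and ultralytics/data/split_dota.py. A rough sketch of exercising the new config through the standard Ultralytics Python API; the weights file and argument values are illustrative assumptions, not taken from this diff:

from ultralytics import YOLO

# Illustrative sketch only: "yolov8n-obb.pt" and the training arguments
# below are assumptions, not part of this diff.
model = YOLO("yolov8n-obb.pt")  # oriented-bounding-box model variant
model.train(data="dota8.yaml", epochs=1, imgsz=640)  # dota8.yaml ships under ultralytics/cfg/datasets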
ultralytics/data/dataset.py
CHANGED
@@ -18,7 +18,7 @@ from .base import BaseDataset
 from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
 
 # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
-DATASET_CACHE_VERSION = '1.0.3'
+DATASET_CACHE_VERSION = "1.0.3"
 
 
 class YOLODataset(BaseDataset):
@@ -33,16 +33,16 @@ class YOLODataset(BaseDataset):
         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
     """
 
-    def __init__(self, *args, data=None, task='detect', **kwargs):
+    def __init__(self, *args, data=None, task="detect", **kwargs):
         """Initializes the YOLODataset with optional configurations for segments and keypoints."""
-        self.use_segments = task == 'segment'
-        self.use_keypoints = task == 'pose'
-        self.use_obb = task == 'obb'
+        self.use_segments = task == "segment"
+        self.use_keypoints = task == "pose"
+        self.use_obb = task == "obb"
         self.data = data
-        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
+        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
         super().__init__(*args, **kwargs)
 
-    def cache_labels(self, path=Path('./labels.cache')):
+    def cache_labels(self, path=Path("./labels.cache")):
         """
         Cache dataset labels, check images and read shapes.
 
@@ -51,19 +51,29 @@ class YOLODataset(BaseDataset):
         Returns:
             (dict): labels.
         """
-        x = {'labels': []}
+        x = {"labels": []}
         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
-        desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
+        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
         total = len(self.im_files)
-        nkpt, ndim = self.data.get('kpt_shape', (0, 0))
+        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
         if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
-            raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
-                             "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
+            raise ValueError(
+                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
+                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
+            )
         with ThreadPool(NUM_THREADS) as pool:
-            results = pool.imap(func=verify_image_label,
-                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
-                                             repeat(self.use_keypoints), repeat(len(self.data['names'])),
-                                             repeat(nkpt), repeat(ndim)))
+            results = pool.imap(
+                func=verify_image_label,
+                iterable=zip(
+                    self.im_files,
+                    self.label_files,
+                    repeat(self.prefix),
+                    repeat(self.use_keypoints),
+                    repeat(len(self.data["names"])),
+                    repeat(nkpt),
+                    repeat(ndim),
+                ),
+            )
             pbar = TQDM(results, desc=desc, total=total)
             for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                 nm += nm_f
@@ -71,7 +81,7 @@ class YOLODataset(BaseDataset):
                 ne += ne_f
                 nc += nc_f
                 if im_file:
-                    x['labels'].append(
+                    x["labels"].append(
                         dict(
                             im_file=im_file,
                             shape=shape,
@@ -80,60 +90,63 @@ class YOLODataset(BaseDataset):
                             segments=segments,
                             keypoints=keypoint,
                             normalized=True,
-                            bbox_format='xywh'))
+                            bbox_format="xywh",
+                        )
+                    )
                 if msg:
                     msgs.append(msg)
-                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
+                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
             pbar.close()
 
         if msgs:
-            LOGGER.info('\n'.join(msgs))
+            LOGGER.info("\n".join(msgs))
         if nf == 0:
-            LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
-        x['hash'] = get_hash(self.label_files + self.im_files)
-        x['results'] = nf, nm, ne, nc, len(self.im_files)
-        x['msgs'] = msgs  # warnings
+            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
+        x["hash"] = get_hash(self.label_files + self.im_files)
+        x["results"] = nf, nm, ne, nc, len(self.im_files)
+        x["msgs"] = msgs  # warnings
         save_dataset_cache_file(self.prefix, path, x)
         return x
 
     def get_labels(self):
         """Returns dictionary of labels for YOLO training."""
         self.label_files = img2label_paths(self.im_files)
-        cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
+        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
         try:
             cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
-            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
-            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
         except (FileNotFoundError, AssertionError, AttributeError):
             cache, exists = self.cache_labels(cache_path), False  # run cache ops
 
         # Display cache
-        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
+        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
         if exists and LOCAL_RANK in (-1, 0):
-            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
+            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
-            if cache['msgs']:
-                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+            if cache["msgs"]:
+                LOGGER.info("\n".join(cache["msgs"]))  # display warnings
 
         # Read cache
-        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
-        labels = cache['labels']
+        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
+        labels = cache["labels"]
         if not labels:
-            LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}')
-        self.im_files = [lb['im_file'] for lb in labels]  # update im_files
+            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
+        self.im_files = [lb["im_file"] for lb in labels]  # update im_files
 
         # Check if the dataset is all boxes or all segments
-        lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
+        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
         len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
         if len_segments and len_boxes != len_segments:
             LOGGER.warning(
-                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
-                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
-                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
+                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
+                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
+                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
+            )
             for lb in labels:
-                lb['segments'] = []
+                lb["segments"] = []
         if len_cls == 0:
-            LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}')
+            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
         return labels
 
     def build_transforms(self, hyp=None):
@@ -145,14 +158,17 @@ class YOLODataset(BaseDataset):
         else:
             transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
         transforms.append(
-            Format(bbox_format='xywh',
-                   normalize=True,
-                   return_mask=self.use_segments,
-                   return_keypoint=self.use_keypoints,
-                   return_obb=self.use_obb,
-                   batch_idx=True,
-                   mask_ratio=hyp.mask_ratio,
-                   mask_overlap=hyp.overlap_mask))
+            Format(
+                bbox_format="xywh",
+                normalize=True,
+                return_mask=self.use_segments,
+                return_keypoint=self.use_keypoints,
+                return_obb=self.use_obb,
+                batch_idx=True,
+                mask_ratio=hyp.mask_ratio,
+                mask_overlap=hyp.overlap_mask,
+            )
+        )
         return transforms
 
     def close_mosaic(self, hyp):
@@ -166,11 +182,11 @@ class YOLODataset(BaseDataset):
         """Custom your label format here."""
         # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
         # We can make it also support classification and semantic segmentation by add or remove some dict keys there.
-        bboxes = label.pop('bboxes')
-        segments = label.pop('segments', [])
-        keypoints = label.pop('keypoints', None)
-        bbox_format = label.pop('bbox_format')
-        normalized = label.pop('normalized')
+        bboxes = label.pop("bboxes")
+        segments = label.pop("segments", [])
+        keypoints = label.pop("keypoints", None)
+        bbox_format = label.pop("bbox_format")
+        normalized = label.pop("normalized")
 
         # NOTE: do NOT resample oriented boxes
         segment_resamples = 100 if self.use_obb else 1000
@@ -180,7 +196,7 @@ class YOLODataset(BaseDataset):
             segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
         else:
             segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
-        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
+        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
         return label
 
     @staticmethod
@@ -191,15 +207,15 @@ class YOLODataset(BaseDataset):
         values = list(zip(*[list(b.values()) for b in batch]))
         for i, k in enumerate(keys):
             value = values[i]
-            if k == 'img':
+            if k == "img":
                 value = torch.stack(value, 0)
-            if k in ['masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb']:
+            if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
                 value = torch.cat(value, 0)
             new_batch[k] = value
-        new_batch['batch_idx'] = list(new_batch['batch_idx'])
-        for i in range(len(new_batch['batch_idx'])):
-            new_batch['batch_idx'][i] += i  # add target image index for build_targets()
-        new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
+        new_batch["batch_idx"] = list(new_batch["batch_idx"])
+        for i in range(len(new_batch["batch_idx"])):
+            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
+        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
         return new_batch
 
 
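A note on the collate_fn hunk above before moving on to ClassificationDataset: images are stacked into one batch tensor while the per-image label tensors are concatenated, so each label row needs a batch_idx entry recording which image it belongs to. A toy illustration of that bookkeeping (the tensors are made up; only the indexing pattern matches the diff):

import torch

# 2 boxes in image 0, 3 boxes in image 1 (values assumed for illustration)
per_image_idx = [torch.zeros(2), torch.zeros(3)]
for i in range(len(per_image_idx)):
    per_image_idx[i] += i  # add target image index for build_targets()
batch_idx = torch.cat(per_image_idx, 0)
print(batch_idx)  # tensor([0., 0., 1., 1., 1.]) -> each box mapped back to its image
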
@@ -219,7 +235,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
     """
 
-    def __init__(self, root, args, augment=False, cache=False, prefix=''):
+    def __init__(self, root, args, augment=False, cache=False, prefix=""):
         """
         Initialize YOLO object with root, image size, augmentations, and cache settings.
 
@@ -231,23 +247,28 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         """
         super().__init__(root=root)
         if augment and args.fraction < 1.0:  # reduce training fraction
-            self.samples = self.samples[:round(len(self.samples) * args.fraction)]
-        self.prefix = colorstr(f'{prefix}: ') if prefix else ''
-        self.cache_ram = cache is True or cache == 'ram'
-        self.cache_disk = cache == 'disk'
+            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
+        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
+        self.cache_ram = cache is True or cache == "ram"
+        self.cache_disk = cache == "disk"
         self.samples = self.verify_images()  # filter out bad images
-        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
+        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
         scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
-        self.torch_transforms = classify_augmentations(size=args.imgsz,
-                                                       scale=scale,
-                                                       hflip=args.fliplr,
-                                                       vflip=args.flipud,
-                                                       erasing=args.erasing,
-                                                       auto_augment=args.auto_augment,
-                                                       hsv_h=args.hsv_h,
-                                                       hsv_s=args.hsv_s,
-                                                       hsv_v=args.hsv_v) if augment else classify_transforms(
-                                                           size=args.imgsz, crop_fraction=args.crop_fraction)
+        self.torch_transforms = (
+            classify_augmentations(
+                size=args.imgsz,
+                scale=scale,
+                hflip=args.fliplr,
+                vflip=args.flipud,
+                erasing=args.erasing,
+                auto_augment=args.auto_augment,
+                hsv_h=args.hsv_h,
+                hsv_s=args.hsv_s,
+                hsv_v=args.hsv_v,
+            )
+            if augment
+            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+        )
 
     def __getitem__(self, i):
         """Returns subset of data and targets corresponding to given indices."""
@@ -263,7 +284,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         # Convert NumPy array to PIL image
         im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
         sample = self.torch_transforms(im)
-        return {'img': sample, 'cls': j}
+        return {"img": sample, "cls": j}
 
     def __len__(self) -> int:
         """Return the total number of samples in the dataset."""
@@ -271,19 +292,19 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
 
     def verify_images(self):
         """Verify all images in dataset."""
-        desc = f'{self.prefix}Scanning {self.root}...'
-        path = Path(self.root).with_suffix('.cache')  # *.cache file path
+        desc = f"{self.prefix}Scanning {self.root}..."
+        path = Path(self.root).with_suffix(".cache")  # *.cache file path
 
         with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
             cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
-            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
-            assert cache['hash'] == get_hash([x[0] for x in self.samples])  # identical hash
-            nf, nc, n, samples = cache.pop('results')  # found, missing, empty, corrupt, total
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
+            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
             if LOCAL_RANK in (-1, 0):
-                d = f'{desc} {nf} images, {nc} corrupt'
+                d = f"{desc} {nf} images, {nc} corrupt"
                 TQDM(None, desc=d, total=n, initial=n)
-                if cache['msgs']:
-                    LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+                if cache["msgs"]:
+                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
             return samples
 
         # Run scan if *.cache retrieval failed
@@ -298,13 +319,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
                     msgs.append(msg)
                 nf += nf_f
                 nc += nc_f
-                pbar.desc = f'{desc} {nf} images, {nc} corrupt'
+                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
             pbar.close()
         if msgs:
-            LOGGER.info('\n'.join(msgs))
-        x['hash'] = get_hash([x[0] for x in self.samples])
-        x['results'] = nf, nc, len(samples), samples
-        x['msgs'] = msgs  # warnings
+            LOGGER.info("\n".join(msgs))
+        x["hash"] = get_hash([x[0] for x in self.samples])
+        x["results"] = nf, nc, len(samples), samples
+        x["msgs"] = msgs  # warnings
         save_dataset_cache_file(self.prefix, path, x)
         return samples
 
@@ -312,6 +333,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
 def load_dataset_cache_file(path):
     """Load an Ultralytics *.cache dictionary from path."""
     import gc
+
     gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
     cache = np.load(str(path), allow_pickle=True).item()  # load dict
     gc.enable()
@@ -320,15 +342,15 @@ def load_dataset_cache_file(path):
 
 def save_dataset_cache_file(prefix, path, x):
     """Save an Ultralytics dataset *.cache dictionary x to path."""
-    x['version'] = DATASET_CACHE_VERSION  # add cache version
+    x["version"] = DATASET_CACHE_VERSION  # add cache version
     if is_dir_writeable(path.parent):
         if path.exists():
             path.unlink()  # remove *.cache file if exists
         np.save(str(path), x)  # save cache for next time
-        path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
-        LOGGER.info(f'{prefix}New cache created: {path}')
+        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
+        LOGGER.info(f"{prefix}New cache created: {path}")
     else:
-        LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
 
 
 # TODO: support semantic segmentation
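
Since cache_labels, get_labels, and verify_images all revolve around the *.cache files handled by the two module-level functions in the final hunks, here is a self-contained sketch of that save/load round trip. The file name and dict contents are illustrative assumptions; the np.save / rename / np.load mechanics mirror the diffed code:

import gc
from pathlib import Path

import numpy as np

DATASET_CACHE_VERSION = "1.0.3"


def save_cache(path, x):
    """Mirrors save_dataset_cache_file(): np.save appends .npy, which is then renamed away."""
    x["version"] = DATASET_CACHE_VERSION  # stamp the cache version
    np.save(str(path), x)  # writes "<path>.npy"
    path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix


def load_cache(path):
    """Mirrors load_dataset_cache_file()."""
    gc.disable()  # reduce pickle load time, as in the diffed code
    cache = np.load(str(path), allow_pickle=True).item()  # 0-d object array -> dict
    gc.enable()
    return cache


cache_file = Path("labels.cache")  # illustrative path
save_cache(cache_file, {"labels": [], "hash": "abc", "results": (0, 0, 0, 0, 0), "msgs": []})
cache = load_cache(cache_file)
assert cache["version"] == DATASET_CACHE_VERSION  # a stale or foreign cache fails this check and is rebuilt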
|