dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,777 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import json
|
4
|
+
import os
|
5
|
+
import random
|
6
|
+
import subprocess
|
7
|
+
import time
|
8
|
+
import zipfile
|
9
|
+
from multiprocessing.pool import ThreadPool
|
10
|
+
from pathlib import Path
|
11
|
+
from tarfile import is_tarfile
|
12
|
+
|
13
|
+
import cv2
|
14
|
+
import numpy as np
|
15
|
+
from PIL import Image, ImageOps
|
16
|
+
|
17
|
+
from ultralytics.nn.autobackend import check_class_names
|
18
|
+
from ultralytics.utils import (
|
19
|
+
DATASETS_DIR,
|
20
|
+
LOGGER,
|
21
|
+
MACOS,
|
22
|
+
NUM_THREADS,
|
23
|
+
ROOT,
|
24
|
+
SETTINGS_FILE,
|
25
|
+
TQDM,
|
26
|
+
YAML,
|
27
|
+
clean_url,
|
28
|
+
colorstr,
|
29
|
+
emojis,
|
30
|
+
is_dir_writeable,
|
31
|
+
)
|
32
|
+
from ultralytics.utils.checks import check_file, check_font, is_ascii
|
33
|
+
from ultralytics.utils.downloads import download, safe_download, unzip_file
|
34
|
+
from ultralytics.utils.ops import segments2boxes
|
35
|
+
|
36
|
+
# Human-readable pointer to the dataset formatting documentation.
HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
# Lowercase suffixes (no leading dot) recognised as images / videos.
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"}  # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
# True when the PIN_MEMORY environment variable equals "true" (case-insensitive); when the variable is unset the
# default is `not MACOS`, i.e. enabled everywhere except macOS.
PIN_MEMORY = str(os.getenv("PIN_MEMORY", not MACOS)).lower() == "true"  # global pin_memory for dataloaders
# Message listing the supported image and video suffixes, built once from the two sets above.
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
41
|
+
|
42
|
+
|
43
|
+
def img2label_paths(img_paths):
    """
    Derive label file paths from image file paths.

    For every image path the last '/images/' directory component is replaced by '/labels/' (using the platform
    path separator) and everything after the final '.' is replaced by 'txt'.

    Args:
        img_paths (Iterable[str]): Image file paths.

    Returns:
        (list[str]): Label file paths, in the same order as the input.
    """
    images_part = f"{os.sep}images{os.sep}"  # /images/ substring
    labels_part = f"{os.sep}labels{os.sep}"  # /labels/ substring
    label_paths = []
    for img_path in img_paths:
        # Only the right-most '/images/' occurrence is swapped
        swapped = labels_part.join(img_path.rsplit(images_part, 1))
        without_suffix = swapped.rsplit(".", 1)[0]
        label_paths.append(without_suffix + ".txt")
    return label_paths
|
47
|
+
|
48
|
+
|
49
|
+
def check_file_speeds(files, threshold_ms=10, threshold_mb=50, max_files=5, prefix=""):
    """
    Check dataset file access speed and provide performance feedback.

    This function tests the access speed of dataset files by measuring ping (stat call) time and read speed.
    It samples up to `max_files` files from the provided list and warns when the average ping time reaches
    `threshold_ms` or the average read speed falls below `threshold_mb`. When no file could be read (only stat
    succeeded), the verdict is based on ping time alone.

    Args:
        files (list): List of file paths to check for access speed.
        threshold_ms (float, optional): Threshold in milliseconds for ping time warnings.
        threshold_mb (float, optional): Threshold in megabytes per second for read speed warnings.
        max_files (int, optional): The maximum number of files to check.
        prefix (str, optional): Prefix string to add to log messages.

    Returns:
        (None): Results are only reported through LOGGER.

    Examples:
        >>> from pathlib import Path
        >>> image_files = list(Path("dataset/images").glob("*.jpg"))
        >>> check_file_speeds(image_files, threshold_ms=15)
    """
    if not files:
        LOGGER.warning(f"{prefix}Image speed checks: No files to check")
        return

    # Sample files (at most max_files)
    files = random.sample(files, min(max_files, len(files)))

    ping_times = []  # ms per stat call
    file_sizes = []  # bytes
    read_speeds = []  # MB/s

    for f in files:
        try:
            # Measure ping (stat call)
            start = time.perf_counter()
            file_size = os.stat(f).st_size
            ping_times.append((time.perf_counter() - start) * 1000)  # ms
            file_sizes.append(file_size)

            # Measure read speed
            start = time.perf_counter()
            with open(f, "rb") as file_obj:
                _ = file_obj.read()
            read_time = time.perf_counter() - start
            if read_time > 0:  # Avoid division by zero
                read_speeds.append(file_size / (1 << 20) / read_time)  # MB/s
        except Exception:
            # Best-effort probe: an inaccessible file is skipped rather than aborting the whole check
            continue

    if not ping_times:
        LOGGER.warning(f"{prefix}Image speed checks: failed to access files")
        return

    # Calculate stats with uncertainties
    avg_ping = np.mean(ping_times)
    std_ping = np.std(ping_times, ddof=1) if len(ping_times) > 1 else 0
    size_msg = f", size: {np.mean(file_sizes) / (1 << 10):.1f} KB"
    ping_msg = f"ping: {avg_ping:.1f}±{std_ping:.1f} ms"

    avg_speed = None  # stays None when stat succeeded but no read did
    speed_msg = ""
    if read_speeds:
        avg_speed = np.mean(read_speeds)
        std_speed = np.std(read_speeds, ddof=1) if len(read_speeds) > 1 else 0
        speed_msg = f", read: {avg_speed:.1f}±{std_speed:.1f} MB/s"

    # Access is slow when stat latency is high OR measured throughput is low
    slow_ping = avg_ping >= threshold_ms
    slow_read = avg_speed is not None and avg_speed < threshold_mb
    if not (slow_ping or slow_read):
        LOGGER.info(f"{prefix}Fast image access ✅ ({ping_msg}{speed_msg}{size_msg})")
    else:
        LOGGER.warning(
            f"{prefix}Slow image access detected ({ping_msg}{speed_msg}{size_msg}). "
            f"Use local storage instead of remote/mounted storage for better performance. "
            f"See https://docs.ultralytics.com/guides/model-training-tips/"
        )
|
123
|
+
|
124
|
+
|
125
|
+
def get_hash(paths):
    """
    Return a single SHA-256 hex digest for a collection of paths (files or dirs).

    The digest covers the summed on-disk size of every path that can be stat'ed, followed by the concatenation
    of all path strings. Paths that cannot be stat'ed contribute their name but no size.

    Args:
        paths (Iterable[str | os.PathLike]): Paths to hash. May be a one-shot iterable; items are converted
            with `str()`, so `pathlib.Path` objects are accepted.

    Returns:
        (str): Hexadecimal SHA-256 digest.
    """
    import hashlib

    # Materialize once: the paths are walked twice (sizes, then names) and must be strings for the join below
    paths = [str(p) for p in paths]
    size = 0
    for p in paths:
        try:
            size += os.stat(p).st_size
        except OSError:
            continue  # missing/unreadable paths add no size
    h = hashlib.sha256(str(size).encode())  # hash sizes
    h.update("".join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
|
136
|
+
|
137
|
+
|
138
|
+
def exif_size(img: Image.Image):
    """
    Return the PIL image size, corrected for EXIF orientation.

    Args:
        img (Image.Image): An opened PIL image.

    Returns:
        (tuple): (width, height); the two values are swapped when the image is a JPEG whose EXIF orientation
            tag (274) is 6 or 8. Any error while reading EXIF data leaves the size unchanged.
    """
    size = img.size  # (width, height)
    if img.format != "JPEG":  # only JPEG images are inspected for EXIF orientation
        return size
    try:
        exif = img.getexif()
        if exif:
            orientation = exif.get(274, None)  # 274 is the EXIF orientation tag
            if orientation in {6, 8}:  # rotation 270 or 90
                size = size[1], size[0]
    except Exception:
        pass
    return size
|
150
|
+
|
151
|
+
|
152
|
+
def verify_image(args):
    """
    Verify one image.

    Args:
        args (tuple): ((im_file, cls), prefix) where `im_file` is the image path, `cls` is passed through
            unchanged and `prefix` is prepended to any message.

    Returns:
        (tuple): ((im_file, cls), nf, nc, msg) where `nf` is 1 if the image passed verification, `nc` is 1 if
            it was rejected, and `msg` is an empty string or a description of what happened.
    """
    (im_file, cls), prefix = args
    found, corrupt, msg = 0, 0, ""  # number found, number corrupt, message
    try:
        im = Image.open(im_file)
        im.verify()  # PIL verify
        size_wh = exif_size(im)  # EXIF-corrected (width, height)
        shape = (size_wh[1], size_wh[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
        is_jpeg = im.format.lower() in {"jpg", "jpeg"}
        if is_jpeg:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)  # last two bytes of the file
                trailer = f.read()
                if trailer != b"\xff\xd9":  # corrupt JPEG: missing end-of-image marker
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}{im_file}: corrupt JPEG restored and saved"
        found = 1
    except Exception as e:
        corrupt = 1
        msg = f"{prefix}{im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), found, corrupt, msg
|
175
|
+
|
176
|
+
|
177
|
+
def verify_image_label(args):
    """
    Verify one image-label pair.

    The image is checked first (PIL verify, minimum size, supported format, JPEG end-of-image marker), then the
    label file, if present, is parsed and validated.

    Args:
        args (tuple): (im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls) where
            `im_file`/`lb_file` are the image and label paths, `prefix` is prepended to messages, `keypoint`
            selects keypoint-label validation, `num_cls` is the number of dataset classes, `nkpt`/`ndim` are the
            number of keypoints and values per keypoint, and `single_cls` forces every class id to 0.

    Returns:
        (tuple | list): On success `(im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg)` where `lb` is
            an (n, 5) float32 array, `shape` is (height, width), `segments` is a list of (k, 2) arrays (empty
            unless segment labels were detected), `keypoints` is an (n, nkpt, 3 if ndim == 2 else ndim) array or
            None, and `nm`, `nf`, `ne`, `nc` are 0/1 flags for label missing, found, empty and corrupt. On any
            failure the first five items are None, `nc` is 1 and `msg` describes the error.
    """
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}{im_file}: corrupt JPEG restored and saved"

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file, encoding="utf-8") as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            if nl := len(lb):
                if keypoint:
                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    points = lb[:, 1:]
                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"

                # All labels
                if single_cls:
                    lb[:, 0] = 0
                max_cls = lb[:, 0].max()  # max label count
                assert max_cls < num_cls, (
                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}"
                )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}{im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            # Width is chosen by the `keypoint` flag (not the `keypoints` output, which is still None here),
            # matching the empty-label branch above
            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                # Append a third channel: 0.0 where either coordinate is negative, otherwise 1.0
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}{im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
248
|
+
|
249
|
+
|
250
|
+
def visualize_image_annotations(image_path, txt_path, label_map):
    """
    Visualize YOLO annotations (bounding boxes and class labels) on an image.

    This function reads an image and its corresponding annotation file in YOLO format, then
    draws bounding boxes around detected objects and labels them with their respective class names.
    The bounding box colors are assigned based on the class ID, and the text color is dynamically
    adjusted for readability, depending on the background color's luminance. Blank lines in the
    annotation file are ignored.

    Args:
        image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL.
        txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object
            with five values: class id, x center, y center, width, height (coordinates normalized to 0-1).
        label_map (dict): A dictionary that maps class IDs (integers) to class labels (strings).

    Raises:
        ValueError: If a non-blank annotation line does not contain exactly five numeric values.
        KeyError: If an annotated class id is missing from `label_map`.

    Examples:
        >>> label_map = {0: "cat", 1: "dog", 2: "bird"}  # It should include all annotated classes details
        >>> visualize_image_annotations("path/to/image.jpg", "path/to/annotations.txt", label_map)
    """
    import matplotlib.pyplot as plt

    from ultralytics.utils.plotting import colors

    img = np.array(Image.open(image_path))
    img_height, img_width = img.shape[:2]
    annotations = []
    with open(txt_path, encoding="utf-8") as file:
        for line in file:
            fields = line.split()
            if not fields:  # skip blank lines (e.g. a trailing newline at end of file)
                continue
            class_id, x_center, y_center, width, height = map(float, fields)
            # Convert normalized center/size to a pixel top-left corner and pixel size
            x = (x_center - width / 2) * img_width
            y = (y_center - height / 2) * img_height
            w = width * img_width
            h = height * img_height
            annotations.append((x, y, w, h, int(class_id)))
    fig, ax = plt.subplots(1)  # Plot the image and annotations
    for x, y, w, h, label in annotations:
        color = tuple(c / 255 for c in colors(label, True))  # Get and normalize the RGB color
        rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor="none")  # Create a rectangle
        ax.add_patch(rect)
        luminance = 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2]  # Formula for luminance
        ax.text(x, y - 5, label_map[label], color="white" if luminance < 0.5 else "black", backgroundcolor=color)
    ax.imshow(img)
    plt.show()
|
292
|
+
|
293
|
+
|
294
|
+
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
    """
    Rasterize a list of polygons into a single mask of the given image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int, optional): The value written into the mask inside the polygons.
        downsample_ratio (int, optional): Factor by which to downsample the mask.

    Returns:
        (np.ndarray): A uint8 mask of shape (height // downsample_ratio, width // downsample_ratio) with the
            polygons filled in.
    """
    canvas = np.zeros(imgsz, dtype=np.uint8)
    # Coordinates are cast to int32 and regrouped as (num_polygons, num_points, 2) for cv2.fillPoly
    vertices = np.asarray(polygons, dtype=np.int32)
    vertices = vertices.reshape((vertices.shape[0], -1, 2))
    cv2.fillPoly(canvas, vertices, color=color)
    out_h = imgsz[0] // downsample_ratio
    out_w = imgsz[1] // downsample_ratio
    # Note: fillPoly first then resize is trying to keep the same loss calculation method when mask-ratio=1
    return cv2.resize(canvas, (out_w, out_h))
|
315
|
+
|
316
|
+
|
317
|
+
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
    """
    Rasterize each polygon into its own mask of the given image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons; each one is flattened with `reshape(-1)` and
            rasterized separately.
        color (int): The value written into each mask inside its polygon.
        downsample_ratio (int, optional): Factor by which to downsample each mask.

    Returns:
        (np.ndarray): The per-polygon masks stacked along a new first axis.
    """
    per_polygon = []
    for polygon in polygons:
        flat = polygon.reshape(-1)
        per_polygon.append(polygon2mask(imgsz, [flat], color, downsample_ratio))
    return np.array(per_polygon)
|
332
|
+
|
333
|
+
|
334
|
+
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
    """
    Build a single overlap mask in which every pixel stores the id of the segment drawn on top of it.

    Segments are rasterized individually, ordered by `np.argsort(-areas)` and painted in that order, so a segment
    painted later overwrites earlier ones where they overlap. Pixel values are 1-based positions in that
    order; 0 means background.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        segments (list[np.ndarray]): Polygons, one per instance; each is flattened with `reshape(-1)`.
        downsample_ratio (int, optional): Factor by which to downsample the mask.

    Returns:
        masks (np.ndarray): Array of shape (height // downsample_ratio, width // downsample_ratio); dtype is
            uint8, or int32 when there are more than 255 segments.
        index (np.ndarray): The ordering applied to `segments` (result of `np.argsort(-areas)`).
    """
    masks = np.zeros(
        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
        dtype=np.int32 if len(segments) > 255 else np.uint8,
    )
    areas = []
    ms = []
    for segment in segments:
        mask = polygon2mask(imgsz, [segment.reshape(-1)], downsample_ratio=downsample_ratio, color=1)
        ms.append(mask.astype(masks.dtype))
        areas.append(mask.sum())
    areas = np.asarray(areas)
    index = np.argsort(-areas)
    ms = np.array(ms)[index]
    for i in range(len(segments)):
        # Assign the id directly instead of add-then-clip: adding ids in uint8 wraps around once
        # old_id + new_id exceeds 255, which corrupts overlapping regions when there are many segments
        masks[ms[i].astype(bool)] = i + 1
    return masks, index
|
354
|
+
|
355
|
+
|
356
|
+
def find_dataset_yaml(path: Path) -> Path:
    """
    Locate the single YAML file that describes a Detect, Segment or Pose dataset.

    YAML files directly inside `path` are considered first; only when there are none is the directory searched
    recursively. If several candidates are found, only those whose stem equals the directory's stem are kept.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.

    Raises:
        AssertionError: If no YAML file is found, or if the candidates cannot be narrowed down to exactly one.
    """
    candidates = list(path.glob("*.yaml"))  # root level first
    if not candidates:
        candidates = list(path.rglob("*.yaml"))  # then recursive
    assert candidates, f"No YAML file found in '{path.resolve()}'"
    if len(candidates) > 1:
        # prefer *.yaml files that match the directory name
        candidates = [candidate for candidate in candidates if candidate.stem == path.stem]
    assert len(candidates) == 1, (
        f"Expected 1 YAML file in '{path.resolve()}', but found {len(candidates)}.\n{candidates}"
    )
    return candidates[0]
|
375
|
+
|
376
|
+
|
377
|
+
def check_det_dataset(dataset, autodownload=True):
    """
    Load a dataset YAML, validate it, resolve its paths and optionally download missing data.

    The input is located with `check_file`. If it is a zip or tar archive it is unpacked into DATASETS_DIR and the
    YAML is searched inside the extracted folder (automatic download is then switched off). The YAML must provide
    'train' and 'val' ('validation' is renamed to 'val') and at least one of 'names' / 'nc'. Split entries are turned
    into absolute path strings. If any 'val' path is missing, the YAML 'download' entry is executed when present and
    `autodownload` is enabled; otherwise a FileNotFoundError is raised.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found.

    Returns:
        (dict): Parsed dataset information and paths.

    Raises:
        SyntaxError: If required keys are missing or 'names' and 'nc' disagree.
        FileNotFoundError: If validation paths are missing and no download is performed.
    """
    yaml_file = check_file(dataset)

    # Archive input: unpack it and continue with the YAML found inside
    archive_root = ""
    if zipfile.is_zipfile(yaml_file) or is_tarfile(yaml_file):
        unpacked = safe_download(yaml_file, dir=DATASETS_DIR, unzip=True, delete=False)
        yaml_file = find_dataset_yaml(DATASETS_DIR / unpacked)
        archive_root = yaml_file.parent
        autodownload = False

    data = YAML.load(yaml_file, append_filename=True)  # dictionary

    # Required split keys, accepting 'validation' as an alias of 'val'
    for key in ("train", "val"):
        if key in data:
            continue
        if key != "val" or "validation" not in data:
            raise SyntaxError(
                emojis(f"{dataset} '{key}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
            )
        LOGGER.warning("renaming data YAML 'validation' key to 'val' to match YOLO format.")
        data["val"] = data.pop("validation")

    # Class metadata: derive whichever of 'names' / 'nc' is absent from the other
    has_names, has_nc = "names" in data, "nc" in data
    if not (has_names or has_nc):
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if has_names and has_nc and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if has_names:
        data["nc"] = len(data["names"])
    else:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]

    data["names"] = check_class_names(data["names"])
    data.setdefault("channels", 3)  # image channels default to 3

    # Dataset root: extracted archive folder, then the YAML 'path' entry, then the folder holding the YAML
    root = Path(archive_root or data.get("path") or Path(data.get("yaml_file", "")).parent)
    if not (root.exists() or root.is_absolute()):
        root = (DATASETS_DIR / root).resolve()  # interpret as relative to DATASETS_DIR
    data["path"] = root  # download scripts

    # Make every split entry absolute
    for key in ("train", "val", "test", "minival"):
        entry = data.get(key)
        if not entry:
            continue
        if isinstance(entry, str):
            resolved = (root / entry).resolve()
            if not resolved.exists() and entry.startswith("../"):
                resolved = (root / entry[3:]).resolve()  # retry without the leading '../'
            data[key] = str(resolved)
        else:
            data[key] = [str((root / item).resolve()) for item in entry]

    # Verify the validation paths and fetch the dataset if they are missing
    val, script = data.get("val"), data.get("download")
    if val:
        val_paths = [Path(p).resolve() for p in (val if isinstance(val, list) else [val])]
        missing = [p for p in val_paths if not p.exists()]
        if missing:
            name = clean_url(dataset)  # dataset name with URL auth stripped
            LOGGER.info("")
            msg = f"Dataset '{name}' images not found, missing path '{missing[0]}'"
            if not (script and autodownload):
                msg += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_FILE}'"
                raise FileNotFoundError(msg)
            LOGGER.warning(msg)

            # NOTE(review): `download` comes straight from the YAML and is run through os.system/exec below,
            # so only trusted dataset YAML files should reach this point.
            start = time.time()
            status = None  # stays None (treated as success) unless a bash script reports an exit code
            if script.startswith("http") and script.endswith(".zip"):  # URL
                safe_download(url=script, dir=DATASETS_DIR, delete=True)
            elif script.startswith("bash "):  # bash script
                LOGGER.info(f"Running {script} ...")
                status = os.system(script)
            else:  # python script
                exec(script, {"yaml": data})
            elapsed = f"({round(time.time() - start, 1)}s)"
            if status in {0, None}:
                outcome = f"success ✅ {elapsed}, saved to {colorstr('bold', DATASETS_DIR)}"
            else:
                outcome = f"failure {elapsed} ❌"
            LOGGER.info(f"Dataset download {outcome}\n")

    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts
    return data
|
470
|
+
|
471
|
+
|
472
|
+
def check_cls_dataset(dataset, split=""):
    """
    Locate (and if necessary download) a classification dataset and describe its splits.

    URLs and archive paths are downloaded/unpacked into DATASETS_DIR first. If the dataset directory still does not
    exist, a download is attempted ('imagenet' uses a bash script, anything else a release zip named after the
    dataset). When no 'train' folder exists but images are found, the dataset is split automatically with an 80%
    train ratio. Classes are the sub-directories of the 'train' folder.

    Args:
        dataset (str | Path): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''.

    Returns:
        (dict): A dictionary containing the following keys:

            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary of class names in the dataset.
            - 'channels' (int): Always 3.

    Raises:
        FileNotFoundError: If the training split contains no images.
    """
    # Download (optional if dataset=https://file.zip is passed directly)
    dataset_str = str(dataset)
    if dataset_str.startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
    elif dataset_str.endswith((".zip", ".tar", ".gz")):
        archive = check_file(dataset)
        dataset = safe_download(archive, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else DATASETS_DIR / dataset).resolve()
    if not data_dir.is_dir():
        LOGGER.info("")
        LOGGER.warning(f"Dataset not found, missing path {data_dir}, attempting download...")
        started = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        LOGGER.info(
            f"Dataset download success ✅ ({time.time() - started:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        )

    # Without a train folder, try to create the splits from whatever images exist
    train_set = data_dir / "train"
    if not train_set.is_dir():
        LOGGER.warning(f"Dataset 'split=train' not found at {train_set}")
        image_files = [*data_dir.rglob("*.jpg"), *data_dir.rglob("*.png")]
        if not image_files:
            LOGGER.error(f"No images found in {data_dir} or its subdirectories.")
        else:
            from ultralytics.data.split import split_classify_dataset

            LOGGER.info(f"Found {len(image_files)} images in subdirectories. Attempting to split...")
            data_dir = split_classify_dataset(data_dir, train_ratio=0.8)
            train_set = data_dir / "train"

    # Validation folder may be called 'val' or 'validation'
    if (data_dir / "val").exists():
        val_set = data_dir / "val"
    elif (data_dir / "validation").exists():
        val_set = data_dir / "validation"
    else:
        val_set = None
    test_set = data_dir / "test" if (data_dir / "test").exists() else None

    # The requested split falls back to the other evaluation split when it is missing
    if split == "val" and not val_set:
        LOGGER.warning("Dataset 'split=val' not found, using 'split=test' instead.")
        val_set = test_set
    elif split == "test" and not test_set:
        LOGGER.warning("Dataset 'split=test' not found, using 'split=val' instead.")
        test_set = val_set

    class_root = data_dir / "train"
    nc = sum(1 for entry in class_root.glob("*") if entry.is_dir())  # number of classes
    names = dict(enumerate(sorted(entry.name for entry in class_root.iterdir() if entry.is_dir())))

    # Print to console
    splits = {"train": train_set, "val": val_set, "test": test_set}
    for split_name, split_dir in splits.items():
        prefix = f"{colorstr(f'{split_name}:')} {split_dir}..."
        if split_dir is None:
            LOGGER.info(prefix)
            continue
        images = [p for p in split_dir.rglob("*.*") if p.suffix[1:].lower() in IMG_FORMATS]
        nf = len(images)  # number of files
        nd = len({p.parent for p in images})  # number of directories
        if nf == 0:
            if split_name == "train":
                raise FileNotFoundError(f"{dataset} '{split_name}:' no training images found")
            LOGGER.warning(f"{prefix} found {nf} images in {nd} classes (no images found)")
        elif nd != nc:
            LOGGER.error(f"{prefix} found {nf} images in {nd} classes (requires {nc} classes, not {nd})")
        else:
            LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {**splits, "nc": nc, "names": names, "channels": 3}
|
562
|
+
|
563
|
+
|
564
|
+
class HUBDatasetStats:
    """
    A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'obb', 'classify'. Default is 'detect'.
        autodownload (bool): Attempt to download dataset if not found locally. Default is False.

    Attributes:
        task (str): The dataset task passed at construction.
        hub_dir (Path): Output directory, named after the dataset root with a '-hub' suffix.
        im_dir (Path): 'images' sub-directory of `hub_dir` that receives the compressed images.
        stats (dict): Statistics dictionary, pre-filled with 'nc' and 'names'.
        data (dict): Dataset dictionary returned by the dataset checks.

    Note:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.

    Examples:
        >>> from ultralytics.data.utils import HUBDatasetStats
        >>> stats = HUBDatasetStats("path/to/coco8.zip", task="detect")  # detect dataset
        >>> stats = HUBDatasetStats("path/to/coco8-seg.zip", task="segment")  # segment dataset
        >>> stats = HUBDatasetStats("path/to/coco8-pose.zip", task="pose")  # pose dataset
        >>> stats = HUBDatasetStats("path/to/dota8.zip", task="obb")  # OBB dataset
        >>> stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify")  # classification dataset
        >>> stats.get_json(save=True)
        >>> stats.process_images()
    """

    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
        """
        Run the dataset checks for `path` and prepare the output locations.

        Args:
            path (str): Path to data.yaml or data.zip.
            task (str): Dataset task; 'classify' uses the classification checks, anything else the YAML checks.
            autodownload (bool): Forwarded to `check_det_dataset` for non-classification tasks.

        Raises:
            Exception: With message 'error/HUB/dataset_stats/init' if loading or checking the dataset YAML fails.
        """
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        self.task = task  # detect, segment, pose, classify, obb
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose, obb
            _, data_dir, yaml_path = self._unzip(path)
            try:
                # Load YAML with checks
                data = YAML.load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                YAML.save(yaml_path, data)
                data = check_det_dataset(yaml_path, autodownload)  # dict
                # Only override the root when an archive was unzipped. For a plain data.yaml input `data_dir` is
                # None, and overwriting would turn the hub directory into the literal 'None-hub'; the root already
                # resolved by check_det_dataset is kept instead.
                if data_dir is not None:
                    data["path"] = data_dir
            except Exception as e:
                raise Exception("error/HUB/dataset_stats/init") from e

        self.hub_dir = Path(f"{data['path']}-hub")
        self.im_dir = self.hub_dir / "images"
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path):
        """
        Unzip data.zip and locate the dataset YAML inside it.

        Args:
            path (Path): Path to a *.zip archive or directly to a data.yaml file.

        Returns:
            (tuple): (zipped, data_dir, yaml_path). For non-zip input this is (False, None, path).

        Raises:
            AssertionError: If the archive did not unzip to a directory.
        """
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        if not unzip_dir.is_dir():  # explicit raise so the check also runs under `python -O`
            raise AssertionError(
                f"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/"
            )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f):
        """Save a compressed copy of image `f` into the hub images directory for HUB previews."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save=False, verbose=False):
        """
        Return dataset JSON for Ultralytics HUB.

        For each of the 'train', 'val' and 'test' splits the entry in `self.stats` is set to None when the split is
        undefined or holds no images, and otherwise to a dict with 'instance_stats', 'image_stats' and 'labels'.

        Args:
            save (bool): If True, write the statistics to 'stats.json' inside the hub directory.
            verbose (bool): If True, log the statistics as indented JSON.

        Returns:
            (dict): The statistics dictionary (also stored in `self.stats`).
        """

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task in {"segment", "obb"}:  # Segment and OBB use segments. OBB segments are normalized xyxyxyxy
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                n, nk, nd = labels["keypoints"].shape
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1)
            else:
                raise ValueError(f"Undefined dataset task={self.task}.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine
            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder  # scope for faster 'import ultralytics'

                dataset = ImageFolder(self.data[split])

                # Count images per class index
                x = np.zeros(len(dataset.classes)).astype(int)
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                # One row of per-class instance counts for every image
                x = np.array(
                    [
                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                    ]
                )  # shape(128x80)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            self.hub_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w", encoding="utf-8") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        """
        Compress images for Ultralytics HUB.

        Every image of each defined split is compressed into the hub images directory using a thread pool.

        Returns:
            (Path): The directory containing the compressed images.
        """
        from ultralytics.data import YOLODataset  # ClassificationDataset

        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/images/
        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            with ThreadPool(NUM_THREADS) as pool:
                # Drain the iterator so every image is processed while TQDM shows progress
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir
|
718
|
+
|
719
|
+
|
720
|
+
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
    """
    Compress a single image file to a reduced size while preserving its aspect ratio.

    The Python Imaging Library (PIL) is tried first and saves the result as JPEG with the requested quality. If PIL
    fails for any reason, a warning is logged and OpenCV is used instead; that fallback writes with OpenCV's default
    settings and does not apply `quality`. If the input image is smaller than the maximum dimension, it will not be
    resized.

    Args:
        f (str | Path): The path to the input image file.
        f_new (str | Path, optional): The path to the output image file. If not specified, the input file will be
            overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image.
        quality (int, optional): The image compression quality as a percentage (PIL path only).

    Examples:
        >>> from pathlib import Path
        >>> from ultralytics.data.utils import compress_one_image
        >>> for f in Path("path/to/dataset").rglob("*.jpg"):
        >>>    compress_one_image(f)
    """
    try:  # use PIL
        Image.MAX_IMAGE_PIXELS = None  # Fix DecompressionBombError, allow optimization of image > ~178.9 million pixels
        with Image.open(f) as im:  # context manager releases the file handle even if processing fails
            if im.mode in {"RGBA", "LA"}:  # Convert to RGB if needed (for JPEG)
                im = im.convert("RGB")
            r = max_dim / max(im.height, im.width)  # ratio
            if r < 1.0:  # image too large
                im = im.resize((int(im.width * r), int(im.height * r)))
            im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
    except Exception as e:  # use OpenCV
        LOGGER.warning(f"HUB ops PIL failure {f}: {e}")
        im = cv2.imread(str(f))  # str() so Path inputs work here just as they do for cv2.imwrite below
        im_height, im_width = im.shape[:2]
        r = max_dim / max(im_height, im_width)  # ratio
        if r < 1.0:  # image too large
            im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), im)
|
755
|
+
|
756
|
+
|
757
|
+
def load_dataset_cache_file(path):
    """
    Load an Ultralytics *.cache dictionary from path.

    Garbage collection is paused while the file is unpickled to reduce load time and is restored to its previous
    state afterwards, including when loading fails.

    Args:
        path (str | Path): Location of the cache file written with `np.save`.

    Returns:
        (dict): The cached dictionary.

    Notes:
        The file is read with `allow_pickle=True`, which can execute arbitrary code contained in the file. Only load
        cache files from trusted locations.
    """
    import gc

    gc_was_enabled = gc.isenabled()
    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    try:
        cache = np.load(str(path), allow_pickle=True).item()  # load dict
    finally:
        # Restore the collector even when loading raises, and never enable it if the caller had it disabled
        if gc_was_enabled:
            gc.enable()
    return cache
|
765
|
+
|
766
|
+
|
767
|
+
def save_dataset_cache_file(prefix, path, x, version):
    """
    Save an Ultralytics dataset *.cache dictionary x to path.

    The version is stored in `x` under the 'version' key before writing. If the target directory is not writeable
    the cache is not saved and only a warning is logged.

    Args:
        prefix (str): Text placed in front of the log messages.
        path (Path): Destination of the cache file; an existing file is removed first.
        x (dict): Cache dictionary; modified in place by the added 'version' entry.
        version (str): Cache version to record.
    """
    x["version"] = version  # add cache version
    if not is_dir_writeable(path.parent):
        LOGGER.warning(f"{prefix}Cache directory {path.parent} is not writeable, cache not saved.")
        return

    if path.exists():
        path.unlink()  # remove *.cache file if exists
    with open(str(path), "wb") as file:  # context manager here fixes windows async np.save bug
        np.save(file, x)
    LOGGER.info(f"{prefix}New cache created: {path}")
|