dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,243 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
"""
|
3
|
+
Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance
|
4
|
+
segmentation, image classification, pose estimation, and multi-object tracking.
|
5
|
+
|
6
|
+
Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
|
7
|
+
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
|
8
|
+
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
|
9
|
+
|
10
|
+
Examples:
|
11
|
+
Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
|
12
|
+
>>> from ultralytics import YOLO
|
13
|
+
>>> model = YOLO("yolo11n.pt")
|
14
|
+
>>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
|
15
|
+
"""
|
16
|
+
|
17
|
+
import random
|
18
|
+
import shutil
|
19
|
+
import subprocess
|
20
|
+
import time
|
21
|
+
|
22
|
+
import numpy as np
|
23
|
+
import torch
|
24
|
+
|
25
|
+
from ultralytics.cfg import get_cfg, get_save_dir
|
26
|
+
from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr
|
27
|
+
from ultralytics.utils.plotting import plot_tune_results
|
28
|
+
|
29
|
+
|
30
|
+
class Tuner:
    """
    A class for hyperparameter tuning of YOLO models.

    The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
    search space and retraining the model to evaluate their performance.

    Attributes:
        space (dict): Hyperparameter search space mapping each name to `(min, max)` or `(min, max, gain)`.
        tune_dir (Path): Directory where evolution logs and results will be saved.
        tune_csv (Path): Path to the CSV file where evolution logs are saved (`tune_dir / "tune_results.csv"`).
        args: Configuration object returned by `get_cfg` for the tuning process.
        callbacks: Callbacks passed in via `_callbacks`, or the defaults from `callbacks.get_default_callbacks()`.
        prefix (str): Prefix string for logging messages.

    Methods:
        _mutate: Mutates the given hyperparameters within the specified bounds.
        __call__: Executes the hyperparameter evolution across multiple iterations.

    Examples:
        Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")
        >>> model.tune(
        ...     data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False
        ... )

        Tune with custom search space.
        >>> model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
    """

    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
        """
        Initialize the Tuner with configurations.

        Args:
            args (dict): Configuration for hyperparameter evolution. An optional "space" entry is popped from it
                (the passed object is modified in place) and used as the search space; the remainder is handed
                to `get_cfg` as overrides.
            _callbacks (optional): Callback functions to be executed during tuning. Defaults to
                `callbacks.get_default_callbacks()` when falsy.
        """
        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
            "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
            "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
            "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
            "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
            "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
            "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
            "box": (1.0, 20.0),  # box loss gain
            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
            "dfl": (0.4, 6.0),  # dfl loss gain
            "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
            "degrees": (0.0, 45.0),  # image rotation (+/- deg)
            "translate": (0.0, 0.9),  # image translation (+/- fraction)
            "scale": (0.0, 0.95),  # image scale (+/- gain)
            "shear": (0.0, 10.0),  # image shear (+/- deg)
            "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            "flipud": (0.0, 1.0),  # image flip up-down (probability)
            "fliplr": (0.0, 1.0),  # image flip left-right (probability)
            "bgr": (0.0, 1.0),  # image channel bgr (probability)
            "mosaic": (0.0, 1.0),  # image mosaic (probability)
            "mixup": (0.0, 1.0),  # image mixup (probability)
            "cutmix": (0.0, 1.0),  # image cutmix (probability)
            "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
        }
        self.args = get_cfg(overrides=args)
        self.args.exist_ok = self.args.resume  # resume w/ same tune_dir
        self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
        self.args.name, self.args.exist_ok, self.args.resume = (None, False, False)  # reset to not affect training
        self.tune_csv = self.tune_dir / "tune_results.csv"
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.prefix = colorstr("Tuner: ")
        callbacks.add_integration_callbacks(self)
        LOGGER.info(
            f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
        )

    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
        """
        Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.

        When `self.tune_csv` exists, a parent is chosen from the top-`n` rows by fitness (first CSV column) and each
        hyperparameter is multiplied by a random factor clipped to [0.3, 3.0]. Otherwise the current values in
        `self.args` are used unchanged. In both cases the result is clamped to the `(min, max)` bounds of
        `self.space` and rounded to 5 decimals.

        Args:
            parent (str): Parent selection method: 'single' or 'weighted'.
            n (int): Number of parents to consider.
            mutation (float): Probability of a parameter mutation in any given iteration.
            sigma (float): Standard deviation for Gaussian random number generator.

        Returns:
            (dict): A dictionary containing mutated hyperparameters.

        Raises:
            ValueError: If `parent` is neither 'single' nor 'weighted' and more than one previous result exists.
        """
        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
            # Select parent(s)
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            n = min(n, len(x))  # number of previous results to consider
            x = x[np.argsort(-fitness)][:n]  # top n mutations
            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
            if parent == "single" or len(x) == 1:
                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
            elif parent == "weighted":
                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination
            else:
                # Without this, x would stay 2D and fail below with an opaque conversion error.
                raise ValueError(f"parent must be 'single' or 'weighted', got {parent!r}")

            # Mutate
            # Use a local generator rather than reseeding the global NumPy RNG from the wall clock: the global
            # reseed altered process-wide RNG state and repeated the same draws for calls within the same second.
            rng = np.random.default_rng()
            g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()])  # gains 0-1
            ng = len(self.space)
            v = np.ones(ng)
            # NOTE: this loop does not terminate if no multiplier can ever differ from 1 (e.g. mutation <= 0).
            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                v = (g * (rng.random(ng) < mutation) * rng.standard_normal(ng) * rng.random() * sigma + 1).clip(
                    0.3, 3.0
                )
            # Column 0 of the parent row is fitness, so hyperparameter i lives at index i + 1.
            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space)}
        else:
            hyp = {k: getattr(self.args, k) for k in self.space}

        # Constrain to limits
        for k, v in self.space.items():
            hyp[k] = max(hyp[k], v[0])  # lower limit
            hyp[k] = min(hyp[k], v[1])  # upper limit
            hyp[k] = round(hyp[k], 5)  # significant digits

        return hyp

    def __call__(self, model=None, iterations=10, cleanup=True):
        """
        Execute the hyperparameter evolution process when the Tuner instance is called.

        This method iterates through the number of iterations, performing the following steps in each iteration:

        1. Load the existing hyperparameters or initialize new ones.
        2. Mutate the hyperparameters using the `_mutate` method.
        3. Train a YOLO model with the mutated hyperparameters.
        4. Log the fitness score and mutated hyperparameters to a CSV file.

        If `self.tune_csv` already exists, iteration numbering continues from the number of rows it contains.

        Args:
            model (Model): A pre-initialized YOLO model. Not referenced by this method body; training is launched
                in a subprocess from `self.args`.
            iterations (int): The number of generations to run the evolution for.
            cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.

        Note:
            The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
            Ensure this path is set correctly in the Tuner instance.
        """
        import sys  # local import: only needed to locate the running interpreter

        t0 = time.time()
        best_save_dir, best_metrics = None, None
        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
        start = 0
        if self.tune_csv.exists():
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            start = x.shape[0]
            LOGGER.info(f"{self.prefix}Resuming tuning run {self.tune_dir} from iteration {start + 1}...")

        launch = [sys.executable, "-m", "ultralytics.cfg.__init__"]  # workaround yolo not found
        for i in range(start, iterations):
            # Mutate hyperparameters
            mutated_hyp = self._mutate()
            LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")

            metrics = {}
            train_args = {**vars(self.args), **mutated_hyp}
            save_dir = get_save_dir(get_cfg(train_args))
            weights_dir = save_dir / "weights"
            try:
                # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
                cmd = [*launch, "train", *(f"{k}={v}" for k, v in train_args.items())]
                subprocess.run(cmd, check=True)  # check=True raises CalledProcessError on a non-zero exit
                ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
                # NOTE(review): torch.load unpickles the file; here it is a checkpoint path under this run's save_dir.
                metrics = torch.load(ckpt_file)["train_metrics"]

            except Exception as e:
                # Best-effort: a failed iteration is logged and recorded below with fitness 0.0.
                LOGGER.error(f"training failure for hyperparameter tuning iteration {i + 1}\n{e}")

            # Save results and mutated_hyp to CSV
            fitness = metrics.get("fitness", 0.0)
            log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space]
            headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
            with open(self.tune_csv, "a", encoding="utf-8") as f:
                f.write(headers + ",".join(map(str, log_row)) + "\n")

            # Get best results
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            best_idx = fitness.argmax()
            best_is_current = best_idx == i
            if best_is_current:
                best_save_dir = save_dir
                best_metrics = {k: round(v, 5) for k, v in metrics.items()}
                for ckpt in weights_dir.glob("*.pt"):
                    shutil.copy2(ckpt, self.tune_dir / "weights")
            elif cleanup:
                shutil.rmtree(weights_dir, ignore_errors=True)  # remove iteration weights/ dir to reduce storage space

            # Plot tune results
            plot_tune_results(self.tune_csv)

            # Save and print tune results
            header = (
                f"{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n"
                f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n"
                f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n"
                f"{self.prefix}Best fitness metrics are {best_metrics}\n"
                f"{self.prefix}Best fitness model is {best_save_dir}\n"
                f"{self.prefix}Best fitness hyperparameters are printed below.\n"
            )
            LOGGER.info("\n" + header)
            # Column 0 is fitness; hyperparameter j is stored in column j + 1.
            data = {k: float(x[best_idx, j + 1]) for j, k in enumerate(self.space)}
            YAML.save(
                self.tune_dir / "best_hyperparameters.yaml",
                data=data,
                header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
            )
            YAML.print(self.tune_dir / "best_hyperparameters.yaml")
|
@@ -0,0 +1,377 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
"""
|
3
|
+
Check a model's accuracy on a test or val split of a dataset.
|
4
|
+
|
5
|
+
Usage:
|
6
|
+
$ yolo mode=val model=yolo11n.pt data=coco8.yaml imgsz=640
|
7
|
+
|
8
|
+
Usage - formats:
|
9
|
+
$ yolo mode=val model=yolo11n.pt # PyTorch
|
10
|
+
yolo11n.torchscript # TorchScript
|
11
|
+
yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
12
|
+
yolo11n_openvino_model # OpenVINO
|
13
|
+
yolo11n.engine # TensorRT
|
14
|
+
yolo11n.mlpackage # CoreML (macOS-only)
|
15
|
+
yolo11n_saved_model # TensorFlow SavedModel
|
16
|
+
yolo11n.pb # TensorFlow GraphDef
|
17
|
+
yolo11n.tflite # TensorFlow Lite
|
18
|
+
yolo11n_edgetpu.tflite # TensorFlow Edge TPU
|
19
|
+
yolo11n_paddle_model # PaddlePaddle
|
20
|
+
yolo11n.mnn # MNN
|
21
|
+
yolo11n_ncnn_model # NCNN
|
22
|
+
yolo11n_imx_model # Sony IMX
|
23
|
+
yolo11n_rknn_model # Rockchip RKNN
|
24
|
+
"""
|
25
|
+
|
26
|
+
import json
|
27
|
+
import time
|
28
|
+
from pathlib import Path
|
29
|
+
|
30
|
+
import numpy as np
|
31
|
+
import torch
|
32
|
+
|
33
|
+
from ultralytics.cfg import get_cfg, get_save_dir
|
34
|
+
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
35
|
+
from ultralytics.nn.autobackend import AutoBackend
|
36
|
+
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
|
37
|
+
from ultralytics.utils.checks import check_imgsz
|
38
|
+
from ultralytics.utils.ops import Profile
|
39
|
+
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
40
|
+
|
41
|
+
|
42
|
+
class BaseValidator:
|
43
|
+
"""
|
44
|
+
A base class for creating validators.
|
45
|
+
|
46
|
+
This class provides the foundation for validation processes, including model evaluation, metric computation, and
|
47
|
+
result visualization.
|
48
|
+
|
49
|
+
Attributes:
|
50
|
+
args (SimpleNamespace): Configuration for the validator.
|
51
|
+
dataloader (DataLoader): Dataloader to use for validation.
|
52
|
+
pbar (tqdm): Progress bar to update during validation.
|
53
|
+
model (nn.Module): Model to validate.
|
54
|
+
data (dict): Data dictionary containing dataset information.
|
55
|
+
device (torch.device): Device to use for validation.
|
56
|
+
batch_i (int): Current batch index.
|
57
|
+
training (bool): Whether the model is in training mode.
|
58
|
+
names (dict): Class names mapping.
|
59
|
+
seen (int): Number of images seen so far during validation.
|
60
|
+
stats (dict): Statistics collected during validation.
|
61
|
+
confusion_matrix: Confusion matrix for classification evaluation.
|
62
|
+
nc (int): Number of classes.
|
63
|
+
iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
|
64
|
+
jdict (list): List to store JSON validation results.
|
65
|
+
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
|
66
|
+
batch processing times in milliseconds.
|
67
|
+
save_dir (Path): Directory to save results.
|
68
|
+
plots (dict): Dictionary to store plots for visualization.
|
69
|
+
callbacks (dict): Dictionary to store various callback functions.
|
70
|
+
|
71
|
+
Methods:
|
72
|
+
__call__: Execute validation process, running inference on dataloader and computing performance metrics.
|
73
|
+
match_predictions: Match predictions to ground truth objects using IoU.
|
74
|
+
add_callback: Append the given callback to the specified event.
|
75
|
+
run_callbacks: Run all callbacks associated with a specified event.
|
76
|
+
get_dataloader: Get data loader from dataset path and batch size.
|
77
|
+
build_dataset: Build dataset from image path.
|
78
|
+
preprocess: Preprocess an input batch.
|
79
|
+
postprocess: Postprocess the predictions.
|
80
|
+
init_metrics: Initialize performance metrics for the YOLO model.
|
81
|
+
update_metrics: Update metrics based on predictions and batch.
|
82
|
+
finalize_metrics: Finalize and return all metrics.
|
83
|
+
get_stats: Return statistics about the model's performance.
|
84
|
+
check_stats: Check statistics.
|
85
|
+
print_results: Print the results of the model's predictions.
|
86
|
+
get_desc: Get description of the YOLO model.
|
87
|
+
on_plot: Register plots (e.g. to be consumed in callbacks).
|
88
|
+
plot_val_samples: Plot validation samples during training.
|
89
|
+
plot_predictions: Plot YOLO model predictions on batch images.
|
90
|
+
pred_to_json: Convert predictions to JSON format.
|
91
|
+
eval_json: Evaluate and return JSON format of prediction statistics.
|
92
|
+
"""
|
93
|
+
|
94
|
+
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
95
|
+
"""
|
96
|
+
Initialize a BaseValidator instance.
|
97
|
+
|
98
|
+
Args:
|
99
|
+
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
100
|
+
save_dir (Path, optional): Directory to save results.
|
101
|
+
pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
|
102
|
+
args (SimpleNamespace, optional): Configuration for the validator.
|
103
|
+
_callbacks (dict, optional): Dictionary to store various callback functions.
|
104
|
+
"""
|
105
|
+
self.args = get_cfg(overrides=args)
|
106
|
+
self.dataloader = dataloader
|
107
|
+
self.pbar = pbar
|
108
|
+
self.stride = None
|
109
|
+
self.data = None
|
110
|
+
self.device = None
|
111
|
+
self.batch_i = None
|
112
|
+
self.training = True
|
113
|
+
self.names = None
|
114
|
+
self.seen = None
|
115
|
+
self.stats = None
|
116
|
+
self.confusion_matrix = None
|
117
|
+
self.nc = None
|
118
|
+
self.iouv = None
|
119
|
+
self.jdict = None
|
120
|
+
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
121
|
+
|
122
|
+
self.save_dir = save_dir or get_save_dir(self.args)
|
123
|
+
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
124
|
+
if self.args.conf is None:
|
125
|
+
self.args.conf = 0.001 # default conf=0.001
|
126
|
+
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
|
127
|
+
|
128
|
+
self.plots = {}
|
129
|
+
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
130
|
+
|
131
|
+
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
    """
    Execute validation process, running inference on dataloader and computing performance metrics.

    When a trainer is given, its device, data and model are reused and the accumulated validation loss is
    merged into the returned results. Otherwise the model is loaded through AutoBackend, the dataset is
    resolved from `self.args.data` and a dataloader is built if none was supplied.

    Args:
        trainer (object, optional): Trainer object that contains the model to validate.
        model (nn.Module, optional): Model to validate if not using a trainer.

    Returns:
        stats (dict): Dictionary containing validation statistics. With a trainer, the values are rounded
            floats and include the labelled validation loss items.

    Raises:
        FileNotFoundError: Without a trainer, if `self.args.data` is not a YAML file and the task is not
            "classify".
    """
    self.training = trainer is not None
    # Augmented inference is only applied outside of training
    augment = self.args.augment and (not self.training)
    if self.training:
        self.device = trainer.device
        self.data = trainer.data
        # Force FP16 val during training
        self.args.half = self.device.type != "cpu" and trainer.amp
        # Use trainer.ema.ema when it is truthy, otherwise fall back to trainer.model
        model = trainer.ema.ema or trainer.model
        model = model.half() if self.args.half else model.float()
        # self.model = model
        self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
        # Keep plotting enabled only on a possible early stop or on the last epoch
        self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
        model.eval()
    else:
        if str(self.args.model).endswith(".yaml") and model is None:
            LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
        callbacks.add_integration_callbacks(self)
        model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, self.args.batch),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
        )
        # self.model = model
        self.device = model.device  # update device
        self.args.half = model.fp16  # update half
        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
        imgsz = check_imgsz(self.args.imgsz, stride=stride)
        # Take the batch size from the loaded model where the backend reports one
        if engine:
            self.args.batch = model.batch_size
        elif not (pt or jit or getattr(model, "dynamic", False)):
            self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
            LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

        # Resolve the dataset description from the configured data argument
        if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
            self.data = check_det_dataset(self.args.data)
        elif self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data, split=self.args.split)
        else:
            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

        if self.device.type in {"cpu", "mps"}:
            self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
        # Rectangular batches are disabled unless the model is `pt` or flagged dynamic
        if not (pt or getattr(model, "dynamic", False)):
            self.args.rect = False
        self.stride = model.stride  # used in get_dataloader() for padding
        # A dataloader passed to __init__ takes precedence over building a new one
        self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

        model.eval()
        model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup

    self.run_callbacks("on_val_start")
    # One profiler per stage, in the same order as the keys of self.speed
    dt = (
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
    )
    bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
    self.init_metrics(de_parallel(model))
    self.jdict = []  # empty before each val
    for batch_i, batch in enumerate(bar):
        self.run_callbacks("on_val_batch_start")
        self.batch_i = batch_i
        # Preprocess
        with dt[0]:
            batch = self.preprocess(batch)

        # Inference
        with dt[1]:
            preds = model(batch["img"], augment=augment)

        # Loss
        with dt[2]:
            if self.training:
                self.loss += model.loss(batch, preds)[1]

        # Postprocess
        with dt[3]:
            preds = self.postprocess(preds)

        self.update_metrics(preds, batch)
        # Only the first three batches are plotted
        if self.args.plots and batch_i < 3:
            self.plot_val_samples(batch, batch_i)
            self.plot_predictions(batch, preds, batch_i)

        self.run_callbacks("on_val_batch_end")
    stats = self.get_stats()
    self.check_stats(stats)
    # Per-stage time divided by the dataset size and scaled by 1e3 (logged below as ms per image)
    self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
    self.finalize_metrics()
    self.print_results()
    self.run_callbacks("on_val_end")
    if self.training:
        model.float()
        # Merge metrics with the validation loss averaged over the number of batches
        results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
        return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
    else:
        LOGGER.info(
            "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                *tuple(self.speed.values())
            )
        )
        if self.args.save_json and self.jdict:
            with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                LOGGER.info(f"Saving {f.name}...")
                json.dump(self.jdict, f)  # flatten and save
            stats = self.eval_json(stats)  # update stats
        if self.args.plots or self.args.save_json:
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
        return stats
|
255
|
+
|
256
|
+
def match_predictions(
    self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
) -> torch.Tensor:
    """
    Match predictions to ground truth objects using IoU.

    Args:
        pred_classes (torch.Tensor): Predicted class indices of shape (N,).
        true_classes (torch.Tensor): Target class indices of shape (M,).
        iou (torch.Tensor): An MxN tensor of pairwise IoU values with ground truth objects in rows and
            predictions in columns (it is multiplied element-wise by the MxN class-equality mask).
        use_scipy (bool): Whether to use scipy's linear sum assignment for matching instead of the greedy
            matching.

    Returns:
        (torch.Tensor): Boolean tensor of shape (N, T), where T is the number of IoU thresholds in
            `self.iouv`; entry (d, t) is True when prediction d is matched at threshold t.
    """
    if use_scipy:
        # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
        # Scoped import (once, outside the loop) to avoid importing scipy for all commands
        from scipy.optimize import linear_sum_assignment

    # DxT matrix, where D - detections, T - IoU thresholds
    correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0]), dtype=bool)
    # LxD matrix where L - labels (rows), D - detections (columns)
    correct_class = true_classes[:, None] == pred_classes
    iou = iou * correct_class  # zero out the wrong classes
    iou = iou.cpu().numpy()
    for i, threshold in enumerate(self.iouv.cpu().tolist()):
        if use_scipy:
            cost_matrix = iou * (iou >= threshold)
            if cost_matrix.any():
                # IoU is a score, not a cost: the assignment must maximise it. The default (minimise)
                # prefers zero-IoU pairs and would discard genuine matches.
                labels_idx, detections_idx = linear_sum_assignment(cost_matrix, maximize=True)
                valid = cost_matrix[labels_idx, detections_idx] > 0
                if valid.any():
                    correct[detections_idx[valid], i] = True
        else:
            matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
            matches = np.array(matches).T
            if matches.shape[0]:
                if matches.shape[0] > 1:
                    # Sort candidate pairs by IoU (descending), then keep the first pair seen for each
                    # detection and afterwards for each label, so every one is used at most once
                    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
|
299
|
+
|
300
|
+
def add_callback(self, event: str, callback):
    """Register `callback` at the end of the callback list stored for `event`."""
    registered = self.callbacks[event]
    registered.append(callback)
|
303
|
+
|
304
|
+
def run_callbacks(self, event: str):
    """Invoke, in order, every callback registered for `event`, passing this validator to each one."""
    hooks = self.callbacks.get(event, [])
    for hook in hooks:
        hook(self)
|
308
|
+
|
309
|
+
def get_dataloader(self, dataset_path, batch_size):
    """Build a data loader for `dataset_path` with `batch_size`; not implemented in this base class."""
    message = "get_dataloader function not implemented for this validator"
    raise NotImplementedError(message)
|
312
|
+
|
313
|
+
def build_dataset(self, img_path):
    """Build a dataset from `img_path`; not implemented in this base class."""
    message = "build_dataset function not implemented in validator"
    raise NotImplementedError(message)
|
316
|
+
|
317
|
+
def preprocess(self, batch):
    """Return the input batch untouched (the base class applies no preprocessing)."""
    unchanged = batch
    return unchanged
|
320
|
+
|
321
|
+
def postprocess(self, preds):
    """Return the predictions untouched (the base class applies no postprocessing)."""
    unchanged = preds
    return unchanged
|
324
|
+
|
325
|
+
def init_metrics(self, model):
    """Initialize performance metrics for the given model; a no-op in this base class."""
    return None
|
328
|
+
|
329
|
+
def update_metrics(self, preds, batch):
    """Update metrics from one batch and its predictions; a no-op in this base class."""
    return None
|
332
|
+
|
333
|
+
def finalize_metrics(self, *args, **kwargs):
    """Finalize the collected metrics; accepts any arguments and does nothing in this base class."""
    return None
|
336
|
+
|
337
|
+
def get_stats(self):
    """Return performance statistics; the base class reports an empty dictionary."""
    return dict()
|
340
|
+
|
341
|
+
def check_stats(self, stats):
    """Check the given statistics; a no-op in this base class."""
    return None
|
344
|
+
|
345
|
+
def print_results(self):
    """Print the results of the model's predictions; a no-op in this base class."""
    return None
|
348
|
+
|
349
|
+
def get_desc(self):
    """Get a description of the YOLO model; the base class provides none and returns None."""
    return None
|
352
|
+
|
353
|
+
@property
def metric_keys(self):
    """Metric keys used in YOLO training/validation; the base class defines none (empty list)."""
    return list()
|
357
|
+
|
358
|
+
def on_plot(self, name, data=None):
    """Record a plot under `Path(name)` together with its data and the current time (e.g. for callbacks)."""
    record = {"data": data, "timestamp": time.time()}
    self.plots[Path(name)] = record
|
361
|
+
|
362
|
+
# TODO: may need to put these following functions into callback
|
363
|
+
def plot_val_samples(self, batch, ni):
    """Plot validation samples for batch index `ni`; a no-op in this base class."""
    return None
|
366
|
+
|
367
|
+
def plot_predictions(self, batch, preds, ni):
    """Plot model predictions on the images of batch index `ni`; a no-op in this base class."""
    return None
|
370
|
+
|
371
|
+
def pred_to_json(self, preds, batch):
    """Convert predictions to JSON format; a no-op in this base class."""
    return None
|
374
|
+
|
375
|
+
def eval_json(self, stats):
    """
    Evaluate the saved JSON predictions and return the (possibly updated) statistics.

    The base class performs no JSON evaluation. It returns `stats` unchanged rather than None, because the
    validation run assigns this method's result back to its statistics and returns it to the caller.

    Args:
        stats (dict): Validation statistics computed so far.

    Returns:
        (dict): The same `stats` object that was passed in.
    """
    return stats
|