dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/engine/tuner.py
CHANGED
|
@@ -14,57 +14,79 @@ Examples:
|
|
|
14
14
|
>>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import gc
|
|
17
20
|
import random
|
|
18
21
|
import shutil
|
|
19
22
|
import subprocess
|
|
20
23
|
import time
|
|
24
|
+
from datetime import datetime
|
|
21
25
|
|
|
22
26
|
import numpy as np
|
|
23
27
|
import torch
|
|
24
28
|
|
|
25
29
|
from ultralytics.cfg import get_cfg, get_save_dir
|
|
26
30
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr
|
|
31
|
+
from ultralytics.utils.checks import check_requirements
|
|
32
|
+
from ultralytics.utils.patches import torch_load
|
|
27
33
|
from ultralytics.utils.plotting import plot_tune_results
|
|
28
34
|
|
|
29
35
|
|
|
30
36
|
class Tuner:
|
|
31
|
-
"""
|
|
32
|
-
A class for hyperparameter tuning of YOLO models.
|
|
37
|
+
"""A class for hyperparameter tuning of YOLO models.
|
|
33
38
|
|
|
34
39
|
The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
|
|
35
|
-
search space and retraining the model to evaluate their performance.
|
|
40
|
+
search space and retraining the model to evaluate their performance. Supports both local CSV storage and distributed
|
|
41
|
+
MongoDB Atlas coordination for multi-machine hyperparameter optimization.
|
|
36
42
|
|
|
37
43
|
Attributes:
|
|
38
|
-
space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
|
|
44
|
+
space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
|
|
39
45
|
tune_dir (Path): Directory where evolution logs and results will be saved.
|
|
40
46
|
tune_csv (Path): Path to the CSV file where evolution logs are saved.
|
|
41
47
|
args (dict): Configuration arguments for the tuning process.
|
|
42
48
|
callbacks (list): Callback functions to be executed during tuning.
|
|
43
49
|
prefix (str): Prefix string for logging messages.
|
|
50
|
+
mongodb (MongoClient): Optional MongoDB client for distributed tuning.
|
|
51
|
+
collection (Collection): MongoDB collection for storing tuning results.
|
|
44
52
|
|
|
45
53
|
Methods:
|
|
46
|
-
_mutate:
|
|
47
|
-
__call__:
|
|
54
|
+
_mutate: Mutate hyperparameters based on bounds and scaling factors.
|
|
55
|
+
__call__: Execute the hyperparameter evolution across multiple iterations.
|
|
48
56
|
|
|
49
57
|
Examples:
|
|
50
58
|
Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
|
|
51
59
|
>>> from ultralytics import YOLO
|
|
52
60
|
>>> model = YOLO("yolo11n.pt")
|
|
53
61
|
>>> model.tune(
|
|
54
|
-
|
|
55
|
-
|
|
62
|
+
>>> data="coco8.yaml",
|
|
63
|
+
>>> epochs=10,
|
|
64
|
+
>>> iterations=300,
|
|
65
|
+
>>> plots=False,
|
|
66
|
+
>>> save=False,
|
|
67
|
+
>>> val=False
|
|
68
|
+
>>> )
|
|
69
|
+
|
|
70
|
+
Tune with distributed MongoDB Atlas coordination across multiple machines:
|
|
71
|
+
>>> model.tune(
|
|
72
|
+
>>> data="coco8.yaml",
|
|
73
|
+
>>> epochs=10,
|
|
74
|
+
>>> iterations=300,
|
|
75
|
+
>>> mongodb_uri="mongodb+srv://user:pass@cluster.mongodb.net/",
|
|
76
|
+
>>> mongodb_db="ultralytics",
|
|
77
|
+
>>> mongodb_collection="tune_results"
|
|
78
|
+
>>> )
|
|
56
79
|
|
|
57
|
-
Tune with custom search space
|
|
58
|
-
>>> model.tune(space={
|
|
80
|
+
Tune with custom search space:
|
|
81
|
+
>>> model.tune(space={"lr0": (1e-5, 1e-1), "momentum": (0.6, 0.98)})
|
|
59
82
|
"""
|
|
60
83
|
|
|
61
|
-
def __init__(self, args=DEFAULT_CFG, _callbacks=None):
|
|
62
|
-
"""
|
|
63
|
-
Initialize the Tuner with configurations.
|
|
84
|
+
def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
|
|
85
|
+
"""Initialize the Tuner with configurations.
|
|
64
86
|
|
|
65
87
|
Args:
|
|
66
88
|
args (dict): Configuration for hyperparameter evolution.
|
|
67
|
-
_callbacks (list, optional): Callback functions to be executed during tuning.
|
|
89
|
+
_callbacks (list | None, optional): Callback functions to be executed during tuning.
|
|
68
90
|
"""
|
|
69
91
|
self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
|
|
70
92
|
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
|
|
@@ -75,7 +97,7 @@ class Tuner:
|
|
|
75
97
|
"warmup_epochs": (0.0, 5.0), # warmup epochs (fractions ok)
|
|
76
98
|
"warmup_momentum": (0.0, 0.95), # warmup initial momentum
|
|
77
99
|
"box": (1.0, 20.0), # box loss gain
|
|
78
|
-
"cls": (0.
|
|
100
|
+
"cls": (0.1, 4.0), # cls loss gain (scale with pixels)
|
|
79
101
|
"dfl": (0.4, 6.0), # dfl loss gain
|
|
80
102
|
"hsv_h": (0.0, 0.1), # image HSV-Hue augmentation (fraction)
|
|
81
103
|
"hsv_s": (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
|
|
@@ -92,7 +114,12 @@ class Tuner:
|
|
|
92
114
|
"mixup": (0.0, 1.0), # image mixup (probability)
|
|
93
115
|
"cutmix": (0.0, 1.0), # image cutmix (probability)
|
|
94
116
|
"copy_paste": (0.0, 1.0), # segment copy-paste (probability)
|
|
117
|
+
"close_mosaic": (0.0, 10.0), # close dataloader mosaic (epochs)
|
|
95
118
|
}
|
|
119
|
+
mongodb_uri = args.pop("mongodb_uri", None)
|
|
120
|
+
mongodb_db = args.pop("mongodb_db", "ultralytics")
|
|
121
|
+
mongodb_collection = args.pop("mongodb_collection", "tuner_results")
|
|
122
|
+
|
|
96
123
|
self.args = get_cfg(overrides=args)
|
|
97
124
|
self.args.exist_ok = self.args.resume # resume w/ same tune_dir
|
|
98
125
|
self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
|
|
@@ -101,88 +128,252 @@ class Tuner:
|
|
|
101
128
|
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
|
102
129
|
self.prefix = colorstr("Tuner: ")
|
|
103
130
|
callbacks.add_integration_callbacks(self)
|
|
131
|
+
|
|
132
|
+
# MongoDB Atlas support (optional)
|
|
133
|
+
self.mongodb = None
|
|
134
|
+
if mongodb_uri:
|
|
135
|
+
self._init_mongodb(mongodb_uri, mongodb_db, mongodb_collection)
|
|
136
|
+
|
|
104
137
|
LOGGER.info(
|
|
105
138
|
f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
|
|
106
139
|
f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
|
|
107
140
|
)
|
|
108
141
|
|
|
109
|
-
def
|
|
142
|
+
def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
|
|
143
|
+
"""Create MongoDB client with exponential backoff retry on connection failures.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
uri (str): MongoDB connection string with credentials and cluster information.
|
|
147
|
+
max_retries (int): Maximum number of connection attempts before giving up.
|
|
148
|
+
|
|
149
|
+
Returns:
|
|
150
|
+
(MongoClient): Connected MongoDB client instance.
|
|
151
|
+
"""
|
|
152
|
+
check_requirements("pymongo")
|
|
153
|
+
|
|
154
|
+
from pymongo import MongoClient
|
|
155
|
+
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
|
|
156
|
+
|
|
157
|
+
for attempt in range(max_retries):
|
|
158
|
+
try:
|
|
159
|
+
client = MongoClient(
|
|
160
|
+
uri,
|
|
161
|
+
serverSelectionTimeoutMS=30000,
|
|
162
|
+
connectTimeoutMS=20000,
|
|
163
|
+
socketTimeoutMS=40000,
|
|
164
|
+
retryWrites=True,
|
|
165
|
+
retryReads=True,
|
|
166
|
+
maxPoolSize=30,
|
|
167
|
+
minPoolSize=3,
|
|
168
|
+
maxIdleTimeMS=60000,
|
|
169
|
+
)
|
|
170
|
+
client.admin.command("ping") # Test connection
|
|
171
|
+
LOGGER.info(f"{self.prefix}Connected to MongoDB Atlas (attempt {attempt + 1})")
|
|
172
|
+
return client
|
|
173
|
+
except (ConnectionFailure, ServerSelectionTimeoutError):
|
|
174
|
+
if attempt == max_retries - 1:
|
|
175
|
+
raise
|
|
176
|
+
wait_time = 2**attempt
|
|
177
|
+
LOGGER.warning(
|
|
178
|
+
f"{self.prefix}MongoDB connection failed (attempt {attempt + 1}), retrying in {wait_time}s..."
|
|
179
|
+
)
|
|
180
|
+
time.sleep(wait_time)
|
|
181
|
+
|
|
182
|
+
def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
|
|
183
|
+
"""Initialize MongoDB connection for distributed tuning.
|
|
184
|
+
|
|
185
|
+
Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines. Each worker
|
|
186
|
+
saves results to a shared collection and reads the latest best hyperparameters from all workers for evolution.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.
|
|
190
|
+
mongodb_db (str, optional): Database name.
|
|
191
|
+
mongodb_collection (str, optional): Collection name.
|
|
192
|
+
|
|
193
|
+
Notes:
|
|
194
|
+
- Creates a fitness index for fast queries of top results
|
|
195
|
+
- Falls back to CSV-only mode if connection fails
|
|
196
|
+
- Uses connection pooling and retry logic for production reliability
|
|
197
|
+
"""
|
|
198
|
+
self.mongodb = self._connect(mongodb_uri)
|
|
199
|
+
self.collection = self.mongodb[mongodb_db][mongodb_collection]
|
|
200
|
+
self.collection.create_index([("fitness", -1)], background=True)
|
|
201
|
+
LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
|
|
202
|
+
|
|
203
|
+
def _get_mongodb_results(self, n: int = 5) -> list:
|
|
204
|
+
"""Get top N results from MongoDB sorted by fitness.
|
|
205
|
+
|
|
206
|
+
Args:
|
|
207
|
+
n (int): Number of top results to retrieve.
|
|
208
|
+
|
|
209
|
+
Returns:
|
|
210
|
+
(list[dict]): List of result documents with fitness scores and hyperparameters.
|
|
211
|
+
"""
|
|
212
|
+
try:
|
|
213
|
+
return list(self.collection.find().sort("fitness", -1).limit(n))
|
|
214
|
+
except Exception:
|
|
215
|
+
return []
|
|
216
|
+
|
|
217
|
+
def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
|
|
218
|
+
"""Save results to MongoDB with proper type conversion.
|
|
219
|
+
|
|
220
|
+
Args:
|
|
221
|
+
fitness (float): Fitness score achieved with these hyperparameters.
|
|
222
|
+
hyperparameters (dict[str, float]): Dictionary of hyperparameter values.
|
|
223
|
+
metrics (dict): Complete training metrics dictionary (mAP, precision, recall, losses, etc.).
|
|
224
|
+
iteration (int): Current iteration number.
|
|
225
|
+
"""
|
|
226
|
+
try:
|
|
227
|
+
self.collection.insert_one(
|
|
228
|
+
{
|
|
229
|
+
"fitness": float(fitness),
|
|
230
|
+
"hyperparameters": {k: (v.item() if hasattr(v, "item") else v) for k, v in hyperparameters.items()},
|
|
231
|
+
"metrics": metrics,
|
|
232
|
+
"timestamp": datetime.now(),
|
|
233
|
+
"iteration": iteration,
|
|
234
|
+
}
|
|
235
|
+
)
|
|
236
|
+
except Exception as e:
|
|
237
|
+
LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
|
|
238
|
+
|
|
239
|
+
def _sync_mongodb_to_csv(self):
|
|
240
|
+
"""Sync MongoDB results to CSV for plotting compatibility.
|
|
241
|
+
|
|
242
|
+
Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
|
|
243
|
+
the existing plotting functions to work seamlessly with distributed MongoDB data.
|
|
110
244
|
"""
|
|
111
|
-
|
|
245
|
+
try:
|
|
246
|
+
# Get all results from MongoDB
|
|
247
|
+
all_results = list(self.collection.find().sort("iteration", 1))
|
|
248
|
+
if not all_results:
|
|
249
|
+
return
|
|
250
|
+
|
|
251
|
+
# Write to CSV
|
|
252
|
+
headers = ",".join(["fitness", *list(self.space.keys())]) + "\n"
|
|
253
|
+
with open(self.tune_csv, "w", encoding="utf-8") as f:
|
|
254
|
+
f.write(headers)
|
|
255
|
+
for result in all_results:
|
|
256
|
+
fitness = result["fitness"]
|
|
257
|
+
hyp_values = [result["hyperparameters"][k] for k in self.space.keys()]
|
|
258
|
+
log_row = [round(fitness, 5), *hyp_values]
|
|
259
|
+
f.write(",".join(map(str, log_row)) + "\n")
|
|
260
|
+
|
|
261
|
+
except Exception as e:
|
|
262
|
+
LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
|
|
263
|
+
|
|
264
|
+
def _crossover(self, x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
|
|
265
|
+
"""BLX-α crossover from up to top-k parents (x[:,0]=fitness, rest=genes)."""
|
|
266
|
+
k = min(k, len(x))
|
|
267
|
+
# fitness weights (shifted to >0); fallback to uniform if degenerate
|
|
268
|
+
weights = x[:, 0] - x[:, 0].min() + 1e-6
|
|
269
|
+
if not np.isfinite(weights).all() or weights.sum() == 0:
|
|
270
|
+
weights = np.ones_like(weights)
|
|
271
|
+
idxs = random.choices(range(len(x)), weights=weights, k=k)
|
|
272
|
+
parents_mat = np.stack([x[i][1:] for i in idxs], 0) # (k, ng) strip fitness
|
|
273
|
+
lo, hi = parents_mat.min(0), parents_mat.max(0)
|
|
274
|
+
span = hi - lo
|
|
275
|
+
return np.random.uniform(lo - alpha * span, hi + alpha * span)
|
|
276
|
+
|
|
277
|
+
def _mutate(
|
|
278
|
+
self,
|
|
279
|
+
n: int = 9,
|
|
280
|
+
mutation: float = 0.5,
|
|
281
|
+
sigma: float = 0.2,
|
|
282
|
+
) -> dict[str, float]:
|
|
283
|
+
"""Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
|
|
112
284
|
|
|
113
285
|
Args:
|
|
114
|
-
parent (str): Parent selection method
|
|
115
|
-
n (int): Number of parents to consider.
|
|
286
|
+
parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
|
|
287
|
+
n (int): Number of top parents to consider.
|
|
116
288
|
mutation (float): Probability of a parameter mutation in any given iteration.
|
|
117
289
|
sigma (float): Standard deviation for Gaussian random number generator.
|
|
118
290
|
|
|
119
291
|
Returns:
|
|
120
|
-
(dict): A dictionary containing mutated hyperparameters.
|
|
292
|
+
(dict[str, float]): A dictionary containing mutated hyperparameters.
|
|
121
293
|
"""
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
x =
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
294
|
+
x = None
|
|
295
|
+
|
|
296
|
+
# Try MongoDB first if available
|
|
297
|
+
if self.mongodb:
|
|
298
|
+
results = self._get_mongodb_results(n)
|
|
299
|
+
if results:
|
|
300
|
+
# MongoDB already sorted by fitness DESC, so results[0] is best
|
|
301
|
+
x = np.array([[r["fitness"]] + [r["hyperparameters"][k] for k in self.space.keys()] for r in results])
|
|
302
|
+
elif self.collection.name in self.collection.database.list_collection_names(): # Tuner started elsewhere
|
|
303
|
+
x = np.array([[0.0] + [getattr(self.args, k) for k in self.space.keys()]])
|
|
304
|
+
|
|
305
|
+
# Fall back to CSV if MongoDB unavailable or empty
|
|
306
|
+
if x is None and self.tune_csv.exists():
|
|
307
|
+
csv_data = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
|
|
308
|
+
if len(csv_data) > 0:
|
|
309
|
+
fitness = csv_data[:, 0] # first column
|
|
310
|
+
order = np.argsort(-fitness)
|
|
311
|
+
x = csv_data[order][:n] # top-n sorted by fitness DESC
|
|
312
|
+
|
|
313
|
+
# Mutate if we have data, otherwise use defaults
|
|
314
|
+
if x is not None:
|
|
315
|
+
np.random.seed(int(time.time()))
|
|
139
316
|
ng = len(self.space)
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
317
|
+
|
|
318
|
+
# Crossover
|
|
319
|
+
genes = self._crossover(x)
|
|
320
|
+
|
|
321
|
+
# Mutation
|
|
322
|
+
gains = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()]) # gains 0-1
|
|
323
|
+
factors = np.ones(ng)
|
|
324
|
+
while np.all(factors == 1): # mutate until a change occurs (prevent duplicates)
|
|
325
|
+
mask = np.random.random(ng) < mutation
|
|
326
|
+
step = np.random.randn(ng) * (sigma * gains)
|
|
327
|
+
factors = np.where(mask, np.exp(step), 1.0).clip(0.25, 4.0)
|
|
328
|
+
hyp = {k: float(genes[i] * factors[i]) for i, k in enumerate(self.space.keys())}
|
|
144
329
|
else:
|
|
145
330
|
hyp = {k: getattr(self.args, k) for k in self.space.keys()}
|
|
146
331
|
|
|
147
332
|
# Constrain to limits
|
|
148
|
-
for k,
|
|
149
|
-
hyp[k] = max(hyp[k],
|
|
150
|
-
hyp[k] = min(hyp[k], v[1]) # upper limit
|
|
151
|
-
hyp[k] = round(hyp[k], 5) # significant digits
|
|
333
|
+
for k, bounds in self.space.items():
|
|
334
|
+
hyp[k] = round(min(max(hyp[k], bounds[0]), bounds[1]), 5)
|
|
152
335
|
|
|
153
|
-
|
|
336
|
+
# Update types
|
|
337
|
+
if "close_mosaic" in hyp:
|
|
338
|
+
hyp["close_mosaic"] = round(hyp["close_mosaic"])
|
|
154
339
|
|
|
155
|
-
|
|
156
|
-
"""
|
|
157
|
-
Execute the hyperparameter evolution process when the Tuner instance is called.
|
|
340
|
+
return hyp
|
|
158
341
|
|
|
159
|
-
|
|
342
|
+
def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
|
|
343
|
+
"""Execute the hyperparameter evolution process when the Tuner instance is called.
|
|
160
344
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
345
|
+
This method iterates through the specified number of iterations, performing the following steps:
|
|
346
|
+
1. Sync MongoDB results to CSV (if using distributed mode)
|
|
347
|
+
2. Mutate hyperparameters using the best previous results or defaults
|
|
348
|
+
3. Train a YOLO model with the mutated hyperparameters
|
|
349
|
+
4. Log fitness scores and hyperparameters to MongoDB and/or CSV
|
|
350
|
+
5. Track the best performing configuration across all iterations
|
|
165
351
|
|
|
166
352
|
Args:
|
|
167
|
-
model (Model): A pre-initialized YOLO model to be used for training.
|
|
353
|
+
model (Model | None, optional): A pre-initialized YOLO model to be used for training.
|
|
168
354
|
iterations (int): The number of generations to run the evolution for.
|
|
169
|
-
cleanup (bool): Whether to delete iteration weights to reduce storage space
|
|
170
|
-
|
|
171
|
-
Note:
|
|
172
|
-
The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
|
|
173
|
-
Ensure this path is set correctly in the Tuner instance.
|
|
355
|
+
cleanup (bool): Whether to delete iteration weights to reduce storage space during tuning.
|
|
174
356
|
"""
|
|
175
357
|
t0 = time.time()
|
|
176
358
|
best_save_dir, best_metrics = None, None
|
|
177
359
|
(self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
|
|
360
|
+
|
|
361
|
+
# Sync MongoDB to CSV at startup for proper resume logic
|
|
362
|
+
if self.mongodb:
|
|
363
|
+
self._sync_mongodb_to_csv()
|
|
364
|
+
|
|
178
365
|
start = 0
|
|
179
366
|
if self.tune_csv.exists():
|
|
180
367
|
x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
|
|
181
368
|
start = x.shape[0]
|
|
182
369
|
LOGGER.info(f"{self.prefix}Resuming tuning run {self.tune_dir} from iteration {start + 1}...")
|
|
183
370
|
for i in range(start, iterations):
|
|
371
|
+
# Linearly decay sigma from 0.2 → 0.1 over first 300 iterations
|
|
372
|
+
frac = min(i / 300.0, 1.0)
|
|
373
|
+
sigma_i = 0.2 - 0.1 * frac
|
|
374
|
+
|
|
184
375
|
# Mutate hyperparameters
|
|
185
|
-
mutated_hyp = self._mutate()
|
|
376
|
+
mutated_hyp = self._mutate(sigma=sigma_i)
|
|
186
377
|
LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")
|
|
187
378
|
|
|
188
379
|
metrics = {}
|
|
@@ -195,18 +386,34 @@ class Tuner:
|
|
|
195
386
|
cmd = [*launch, "train", *(f"{k}={v}" for k, v in train_args.items())]
|
|
196
387
|
return_code = subprocess.run(cmd, check=True).returncode
|
|
197
388
|
ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
|
|
198
|
-
metrics =
|
|
389
|
+
metrics = torch_load(ckpt_file)["train_metrics"]
|
|
199
390
|
assert return_code == 0, "training failed"
|
|
200
391
|
|
|
392
|
+
# Cleanup
|
|
393
|
+
time.sleep(1)
|
|
394
|
+
gc.collect()
|
|
395
|
+
torch.cuda.empty_cache()
|
|
396
|
+
|
|
201
397
|
except Exception as e:
|
|
202
398
|
LOGGER.error(f"training failure for hyperparameter tuning iteration {i + 1}\n{e}")
|
|
203
399
|
|
|
204
|
-
# Save results
|
|
400
|
+
# Save results - MongoDB takes precedence
|
|
205
401
|
fitness = metrics.get("fitness", 0.0)
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
402
|
+
if self.mongodb:
|
|
403
|
+
self._save_to_mongodb(fitness, mutated_hyp, metrics, i + 1)
|
|
404
|
+
self._sync_mongodb_to_csv()
|
|
405
|
+
total_mongo_iterations = self.collection.count_documents({})
|
|
406
|
+
if total_mongo_iterations >= iterations:
|
|
407
|
+
LOGGER.info(
|
|
408
|
+
f"{self.prefix}Target iterations ({iterations}) reached in MongoDB ({total_mongo_iterations}). Stopping."
|
|
409
|
+
)
|
|
410
|
+
break
|
|
411
|
+
else:
|
|
412
|
+
# Save to CSV only if no MongoDB
|
|
413
|
+
log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
|
|
414
|
+
headers = "" if self.tune_csv.exists() else (",".join(["fitness", *list(self.space.keys())]) + "\n")
|
|
415
|
+
with open(self.tune_csv, "a", encoding="utf-8") as f:
|
|
416
|
+
f.write(headers + ",".join(map(str, log_row)) + "\n")
|
|
210
417
|
|
|
211
418
|
# Get best results
|
|
212
419
|
x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
|
|
@@ -214,15 +421,15 @@ class Tuner:
|
|
|
214
421
|
best_idx = fitness.argmax()
|
|
215
422
|
best_is_current = best_idx == i
|
|
216
423
|
if best_is_current:
|
|
217
|
-
best_save_dir = save_dir
|
|
424
|
+
best_save_dir = str(save_dir)
|
|
218
425
|
best_metrics = {k: round(v, 5) for k, v in metrics.items()}
|
|
219
426
|
for ckpt in weights_dir.glob("*.pt"):
|
|
220
427
|
shutil.copy2(ckpt, self.tune_dir / "weights")
|
|
221
|
-
elif cleanup:
|
|
222
|
-
shutil.rmtree(
|
|
428
|
+
elif cleanup and best_save_dir:
|
|
429
|
+
shutil.rmtree(best_save_dir, ignore_errors=True) # remove iteration dirs to reduce storage space
|
|
223
430
|
|
|
224
431
|
# Plot tune results
|
|
225
|
-
plot_tune_results(self.tune_csv)
|
|
432
|
+
plot_tune_results(str(self.tune_csv))
|
|
226
433
|
|
|
227
434
|
# Save and print tune results
|
|
228
435
|
header = (
|
|
@@ -230,8 +437,7 @@ class Tuner:
|
|
|
230
437
|
f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n"
|
|
231
438
|
f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n"
|
|
232
439
|
f"{self.prefix}Best fitness metrics are {best_metrics}\n"
|
|
233
|
-
f"{self.prefix}Best fitness model is {best_save_dir}
|
|
234
|
-
f"{self.prefix}Best fitness hyperparameters are printed below.\n"
|
|
440
|
+
f"{self.prefix}Best fitness model is {best_save_dir}"
|
|
235
441
|
)
|
|
236
442
|
LOGGER.info("\n" + header)
|
|
237
443
|
data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
|