dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -14,57 +14,79 @@ Examples:
14
14
  >>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
15
15
  """
16
16
 
17
+ from __future__ import annotations
18
+
19
+ import gc
17
20
  import random
18
21
  import shutil
19
22
  import subprocess
20
23
  import time
24
+ from datetime import datetime
21
25
 
22
26
  import numpy as np
23
27
  import torch
24
28
 
25
29
  from ultralytics.cfg import get_cfg, get_save_dir
26
30
  from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr
31
+ from ultralytics.utils.checks import check_requirements
32
+ from ultralytics.utils.patches import torch_load
27
33
  from ultralytics.utils.plotting import plot_tune_results
28
34
 
29
35
 
30
36
  class Tuner:
31
- """
32
- A class for hyperparameter tuning of YOLO models.
37
+ """A class for hyperparameter tuning of YOLO models.
33
38
 
34
39
  The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
35
- search space and retraining the model to evaluate their performance.
40
+ search space and retraining the model to evaluate their performance. Supports both local CSV storage and distributed
41
+ MongoDB Atlas coordination for multi-machine hyperparameter optimization.
36
42
 
37
43
  Attributes:
38
- space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
44
+ space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
39
45
  tune_dir (Path): Directory where evolution logs and results will be saved.
40
46
  tune_csv (Path): Path to the CSV file where evolution logs are saved.
41
47
  args (dict): Configuration arguments for the tuning process.
42
48
  callbacks (list): Callback functions to be executed during tuning.
43
49
  prefix (str): Prefix string for logging messages.
50
+ mongodb (MongoClient): Optional MongoDB client for distributed tuning.
51
+ collection (Collection): MongoDB collection for storing tuning results.
44
52
 
45
53
  Methods:
46
- _mutate: Mutates the given hyperparameters within the specified bounds.
47
- __call__: Executes the hyperparameter evolution across multiple iterations.
54
+ _mutate: Mutate hyperparameters based on bounds and scaling factors.
55
+ __call__: Execute the hyperparameter evolution across multiple iterations.
48
56
 
49
57
  Examples:
50
58
  Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
51
59
  >>> from ultralytics import YOLO
52
60
  >>> model = YOLO("yolo11n.pt")
53
61
  >>> model.tune(
54
- ... data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False
55
- ... )
62
+ >>> data="coco8.yaml",
63
+ >>> epochs=10,
64
+ >>> iterations=300,
65
+ >>> plots=False,
66
+ >>> save=False,
67
+ >>> val=False
68
+ >>> )
69
+
70
+ Tune with distributed MongoDB Atlas coordination across multiple machines:
71
+ >>> model.tune(
72
+ >>> data="coco8.yaml",
73
+ >>> epochs=10,
74
+ >>> iterations=300,
75
+ >>> mongodb_uri="mongodb+srv://user:pass@cluster.mongodb.net/",
76
+ >>> mongodb_db="ultralytics",
77
+ >>> mongodb_collection="tune_results"
78
+ >>> )
56
79
 
57
- Tune with custom search space.
58
- >>> model.tune(space={key1: val1, key2: val2}) # custom search space dictionary
80
+ Tune with custom search space:
81
+ >>> model.tune(space={"lr0": (1e-5, 1e-1), "momentum": (0.6, 0.98)})
59
82
  """
60
83
 
61
- def __init__(self, args=DEFAULT_CFG, _callbacks=None):
62
- """
63
- Initialize the Tuner with configurations.
84
+ def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
85
+ """Initialize the Tuner with configurations.
64
86
 
65
87
  Args:
66
88
  args (dict): Configuration for hyperparameter evolution.
67
- _callbacks (list, optional): Callback functions to be executed during tuning.
89
+ _callbacks (list | None, optional): Callback functions to be executed during tuning.
68
90
  """
69
91
  self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
70
92
  # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
@@ -75,7 +97,7 @@ class Tuner:
75
97
  "warmup_epochs": (0.0, 5.0), # warmup epochs (fractions ok)
76
98
  "warmup_momentum": (0.0, 0.95), # warmup initial momentum
77
99
  "box": (1.0, 20.0), # box loss gain
78
- "cls": (0.2, 4.0), # cls loss gain (scale with pixels)
100
+ "cls": (0.1, 4.0), # cls loss gain (scale with pixels)
79
101
  "dfl": (0.4, 6.0), # dfl loss gain
80
102
  "hsv_h": (0.0, 0.1), # image HSV-Hue augmentation (fraction)
81
103
  "hsv_s": (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
@@ -92,7 +114,12 @@ class Tuner:
92
114
  "mixup": (0.0, 1.0), # image mixup (probability)
93
115
  "cutmix": (0.0, 1.0), # image cutmix (probability)
94
116
  "copy_paste": (0.0, 1.0), # segment copy-paste (probability)
117
+ "close_mosaic": (0.0, 10.0), # close dataloader mosaic (epochs)
95
118
  }
119
+ mongodb_uri = args.pop("mongodb_uri", None)
120
+ mongodb_db = args.pop("mongodb_db", "ultralytics")
121
+ mongodb_collection = args.pop("mongodb_collection", "tuner_results")
122
+
96
123
  self.args = get_cfg(overrides=args)
97
124
  self.args.exist_ok = self.args.resume # resume w/ same tune_dir
98
125
  self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
@@ -101,88 +128,252 @@ class Tuner:
101
128
  self.callbacks = _callbacks or callbacks.get_default_callbacks()
102
129
  self.prefix = colorstr("Tuner: ")
103
130
  callbacks.add_integration_callbacks(self)
131
+
132
+ # MongoDB Atlas support (optional)
133
+ self.mongodb = None
134
+ if mongodb_uri:
135
+ self._init_mongodb(mongodb_uri, mongodb_db, mongodb_collection)
136
+
104
137
  LOGGER.info(
105
138
  f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
106
139
  f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
107
140
  )
108
141
 
109
- def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
142
+ def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
143
+ """Create MongoDB client with exponential backoff retry on connection failures.
144
+
145
+ Args:
146
+ uri (str): MongoDB connection string with credentials and cluster information.
147
+ max_retries (int): Maximum number of connection attempts before giving up.
148
+
149
+ Returns:
150
+ (MongoClient): Connected MongoDB client instance.
151
+ """
152
+ check_requirements("pymongo")
153
+
154
+ from pymongo import MongoClient
155
+ from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
156
+
157
+ for attempt in range(max_retries):
158
+ try:
159
+ client = MongoClient(
160
+ uri,
161
+ serverSelectionTimeoutMS=30000,
162
+ connectTimeoutMS=20000,
163
+ socketTimeoutMS=40000,
164
+ retryWrites=True,
165
+ retryReads=True,
166
+ maxPoolSize=30,
167
+ minPoolSize=3,
168
+ maxIdleTimeMS=60000,
169
+ )
170
+ client.admin.command("ping") # Test connection
171
+ LOGGER.info(f"{self.prefix}Connected to MongoDB Atlas (attempt {attempt + 1})")
172
+ return client
173
+ except (ConnectionFailure, ServerSelectionTimeoutError):
174
+ if attempt == max_retries - 1:
175
+ raise
176
+ wait_time = 2**attempt
177
+ LOGGER.warning(
178
+ f"{self.prefix}MongoDB connection failed (attempt {attempt + 1}), retrying in {wait_time}s..."
179
+ )
180
+ time.sleep(wait_time)
181
+
182
+ def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
183
+ """Initialize MongoDB connection for distributed tuning.
184
+
185
+ Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines. Each worker
186
+ saves results to a shared collection and reads the latest best hyperparameters from all workers for evolution.
187
+
188
+ Args:
189
+ mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.
190
+ mongodb_db (str, optional): Database name.
191
+ mongodb_collection (str, optional): Collection name.
192
+
193
+ Notes:
194
+ - Creates a fitness index for fast queries of top results
195
+ - Falls back to CSV-only mode if connection fails
196
+ - Uses connection pooling and retry logic for production reliability
197
+ """
198
+ self.mongodb = self._connect(mongodb_uri)
199
+ self.collection = self.mongodb[mongodb_db][mongodb_collection]
200
+ self.collection.create_index([("fitness", -1)], background=True)
201
+ LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
202
+
203
+ def _get_mongodb_results(self, n: int = 5) -> list:
204
+ """Get top N results from MongoDB sorted by fitness.
205
+
206
+ Args:
207
+ n (int): Number of top results to retrieve.
208
+
209
+ Returns:
210
+ (list[dict]): List of result documents with fitness scores and hyperparameters.
211
+ """
212
+ try:
213
+ return list(self.collection.find().sort("fitness", -1).limit(n))
214
+ except Exception:
215
+ return []
216
+
217
+ def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
218
+ """Save results to MongoDB with proper type conversion.
219
+
220
+ Args:
221
+ fitness (float): Fitness score achieved with these hyperparameters.
222
+ hyperparameters (dict[str, float]): Dictionary of hyperparameter values.
223
+ metrics (dict): Complete training metrics dictionary (mAP, precision, recall, losses, etc.).
224
+ iteration (int): Current iteration number.
225
+ """
226
+ try:
227
+ self.collection.insert_one(
228
+ {
229
+ "fitness": float(fitness),
230
+ "hyperparameters": {k: (v.item() if hasattr(v, "item") else v) for k, v in hyperparameters.items()},
231
+ "metrics": metrics,
232
+ "timestamp": datetime.now(),
233
+ "iteration": iteration,
234
+ }
235
+ )
236
+ except Exception as e:
237
+ LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
238
+
239
+ def _sync_mongodb_to_csv(self):
240
+ """Sync MongoDB results to CSV for plotting compatibility.
241
+
242
+ Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
243
+ the existing plotting functions to work seamlessly with distributed MongoDB data.
110
244
  """
111
- Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
245
+ try:
246
+ # Get all results from MongoDB
247
+ all_results = list(self.collection.find().sort("iteration", 1))
248
+ if not all_results:
249
+ return
250
+
251
+ # Write to CSV
252
+ headers = ",".join(["fitness", *list(self.space.keys())]) + "\n"
253
+ with open(self.tune_csv, "w", encoding="utf-8") as f:
254
+ f.write(headers)
255
+ for result in all_results:
256
+ fitness = result["fitness"]
257
+ hyp_values = [result["hyperparameters"][k] for k in self.space.keys()]
258
+ log_row = [round(fitness, 5), *hyp_values]
259
+ f.write(",".join(map(str, log_row)) + "\n")
260
+
261
+ except Exception as e:
262
+ LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
263
+
264
+ def _crossover(self, x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
265
+ """BLX-α crossover from up to top-k parents (x[:,0]=fitness, rest=genes)."""
266
+ k = min(k, len(x))
267
+ # fitness weights (shifted to >0); fallback to uniform if degenerate
268
+ weights = x[:, 0] - x[:, 0].min() + 1e-6
269
+ if not np.isfinite(weights).all() or weights.sum() == 0:
270
+ weights = np.ones_like(weights)
271
+ idxs = random.choices(range(len(x)), weights=weights, k=k)
272
+ parents_mat = np.stack([x[i][1:] for i in idxs], 0) # (k, ng) strip fitness
273
+ lo, hi = parents_mat.min(0), parents_mat.max(0)
274
+ span = hi - lo
275
+ return np.random.uniform(lo - alpha * span, hi + alpha * span)
276
+
277
+ def _mutate(
278
+ self,
279
+ n: int = 9,
280
+ mutation: float = 0.5,
281
+ sigma: float = 0.2,
282
+ ) -> dict[str, float]:
283
+ """Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
112
284
 
113
285
  Args:
114
- parent (str): Parent selection method: 'single' or 'weighted'.
115
- n (int): Number of parents to consider.
286
+ parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
287
+ n (int): Number of top parents to consider.
116
288
  mutation (float): Probability of a parameter mutation in any given iteration.
117
289
  sigma (float): Standard deviation for Gaussian random number generator.
118
290
 
119
291
  Returns:
120
- (dict): A dictionary containing mutated hyperparameters.
292
+ (dict[str, float]): A dictionary containing mutated hyperparameters.
121
293
  """
122
- if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate
123
- # Select parent(s)
124
- x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
125
- fitness = x[:, 0] # first column
126
- n = min(n, len(x)) # number of previous results to consider
127
- x = x[np.argsort(-fitness)][:n] # top n mutations
128
- w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0)
129
- if parent == "single" or len(x) == 1:
130
- # x = x[random.randint(0, n - 1)] # random selection
131
- x = x[random.choices(range(n), weights=w)[0]] # weighted selection
132
- elif parent == "weighted":
133
- x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
134
-
135
- # Mutate
136
- r = np.random # method
137
- r.seed(int(time.time()))
138
- g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()]) # gains 0-1
294
+ x = None
295
+
296
+ # Try MongoDB first if available
297
+ if self.mongodb:
298
+ results = self._get_mongodb_results(n)
299
+ if results:
300
+ # MongoDB already sorted by fitness DESC, so results[0] is best
301
+ x = np.array([[r["fitness"]] + [r["hyperparameters"][k] for k in self.space.keys()] for r in results])
302
+ elif self.collection.name in self.collection.database.list_collection_names(): # Tuner started elsewhere
303
+ x = np.array([[0.0] + [getattr(self.args, k) for k in self.space.keys()]])
304
+
305
+ # Fall back to CSV if MongoDB unavailable or empty
306
+ if x is None and self.tune_csv.exists():
307
+ csv_data = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
308
+ if len(csv_data) > 0:
309
+ fitness = csv_data[:, 0] # first column
310
+ order = np.argsort(-fitness)
311
+ x = csv_data[order][:n] # top-n sorted by fitness DESC
312
+
313
+ # Mutate if we have data, otherwise use defaults
314
+ if x is not None:
315
+ np.random.seed(int(time.time()))
139
316
  ng = len(self.space)
140
- v = np.ones(ng)
141
- while all(v == 1): # mutate until a change occurs (prevent duplicates)
142
- v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
143
- hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
317
+
318
+ # Crossover
319
+ genes = self._crossover(x)
320
+
321
+ # Mutation
322
+ gains = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()]) # gains 0-1
323
+ factors = np.ones(ng)
324
+ while np.all(factors == 1): # mutate until a change occurs (prevent duplicates)
325
+ mask = np.random.random(ng) < mutation
326
+ step = np.random.randn(ng) * (sigma * gains)
327
+ factors = np.where(mask, np.exp(step), 1.0).clip(0.25, 4.0)
328
+ hyp = {k: float(genes[i] * factors[i]) for i, k in enumerate(self.space.keys())}
144
329
  else:
145
330
  hyp = {k: getattr(self.args, k) for k in self.space.keys()}
146
331
 
147
332
  # Constrain to limits
148
- for k, v in self.space.items():
149
- hyp[k] = max(hyp[k], v[0]) # lower limit
150
- hyp[k] = min(hyp[k], v[1]) # upper limit
151
- hyp[k] = round(hyp[k], 5) # significant digits
333
+ for k, bounds in self.space.items():
334
+ hyp[k] = round(min(max(hyp[k], bounds[0]), bounds[1]), 5)
152
335
 
153
- return hyp
336
+ # Update types
337
+ if "close_mosaic" in hyp:
338
+ hyp["close_mosaic"] = round(hyp["close_mosaic"])
154
339
 
155
- def __call__(self, model=None, iterations=10, cleanup=True):
156
- """
157
- Execute the hyperparameter evolution process when the Tuner instance is called.
340
+ return hyp
158
341
 
159
- This method iterates through the number of iterations, performing the following steps in each iteration:
342
+ def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
343
+ """Execute the hyperparameter evolution process when the Tuner instance is called.
160
344
 
161
- 1. Load the existing hyperparameters or initialize new ones.
162
- 2. Mutate the hyperparameters using the `mutate` method.
163
- 3. Train a YOLO model with the mutated hyperparameters.
164
- 4. Log the fitness score and mutated hyperparameters to a CSV file.
345
+ This method iterates through the specified number of iterations, performing the following steps:
346
+ 1. Sync MongoDB results to CSV (if using distributed mode)
347
+ 2. Mutate hyperparameters using the best previous results or defaults
348
+ 3. Train a YOLO model with the mutated hyperparameters
349
+ 4. Log fitness scores and hyperparameters to MongoDB and/or CSV
350
+ 5. Track the best performing configuration across all iterations
165
351
 
166
352
  Args:
167
- model (Model): A pre-initialized YOLO model to be used for training.
353
+ model (Model | None, optional): A pre-initialized YOLO model to be used for training.
168
354
  iterations (int): The number of generations to run the evolution for.
169
- cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
170
-
171
- Note:
172
- The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
173
- Ensure this path is set correctly in the Tuner instance.
355
+ cleanup (bool): Whether to delete iteration weights to reduce storage space during tuning.
174
356
  """
175
357
  t0 = time.time()
176
358
  best_save_dir, best_metrics = None, None
177
359
  (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
360
+
361
+ # Sync MongoDB to CSV at startup for proper resume logic
362
+ if self.mongodb:
363
+ self._sync_mongodb_to_csv()
364
+
178
365
  start = 0
179
366
  if self.tune_csv.exists():
180
367
  x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
181
368
  start = x.shape[0]
182
369
  LOGGER.info(f"{self.prefix}Resuming tuning run {self.tune_dir} from iteration {start + 1}...")
183
370
  for i in range(start, iterations):
371
+ # Linearly decay sigma from 0.2 → 0.1 over first 300 iterations
372
+ frac = min(i / 300.0, 1.0)
373
+ sigma_i = 0.2 - 0.1 * frac
374
+
184
375
  # Mutate hyperparameters
185
- mutated_hyp = self._mutate()
376
+ mutated_hyp = self._mutate(sigma=sigma_i)
186
377
  LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")
187
378
 
188
379
  metrics = {}
@@ -195,18 +386,34 @@ class Tuner:
195
386
  cmd = [*launch, "train", *(f"{k}={v}" for k, v in train_args.items())]
196
387
  return_code = subprocess.run(cmd, check=True).returncode
197
388
  ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
198
- metrics = torch.load(ckpt_file)["train_metrics"]
389
+ metrics = torch_load(ckpt_file)["train_metrics"]
199
390
  assert return_code == 0, "training failed"
200
391
 
392
+ # Cleanup
393
+ time.sleep(1)
394
+ gc.collect()
395
+ torch.cuda.empty_cache()
396
+
201
397
  except Exception as e:
202
398
  LOGGER.error(f"training failure for hyperparameter tuning iteration {i + 1}\n{e}")
203
399
 
204
- # Save results and mutated_hyp to CSV
400
+ # Save results - MongoDB takes precedence
205
401
  fitness = metrics.get("fitness", 0.0)
206
- log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
207
- headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
208
- with open(self.tune_csv, "a", encoding="utf-8") as f:
209
- f.write(headers + ",".join(map(str, log_row)) + "\n")
402
+ if self.mongodb:
403
+ self._save_to_mongodb(fitness, mutated_hyp, metrics, i + 1)
404
+ self._sync_mongodb_to_csv()
405
+ total_mongo_iterations = self.collection.count_documents({})
406
+ if total_mongo_iterations >= iterations:
407
+ LOGGER.info(
408
+ f"{self.prefix}Target iterations ({iterations}) reached in MongoDB ({total_mongo_iterations}). Stopping."
409
+ )
410
+ break
411
+ else:
412
+ # Save to CSV only if no MongoDB
413
+ log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
414
+ headers = "" if self.tune_csv.exists() else (",".join(["fitness", *list(self.space.keys())]) + "\n")
415
+ with open(self.tune_csv, "a", encoding="utf-8") as f:
416
+ f.write(headers + ",".join(map(str, log_row)) + "\n")
210
417
 
211
418
  # Get best results
212
419
  x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
@@ -214,15 +421,15 @@ class Tuner:
214
421
  best_idx = fitness.argmax()
215
422
  best_is_current = best_idx == i
216
423
  if best_is_current:
217
- best_save_dir = save_dir
424
+ best_save_dir = str(save_dir)
218
425
  best_metrics = {k: round(v, 5) for k, v in metrics.items()}
219
426
  for ckpt in weights_dir.glob("*.pt"):
220
427
  shutil.copy2(ckpt, self.tune_dir / "weights")
221
- elif cleanup:
222
- shutil.rmtree(weights_dir, ignore_errors=True) # remove iteration weights/ dir to reduce storage space
428
+ elif cleanup and best_save_dir:
429
+ shutil.rmtree(best_save_dir, ignore_errors=True) # remove iteration dirs to reduce storage space
223
430
 
224
431
  # Plot tune results
225
- plot_tune_results(self.tune_csv)
432
+ plot_tune_results(str(self.tune_csv))
226
433
 
227
434
  # Save and print tune results
228
435
  header = (
@@ -230,8 +437,7 @@ class Tuner:
230
437
  f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n"
231
438
  f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n"
232
439
  f"{self.prefix}Best fitness metrics are {best_metrics}\n"
233
- f"{self.prefix}Best fitness model is {best_save_dir}\n"
234
- f"{self.prefix}Best fitness hyperparameters are printed below.\n"
440
+ f"{self.prefix}Best fitness model is {best_save_dir}"
235
441
  )
236
442
  LOGGER.info("\n" + header)
237
443
  data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}