dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
@@ -0,0 +1,387 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+ """
3
+ Check a model's accuracy on a test or val split of a dataset.
4
+
5
+ Usage:
6
+ $ yolo mode=val model=yolo11n.pt data=coco8.yaml imgsz=640
7
+
8
+ Usage - formats:
9
+ $ yolo mode=val model=yolo11n.pt # PyTorch
10
+ yolo11n.torchscript # TorchScript
11
+ yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
12
+ yolo11n_openvino_model # OpenVINO
13
+ yolo11n.engine # TensorRT
14
+ yolo11n.mlpackage # CoreML (macOS-only)
15
+ yolo11n_saved_model # TensorFlow SavedModel
16
+ yolo11n.pb # TensorFlow GraphDef
17
+ yolo11n.tflite # TensorFlow Lite
18
+ yolo11n_edgetpu.tflite # TensorFlow Edge TPU
19
+ yolo11n_paddle_model # PaddlePaddle
20
+ yolo11n.mnn # MNN
21
+ yolo11n_ncnn_model # NCNN
22
+ yolo11n_imx_model # Sony IMX
23
+ yolo11n_rknn_model # Rockchip RKNN
24
+ """
25
+
26
+ import json
27
+ import time
28
+ from pathlib import Path
29
+
30
+ import numpy as np
31
+ import torch
32
+ import torch.distributed as dist
33
+
34
+ from ultralytics.cfg import get_cfg, get_save_dir
35
+ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
36
+ from ultralytics.nn.autobackend import AutoBackend
37
+ from ultralytics.utils import LOGGER, RANK, TQDM, callbacks, colorstr, emojis
38
+ from ultralytics.utils.checks import check_imgsz
39
+ from ultralytics.utils.ops import Profile
40
+ from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model
41
+
42
+
43
class BaseValidator:
    """A base class for creating validators.

    This class provides the foundation for validation processes, including model evaluation, metric computation, and
    result visualization. Subclasses override the hook methods (dataloader construction, preprocessing,
    postprocessing, metric bookkeeping, plotting, JSON export) while ``__call__`` drives the validation loop.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): DataLoader to use for validation.
        data (dict): Data dictionary containing dataset information.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether validation was invoked from a trainer (set at the start of ``__call__``).
        names (dict): Class names mapping.
        seen (int): Number of images seen so far during validation.
        stats (dict): Statistics collected during validation.
        confusion_matrix: Confusion matrix populated by subclasses.
        nc (int): Number of classes.
        iouv (torch.Tensor): IoU thresholds used by ``match_predictions`` (e.g. 0.50 to 0.95 in steps of 0.05); set
            by subclasses, None here.
        jdict (list): List to store JSON validation results; reset at the start of every ``__call__``.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
            per-image processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization, keyed by ``Path``.
        callbacks (dict): Dictionary to store various callback functions.
        stride (int): Model stride for padding calculations.
        loss (torch.Tensor): Accumulated loss during training validation (only set when a trainer is passed).

    Methods:
        __call__: Execute validation process, running inference on dataloader and computing performance metrics.
        match_predictions: Match predictions to ground truth objects using IoU.
        add_callback: Append the given callback to the specified event.
        run_callbacks: Run all callbacks associated with a specified event.
        get_dataloader: Get data loader from dataset path and batch size.
        build_dataset: Build dataset from image path.
        preprocess: Preprocess an input batch.
        postprocess: Postprocess the predictions.
        init_metrics: Initialize performance metrics for the YOLO model.
        update_metrics: Update metrics based on predictions and batch.
        finalize_metrics: Finalize metrics once all batches are processed.
        get_stats: Return statistics about the model's performance.
        gather_stats: Gather statistics from all GPUs during DDP training.
        print_results: Print the results of the model's predictions.
        get_desc: Get description of the YOLO model.
        metric_keys: Metric keys used in YOLO training/validation.
        on_plot: Register plots for visualization.
        plot_val_samples: Plot validation samples during training.
        plot_predictions: Plot YOLO model predictions on batch images.
        pred_to_json: Convert predictions to JSON format.
        eval_json: Evaluate saved JSON predictions and return updated statistics.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
        """Initialize a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation. When None and
                validating outside of training, one is built in ``__call__`` via ``get_dataloader()``.
            save_dir (Path, optional): Directory to save results. Defaults to ``get_save_dir(self.args)``. The
                directory (or its ``labels`` subdirectory when ``save_txt`` is set) is created immediately.
            args (SimpleNamespace, optional): Configuration overrides passed to ``get_cfg``.
            _callbacks (dict, optional): Dictionary to store various callback functions. Defaults to the default
                callbacks.
        """
        import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)

        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.01 if self.args.task == "obb" else 0.001  # reduce OBB val memory usage
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """Execute validation process, running inference on dataloader and computing performance metrics.

        Args:
            trainer (object, optional): Trainer object that contains the model to validate. When given, the trainer's
                device, data, EMA/model and AMP setting are reused and the validation loss is accumulated.
            model (nn.Module, optional): Model to validate if not using a trainer; falls back to ``self.args.model``
                when None. It is wrapped in ``AutoBackend``.

        Returns:
            (dict | None): With a trainer: statistics merged with the averaged validation loss, rounded to 5 decimals
                (None on ranks > 0). Without a trainer: the statistics dict (possibly updated by ``eval_json``).
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # Force FP16 val during training
            self.args.half = self.device.type != "cpu" and trainer.amp
            model = trainer.ema.ema or trainer.model
            if trainer.args.compile and hasattr(model, "_orig_mod"):
                model = model._orig_mod  # validate non-compiled original model to avoid issues
            model = model.half() if self.args.half else model.float()
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            # Only plot when training may stop now or on the final epoch
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            if str(self.args.model).endswith(".yaml") and model is None:
                LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                model=model or self.args.model,
                device=select_device(self.args.device) if RANK == -1 else torch.device("cuda", RANK),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit = model.stride, model.pt, model.jit
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if not (pt or jit or getattr(model, "dynamic", False)):
                self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
                LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

            if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in {"cpu", "mps"}:
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            if self.args.compile:
                model = attempt_compile(model, device=self.device)
            model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        # Timers for preprocess, inference, loss and postprocess (same order as self.speed keys)
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(unwrap_model(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"], augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")

        stats = {}
        self.gather_stats()
        if RANK in {-1, 0}:
            stats = self.get_stats()
            # Convert accumulated times to milliseconds per image
            self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
            self.finalize_metrics()
            self.print_results()
            self.run_callbacks("on_val_end")

        if self.training:
            model.float()
            # Reduce loss across all GPUs
            loss = self.loss.clone().detach()
            if trainer.world_size > 1:
                dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
            if RANK > 0:
                return
            results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            if RANK > 0:
                return stats
            LOGGER.info(
                "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                    *tuple(self.speed.values())
                )
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def match_predictions(
        self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
    ) -> torch.Tensor:
        """Match predictions to ground truth objects using IoU.

        At each threshold in ``self.iouv``, a prediction is marked correct when it is matched to a ground-truth object
        of the same class with IoU >= threshold. Each prediction and each ground-truth object is used at most once per
        threshold.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape (N,).
            true_classes (torch.Tensor): Target class indices of shape (M,).
            iou (torch.Tensor): Pairwise IoU values of shape (M, N); rows index ground truth, columns index predictions.
            use_scipy (bool, optional): Whether to use scipy's optimal linear sum assignment for matching instead of
                the greedy matcher.

        Returns:
            (torch.Tensor): Boolean tensor of shape (N, T), T = len(self.iouv), on the same device as
                ``pred_classes``, marking which predictions are correct at each IoU threshold.
        """
        # NxT matrix, where N - detections (rows), T - IoU thresholds (columns)
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0]), dtype=bool)
        # MxN matrix where M - labels (rows), N - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        if use_scipy:
            # Import the submodule explicitly: a bare `import scipy` does not guarantee `scipy.optimize` is loaded.
            # Scoped import to avoid importing scipy for all commands.
            from scipy.optimize import linear_sum_assignment
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    # linear_sum_assignment minimizes by default, which would prefer zero-IoU pairs and drop valid
                    # matches; maximize=True selects the assignment with the highest total IoU.
                    labels_idx, detections_idx = linear_sum_assignment(cost_matrix, maximize=True)
                    valid = cost_matrix[labels_idx, detections_idx] > 0  # discard pairs below threshold/wrong class
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                matches = np.nonzero(iou >= threshold)  # IoU >= threshold and classes match
                matches = np.array(matches).T  # (K, 2) array of [label_idx, detection_idx]
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        # Greedy: sort by IoU descending, keep the best match per detection, then one per label
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

    def add_callback(self, event: str, callback):
        """Append the given callback to the specified event."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Run all callbacks associated with a specified event, passing this validator to each."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size; must be implemented by subclasses."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset from image path; must be implemented by subclasses."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocess an input batch (identity in the base class)."""
        return batch

    def postprocess(self, preds):
        """Postprocess the predictions (identity in the base class)."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model (no-op hook for subclasses)."""
        pass

    def update_metrics(self, preds, batch):
        """Update metrics based on predictions and batch (no-op hook for subclasses)."""
        pass

    def finalize_metrics(self):
        """Finalize metrics once all batches are processed (no-op hook; the return value is not used)."""
        pass

    def get_stats(self):
        """Return statistics about the model's performance (empty dict in the base class)."""
        return {}

    def gather_stats(self):
        """Gather statistics from all the GPUs during DDP training to GPU 0 (no-op hook for subclasses)."""
        pass

    def print_results(self):
        """Print the results of the model's predictions (no-op hook for subclasses)."""
        pass

    def get_desc(self):
        """Get the progress-bar description for validation (None in the base class)."""
        pass

    @property
    def metric_keys(self):
        """Return the metric keys used in YOLO training/validation (empty list in the base class)."""
        return []

    def on_plot(self, name, data=None):
        """Register plots for visualization, deduplicating by type.

        Args:
            name (str | Path): Plot name/path; used as the key in ``self.plots`` after conversion to ``Path``.
            data (dict, optional): Plot payload. If it has a ``"type"`` already registered, the plot is skipped.
        """
        plot_type = data.get("type") if data else None
        if plot_type and any((v.get("data") or {}).get("type") == plot_type for v in self.plots.values()):
            return  # Skip duplicate plot types
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    def plot_val_samples(self, batch, ni):
        """Plot validation samples during training (no-op hook for subclasses)."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plot YOLO model predictions on batch images (no-op hook for subclasses)."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format (no-op hook for subclasses)."""
        pass

    def eval_json(self, stats):
        """Evaluate saved JSON predictions and return updated statistics.

        The base implementation performs no evaluation and returns ``stats`` unchanged. ``__call__`` assigns the
        result back to its statistics, so a subclass that fills ``jdict`` without overriding this method keeps its
        statistics instead of having them replaced by None.

        Args:
            stats (dict): Current validation statistics.

        Returns:
            (dict): The (unchanged) statistics.
        """
        return stats
@@ -0,0 +1,166 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from ultralytics.data.utils import HUBDatasetStats
6
+ from ultralytics.hub.auth import Auth
7
+ from ultralytics.hub.session import HUBTrainingSession
8
+ from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
9
+ from ultralytics.utils import LOGGER, SETTINGS, checks
10
+
11
# Public API of the ``ultralytics.hub`` package: the names exported by ``from ultralytics.hub import *``.
__all__ = (
    "HUB_WEB_ROOT",
    "PREFIX",
    "HUBTrainingSession",
    "check_dataset",
    "export_fmts_hub",
    "export_model",
    "get_export",
    "login",
    "logout",
    "reset_model",
)
23
+
24
+
25
def login(api_key: str | None = None, save: bool = True) -> bool:
    """Authenticate with the Ultralytics HUB API.

    No session object is kept; later operations build a new session on demand from the saved SETTINGS key or the
    HUB_API_KEY environment variable once authentication has succeeded.

    Args:
        api_key (str, optional): API key to authenticate with. When omitted, the key stored in SETTINGS is used, or
            the HUB_API_KEY environment variable.
        save (bool, optional): Persist the authenticated key to SETTINGS when it differs from the stored one.

    Returns:
        (bool): True when authentication succeeds, False otherwise.
    """
    checks.check_requirements("hub-sdk>=0.0.12")
    from hub_sdk import HUBClient

    keys_page = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # where users can create an API key
    stored_key = SETTINGS.get("api_key")
    chosen_key = api_key or stored_key
    # Only pass credentials when a non-empty key is available
    credentials = {"api_key": chosen_key} if chosen_key else None

    client = HUBClient(credentials)

    if not client.authenticated:
        LOGGER.info(f"{PREFIX}Get API key from {keys_page} and then run 'yolo login API_KEY'")
        return False

    if save and client.api_key != stored_key:
        SETTINGS.update({"api_key": client.api_key})  # remember the key that just worked

    # "New" when the caller's key was used, or when no key was supplied up front
    is_new_login = client.api_key == api_key or not credentials
    message = "New authentication successful ✅" if is_new_login else "Authenticated ✅"
    LOGGER.info(f"{PREFIX}{message}")
    return True
66
+
67
+
68
def logout():
    """Log out of Ultralytics HUB by clearing the API key stored in the settings file."""
    # An empty string is falsy, so login() will not send it as credentials
    SETTINGS["api_key"] = ""
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")
72
+
73
+
74
def reset_model(model_id: str = "", timeout: float = 30):
    """Reset a trained Ultralytics HUB model to an untrained state.

    Args:
        model_id (str): ID of the HUB model to reset.
        timeout (float): Seconds to wait for the HUB API to respond. A stalled connection raises instead of
            blocking forever.

    Raises:
        requests.exceptions.Timeout: If the HUB API does not respond within ``timeout`` seconds.

    Notes:
        The outcome is reported through LOGGER. A non-200 response is logged as a warning, not raised.
    """
    import requests  # scoped as slow import

    r = requests.post(
        f"{HUB_API_ROOT}/model-reset",
        json={"modelId": model_id},
        headers={"x-api-key": Auth().api_key},
        timeout=timeout,  # requests waits indefinitely when no timeout is given
    )
    if r.status_code == 200:
        LOGGER.info(f"{PREFIX}Model reset successfully")
        return
    LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")
83
+
84
+
85
def export_fmts_hub():
    """Return a new list of the export formats supported by Ultralytics HUB."""
    from ultralytics.engine.exporter import export_formats

    # Every exporter "Argument" except the first entry, followed by the HUB-specific formats
    formats = list(export_formats()["Argument"][1:])
    formats += ["ultralytics_tflite", "ultralytics_coreml"]
    return formats
90
+
91
+
92
def export_model(model_id: str = "", format: str = "torchscript", timeout: float = 30):
    """Request an export of a model to a deployment format via the Ultralytics HUB API.

    Args:
        model_id (str): The ID of the model to export. An empty string will use the default model.
        format (str): The format to export the model to. Must be one of the supported formats returned by
            export_fmts_hub().
        timeout (float): Seconds to wait for the HUB API to respond before giving up.

    Raises:
        AssertionError: If the specified format is not supported or if the export request fails.
        requests.exceptions.Timeout: If the HUB API does not respond within ``timeout`` seconds.

    Examples:
        >>> from ultralytics import hub
        >>> hub.export_model(model_id="your_model_id", format="torchscript")
    """
    import requests  # scoped as slow import

    supported = export_fmts_hub()
    # Explicit raises instead of `assert`: the checks must survive `python -O`. AssertionError is kept so existing
    # callers that catch it keep working.
    if format not in supported:
        raise AssertionError(f"Unsupported export format '{format}', valid formats are {supported}")
    r = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export",
        json={"format": format},
        headers={"x-api-key": Auth().api_key},
        timeout=timeout,  # requests waits indefinitely when no timeout is given
    )
    if r.status_code != 200:
        raise AssertionError(f"{PREFIX}{format} export failure {r.status_code} {r.reason}")
    LOGGER.info(f"{PREFIX}{format} export started ✅")
115
+
116
+
117
def get_export(model_id: str = "", format: str = "torchscript", timeout: float = 30):
    """Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.

    Args:
        model_id (str): The ID of the model to retrieve from Ultralytics HUB.
        format (str): The export format to retrieve. Must be one of the supported formats returned by export_fmts_hub().
        timeout (float): Seconds to wait for the HUB API to respond before giving up.

    Returns:
        (dict): JSON response containing the exported model information.

    Raises:
        AssertionError: If the specified format is not supported or if the API request fails.
        requests.exceptions.Timeout: If the HUB API does not respond within ``timeout`` seconds.

    Examples:
        >>> from ultralytics import hub
        >>> result = hub.get_export(model_id="your_model_id", format="torchscript")
    """
    import requests  # scoped as slow import

    supported = export_fmts_hub()
    # Explicit raises instead of `assert`: the checks must survive `python -O`. AssertionError is kept so existing
    # callers that catch it keep working.
    if format not in supported:
        raise AssertionError(f"Unsupported export format '{format}', valid formats are {supported}")
    api_key = Auth().api_key  # build Auth once; the key is sent in both the JSON body and the header
    r = requests.post(
        f"{HUB_API_ROOT}/get-export",
        json={"apiKey": api_key, "modelId": model_id, "format": format},
        headers={"x-api-key": api_key},
        timeout=timeout,  # requests waits indefinitely when no timeout is given
    )
    if r.status_code != 200:
        raise AssertionError(f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}")
    return r.json()
144
+
145
+
146
def check_dataset(path: str, task: str) -> None:
    """Validate a HUB dataset zip file before it is uploaded.

    Args:
        path (str): Path to data.zip (with data.yaml inside data.zip).
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.

    Examples:
        >>> from ultralytics.hub import check_dataset
        >>> check_dataset("path/to/coco8.zip", task="detect")  # detect dataset
        >>> check_dataset("path/to/coco8-seg.zip", task="segment")  # segment dataset
        >>> check_dataset("path/to/coco8-pose.zip", task="pose")  # pose dataset
        >>> check_dataset("path/to/dota8.zip", task="obb")  # OBB dataset
        >>> check_dataset("path/to/imagenet10.zip", task="classify")  # classification dataset

    Notes:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
    """
    dataset_stats = HUBDatasetStats(path=path, task=task)
    dataset_stats.get_json()  # called for its checks; the returned JSON is not used here
    LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")