ultralytics-opencv-headless 8.3.246__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298) hide show
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1578 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +313 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +1006 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +501 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1563 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.246.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.246.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.246.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.246.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.246.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.246.dist-info/top_level.txt +1 -0
@@ -0,0 +1,302 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from ultralytics.models.yolo.detect import DetectionValidator
12
+ from ultralytics.utils import LOGGER, ops
13
+ from ultralytics.utils.metrics import OBBMetrics, batch_probiou
14
+ from ultralytics.utils.nms import TorchNMS
15
+ from ultralytics.utils.plotting import plot_images
16
+
17
+
18
class OBBValidator(DetectionValidator):
    """A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.

    This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
    satellite imagery where objects can appear at various orientations.

    Attributes:
        args (dict): Configuration arguments for the validator.
        metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
        is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.

    Methods:
        init_metrics: Initialize evaluation metrics for YOLO.
        _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
        postprocess: Append the predicted angle to the bounding boxes of every prediction.
        _prepare_batch: Prepare batch data for OBB validation.
        plot_predictions: Plot predicted bounding boxes on input images.
        pred_to_json: Serialize YOLO predictions to COCO json format.
        save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
        scale_preds: Scale predicted boxes back to the original image size.
        eval_json: Evaluate YOLO output in JSON format and return performance statistics.

    Examples:
        >>> from ultralytics.models.yolo.obb import OBBValidator
        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
        >>> validator = OBBValidator(args=args)
        >>> validator(model=args["model"])
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.

        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
        extends the DetectionValidator class and configures it specifically for the OBB task.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
            save_dir (str | Path, optional): Directory to save results.
            args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
            _callbacks (list, optional): List of callback functions to be called during validation.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.args.task = "obb"
        self.metrics = OBBMetrics()

    def init_metrics(self, model: torch.nn.Module) -> None:
        """Initialize evaluation metrics for YOLO obb validation.

        Besides the parent initialization, this records whether the evaluated split looks like a DOTA dataset (its
        path is a string containing "DOTA") and switches the confusion matrix to the 'obb' task.

        Args:
            model (torch.nn.Module): Model to validate.
        """
        super().init_metrics(model)
        val = self.data.get(self.args.split, "")  # validation path
        self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
        self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'

    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
        """Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.

        Args:
            preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
                class labels and bounding boxes.
            batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
                labels and bounding boxes.

        Returns:
            (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
                with shape (N, self.niou), one column per IoU level for each of the N detections, indicating the
                accuracy of predictions compared to the ground truth.

        Examples:
            >>> preds = {"cls": torch.randint(0, 5, (100,)), "bboxes": torch.rand(100, 5)}  # 100 sample detections
            >>> batch = {"cls": torch.randint(0, 5, (50,)), "bboxes": torch.rand(50, 5)}  # 50 sample ground truths
            >>> correct_matrix = validator._process_batch(preds, batch)["tp"]
        """
        # Nothing to match when either side is empty: every detection is a miss at every IoU level.
        if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
            return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
        iou = batch_probiou(batch["bboxes"], preds["bboxes"])
        return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}

    def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
        """Postprocess OBB predictions.

        Args:
            preds (torch.Tensor): Raw predictions from the model.

        Returns:
            (list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
        """
        preds = super().postprocess(preds)
        for pred in preds:
            # The parent returns the angle under 'extra'; fold it into 'bboxes' as the last column.
            pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1)  # concatenate angle
        return preds

    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
        """Prepare batch data for OBB validation with proper scaling and formatting.

        Args:
            si (int): Batch index to process.
            batch (dict[str, Any]): Dictionary containing batch data with keys:
                - batch_idx: Tensor of batch indices
                - cls: Tensor of class labels
                - bboxes: Tensor of bounding boxes
                - ori_shape: Original image shapes
                - img: Batch of images
                - ratio_pad: Ratio and padding information
                - im_file: Image file paths

        Returns:
            (dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
        """
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if cls.shape[0]:
            # imgsz is (h, w); index [1, 0, 1, 0] yields (w, h, w, h) to scale the first four box columns.
            bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
        return {
            "cls": cls,
            "bboxes": bbox,
            "ori_shape": ori_shape,
            "imgsz": imgsz,
            "ratio_pad": ratio_pad,
            "im_file": batch["im_file"][si],
        }

    def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
        """Plot predicted bounding boxes on input images and save the result.

        Note that each prediction dictionary in `preds` gains a 'batch_idx' entry as a side effect.

        Args:
            batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
            preds (list[dict[str, torch.Tensor]]): List of prediction dictionaries for each image in the batch.
            ni (int): Batch index used for naming the output file.

        Examples:
            >>> validator = OBBValidator()
            >>> batch = {"img": images, "im_file": paths}
            >>> preds = [{"bboxes": torch.rand(10, 5), "cls": torch.zeros(10), "conf": torch.rand(10)}]
            >>> validator.plot_predictions(batch, preds, 0)
        """
        if not preds:
            return
        for i, pred in enumerate(preds):
            pred["batch_idx"] = torch.ones_like(pred["conf"]) * i
        keys = preds[0].keys()
        batched_preds = {k: torch.cat([x[k] for x in preds], dim=0) for k in keys}
        plot_images(
            images=batch["img"],
            labels=batched_preds,
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
        """Convert YOLO predictions to COCO JSON format with rotated bounding box information.

        Args:
            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
                bounding box coordinates, confidence scores, and class predictions.
            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

        Notes:
            This method processes rotated bounding box predictions and converts them to both rbox format
            (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
            to the JSON dictionary.
        """
        path = Path(pbatch["im_file"])
        stem = path.stem
        image_id = int(stem) if stem.isnumeric() else stem
        rbox = predn["bboxes"]
        poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
        for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "file_name": path.name,
                    "category_id": self.class_map[int(c)],
                    "score": round(s, 5),
                    "rbox": [round(x, 3) for x in r],
                    "poly": [round(x, 3) for x in b],
                }
            )

    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
        """Save YOLO OBB detections to a text file in normalized coordinates.

        Args:
            predn (dict[str, torch.Tensor]): Prediction dictionary with 'bboxes' (N rows, angle as last column as
                produced by `postprocess`), 'conf' (N,) and 'cls' (N,) entries.
            save_conf (bool): Whether to save confidence scores in the text file.
            shape (tuple[int, int]): Original image shape in format (height, width).
            file (Path): Output file path to save detections.

        Examples:
            >>> validator = OBBValidator()
            >>> predn = {
            ...     "bboxes": torch.tensor([[100.0, 100.0, 50.0, 30.0, 0.5]]),  # x, y, w, h, angle
            ...     "conf": torch.tensor([0.9]),
            ...     "cls": torch.tensor([0.0]),
            ... }
            >>> validator.save_one_txt(predn, True, (640, 480), Path("detection.txt"))
        """
        # Imported at call time as in the original code (presumably to avoid an import cycle - TODO confirm).
        from ultralytics.engine.results import Results

        # A blank image of the original size is enough: Results only needs it for the shape used in normalization.
        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
        ).save_txt(file, save_conf=save_conf)

    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Scale predictions to the original image size.

        Args:
            predn (dict[str, torch.Tensor]): Prediction dictionary; only 'bboxes' is rescaled, on a clone.
            pbatch (dict[str, Any]): Prepared batch providing 'imgsz', 'ori_shape' and 'ratio_pad'.

        Returns:
            (dict[str, torch.Tensor]): Copy of `predn` whose 'bboxes' entry is rescaled; other entries are shared.
        """
        return {
            **predn,
            "bboxes": ops.scale_boxes(
                pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
            ),
        }

    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
        """Evaluate YOLO output in JSON format and save predictions in DOTA format.

        When `save_json` is enabled, the dataset is DOTA and predictions were collected, this reads
        `predictions.json` from the save directory and appends two sets of `Task1_<classname>.txt` files:

        - `predictions_txt/`: one line per prediction, using the per-crop image id and polygon as stored.
        - `predictions_merged_txt/`: predictions of all crops of an image are shifted by the crop offset parsed
          from the image id (the first `<x>___<y>` match), de-duplicated with class-aware probiou NMS and written
          under the base image id (text before the first `__`).

        Args:
            stats (dict[str, Any]): Performance statistics dictionary.

        Returns:
            (dict[str, Any]): The `stats` dictionary, returned unchanged.
        """
        if self.args.save_json and self.is_dota and len(self.jdict):
            import json
            import re
            from collections import defaultdict

            pred_json = self.save_dir / "predictions.json"  # predictions
            pred_txt = self.save_dir / "predictions_txt"  # predictions
            pred_txt.mkdir(parents=True, exist_ok=True)
            with open(pred_json, encoding="utf-8") as f:  # context manager so the handle is always closed
                data = json.load(f)

            # Save split results
            LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
            split_lines = defaultdict(list)  # classname -> output lines, in prediction order
            for d in data:
                classname = self.names[d["category_id"] - 1].replace(" ", "-")
                coords = " ".join(str(v) for v in d["poly"][:8])
                split_lines[classname].append(f"{d['image_id']} {d['score']} {coords}\n")
            for classname, lines in split_lines.items():
                # Append mode is kept so repeated runs accumulate exactly as before; one open per class file.
                with open(pred_txt / f"Task1_{classname}.txt", "a", encoding="utf-8") as f:
                    f.writelines(lines)

            # Save merged results, this could result slightly lower map than using official merging script,
            # because of the probiou calculation.
            pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
            pred_merged_txt.mkdir(parents=True, exist_ok=True)
            merged_results = defaultdict(list)
            LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
            offset_pattern = re.compile(r"\d+___\d+")  # crop offset "<x>___<y>" embedded in the image id
            for d in data:
                image_id = d["image_id"].split("__", 1)[0]
                x, y = (int(c) for c in offset_pattern.findall(d["image_id"])[0].split("___"))
                rbox = d["rbox"]
                # Row layout: x, y, w, h, angle, score, cls - with the center moved into full-image coordinates.
                merged_results[image_id].append([rbox[0] + x, rbox[1] + y, *rbox[2:], d["score"], d["category_id"] - 1])

            merged_lines = defaultdict(list)  # classname -> output lines, in image then detection order
            for image_id, rows in merged_results.items():
                bbox = torch.tensor(rows)
                max_wh = torch.max(bbox[:, :2]).item() * 2
                c = bbox[:, 6:7] * max_wh  # classes
                scores = bbox[:, 5]  # scores
                b = bbox[:, :5].clone()
                b[:, :2] += c  # shift boxes per class so NMS never suppresses across classes
                # 0.3 could get results close to the ones from official merging script, even slightly better.
                keep = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
                bbox = bbox[keep]

                polys = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
                for row in torch.cat([polys, bbox[:, 5:7]], dim=-1).tolist():
                    classname = self.names[int(row[-1])].replace(" ", "-")
                    coords = " ".join(str(round(v, 3)) for v in row[:-2])  # poly
                    score = round(row[-2], 3)
                    merged_lines[classname].append(f"{image_id} {score} {coords}\n")

            for classname, lines in merged_lines.items():
                with open(pred_merged_txt / f"Task1_{classname}.txt", "a", encoding="utf-8") as f:
                    f.writelines(lines)

        return stats
@@ -0,0 +1,7 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from .predict import PosePredictor
4
+ from .train import PoseTrainer
5
+ from .val import PoseValidator
6
+
7
+ __all__ = "PosePredictor", "PoseTrainer", "PoseValidator"
@@ -0,0 +1,65 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from ultralytics.models.yolo.detect.predict import DetectionPredictor
4
+ from ultralytics.utils import DEFAULT_CFG, ops
5
+
6
+
7
class PosePredictor(DetectionPredictor):
    """Predictor for pose models, built on top of DetectionPredictor.

    Detection results produced by the parent class are enriched with per-detection keypoints that are reshaped
    according to the model's keypoint shape and rescaled to the original image.

    Attributes:
        args (namespace): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.

    Methods:
        construct_result: Construct the result object from the prediction, including keypoints.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.pose import PosePredictor
        >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
        >>> predictor = PosePredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Create a PosePredictor and force its task to 'pose'.

        Args:
            cfg (Any): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "pose"

    def construct_result(self, pred, img, orig_img, img_path):
        """Build the detection result via the parent class and attach scaled keypoints to it.

        Args:
            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
                the number of detections, K is the number of keypoints, and D is the keypoint dimension.
            img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
            orig_img (np.ndarray): The original unprocessed image as a numpy array.
            img_path (str): The path to the original image file.

        Returns:
            (Results): The parent's result object, updated with the keypoints of every detection.
        """
        base_result = super().construct_result(pred, img, orig_img, img_path)
        num_detections = pred.shape[0]
        # Everything after the first six columns holds the flattened keypoints of each detection.
        flat_kpts = pred[:, 6:]
        keypoints = flat_kpts.view(num_detections, *self.model.kpt_shape)
        # Map keypoint coordinates from the network input size back onto the original image.
        network_hw = img.shape[2:]
        keypoints = ops.scale_coords(network_hw, keypoints, orig_img.shape)
        base_result.update(keypoints=keypoints)
        return base_result
@@ -0,0 +1,110 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from copy import copy
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from ultralytics.models import yolo
10
+ from ultralytics.nn.tasks import PoseModel
11
+ from ultralytics.utils import DEFAULT_CFG
12
+
13
+
14
class PoseTrainer(yolo.detect.DetectionTrainer):
    """A class extending the DetectionTrainer class for training YOLO pose estimation models.

    This trainer specializes in handling pose estimation tasks: it builds a PoseModel, propagates keypoint metadata
    from the dataset onto the model, and validates with a PoseValidator.

    Attributes:
        args (dict): Configuration arguments for training.
        model (PoseModel): The pose estimation model being trained.
        data (dict): Dataset configuration including keypoint shape information.
        loss_names (tuple): Names of the loss components used in training.

    Methods:
        get_model: Retrieve a pose estimation model with specified configuration.
        set_model_attributes: Set keypoint shape and keypoint names on the model.
        get_validator: Create a validator instance for model evaluation.
        get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseTrainer
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
        >>> trainer = PoseTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """Initialize a PoseTrainer object for training YOLO pose estimation models.

        Args:
            cfg (dict, optional): Default configuration dictionary containing training parameters.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration. The dict
                passed by the caller is not modified.
            _callbacks (list, optional): List of callback functions to be executed during training.

        Notes:
            This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
        """
        # Build a fresh dict instead of writing into the caller's object, so a shared `overrides` dict
        # is never silently changed by constructing a trainer.
        overrides = {**(overrides or {}), "task": "pose"}
        super().__init__(cfg, overrides, _callbacks)

    def get_model(
        self,
        cfg: str | Path | dict[str, Any] | None = None,
        weights: str | Path | None = None,
        verbose: bool = True,
    ) -> PoseModel:
        """Get pose estimation model with specified configuration and weights.

        Args:
            cfg (str | Path | dict, optional): Model configuration file path or dictionary.
            weights (str | Path, optional): Path to the model weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (PoseModel): Initialized pose estimation model.
        """
        # Class count, input channels and keypoint shape all come from the dataset configuration
        model = PoseModel(
            cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
        )
        if weights:
            model.load(weights)

        return model

    def set_model_attributes(self):
        """Set keypoint shape and keypoint names attributes of PoseModel.

        When the dataset provides no `kpt_names`, default names "0", "1", ... (one per keypoint) are generated and
        assigned to every class index.
        """
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]
        kpt_names = self.data.get("kpt_names")
        if not kpt_names:
            # Fall back to stringified keypoint indices; the same list object is mapped to every class
            names = list(map(str, range(self.model.kpt_shape[0])))
            kpt_names = {i: names for i in range(self.model.nc)}
        self.model.kpt_names = kpt_names

    def get_validator(self):
        """Return an instance of the PoseValidator class for validation.

        Also sets `self.loss_names` to the five pose loss component names as a side effect.
        """
        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
        # The validator receives a copy of the args so its changes do not leak back into the trainer
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def get_dataset(self) -> dict[str, Any]:
        """Retrieve the dataset and ensure it contains the required `kpt_shape` key.

        Returns:
            (dict): A dictionary containing the training/validation/test dataset and category names.

        Raises:
            KeyError: If the `kpt_shape` key is not present in the dataset.
        """
        data = super().get_dataset()
        if "kpt_shape" not in data:
            raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
        return data