dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
@@ -0,0 +1,248 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from ultralytics.models.yolo.detect import DetectionValidator
12
+ from ultralytics.utils import ops
13
+ from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
14
+
15
+
16
class PoseValidator(DetectionValidator):
    """A class extending the DetectionValidator class for validation based on a pose model.

    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
    metrics for pose evaluation.

    Attributes:
        sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
        kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
        args (dict): Arguments for the validator including task set to "pose".
        metrics (PoseMetrics): Metrics object for pose evaluation.

    Methods:
        preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
        get_desc: Return description of evaluation metrics in string format.
        init_metrics: Initialize pose estimation metrics for YOLO model.
        _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
            dimensions.
        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
            and ground truth.
        save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
        pred_to_json: Convert YOLO predictions to COCO JSON format.
        scale_preds: Scale predictions (boxes and keypoints) to the original image size.
        eval_json: Evaluate pose estimation model using COCO JSON format.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseValidator
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
        >>> validator = PoseValidator(args=args)
        >>> validator()

    Notes:
        This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
        for OKS calculation and sets up PoseMetrics for evaluation.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """Initialize a PoseValidator object for pose estimation validation.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
            save_dir (Path | str, optional): Directory to save results.
            args (dict, optional): Arguments for the validator including task set to "pose".
            _callbacks (list, optional): List of callback functions to be executed during validation.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.sigma = None  # set in init_metrics once kpt_shape is known
        self.kpt_shape = None
        self.args.task = "pose"
        self.metrics = PoseMetrics()

    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Preprocess batch by converting keypoints data to float and moving it to the device."""
        batch = super().preprocess(batch)
        batch["keypoints"] = batch["keypoints"].float()
        return batch

    def get_desc(self) -> str:
        """Return description of evaluation metrics in string format."""
        return ("%22s" + "%11s" * 10) % (
            "Class",
            "Images",
            "Instances",
            "Box(P",
            "R",
            "mAP50",
            "mAP50-95)",
            "Pose(P",
            "R",
            "mAP50",
            "mAP50-95)",
        )

    def init_metrics(self, model: torch.nn.Module) -> None:
        """Initialize evaluation metrics for YOLO pose validation.

        Args:
            model (torch.nn.Module): Model to validate.
        """
        super().init_metrics(model)
        self.kpt_shape = self.data["kpt_shape"]
        is_pose = self.kpt_shape == [17, 3]  # COCO-format keypoints use the standard OKS sigmas
        nkpt = self.kpt_shape[0]
        self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt

    def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
        """Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.

        This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of each
        prediction and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
        flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).

        Args:
            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
                scores, class predictions, and keypoint data.

        Returns:
            (list[dict[str, torch.Tensor]]): Processed prediction dictionaries, one per image, each containing:
                - 'bboxes': Bounding box coordinates
                - 'conf': Confidence scores
                - 'cls': Class predictions
                - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)

        Notes:
            The keypoints are extracted from the 'extra' field, which contains additional task-specific data beyond
            basic detection; 'extra' is removed from each prediction dict in the process.
        """
        preds = super().postprocess(preds)
        for pred in preds:
            pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape)  # remove extra if exists
        return preds

    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.

        Args:
            si (int): Batch index.
            batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.

        Returns:
            (dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.

        Notes:
            This method extends the parent class's _prepare_batch method by adding keypoint processing.
            Keypoints are scaled from normalized coordinates to original image dimensions.
        """
        pbatch = super()._prepare_batch(si, batch)
        kpts = batch["keypoints"][batch["batch_idx"] == si]
        h, w = pbatch["imgsz"]
        kpts = kpts.clone()  # avoid mutating the shared batch tensor
        kpts[..., 0] *= w
        kpts[..., 1] *= h
        pbatch["keypoints"] = kpts
        return pbatch

    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
        """Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
        truth.

        Args:
            preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
                and 'keypoints' for keypoint predictions.
            batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
                for bounding boxes, and 'keypoints' for keypoint annotations.

        Returns:
            (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
                positives across 10 IoU levels.

        Notes:
            `0.53` scale factor used in area computation is referenced from
            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
        """
        tp = super()._process_batch(preds, batch)
        gt_cls = batch["cls"]
        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
            # No GT or no predictions: all-false matrix with one column per IoU threshold
            tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
        else:
            # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
            area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
            iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
            tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
        tp.update({"tp_p": tp_p})  # update tp with kpts IoU
        return tp

    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
        """Save YOLO pose detections to a text file in normalized coordinates.

        Args:
            predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
            save_conf (bool): Whether to save confidence scores.
            shape (tuple[int, int]): Shape of the original image (height, width).
            file (Path): Output file path to save detections.

        Notes:
            The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
            normalized (x, y, visibility) values for each point.
        """
        from ultralytics.engine.results import Results

        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),  # dummy image; only shape is used for normalization
            path=None,
            names=self.names,
            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
            keypoints=predn["keypoints"],
        ).save_txt(file, save_conf=save_conf)

    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
        """Convert YOLO predictions to COCO JSON format.

        This method extends the parent conversion by attaching flattened keypoints to the JSON entries the parent
        just appended to the internal JSON list (self.jdict).

        Args:
            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'kpts'
                tensors (as produced by scale_preds).
            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
        """
        super().pred_to_json(predn, pbatch)
        kpts = predn["kpts"]
        # The parent appended len(kpts) entries; index from the end to enrich exactly those entries
        for i, k in enumerate(kpts.flatten(1, 2).tolist()):
            self.jdict[-len(kpts) + i]["keypoints"] = k  # keypoints

    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Scales predictions to the original image size."""
        return {
            **super().scale_preds(predn, pbatch),
            "kpts": ops.scale_coords(
                pbatch["imgsz"],
                predn["keypoints"].clone(),  # clone so stored predictions keep inference-size coords
                pbatch["ori_shape"],
                ratio_pad=pbatch["ratio_pad"],
            ),
        }

    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
        """Evaluate pose estimation model using COCO JSON format (bbox and keypoints metrics)."""
        anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
        pred_json = self.save_dir / "predictions.json"  # predictions
        return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
@@ -0,0 +1,7 @@
1
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Public API of the YOLO segmentation task: predictor, trainer, and validator."""

from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator

__all__ = ("SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator")
@@ -0,0 +1,109 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from ultralytics.engine.results import Results
4
+ from ultralytics.models.yolo.detect.predict import DetectionPredictor
5
+ from ultralytics.utils import DEFAULT_CFG, ops
6
+
7
+
8
class SegmentationPredictor(DetectionPredictor):
    """A class extending the DetectionPredictor class for prediction based on a segmentation model.

    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
    prediction results.

    Attributes:
        args (dict): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO segmentation model.
        batch (list): Current batch of images being processed.

    Methods:
        postprocess: Apply non-max suppression and process segmentation detections.
        construct_results: Construct a list of result objects from predictions.
        construct_result: Construct a single result object from a prediction.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.segment import SegmentationPredictor
        >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
        >>> predictor = SegmentationPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.

        Args:
            cfg (dict): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "segment"

    def postprocess(self, preds, img, orig_imgs):
        """Apply non-max suppression and process segmentation detections for each image in the input batch.

        Args:
            preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
            img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
            orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.

        Returns:
            (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
                Results object includes both bounding boxes and segmentation masks.

        Examples:
            >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
            >>> results = predictor.postprocess(preds, img, orig_img)
        """
        # Extract protos - tuple if PyTorch model or array if exported
        protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
        return super().postprocess(preds[0], img, orig_imgs, protos=protos)

    def construct_results(self, preds, img, orig_imgs, protos):
        """Construct a list of result objects from the predictions.

        Args:
            preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
            img (torch.Tensor): The image after preprocessing.
            orig_imgs (list[np.ndarray]): List of original images before preprocessing.
            protos (list[torch.Tensor]): List of prototype masks.

        Returns:
            (list[Results]): List of result objects containing the original images, image paths, class names, bounding
                boxes, and masks.
        """
        return [
            self.construct_result(pred, img, orig_img, img_path, proto)
            for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
        ]

    def construct_result(self, pred, img, orig_img, img_path, proto):
        """Construct a single result object from the prediction.

        Args:
            pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
            img (torch.Tensor): The image after preprocessing.
            orig_img (np.ndarray): The original image before preprocessing.
            img_path (str): The path to the original image.
            proto (torch.Tensor): The prototype masks.

        Returns:
            (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
        """
        if pred.shape[0] == 0:  # save empty boxes
            masks = None
        elif self.args.retina_masks:
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # NHW
        else:
            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # NHW
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        if masks is not None:
            keep = masks.amax((-2, -1)) > 0  # only keep predictions with masks
            # Tensor.all() is a single reduction; the builtin all() would iterate element-by-element in Python
            if not keep.all():  # most predictions have masks
                pred, masks = pred[keep], masks[keep]  # indexing is slow
        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
@@ -0,0 +1,69 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from copy import copy
6
+ from pathlib import Path
7
+
8
+ from ultralytics.models import yolo
9
+ from ultralytics.nn.tasks import SegmentationModel
10
+ from ultralytics.utils import DEFAULT_CFG, RANK
11
+
12
+
13
class SegmentationTrainer(yolo.detect.DetectionTrainer):
    """Trainer for YOLO segmentation models, extending the detection trainer.

    Extends DetectionTrainer with segmentation-specific behavior: builds a SegmentationModel, forces the task to
    "segment", and wires up a SegmentationValidator with segmentation loss names.

    Attributes:
        loss_names (tuple[str]): Names of the loss components used during training.

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationTrainer
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
        >>> trainer = SegmentationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
        """Initialize a SegmentationTrainer, forcing the task to "segment".

        Args:
            cfg (dict): Configuration dictionary with default training settings.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.
        """
        overrides = {} if overrides is None else overrides
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
        """Build a SegmentationModel, optionally loading pretrained weights.

        Args:
            cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
            weights (str | Path, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information during initialization.

        Returns:
            (SegmentationModel): Initialized segmentation model with loaded weights if specified.

        Examples:
            >>> trainer = SegmentationTrainer()
            >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
            >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
        """
        # Only print model info on the main process (RANK == -1) to avoid duplicate output in DDP
        seg_model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            seg_model.load(weights)
        return seg_model

    def get_validator(self):
        """Return a SegmentationValidator configured for this trainer's test loader and save directory."""
        self.loss_names = ("box_loss", "seg_loss", "cls_loss", "dfl_loss")
        return yolo.segment.SegmentationValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )