dgenerate-ultralytics-headless 8.3.134 (py3-none-any.whl)

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
ultralytics/models/yolo/obb/val.py
@@ -0,0 +1,283 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from pathlib import Path
+
+ import torch
+
+ from ultralytics.models.yolo.detect import DetectionValidator
+ from ultralytics.utils import LOGGER, ops
+ from ultralytics.utils.metrics import OBBMetrics, batch_probiou
+ from ultralytics.utils.plotting import output_to_rotated_target, plot_images
+
+
+ class OBBValidator(DetectionValidator):
+     """
+     A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
+
+     This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
+     satellite imagery where objects can appear at various orientations.
+
+     Attributes:
+         args (dict): Configuration arguments for the validator.
+         metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
+         is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
+
+     Methods:
+         init_metrics: Initialize evaluation metrics for YOLO.
+         _process_batch: Process a batch of detections and ground truth boxes to compute the IoU matrix.
+         _prepare_batch: Prepare batch data for OBB validation.
+         _prepare_pred: Prepare predictions with scaled and padded bounding boxes.
+         plot_predictions: Plot predicted bounding boxes on input images.
+         pred_to_json: Serialize YOLO predictions to COCO JSON format.
+         save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
+         eval_json: Evaluate YOLO output in JSON format and return performance statistics.
+
+     Examples:
+         >>> from ultralytics.models.yolo.obb import OBBValidator
+         >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
+         >>> validator = OBBValidator(args=args)
+         >>> validator(model=args["model"])
+     """
+
+     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+         """
+         Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
+
+         This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
+         It extends the DetectionValidator class and configures it specifically for the OBB task.
+
+         Args:
+             dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+             save_dir (str | Path, optional): Directory to save results.
+             pbar (bool, optional): Display progress bar during validation.
+             args (dict, optional): Arguments containing validation parameters.
+             _callbacks (list, optional): List of callback functions to be called during validation.
+         """
+         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+         self.args.task = "obb"
+         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
+
+     def init_metrics(self, model):
+         """Initialize evaluation metrics for YOLO."""
+         super().init_metrics(model)
+         val = self.data.get(self.args.split, "")  # validation path
+         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
+
+     def _process_batch(self, detections, gt_bboxes, gt_cls):
+         """
+         Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
+
+         Args:
+             detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and
+                 associated data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle).
+             gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each
+                 box is represented as (x1, y1, x2, y2, angle).
+             gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
+
+         Returns:
+             (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
+                 Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
+
+         Examples:
+             >>> detections = torch.rand(100, 7)  # 100 sample detections
+             >>> gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
+             >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
+             >>> correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
+
+         Note:
+             This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
+         """
+         iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
+         return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+     def _prepare_batch(self, si, batch):
+         """
+         Prepare batch data for OBB validation with proper scaling and formatting.
+
+         Args:
+             si (int): Batch index to process.
+             batch (dict): Dictionary containing batch data with keys:
+                 - batch_idx: Tensor of batch indices
+                 - cls: Tensor of class labels
+                 - bboxes: Tensor of bounding boxes
+                 - ori_shape: Original image shapes
+                 - img: Batch of images
+                 - ratio_pad: Ratio and padding information
+
+         This method filters the batch data for a specific batch index, extracts class labels and bounding boxes,
+         and scales the bounding boxes to the original image dimensions.
+         """
+         idx = batch["batch_idx"] == si
+         cls = batch["cls"][idx].squeeze(-1)
+         bbox = batch["bboxes"][idx]
+         ori_shape = batch["ori_shape"][si]
+         imgsz = batch["img"].shape[2:]
+         ratio_pad = batch["ratio_pad"][si]
+         if len(cls):
+             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
+             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
+         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+
+     def _prepare_pred(self, pred, pbatch):
+         """
+         Prepare predictions by scaling bounding boxes to original image dimensions.
+
+         This method takes prediction tensors containing bounding box coordinates and scales them from the model's
+         input dimensions to the original image dimensions using the provided batch information.
+
+         Args:
+             pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
+             pbatch (dict): Dictionary containing batch information with keys:
+                 - imgsz (tuple): Model input image size.
+                 - ori_shape (tuple): Original image shape.
+                 - ratio_pad (tuple): Ratio and padding information for scaling.
+
+         Returns:
+             (torch.Tensor): Scaled prediction tensor with bounding boxes in original image dimensions.
+         """
+         predn = pred.clone()
+         ops.scale_boxes(
+             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+         )  # native-space pred
+         return predn
+
+     def plot_predictions(self, batch, preds, ni):
+         """
+         Plot predicted bounding boxes on input images and save the result.
+
+         Args:
+             batch (dict): Batch data containing images, file paths, and other metadata.
+             preds (list): List of prediction tensors for each image in the batch.
+             ni (int): Batch index used for naming the output file.
+
+         Examples:
+             >>> validator = OBBValidator()
+             >>> batch = {"img": images, "im_file": paths}
+             >>> preds = [torch.rand(10, 7)]  # Example predictions for one image
+             >>> validator.plot_predictions(batch, preds, 0)
+         """
+         plot_images(
+             batch["img"],
+             *output_to_rotated_target(preds, max_det=self.args.max_det),
+             paths=batch["im_file"],
+             fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+             names=self.names,
+             on_plot=self.on_plot,
+         )  # pred
+
+     def pred_to_json(self, predn, filename):
+         """
+         Convert YOLO predictions to COCO JSON format with rotated bounding box information.
+
+         Args:
+             predn (torch.Tensor): Prediction tensor containing bounding box coordinates, confidence scores,
+                 class predictions, and rotation angles with shape (N, 6+) where the last column is the angle.
+             filename (str | Path): Path to the image file for which predictions are being processed.
+
+         Notes:
+             This method processes rotated bounding box predictions and converts them to both rbox format
+             (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
+             to the JSON dictionary.
+         """
+         stem = Path(filename).stem
+         image_id = int(stem) if stem.isnumeric() else stem
+         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
+         for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
+             self.jdict.append(
+                 {
+                     "image_id": image_id,
+                     "category_id": self.class_map[int(predn[i, 5].item())],
+                     "score": round(predn[i, 4].item(), 5),
+                     "rbox": [round(x, 3) for x in r],
+                     "poly": [round(x, 3) for x in b],
+                 }
+             )
+
+     def save_one_txt(self, predn, save_conf, shape, file):
+         """
+         Save YOLO OBB (Oriented Bounding Box) detections to a text file in normalized coordinates.
+
+         Args:
+             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
+                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
+             save_conf (bool): Whether to save confidence scores in the text file.
+             shape (tuple): Original image shape in format (height, width).
+             file (Path | str): Output file path to save detections.
+
+         Examples:
+             >>> validator = OBBValidator()
+             >>> predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]])  # One detection: x,y,w,h,conf,cls,angle
+             >>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
+         """
+         import numpy as np
+
+         from ultralytics.engine.results import Results
+
+         rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+         # xywh, r, conf, cls
+         obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
+         Results(
+             np.zeros((shape[0], shape[1]), dtype=np.uint8),
+             path=None,
+             names=self.names,
+             obb=obb,
+         ).save_txt(file, save_conf=save_conf)
+
+     def eval_json(self, stats):
+         """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
+         if self.args.save_json and self.is_dota and len(self.jdict):
+             import json
+             import re
+             from collections import defaultdict
+
+             pred_json = self.save_dir / "predictions.json"  # predictions
+             pred_txt = self.save_dir / "predictions_txt"  # predictions
+             pred_txt.mkdir(parents=True, exist_ok=True)
+             data = json.load(open(pred_json))
+             # Save split results
+             LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
+             for d in data:
+                 image_id = d["image_id"]
+                 score = d["score"]
+                 classname = self.names[d["category_id"] - 1].replace(" ", "-")
+                 p = d["poly"]
+
+                 with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
+                     f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
+             # Save merged results; this could yield a slightly lower mAP than the official merging script
+             # because of the probiou calculation.
+             pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
+             pred_merged_txt.mkdir(parents=True, exist_ok=True)
+             merged_results = defaultdict(list)
+             LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
+             for d in data:
+                 image_id = d["image_id"].split("__")[0]
+                 pattern = re.compile(r"\d+___\d+")
+                 x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
+                 bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
+                 bbox[0] += x
+                 bbox[1] += y
+                 bbox.extend([score, cls])
+                 merged_results[image_id].append(bbox)
+             for image_id, bbox in merged_results.items():
+                 bbox = torch.tensor(bbox)
+                 max_wh = torch.max(bbox[:, :2]).item() * 2
+                 c = bbox[:, 6:7] * max_wh  # classes
+                 scores = bbox[:, 5]  # scores
+                 b = bbox[:, :5].clone()
+                 b[:, :2] += c
+                 # An IoU threshold of 0.3 gives results close to the official merging script, sometimes slightly better.
+                 i = ops.nms_rotated(b, scores, 0.3)
+                 bbox = bbox[i]
+
+                 b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
+                 for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
+                     classname = self.names[int(x[-1])].replace(" ", "-")
+                     p = [round(i, 3) for i in x[:-2]]  # poly
+                     score = round(x[-2], 3)
+
+                     with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
+                         f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
+
+         return stats
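
Editor's note: a minimal sketch of how this validator is normally exercised. Validation is usually driven through the top-level YOLO API rather than by constructing OBBValidator directly, and save_json=True is what routes predictions through pred_to_json and eval_json above; the checkpoint and dataset names here are illustrative.

from ultralytics import YOLO

model = YOLO("yolo11n-obb.pt")  # any OBB-task checkpoint
metrics = model.val(data="dota8.yaml", save_json=True)  # save_json feeds pred_to_json/eval_json
print(metrics.box.map)  # mAP50-95 averaged over the 10 IoU levels

With a DOTA-format dataset, eval_json additionally writes Task1_<class>.txt files in both per-patch and merged form, as implemented above.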
ultralytics/models/yolo/pose/__init__.py
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from .predict import PosePredictor
+ from .train import PoseTrainer
+ from .val import PoseValidator
+
+ __all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
ultralytics/models/yolo/pose/predict.py
@@ -0,0 +1,79 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from ultralytics.models.yolo.detect.predict import DetectionPredictor
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
+
+
+ class PosePredictor(DetectionPredictor):
+     """
+     A class extending the DetectionPredictor class for prediction based on a pose model.
+
+     This class specializes in pose estimation, handling keypoint detection alongside the standard object detection
+     capabilities inherited from DetectionPredictor.
+
+     Attributes:
+         args (namespace): Configuration arguments for the predictor.
+         model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
+
+     Methods:
+         construct_result: Construct the result object from the prediction, including keypoints.
+
+     Examples:
+         >>> from ultralytics.utils import ASSETS
+         >>> from ultralytics.models.yolo.pose import PosePredictor
+         >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
+         >>> predictor = PosePredictor(overrides=args)
+         >>> predictor.predict_cli()
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """
+         Initialize PosePredictor, a specialized predictor for pose estimation tasks.
+
+         This initializer sets up a PosePredictor instance, configuring it for pose detection tasks and handling
+         device-specific warnings for Apple MPS.
+
+         Args:
+             cfg (Any): Configuration for the predictor. Default is DEFAULT_CFG.
+             overrides (dict, optional): Configuration overrides that take precedence over cfg.
+             _callbacks (list, optional): List of callback functions to be invoked during prediction.
+
+         Examples:
+             >>> from ultralytics.utils import ASSETS
+             >>> from ultralytics.models.yolo.pose import PosePredictor
+             >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
+             >>> predictor = PosePredictor(overrides=args)
+             >>> predictor.predict_cli()
+         """
+         super().__init__(cfg, overrides, _callbacks)
+         self.args.task = "pose"
+         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+             LOGGER.warning(
+                 "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                 "See https://github.com/ultralytics/ultralytics/issues/4031."
+             )
+
+     def construct_result(self, pred, img, orig_img, img_path):
+         """
+         Construct the result object from the prediction, including keypoints.
+
+         This method extends the parent class implementation by extracting keypoint data from predictions
+         and adding it to the result object.
+
+         Args:
+             pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
+                 the number of detections, K is the number of keypoints, and D is the keypoint dimension.
+             img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
+             orig_img (np.ndarray): The original unprocessed image as a numpy array.
+             img_path (str): The path to the original image file.
+
+         Returns:
+             (Results): The result object containing the original image, image path, class names, bounding boxes,
+                 and keypoints.
+         """
+         result = super().construct_result(pred, img, orig_img, img_path)
+         # Extract keypoints from the prediction and reshape according to the model's keypoint shape
+         pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+         # Scale keypoint coordinates to match the original image dimensions
+         pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
+         result.update(keypoints=pred_kpts)
+         return result
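
Editor's note: the reshape in construct_result is the one subtle step, folding the flat keypoint columns after the six box/score/class values back into the model's kpt_shape. A standalone sketch of just that operation, assuming the COCO pose layout of 17 keypoints with (x, y, visibility):

import torch

kpt_shape = (17, 3)  # assumed COCO layout: 17 keypoints, each (x, y, visibility)
pred = torch.rand(5, 6 + 17 * 3)  # 5 detections: (x1, y1, x2, y2, conf, cls) + flat keypoints

pred_kpts = pred[:, 6:].view(len(pred), *kpt_shape)
print(pred_kpts.shape)  # torch.Size([5, 17, 3])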
ultralytics/models/yolo/pose/train.py
@@ -0,0 +1,154 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from copy import copy
+
+ from ultralytics.models import yolo
+ from ultralytics.nn.tasks import PoseModel
+ from ultralytics.utils import DEFAULT_CFG, LOGGER
+ from ultralytics.utils.plotting import plot_images, plot_results
+
+
+ class PoseTrainer(yolo.detect.DetectionTrainer):
+     """
+     A class extending the DetectionTrainer class for training YOLO pose estimation models.
+
+     This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
+     of pose keypoints alongside bounding boxes.
+
+     Attributes:
+         args (dict): Configuration arguments for training.
+         model (PoseModel): The pose estimation model being trained.
+         data (dict): Dataset configuration including keypoint shape information.
+         loss_names (Tuple[str]): Names of the loss components used in training.
+
+     Methods:
+         get_model: Retrieve a pose estimation model with the specified configuration.
+         set_model_attributes: Set the keypoint shape attribute on the model.
+         get_validator: Create a validator instance for model evaluation.
+         plot_training_samples: Visualize training samples with keypoints.
+         plot_metrics: Generate and save training/validation metric plots.
+
+     Examples:
+         >>> from ultralytics.models.yolo.pose import PoseTrainer
+         >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
+         >>> trainer = PoseTrainer(overrides=args)
+         >>> trainer.train()
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """
+         Initialize a PoseTrainer object for training YOLO pose estimation models.
+
+         This initializer sets up a trainer specialized for pose estimation tasks, setting the task to 'pose' and
+         handling the specific configurations needed for keypoint detection models.
+
+         Args:
+             cfg (dict, optional): Default configuration dictionary containing training parameters.
+             overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+             _callbacks (list, optional): List of callback functions to be executed during training.
+
+         Notes:
+             This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
+             A warning is issued when using the Apple MPS device due to known bugs with pose models.
+
+         Examples:
+             >>> from ultralytics.models.yolo.pose import PoseTrainer
+             >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
+             >>> trainer = PoseTrainer(overrides=args)
+             >>> trainer.train()
+         """
+         if overrides is None:
+             overrides = {}
+         overrides["task"] = "pose"
+         super().__init__(cfg, overrides, _callbacks)
+
+         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+             LOGGER.warning(
+                 "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                 "See https://github.com/ultralytics/ultralytics/issues/4031."
+             )
+
+     def get_model(self, cfg=None, weights=None, verbose=True):
+         """
+         Get a pose estimation model with the specified configuration and weights.
+
+         Args:
+             cfg (str | Path | dict | None): Model configuration file path or dictionary.
+             weights (str | Path | None): Path to the model weights file.
+             verbose (bool): Whether to display model information.
+
+         Returns:
+             (PoseModel): Initialized pose estimation model.
+         """
+         model = PoseModel(
+             cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
+         )
+         if weights:
+             model.load(weights)
+
+         return model
+
+     def set_model_attributes(self):
+         """Set the keypoint shape attribute of PoseModel."""
+         super().set_model_attributes()
+         self.model.kpt_shape = self.data["kpt_shape"]
+
+     def get_validator(self):
+         """Return an instance of the PoseValidator class for validation."""
+         self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
+         return yolo.pose.PoseValidator(
+             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+         )
+
+     def plot_training_samples(self, batch, ni):
+         """
+         Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
+
+         Args:
+             batch (dict): Dictionary containing batch data with the following keys:
+                 - img (torch.Tensor): Batch of images
+                 - keypoints (torch.Tensor): Keypoint coordinates for pose estimation
+                 - cls (torch.Tensor): Class labels
+                 - bboxes (torch.Tensor): Bounding box coordinates
+                 - im_file (list): List of image file paths
+                 - batch_idx (torch.Tensor): Batch indices for each instance
+             ni (int): Current training iteration number, used for the filename.
+
+         The function saves the plotted batch as an image in the trainer's save directory with the filename
+         'train_batch{ni}.jpg', where ni is the iteration number.
+         """
+         images = batch["img"]
+         kpts = batch["keypoints"]
+         cls = batch["cls"].squeeze(-1)
+         bboxes = batch["bboxes"]
+         paths = batch["im_file"]
+         batch_idx = batch["batch_idx"]
+         plot_images(
+             images,
+             batch_idx,
+             cls,
+             bboxes,
+             kpts=kpts,
+             paths=paths,
+             fname=self.save_dir / f"train_batch{ni}.jpg",
+             on_plot=self.on_plot,
+         )
+
+     def plot_metrics(self):
+         """Plot training/validation metrics."""
+         plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
+
+     def get_dataset(self):
+         """
+         Retrieve the dataset and ensure it contains the required `kpt_shape` key.
+
+         Returns:
+             (dict): A dictionary containing the training/validation/test dataset and category names.
+
+         Raises:
+             KeyError: If the `kpt_shape` key is not present in the dataset.
+         """
+         data = super().get_dataset()
+         if "kpt_shape" not in data:
+             raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
+         return data
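
Editor's note: a minimal end-to-end sketch matching the docstring example above. coco8-pose.yaml ships with the package and defines the kpt_shape entry that get_dataset checks for; the epoch count and image size here are illustrative.

from ultralytics.models.yolo.pose import PoseTrainer

args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=1, imgsz=640)
trainer = PoseTrainer(overrides=args)  # forces task="pose" regardless of overrides
trainer.train()  # per-epoch losses follow loss_names: box, pose, kobj, cls, dfl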