dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/obb/val.py
@@ -1,18 +1,21 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from pathlib import Path
+from typing import Any
 
+import numpy as np
 import torch
 
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.metrics import OBBMetrics, batch_probiou
-from ultralytics.utils.plotting import output_to_rotated_target, plot_images
+from ultralytics.utils.nms import TorchNMS
 
 
 class OBBValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
 
     This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
     satellite imagery where objects can appear at various orientations.
@@ -39,64 +42,78 @@ class OBBValidator(DetectionValidator):
         >>> validator(model=args["model"])
         """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
-        Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
+        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
 
-        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
-        It extends the DetectionValidator class and configures it specifically for the OBB task.
+        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
+        extends the DetectionValidator class and configures it specifically for the OBB task.
 
         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
             save_dir (str | Path, optional): Directory to save results.
-            pbar (bool, optional): Display progress bar during validation.
-            args (dict, optional): Arguments containing validation parameters.
+            args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
            _callbacks (list, optional): List of callback functions to be called during validation.
         """
-        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        super().__init__(dataloader, save_dir, args, _callbacks)
         self.args.task = "obb"
-        self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
+        self.metrics = OBBMetrics()
 
-    def init_metrics(self, model):
-        """Initialize evaluation metrics for YOLO."""
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """Initialize evaluation metrics for YOLO obb validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
+        self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'
 
-    def _process_batch(self, detections, gt_bboxes, gt_cls):
-        """
-        Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
+    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
+        """Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
 
         Args:
-            detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
-                data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle).
-            gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
-                represented as (x1, y1, x2, y2, angle).
-            gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
+            preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
+                class labels and bounding boxes.
+            batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
+                labels and bounding boxes.
 
         Returns:
-            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
-                Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
+            (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
+                with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
+                predictions compared to the ground truth.
 
         Examples:
             >>> detections = torch.rand(100, 7)  # 100 sample detections
             >>> gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
             >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
-            >>> correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
-
-        Note:
-            This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
+            >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
-        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
+        if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
+            return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
+        iou = batch_probiou(batch["bboxes"], preds["bboxes"])
+        return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
+
+    def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
+        """Postprocess OBB predictions.
 
-    def _prepare_batch(self, si, batch):
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
         """
-        Prepare batch data for OBB validation with proper scaling and formatting.
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1)  # concatenate angle
+        return preds
+
+    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
+        """Prepare batch data for OBB validation with proper scaling and formatting.
 
         Args:
             si (int): Batch index to process.
-            batch (dict): Dictionary containing batch data with keys:
+            batch (dict[str, Any]): Dictionary containing batch data with keys:
                 - batch_idx: Tensor of batch indices
                 - cls: Tensor of class labels
                 - bboxes: Tensor of bounding boxes
@@ -104,8 +121,8 @@ class OBBValidator(DetectionValidator):
                 - img: Batch of images
                 - ratio_pad: Ratio and padding information
 
-        This method filters the batch data for a specific batch index, extracts class labels and bounding boxes,
-        and scales the bounding boxes to the original image dimensions.
+        Returns:
+            (dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
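Note on the dict-based `_process_batch` shown above: predictions and ground truth now arrive as dictionaries, and the method returns a 'tp' boolean matrix with one column per IoU threshold. A minimal, self-contained sketch of that matching step (a simplified stand-in for the library's `match_predictions`, with a plain class-masked IoU matrix in place of `batch_probiou`):

    import numpy as np
    import torch

    def tp_matrix(pred_cls: torch.Tensor, gt_cls: torch.Tensor, iou: torch.Tensor) -> np.ndarray:
        """Greedy one-to-one matching over the 10 mAP IoU thresholds (0.50:0.95)."""
        thresholds = torch.linspace(0.5, 0.95, 10)
        iou = iou * (gt_cls[:, None] == pred_cls[None, :])  # zero out cross-class overlaps
        tp = np.zeros((pred_cls.shape[0], 10), dtype=bool)  # rows: predictions, cols: thresholds
        for j, t in enumerate(thresholds.tolist()):
            gt_idx, pred_idx = torch.nonzero(iou >= t, as_tuple=True)
            order = iou[gt_idx, pred_idx].argsort(descending=True)
            seen_gt, seen_pred = set(), set()
            for g, p in zip(gt_idx[order].tolist(), pred_idx[order].tolist()):
                if g not in seen_gt and p not in seen_pred:  # best remaining pair wins
                    seen_gt.add(g)
                    seen_pred.add(p)
                    tp[p, j] = True
        return tp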
@@ -113,41 +130,23 @@ class OBBValidator(DetectionValidator):
         ori_shape = batch["ori_shape"][si]
         imgsz = batch["img"].shape[2:]
         ratio_pad = batch["ratio_pad"][si]
-        if len(cls):
+        if cls.shape[0]:
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
-            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
-        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
-
-    def _prepare_pred(self, pred, pbatch):
-        """
-        Prepare predictions by scaling bounding boxes to original image dimensions.
-
-        This method takes prediction tensors containing bounding box coordinates and scales them from the model's
-        input dimensions to the original image dimensions using the provided batch information.
+        return {
+            "cls": cls,
+            "bboxes": bbox,
+            "ori_shape": ori_shape,
+            "imgsz": imgsz,
+            "ratio_pad": ratio_pad,
+            "im_file": batch["im_file"][si],
+        }
+
+    def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
+        """Plot predicted bounding boxes on input images and save the result.
 
         Args:
-            pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
-            pbatch (dict): Dictionary containing batch information with keys:
-                - imgsz (tuple): Model input image size.
-                - ori_shape (tuple): Original image shape.
-                - ratio_pad (tuple): Ratio and padding information for scaling.
-
-        Returns:
-            (torch.Tensor): Scaled prediction tensor with bounding boxes in original image dimensions.
-        """
-        predn = pred.clone()
-        ops.scale_boxes(
-            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
-        )  # native-space pred
-        return predn
-
-    def plot_predictions(self, batch, preds, ni):
-        """
-        Plot predicted bounding boxes on input images and save the result.
-
-        Args:
-            batch (dict): Batch data containing images, file paths, and other metadata.
-            preds (list): List of prediction tensors for each image in the batch.
+            batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
+            preds (list[torch.Tensor]): List of prediction tensors for each image in the batch.
             ni (int): Batch index used for naming the output file.
 
         Examples:
@@ -156,54 +155,50 @@ class OBBValidator(DetectionValidator):
             >>> preds = [torch.rand(10, 7)]  # Example predictions for one image
             >>> validator.plot_predictions(batch, preds, 0)
         """
-        plot_images(
-            batch["img"],
-            *output_to_rotated_target(preds, max_det=self.args.max_det),
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
+        for p in preds:
+            # TODO: fix this duplicated `xywh2xyxy`
+            p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4])  # convert to xyxy format for plotting
+        super().plot_predictions(batch, preds, ni)  # plot bboxes
 
-    def pred_to_json(self, predn, filename):
-        """
-        Convert YOLO predictions to COCO JSON format with rotated bounding box information.
+    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+        """Convert YOLO predictions to COCO JSON format with rotated bounding box information.
 
         Args:
-            predn (torch.Tensor): Prediction tensor containing bounding box coordinates, confidence scores,
-                class predictions, and rotation angles with shape (N, 6+) where the last column is the angle.
-            filename (str | Path): Path to the image file for which predictions are being processed.
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
+            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
 
         Notes:
             This method processes rotated bounding box predictions and converts them to both rbox format
             (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
             to the JSON dictionary.
         """
-        stem = Path(filename).stem
+        path = Path(pbatch["im_file"])
+        stem = path.stem
         image_id = int(stem) if stem.isnumeric() else stem
-        rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+        rbox = predn["bboxes"]
         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
-        for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
+        for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(predn[i, 5].item())],
-                    "score": round(predn[i, 4].item(), 5),
+                    "file_name": path.name,
+                    "category_id": self.class_map[int(c)],
+                    "score": round(s, 5),
                     "rbox": [round(x, 3) for x in r],
                     "poly": [round(x, 3) for x in b],
                 }
             )
 
-    def save_one_txt(self, predn, save_conf, shape, file):
-        """
-        Save YOLO OBB (Oriented Bounding Box) detections to a text file in normalized coordinates.
+    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
+        """Save YOLO OBB detections to a text file in normalized coordinates.
 
         Args:
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
             save_conf (bool): Whether to save confidence scores in the text file.
-            shape (tuple): Original image shape in format (height, width).
-            file (Path | str): Output file path to save detections.
+            shape (tuple[int, int]): Original image shape in format (height, width).
+            file (Path): Output file path to save detections.
 
         Examples:
             >>> validator = OBBValidator()
@@ -214,18 +209,31 @@ class OBBValidator(DetectionValidator):
 
         from ultralytics.engine.results import Results
 
-        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
-        # xywh, r, conf, cls
-        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
         Results(
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            obb=obb,
+            obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
         ).save_txt(file, save_conf=save_conf)
 
-    def eval_json(self, stats):
-        """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return {
+            **predn,
+            "bboxes": ops.scale_boxes(
+                pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+            ),
+        }
+
+    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
+        """Evaluate YOLO output in JSON format and save predictions in DOTA format.
+
+        Args:
+            stats (dict[str, Any]): Performance statistics dictionary.
+
+        Returns:
+            (dict[str, Any]): Updated performance statistics.
+        """
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json
             import re
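The new `scale_preds` above delegates coordinate mapping to `ops.scale_boxes`, which undoes the letterbox transform applied at load time. A hedged sketch of the idea with a scalar gain and an (x, y) pad, assuming xywh boxes (the real helper also handles clipping and the full ratio_pad bookkeeping):

    import torch

    def scale_back_xywh(boxes: torch.Tensor, gain: float, pad: tuple[float, float]) -> torch.Tensor:
        """Map xywh boxes from letterboxed model-input space back to original image space."""
        out = boxes.clone()
        out[:, 0] = (out[:, 0] - pad[0]) / gain  # x center: remove padding, undo resize
        out[:, 1] = (out[:, 1] - pad[1]) / gain  # y center
        out[:, 2:4] /= gain                      # width/height: only the resize applies
        return out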
@@ -252,7 +260,7 @@ class OBBValidator(DetectionValidator):
             merged_results = defaultdict(list)
             LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
             for d in data:
-                image_id = d["image_id"].split("__")[0]
+                image_id = d["image_id"].split("__", 1)[0]
                 pattern = re.compile(r"\d+___\d+")
                 x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
                 bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
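For context, the merge loop above reassembles per-patch DOTA predictions: each patch's image_id encodes the source image plus the patch's top-left offset, which the regex recovers. A small illustration with a hypothetical patch stem (format inferred from the regex, in the style produced by split_dota):

    import re

    stem = "P0006__1024__0___512"      # hypothetical patch stem: image__size__x___y
    image_id = stem.split("__", 1)[0]  # "P0006"; maxsplit=1 just avoids splitting the tail
    x, y = (int(c) for c in re.findall(r"\d+___\d+", stem)[0].split("___"))
    print(image_id, x, y)              # P0006 0 512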
@@ -268,7 +276,7 @@ class OBBValidator(DetectionValidator):
                 b = bbox[:, :5].clone()
                 b[:, :2] += c
                 # 0.3 could get results close to the ones from official merging script, even slightly better.
-                i = ops.nms_rotated(b, scores, 0.3)
+                i = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
                 bbox = bbox[i]
 
                 b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
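The `TorchNMS.fast_nms` call above replaces `ops.nms_rotated`, passing `batch_probiou` as the IoU kernel. A self-contained sketch of the Fast-NMS idea (with a plain axis-aligned IoU as a stand-in kernel; not the library's implementation): a box is suppressed whenever any higher-scoring box overlaps it beyond the threshold, computed in one upper-triangular matrix pass instead of a sequential loop.

    import torch

    def fast_nms(boxes: torch.Tensor, scores: torch.Tensor, iou_thres: float, iou_func) -> torch.Tensor:
        """Return indices of kept boxes, Fast-NMS style."""
        order = scores.argsort(descending=True)
        iou = iou_func(boxes[order], boxes[order]).triu_(diagonal=1)  # entry (i, j): i outranks j
        keep = iou.amax(dim=0) <= iou_thres  # j survives if no higher-ranked box overlaps too much
        return order[keep]

    def box_iou(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Pairwise IoU for xyxy boxes; stand-in for a rotated-box IoU such as batch_probiou."""
        lt = torch.maximum(a[:, None, :2], b[None, :, :2])
        rb = torch.minimum(a[:, None, 2:], b[None, :, 2:])
        inter = (rb - lt).clamp(min=0).prod(-1)
        area_a, area_b = (a[:, 2:] - a[:, :2]).prod(-1), (b[:, 2:] - b[:, :2]).prod(-1)
        return inter / (area_a[:, None] + area_b[None, :] - inter)

The `ops.xywhr2xyxyxyxy` context line above converts each (x, y, w, h, angle) rbox into its four polygon corners. The underlying geometry, sketched for a single box (illustrative only):

    import math

    def rbox_corners(x: float, y: float, w: float, h: float, r: float) -> list[float]:
        """Return (x1, y1, ..., x4, y4) corners of a rotated rect from center, size, angle (radians)."""
        cos_r, sin_r = math.cos(r), math.sin(r)
        dx1, dy1 = w / 2 * cos_r, w / 2 * sin_r   # half-width vector along the rotated x-axis
        dx2, dy2 = -h / 2 * sin_r, h / 2 * cos_r  # half-height vector along the rotated y-axis
        corners = [
            (x + dx1 + dx2, y + dy1 + dy2),
            (x + dx1 - dx2, y + dy1 - dy2),
            (x - dx1 - dx2, y - dy1 - dy2),
            (x - dx1 + dx2, y - dy1 + dy2),
        ]
        return [v for pt in corners for v in pt]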
ultralytics/models/yolo/pose/__init__.py
@@ -4,4 +4,4 @@ from .predict import PosePredictor
 from .train import PoseTrainer
 from .val import PoseValidator
 
-__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
+__all__ = "PosePredictor", "PoseTrainer", "PoseValidator"
ultralytics/models/yolo/pose/predict.py
@@ -5,8 +5,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
 
 
 class PosePredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a pose model.
+    """A class extending the DetectionPredictor class for prediction based on a pose model.
 
     This class specializes in pose estimation, handling keypoints detection alongside standard object detection
     capabilities inherited from DetectionPredictor.
@@ -16,7 +15,7 @@ class PosePredictor(DetectionPredictor):
         model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
 
     Methods:
-        construct_result: Constructs the result object from the prediction, including keypoints.
+        construct_result: Construct the result object from the prediction, including keypoints.
 
     Examples:
         >>> from ultralytics.utils import ASSETS
@@ -27,14 +26,13 @@ class PosePredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize PosePredictor, a specialized predictor for pose estimation tasks.
+        """Initialize PosePredictor for pose estimation tasks.
 
-        This initializer sets up a PosePredictor instance, configuring it for pose detection tasks and handling
-        device-specific warnings for Apple MPS.
+        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific warnings
+        for Apple MPS.
 
         Args:
-            cfg (Any): Configuration for the predictor. Default is DEFAULT_CFG.
+            cfg (Any): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over cfg.
             _callbacks (list, optional): List of callback functions to be invoked during prediction.
@@ -54,11 +52,10 @@ class PosePredictor(DetectionPredictor):
         )
 
     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct the result object from the prediction, including keypoints.
+        """Construct the result object from the prediction, including keypoints.
 
-        This method extends the parent class implementation by extracting keypoint data from predictions
-        and adding them to the result object.
+        Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
+        result object.
 
         Args:
             pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
@@ -68,11 +65,12 @@ class PosePredictor(DetectionPredictor):
             img_path (str): The path to the original image file.
 
         Returns:
-            (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
+            (Results): The result object containing the original image, image path, class names, bounding boxes, and
+                keypoints.
         """
         result = super().construct_result(pred, img, orig_img, img_path)
         # Extract keypoints from prediction and reshape according to model's keypoint shape
-        pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+        pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
         # Scale keypoints coordinates to match the original image dimensions
         pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
         result.update(keypoints=pred_kpts)
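The simplified reshape in `construct_result` above turns flat per-detection outputs into structured keypoints (the removed `if len(pred)` guard was unnecessary, since `view` also handles zero-row tensors). A quick illustration with a hypothetical keypoint shape of (17, 3), i.e. 17 keypoints with (x, y, visibility):

    import torch

    kpt_shape = (17, 3)                                    # hypothetical (num_keypoints, dims)
    pred = torch.rand(8, 6 + kpt_shape[0] * kpt_shape[1])  # 8 detections: box(4) + conf + cls + kpts
    pred_kpts = pred[:, 6:].view(pred.shape[0], *kpt_shape)
    print(pred_kpts.shape)                                 # torch.Size([8, 17, 3])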
ultralytics/models/yolo/pose/train.py
@@ -1,16 +1,18 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from copy import copy
+from pathlib import Path
+from typing import Any
 
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import PoseModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER
-from ultralytics.utils.plotting import plot_images, plot_results
 
 
 class PoseTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training YOLO pose estimation models.
+    """A class extending the DetectionTrainer class for training YOLO pose estimation models.
 
     This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
     of pose keypoints alongside bounding boxes.
@@ -19,14 +21,14 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         args (dict): Configuration arguments for training.
         model (PoseModel): The pose estimation model being trained.
         data (dict): Dataset configuration including keypoint shape information.
-        loss_names (Tuple[str]): Names of the loss components used in training.
+        loss_names (tuple): Names of the loss components used in training.
 
     Methods:
-        get_model: Retrieves a pose estimation model with specified configuration.
-        set_model_attributes: Sets keypoints shape attribute on the model.
-        get_validator: Creates a validator instance for model evaluation.
-        plot_training_samples: Visualizes training samples with keypoints.
-        plot_metrics: Generates and saves training/validation metric plots.
+        get_model: Retrieve a pose estimation model with specified configuration.
+        set_model_attributes: Set keypoints shape attribute on the model.
+        get_validator: Create a validator instance for model evaluation.
+        plot_training_samples: Visualize training samples with keypoints.
+        get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.
 
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseTrainer
@@ -35,12 +37,8 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         >>> trainer.train()
     """
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a PoseTrainer object for training YOLO pose estimation models.
-
-        This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
-        handling specific configurations needed for keypoint detection models.
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
+        """Initialize a PoseTrainer object for training YOLO pose estimation models.
 
         Args:
             cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -50,12 +48,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         Notes:
             This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
             A warning is issued when using Apple MPS device due to known bugs with pose models.
-
-        Examples:
-            >>> from ultralytics.models.yolo.pose import PoseTrainer
-            >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
-            >>> trainer = PoseTrainer(overrides=args)
-            >>> trainer.train()
         """
         if overrides is None:
             overrides = {}
@@ -68,13 +60,17 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
                 "See https://github.com/ultralytics/ultralytics/issues/4031."
             )
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Get pose estimation model with specified configuration and weights.
+    def get_model(
+        self,
+        cfg: str | Path | dict[str, Any] | None = None,
+        weights: str | Path | None = None,
+        verbose: bool = True,
+    ) -> PoseModel:
+        """Get pose estimation model with specified configuration and weights.
 
         Args:
-            cfg (str | Path | dict | None): Model configuration file path or dictionary.
-            weights (str | Path | None): Path to the model weights file.
+            cfg (str | Path | dict, optional): Model configuration file path or dictionary.
+            weights (str | Path, optional): Path to the model weights file.
             verbose (bool): Whether to display model information.
 
         Returns:
@@ -89,58 +85,24 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         return model
 
     def set_model_attributes(self):
-        """Sets keypoints shape attribute of PoseModel."""
+        """Set keypoints shape attribute of PoseModel."""
         super().set_model_attributes()
         self.model.kpt_shape = self.data["kpt_shape"]
+        kpt_names = self.data.get("kpt_names")
+        if not kpt_names:
+            names = list(map(str, range(self.model.kpt_shape[0])))
+            kpt_names = {i: names for i in range(self.model.nc)}
+        self.model.kpt_names = kpt_names
 
     def get_validator(self):
-        """Returns an instance of the PoseValidator class for validation."""
+        """Return an instance of the PoseValidator class for validation."""
         self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
         return yolo.pose.PoseValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def plot_training_samples(self, batch, ni):
-        """
-        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
-
-        Args:
-            batch (dict): Dictionary containing batch data with the following keys:
-                - img (torch.Tensor): Batch of images
-                - keypoints (torch.Tensor): Keypoints coordinates for pose estimation
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - im_file (list): List of image file paths
-                - batch_idx (torch.Tensor): Batch indices for each instance
-            ni (int): Current training iteration number used for filename
-
-        The function saves the plotted batch as an image in the trainer's save directory with the filename
-        'train_batch{ni}.jpg', where ni is the iteration number.
-        """
-        images = batch["img"]
-        kpts = batch["keypoints"]
-        cls = batch["cls"].squeeze(-1)
-        bboxes = batch["bboxes"]
-        paths = batch["im_file"]
-        batch_idx = batch["batch_idx"]
-        plot_images(
-            images,
-            batch_idx,
-            cls,
-            bboxes,
-            kpts=kpts,
-            paths=paths,
-            fname=self.save_dir / f"train_batch{ni}.jpg",
-            on_plot=self.on_plot,
-        )
-
-    def plot_metrics(self):
-        """Plots training/val metrics."""
-        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
-
-    def get_dataset(self):
-        """
-        Retrieves the dataset and ensures it contains the required `kpt_shape` key.
+    def get_dataset(self) -> dict[str, Any]:
+        """Retrieve the dataset and ensure it contains the required `kpt_shape` key.
 
         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
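A closing note on the `kpt_names` fallback added in `set_model_attributes` above: when the dataset YAML provides no keypoint names, every class is given the same numeric name list. A sketch with hypothetical values:

    nc, kpt_shape = 2, (17, 3)                   # hypothetical class count and keypoint shape
    names = list(map(str, range(kpt_shape[0])))  # ["0", "1", ..., "16"]
    kpt_names = {i: names for i in range(nc)}    # same per-keypoint name list for each class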