dgenerate-ultralytics-headless 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
--- a/ultralytics/models/yolo/pose/val.py
+++ b/ultralytics/models/yolo/pose/val.py
@@ -1,43 +1,44 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import torch
 
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
-from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
 
 
 class PoseValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a pose model.
+    """A class extending the DetectionValidator class for validation based on a pose model.
 
-    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
-    specialized metrics for pose evaluation.
+    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
+    metrics for pose evaluation.
 
     Attributes:
         sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
-        kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
+        kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
         args (dict): Arguments for the validator including task set to "pose".
         metrics (PoseMetrics): Metrics object for pose evaluation.
 
     Methods:
-        preprocess: Preprocesses batch data for pose validation.
-        get_desc: Returns description of evaluation metrics.
-        init_metrics: Initializes pose metrics for the model.
-        _prepare_batch: Prepares a batch for processing.
-        _prepare_pred: Prepares and scales predictions for evaluation.
-        update_metrics: Updates metrics with new predictions.
-        _process_batch: Processes batch to compute IoU between detections and ground truth.
-        plot_val_samples: Plots validation samples with ground truth annotations.
-        plot_predictions: Plots model predictions.
-        save_one_txt: Saves detections to a text file.
-        pred_to_json: Converts predictions to COCO JSON format.
-        eval_json: Evaluates model using COCO JSON format.
+        preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
+        get_desc: Return description of evaluation metrics in string format.
+        init_metrics: Initialize pose estimation metrics for YOLO model.
+        _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
+            dimensions.
+        _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
+        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
+            and ground truth.
+        plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
+        plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
+        save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
+        pred_to_json: Convert YOLO predictions to COCO JSON format.
+        eval_json: Evaluate object detection model using COCO JSON format.
 
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseValidator
@@ -46,9 +47,8 @@ class PoseValidator(DetectionValidator):
         >>> validator()
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
-        Initialize a PoseValidator object for pose estimation validation.
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
+        """Initialize a PoseValidator object for pose estimation validation.
 
         This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
         specialized metrics for pose evaluation.
@@ -56,7 +56,6 @@ class PoseValidator(DetectionValidator):
         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
             save_dir (Path | str, optional): Directory to save results.
-            pbar (Any, optional): Progress bar for displaying progress.
             args (dict, optional): Arguments for the validator including task set to "pose".
             _callbacks (list, optional): List of callback functions to be executed during validation.
 
@@ -71,24 +70,24 @@ class PoseValidator(DetectionValidator):
         for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
         due to a known bug with pose models.
         """
-        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        super().__init__(dataloader, save_dir, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
         self.args.task = "pose"
-        self.metrics = PoseMetrics(save_dir=self.save_dir)
+        self.metrics = PoseMetrics()
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
             LOGGER.warning(
                 "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                 "See https://github.com/ultralytics/ultralytics/issues/4031."
             )
 
-    def preprocess(self, batch):
+    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Preprocess batch by converting keypoints data to float and moving it to the device."""
         batch = super().preprocess(batch)
-        batch["keypoints"] = batch["keypoints"].to(self.device).float()
+        batch["keypoints"] = batch["keypoints"].float()
         return batch
 
-    def get_desc(self):
+    def get_desc(self) -> str:
         """Return description of evaluation metrics in string format."""
         return ("%22s" + "%11s" * 10) % (
             "Class",
@@ -104,25 +103,55 @@ class PoseValidator(DetectionValidator):
             "mAP50-95)",
         )
 
-    def init_metrics(self, model):
-        """Initialize pose estimation metrics for YOLO model."""
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """Initialize evaluation metrics for YOLO pose validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         self.kpt_shape = self.data["kpt_shape"]
         is_pose = self.kpt_shape == [17, 3]
         nkpt = self.kpt_shape[0]
         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
-        self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
-    def _prepare_batch(self, si, batch):
+    def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
+        """Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
+
+        This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
+        predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
+        flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
+
+        Args:
+            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
+                scores, class predictions, and keypoint data.
+
+        Returns:
+            (dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
+                - 'bboxes': Bounding box coordinates
+                - 'conf': Confidence scores
+                - 'cls': Class predictions
+                - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
+
+        Notes:
+            If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
+            to the next one. The keypoints are extracted from the 'extra' field which contains additional
+            task-specific data beyond basic detection.
         """
-        Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape)  # remove extra if exists
+        return preds
+
+    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
+        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
 
         Args:
             si (int): Batch index.
-            batch (dict): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
+            batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
 
         Returns:
-            pbatch (dict): Prepared batch with keypoints scaled to original image dimensions.
+            (dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
 
         Notes:
             This method extends the parent class's _prepare_batch method by adding keypoint processing.
@@ -134,187 +163,46 @@ class PoseValidator(DetectionValidator):
             kpts = kpts.clone()
             kpts[..., 0] *= w
             kpts[..., 1] *= h
-            kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-            pbatch["kpts"] = kpts
+            pbatch["keypoints"] = kpts
         return pbatch
 
-    def _prepare_pred(self, pred, pbatch):
-        """
-        Prepare and scale keypoints in predictions for pose processing.
-
-        This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls
-        the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates
-        to match the original image dimensions.
-
-        Args:
-            pred (torch.Tensor): Raw prediction tensor from the model.
-            pbatch (dict): Processed batch dictionary containing image information including:
-                - imgsz: Image size used for inference
-                - ori_shape: Original image shape
-                - ratio_pad: Ratio and padding information for coordinate scaling
-
-        Returns:
-            predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
-        """
-        predn = super()._prepare_pred(pred, pbatch)
-        nk = pbatch["kpts"].shape[1]
-        pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
-        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-        return predn, pred_kpts
-
-    def update_metrics(self, preds, batch):
-        """
-        Update metrics with new predictions and ground truth data.
-
-        This method processes each prediction, compares it with ground truth, and updates various statistics
-        for performance evaluation.
+    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
+        """Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
+        truth.
 
         Args:
-            preds (List[torch.Tensor]): List of prediction tensors from the model.
-            batch (dict): Batch data containing images and ground truth annotations.
-        """
-        for si, pred in enumerate(preds):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-            )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_kpts = self._prepare_pred(pred, pbatch)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
-            if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(predn, batch["im_file"][si])
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_kpts,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
-        """
-        Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
-
-        Args:
-            detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
-                detection is of the format (x1, y1, x2, y2, conf, class).
-            gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
-                box is of the format (x1, y1, x2, y2).
-            gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
-            pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
-                51 corresponds to 17 keypoints each having 3 values.
-            gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints.
+            preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
+                and 'keypoints' for keypoint predictions.
+            batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
+                for bounding boxes, and 'keypoints' for keypoint annotations.
 
         Returns:
-            (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
-                where N is the number of detections.
+            (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
+                positives across 10 IoU levels.
 
         Notes:
            `0.53` scale factor used in area computation is referenced from
            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
         """
-        if pred_kpts is not None and gt_kpts is not None:
+        tp = super()._process_batch(preds, batch)
+        gt_cls = batch["cls"]
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
+        else:
            # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
-            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
-            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
-
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
-
-    def plot_val_samples(self, batch, ni):
-        """
-        Plot and save validation set samples with ground truth bounding boxes and keypoints.
-
-        Args:
-            batch (dict): Dictionary containing batch data with keys:
-                - img (torch.Tensor): Batch of images
-                - batch_idx (torch.Tensor): Batch indices for each image
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - keypoints (torch.Tensor): Keypoint coordinates
-                - im_file (list): List of image file paths
-            ni (int): Batch index used for naming the output file
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            kpts=batch["keypoints"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
-
-    def plot_predictions(self, batch, preds, ni):
-        """
-        Plot and save model predictions with bounding boxes and keypoints.
-
-        Args:
-            batch (dict): Dictionary containing batch data including images, file paths, and other metadata.
-            preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
-                confidence scores, class predictions, and keypoints.
-            ni (int): Batch index used for naming the output file.
-
-        The function extracts keypoints from predictions, converts predictions to target format, and plots them
-        on the input images. The resulting visualization is saved to the specified save directory.
-        """
-        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
-        plot_images(
-            batch["img"],
-            *output_to_target(preds, max_det=self.args.max_det),
-            kpts=pred_kpts,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
+            area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
+            iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
+            tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_p": tp_p})  # update tp with kpts IoU
+        return tp
 
-    def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
-        """
-        Save YOLO pose detections to a text file in normalized coordinates.
+    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
+        """Save YOLO pose detections to a text file in normalized coordinates.
 
         Args:
-            predn (torch.Tensor): Prediction boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).
-            pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
-                and D is the dimension (typically 3 for x, y, visibility).
+            predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
            save_conf (bool): Whether to save confidence scores.
-            shape (tuple): Original image shape (height, width).
+            shape (tuple[int, int]): Shape of the original image (height, width).
            file (Path): Output file path to save detections.
 
        Notes:
@@ -327,68 +215,45 @@ class PoseValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            keypoints=pred_kpts,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            keypoints=predn["keypoints"],
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn, filename):
-        """
-        Convert YOLO predictions to COCO JSON format.
+    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+        """Convert YOLO predictions to COCO JSON format.
 
-        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
-        to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
+        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
+        format, and appends the results to the internal JSON dictionary (self.jdict).
 
         Args:
-            predn (torch.Tensor): Prediction tensor containing bounding boxes, confidence scores, class IDs,
-                and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the flattened
-                keypoints dimension.
-            filename (str | Path): Path to the image file for which predictions are being processed.
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
+                tensors.
+            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
 
         Notes:
            The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
            converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
            before saving to the JSON dictionary.
         """
-        stem = Path(filename).stem
-        image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
-        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        for p, b in zip(predn.tolist(), box.tolist()):
-            self.jdict.append(
-                {
-                    "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
-                    "bbox": [round(x, 3) for x in b],
-                    "keypoints": p[6:],
-                    "score": round(p[4], 5),
-                }
-            )
-
-    def eval_json(self, stats):
+        super().pred_to_json(predn, pbatch)
+        kpts = predn["kpts"]
+        for i, k in enumerate(kpts.flatten(1, 2).tolist()):
+            self.jdict[-len(kpts) + i]["keypoints"] = k  # keypoints
+
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return {
+            **super().scale_preds(predn, pbatch),
+            "kpts": ops.scale_coords(
+                pbatch["imgsz"],
+                predn["keypoints"].clone(),
+                pbatch["ori_shape"],
+                ratio_pad=pbatch["ratio_pad"],
+            ),
+        }
+
+    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
         """Evaluate object detection model using COCO JSON format."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
-            pred_json = self.save_dir / "predictions.json"  # predictions
-            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
-            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements("pycocotools>=2.0.6")
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
-
-                for x in anno_json, pred_json:
-                    assert x.is_file(), f"{x} file not found"
-                anno = COCO(str(anno_json))  # init annotations api
-                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
-                    if self.is_coco:
-                        eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
-                    eval.evaluate()
-                    eval.accumulate()
-                    eval.summarize()
-                    idx = i * 4 + 2
-                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
-                        :2
-                    ]  # update mAP50-95 and mAP50
-            except Exception as e:
-                LOGGER.warning(f"pycocotools unable to run: {e}")
-        return stats
+        anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
+        pred_json = self.save_dir / "predictions.json"  # predictions
+        return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
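
The hunks above change PoseValidator's surface: the `pbar` argument is gone from `__init__`, predictions flow through dict-based `postprocess`/`_process_batch`, and COCO evaluation is delegated to `coco_evaluate`. The entry point shown in the class docstring is unchanged, so a minimal usage sketch (the model name here is illustrative, not taken from the diff; `coco8-pose.yaml` ships in `ultralytics/cfg/datasets/`) might look like:

    from ultralytics.models.yolo.pose import PoseValidator

    # Hypothetical weights; any pose checkpoint compatible with 8.3.224 works.
    args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
    validator = PoseValidator(args=args)  # note: no `pbar` keyword in this release
    stats = validator()  # returned stats now include 'tp_p' pose true positives from _process_batch
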
--- a/ultralytics/models/yolo/segment/predict.py
+++ b/ultralytics/models/yolo/segment/predict.py
@@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class SegmentationPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a segmentation model.
+    """A class extending the DetectionPredictor class for prediction based on a segmentation model.
 
     This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
     prediction results.
@@ -18,9 +17,9 @@ class SegmentationPredictor(DetectionPredictor):
         batch (list): Current batch of images being processed.
 
     Methods:
-        postprocess: Applies non-max suppression and processes detections.
-        construct_results: Constructs a list of result objects from predictions.
-        construct_result: Constructs a single result object from a prediction.
+        postprocess: Apply non-max suppression and process segmentation detections.
+        construct_results: Construct a list of result objects from predictions.
+        construct_result: Construct a single result object from a prediction.
 
     Examples:
         >>> from ultralytics.utils import ASSETS
@@ -31,14 +30,13 @@ class SegmentationPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
 
         This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
         prediction results.
 
         Args:
-            cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
+            cfg (dict): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over cfg.
             _callbacks (list, optional): List of callback functions to be invoked during prediction.
         """
@@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
         self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Apply non-max suppression and process segmentation detections for each image in the input batch.
+        """Apply non-max suppression and process segmentation detections for each image in the input batch.
 
         Args:
             preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
@@ -55,8 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
             orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
 
         Returns:
-            (list): List of Results objects containing the segmentation predictions for each image in the batch.
-                Each Results object includes both bounding boxes and segmentation masks.
+            (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
+                Results object includes both bounding boxes and segmentation masks.
 
         Examples:
             >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
@@ -67,18 +64,17 @@ class SegmentationPredictor(DetectionPredictor):
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)
 
     def construct_results(self, preds, img, orig_imgs, protos):
-        """
-        Construct a list of result objects from the predictions.
+        """Construct a list of result objects from the predictions.
 
         Args:
-            preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
+            preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
             img (torch.Tensor): The image after preprocessing.
-            orig_imgs (List[np.ndarray]): List of original images before preprocessing.
-            protos (List[torch.Tensor]): List of prototype masks.
+            orig_imgs (list[np.ndarray]): List of original images before preprocessing.
+            protos (list[torch.Tensor]): List of prototype masks.
 
         Returns:
-            (List[Results]): List of result objects containing the original images, image paths, class names,
-                bounding boxes, and masks.
+            (list[Results]): List of result objects containing the original images, image paths, class names, bounding
+                boxes, and masks.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path, proto)
@@ -86,11 +82,10 @@ class SegmentationPredictor(DetectionPredictor):
         ]
 
     def construct_result(self, pred, img, orig_img, img_path, proto):
-        """
-        Construct a single result object from the prediction.
+        """Construct a single result object from the prediction.
 
         Args:
-            pred (np.ndarray): The predicted bounding boxes, scores, and masks.
+            pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
             img (torch.Tensor): The image after preprocessing.
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
@@ -99,7 +94,7 @@ class SegmentationPredictor(DetectionPredictor):
         Returns:
             (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
         """
-        if not len(pred):  # save empty boxes
+        if pred.shape[0] == 0:  # save empty boxes
             masks = None
         elif self.args.retina_masks:
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
@@ -108,6 +103,7 @@ class SegmentationPredictor(DetectionPredictor):
             masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
         if masks is not None:
-            keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
-            pred, masks = pred[keep], masks[keep]
+            keep = masks.amax((-2, -1)) > 0  # only keep predictions with masks
+            if not all(keep):  # most predictions have masks
+                pred, masks = pred[keep], masks[keep]  # indexing is slow
         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
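
SegmentationPredictor keeps its public API; the changes above only touch docstring style and `construct_result` internals (`pred.shape[0] == 0` for the empty check, `masks.amax(...) > 0` for mask filtering, with tensor indexing skipped when every prediction keeps its mask, since indexing is the slow path). A minimal usage sketch based on the class docstring (the weights name is illustrative):

    from ultralytics.models.yolo.segment import SegmentationPredictor
    from ultralytics.utils import ASSETS  # bundled sample images

    predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt", source=ASSETS))
    results = predictor()  # list of Results; each carries boxes and, when present, masks
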