dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
@@ -8,12 +8,11 @@ from typing import Any
8
8
 
9
9
  from ultralytics.models import yolo
10
10
  from ultralytics.nn.tasks import PoseModel
11
- from ultralytics.utils import DEFAULT_CFG, LOGGER
11
+ from ultralytics.utils import DEFAULT_CFG
12
12
 
13
13
 
14
14
  class PoseTrainer(yolo.detect.DetectionTrainer):
15
- """
16
- A class extending the DetectionTrainer class for training YOLO pose estimation models.
15
+ """A class extending the DetectionTrainer class for training YOLO pose estimation models.
17
16
 
18
17
  This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
19
18
  of pose keypoints alongside bounding boxes.
@@ -33,14 +32,13 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
33
32
 
34
33
  Examples:
35
34
  >>> from ultralytics.models.yolo.pose import PoseTrainer
36
- >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
35
+ >>> args = dict(model="yolo26n-pose.pt", data="coco8-pose.yaml", epochs=3)
37
36
  >>> trainer = PoseTrainer(overrides=args)
38
37
  >>> trainer.train()
39
38
  """
40
39
 
41
40
  def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
42
- """
43
- Initialize a PoseTrainer object for training YOLO pose estimation models.
41
+ """Initialize a PoseTrainer object for training YOLO pose estimation models.
44
42
 
45
43
  Args:
46
44
  cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -56,20 +54,13 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
56
54
  overrides["task"] = "pose"
57
55
  super().__init__(cfg, overrides, _callbacks)
58
56
 
59
- if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
60
- LOGGER.warning(
61
- "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
62
- "See https://github.com/ultralytics/ultralytics/issues/4031."
63
- )
64
-
65
57
  def get_model(
66
58
  self,
67
59
  cfg: str | Path | dict[str, Any] | None = None,
68
60
  weights: str | Path | None = None,
69
61
  verbose: bool = True,
70
62
  ) -> PoseModel:
71
- """
72
- Get pose estimation model with specified configuration and weights.
63
+ """Get pose estimation model with specified configuration and weights.
73
64
 
74
65
  Args:
75
66
  cfg (str | Path | dict, optional): Model configuration file path or dictionary.
@@ -91,17 +82,23 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
91
82
  """Set keypoints shape attribute of PoseModel."""
92
83
  super().set_model_attributes()
93
84
  self.model.kpt_shape = self.data["kpt_shape"]
85
+ kpt_names = self.data.get("kpt_names")
86
+ if not kpt_names:
87
+ names = list(map(str, range(self.model.kpt_shape[0])))
88
+ kpt_names = {i: names for i in range(self.model.nc)}
89
+ self.model.kpt_names = kpt_names
94
90
 
95
91
  def get_validator(self):
96
92
  """Return an instance of the PoseValidator class for validation."""
97
93
  self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
94
+ if getattr(self.model.model[-1], "flow_model", None) is not None:
95
+ self.loss_names += ("rle_loss",)
98
96
  return yolo.pose.PoseValidator(
99
97
  self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
100
98
  )
101
99
 
102
100
  def get_dataset(self) -> dict[str, Any]:
103
- """
104
- Retrieve the dataset and ensure it contains the required `kpt_shape` key.
101
+ """Retrieve the dataset and ensure it contains the required `kpt_shape` key.
105
102
 
106
103
  Returns:
107
104
  (dict): A dictionary containing the training/validation/test dataset and category names.
@@ -9,16 +9,15 @@ import numpy as np
9
9
  import torch
10
10
 
11
11
  from ultralytics.models.yolo.detect import DetectionValidator
12
- from ultralytics.utils import LOGGER, ops
12
+ from ultralytics.utils import ops
13
13
  from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
14
14
 
15
15
 
16
16
  class PoseValidator(DetectionValidator):
17
- """
18
- A class extending the DetectionValidator class for validation based on a pose model.
17
+ """A class extending the DetectionValidator class for validation based on a pose model.
19
18
 
20
- This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
21
- specialized metrics for pose evaluation.
19
+ This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
20
+ metrics for pose evaluation.
22
21
 
23
22
  Attributes:
24
23
  sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
@@ -33,8 +32,8 @@ class PoseValidator(DetectionValidator):
33
32
  _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
34
33
  dimensions.
35
34
  _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
36
- _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
37
- detections and ground truth.
35
+ _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
36
+ and ground truth.
38
37
  plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
39
38
  plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
40
39
  save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
@@ -43,45 +42,33 @@ class PoseValidator(DetectionValidator):
43
42
 
44
43
  Examples:
45
44
  >>> from ultralytics.models.yolo.pose import PoseValidator
46
- >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
45
+ >>> args = dict(model="yolo26n-pose.pt", data="coco8-pose.yaml")
47
46
  >>> validator = PoseValidator(args=args)
48
47
  >>> validator()
48
+
49
+ Notes:
50
+ This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
51
+ for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
52
+ due to a known bug with pose models.
49
53
  """
50
54
 
51
55
  def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
52
- """
53
- Initialize a PoseValidator object for pose estimation validation.
56
+ """Initialize a PoseValidator object for pose estimation validation.
54
57
 
55
58
  This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
56
59
  specialized metrics for pose evaluation.
57
60
 
58
61
  Args:
59
- dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
62
+ dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
60
63
  save_dir (Path | str, optional): Directory to save results.
61
64
  args (dict, optional): Arguments for the validator including task set to "pose".
62
65
  _callbacks (list, optional): List of callback functions to be executed during validation.
63
-
64
- Examples:
65
- >>> from ultralytics.models.yolo.pose import PoseValidator
66
- >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
67
- >>> validator = PoseValidator(args=args)
68
- >>> validator()
69
-
70
- Notes:
71
- This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
72
- for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
73
- due to a known bug with pose models.
74
66
  """
75
67
  super().__init__(dataloader, save_dir, args, _callbacks)
76
68
  self.sigma = None
77
69
  self.kpt_shape = None
78
70
  self.args.task = "pose"
79
71
  self.metrics = PoseMetrics()
80
- if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
81
- LOGGER.warning(
82
- "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
83
- "See https://github.com/ultralytics/ultralytics/issues/4031."
84
- )
85
72
 
86
73
  def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
87
74
  """Preprocess batch by converting keypoints data to float and moving it to the device."""
@@ -106,8 +93,7 @@ class PoseValidator(DetectionValidator):
106
93
  )
107
94
 
108
95
  def init_metrics(self, model: torch.nn.Module) -> None:
109
- """
110
- Initialize evaluation metrics for YOLO pose validation.
96
+ """Initialize evaluation metrics for YOLO pose validation.
111
97
 
112
98
  Args:
113
99
  model (torch.nn.Module): Model to validate.
@@ -119,17 +105,15 @@ class PoseValidator(DetectionValidator):
119
105
  self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
120
106
 
121
107
  def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
122
- """
123
- Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
108
+ """Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
124
109
 
125
- This method extends the parent class postprocessing by extracting keypoints from the 'extra'
126
- field of predictions and reshaping them according to the keypoint shape configuration.
127
- The keypoints are reshaped from a flattened format to the proper dimensional structure
128
- (typically [N, 17, 3] for COCO pose format).
110
+ This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
111
+ predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
112
+ flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
129
113
 
130
114
  Args:
131
- preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
132
- bounding boxes, confidence scores, class predictions, and keypoint data.
115
+ preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
116
+ scores, class predictions, and keypoint data.
133
117
 
134
118
  Returns:
135
119
  (dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
@@ -138,10 +122,10 @@ class PoseValidator(DetectionValidator):
138
122
  - 'cls': Class predictions
139
123
  - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
140
124
 
141
- Note:
142
- If no keypoints are present in a prediction (empty keypoints), that prediction
143
- is skipped and continues to the next one. The keypoints are extracted from the
144
- 'extra' field which contains additional task-specific data beyond basic detection.
125
+ Notes:
126
+ If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
127
+ to the next one. The keypoints are extracted from the 'extra' field which contains additional
128
+ task-specific data beyond basic detection.
145
129
  """
146
130
  preds = super().postprocess(preds)
147
131
  for pred in preds:
@@ -149,8 +133,7 @@ class PoseValidator(DetectionValidator):
149
133
  return preds
150
134
 
151
135
  def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
152
- """
153
- Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
136
+ """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
154
137
 
155
138
  Args:
156
139
  si (int): Batch index.
@@ -173,18 +156,18 @@ class PoseValidator(DetectionValidator):
173
156
  return pbatch
174
157
 
175
158
  def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
176
- """
177
- Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
159
+ """Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
160
+ truth.
178
161
 
179
162
  Args:
180
163
  preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
181
164
  and 'keypoints' for keypoint predictions.
182
- batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
183
- 'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
165
+ batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
166
+ for bounding boxes, and 'keypoints' for keypoint annotations.
184
167
 
185
168
  Returns:
186
- (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
187
- true positives across 10 IoU levels.
169
+ (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
170
+ positives across 10 IoU levels.
188
171
 
189
172
  Notes:
190
173
  `0.53` scale factor used in area computation is referenced from
@@ -203,11 +186,10 @@ class PoseValidator(DetectionValidator):
203
186
  return tp
204
187
 
205
188
  def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
206
- """
207
- Save YOLO pose detections to a text file in normalized coordinates.
189
+ """Save YOLO pose detections to a text file in normalized coordinates.
208
190
 
209
191
  Args:
210
- predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
192
+ predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints.
211
193
  save_conf (bool): Whether to save confidence scores.
212
194
  shape (tuple[int, int]): Shape of the original image (height, width).
213
195
  file (Path): Output file path to save detections.
@@ -227,15 +209,14 @@ class PoseValidator(DetectionValidator):
227
209
  ).save_txt(file, save_conf=save_conf)
228
210
 
229
211
  def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
230
- """
231
- Convert YOLO predictions to COCO JSON format.
212
+ """Convert YOLO predictions to COCO JSON format.
232
213
 
233
- This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
234
- to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
214
+ This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
215
+ format, and appends the results to the internal JSON dictionary (self.jdict).
235
216
 
236
217
  Args:
237
- predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
238
- and 'keypoints' tensors.
218
+ predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
219
+ tensors.
239
220
  pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
240
221
 
241
222
  Notes:
@@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
6
6
 
7
7
 
8
8
  class SegmentationPredictor(DetectionPredictor):
9
- """
10
- A class extending the DetectionPredictor class for prediction based on a segmentation model.
9
+ """A class extending the DetectionPredictor class for prediction based on a segmentation model.
11
10
 
12
11
  This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
13
12
  prediction results.
@@ -25,14 +24,13 @@ class SegmentationPredictor(DetectionPredictor):
25
24
  Examples:
26
25
  >>> from ultralytics.utils import ASSETS
27
26
  >>> from ultralytics.models.yolo.segment import SegmentationPredictor
28
- >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
27
+ >>> args = dict(model="yolo26n-seg.pt", source=ASSETS)
29
28
  >>> predictor = SegmentationPredictor(overrides=args)
30
29
  >>> predictor.predict_cli()
31
30
  """
32
31
 
33
32
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
34
- """
35
- Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
33
+ """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
36
34
 
37
35
  This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
38
36
  prediction results.
@@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
46
44
  self.args.task = "segment"
47
45
 
48
46
  def postprocess(self, preds, img, orig_imgs):
49
- """
50
- Apply non-max suppression and process segmentation detections for each image in the input batch.
47
+ """Apply non-max suppression and process segmentation detections for each image in the input batch.
51
48
 
52
49
  Args:
53
50
  preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
@@ -55,20 +52,19 @@ class SegmentationPredictor(DetectionPredictor):
55
52
  orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
56
53
 
57
54
  Returns:
58
- (list): List of Results objects containing the segmentation predictions for each image in the batch.
59
- Each Results object includes both bounding boxes and segmentation masks.
55
+ (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
56
+ Results object includes both bounding boxes and segmentation masks.
60
57
 
61
58
  Examples:
62
- >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
59
+ >>> predictor = SegmentationPredictor(overrides=dict(model="yolo26n-seg.pt"))
63
60
  >>> results = predictor.postprocess(preds, img, orig_img)
64
61
  """
65
62
  # Extract protos - tuple if PyTorch model or array if exported
66
- protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
63
+ protos = preds[0][1] if isinstance(preds[0], tuple) else preds[1]
67
64
  return super().postprocess(preds[0], img, orig_imgs, protos=protos)
68
65
 
69
66
  def construct_results(self, preds, img, orig_imgs, protos):
70
- """
71
- Construct a list of result objects from the predictions.
67
+ """Construct a list of result objects from the predictions.
72
68
 
73
69
  Args:
74
70
  preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@@ -77,8 +73,8 @@ class SegmentationPredictor(DetectionPredictor):
77
73
  protos (list[torch.Tensor]): List of prototype masks.
78
74
 
79
75
  Returns:
80
- (list[Results]): List of result objects containing the original images, image paths, class names,
81
- bounding boxes, and masks.
76
+ (list[Results]): List of result objects containing the original images, image paths, class names, bounding
77
+ boxes, and masks.
82
78
  """
83
79
  return [
84
80
  self.construct_result(pred, img, orig_img, img_path, proto)
@@ -86,8 +82,7 @@ class SegmentationPredictor(DetectionPredictor):
86
82
  ]
87
83
 
88
84
  def construct_result(self, pred, img, orig_img, img_path, proto):
89
- """
90
- Construct a single result object from the prediction.
85
+ """Construct a single result object from the prediction.
91
86
 
92
87
  Args:
93
88
  pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
@@ -103,11 +98,12 @@ class SegmentationPredictor(DetectionPredictor):
103
98
  masks = None
104
99
  elif self.args.retina_masks:
105
100
  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
106
- masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
101
+ masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # NHW
107
102
  else:
108
- masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
103
+ masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # NHW
109
104
  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
110
105
  if masks is not None:
111
- keep = masks.sum((-2, -1)) > 0 # only keep predictions with masks
112
- pred, masks = pred[keep], masks[keep]
106
+ keep = masks.amax((-2, -1)) > 0 # only keep predictions with masks
107
+ if not all(keep): # most predictions have masks
108
+ pred, masks = pred[keep], masks[keep] # indexing is slow
113
109
  return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
@@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, RANK
11
11
 
12
12
 
13
13
  class SegmentationTrainer(yolo.detect.DetectionTrainer):
14
- """
15
- A class extending the DetectionTrainer class for training based on a segmentation model.
14
+ """A class extending the DetectionTrainer class for training based on a segmentation model.
16
15
 
17
16
  This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
18
17
  functionality including model initialization, validation, and visualization.
@@ -22,14 +21,13 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
22
21
 
23
22
  Examples:
24
23
  >>> from ultralytics.models.yolo.segment import SegmentationTrainer
25
- >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
24
+ >>> args = dict(model="yolo26n-seg.pt", data="coco8-seg.yaml", epochs=3)
26
25
  >>> trainer = SegmentationTrainer(overrides=args)
27
26
  >>> trainer.train()
28
27
  """
29
28
 
30
29
  def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
31
- """
32
- Initialize a SegmentationTrainer object.
30
+ """Initialize a SegmentationTrainer object.
33
31
 
34
32
  Args:
35
33
  cfg (dict): Configuration dictionary with default training settings.
@@ -42,8 +40,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
42
40
  super().__init__(cfg, overrides, _callbacks)
43
41
 
44
42
  def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
45
- """
46
- Initialize and return a SegmentationModel with specified configuration and weights.
43
+ """Initialize and return a SegmentationModel with specified configuration and weights.
47
44
 
48
45
  Args:
49
46
  cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
@@ -55,8 +52,8 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
55
52
 
56
53
  Examples:
57
54
  >>> trainer = SegmentationTrainer()
58
- >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
59
- >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
55
+ >>> model = trainer.get_model(cfg="yolo26n-seg.yaml")
56
+ >>> model = trainer.get_model(weights="yolo26n-seg.pt", verbose=False)
60
57
  """
61
58
  model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
62
59
  if weights:
@@ -66,7 +63,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
66
63
 
67
64
  def get_validator(self):
68
65
  """Return an instance of SegmentationValidator for validation of YOLO model."""
69
- self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
66
+ self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss", "sem_loss"
70
67
  return yolo.segment.SegmentationValidator(
71
68
  self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
72
69
  )