dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
  2. dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +39 -39
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +187 -157
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +1 -1
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +1 -1
  95. ultralytics/solutions/instance_segmentation.py +6 -3
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -7
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +184 -75
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +42 -28
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from pathlib import Path
+from typing import Any, Dict, List, Optional, Union

 from ultralytics.data.build import load_inference_source
 from ultralytics.engine.model import Model
@@ -19,9 +20,34 @@ from ultralytics.utils import ROOT, YAML


 class YOLO(Model):
-    """YOLO (You Only Look Once) object detection model."""
+    """
+    YOLO (You Only Look Once) object detection model.

-    def __init__(self, model="yolo11n.pt", task=None, verbose=False):
+    This class provides a unified interface for YOLO models, automatically switching to specialized model types
+    (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
+    detection, segmentation, classification, pose estimation, and oriented bounding box detection.
+
+    Attributes:
+        model: The loaded YOLO model instance.
+        task: The task type (detect, segment, classify, pose, obb).
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize a YOLO model with automatic type detection.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+
+    Examples:
+        Load a pretrained YOLOv11n detection model
+        >>> model = YOLO("yolo11n.pt")
+
+        Load a pretrained YOLO11n segmentation model
+        >>> model = YOLO("yolo11n-seg.pt")
+
+        Initialize from a YAML configuration
+        >>> model = YOLO("yolo11n.yaml")
+    """
+
+    def __init__(self, model: Union[str, Path] = "yolo11n.pt", task: Optional[str] = None, verbose: bool = False):
         """
         Initialize a YOLO model.

@@ -30,7 +56,7 @@ class YOLO(Model):

         Args:
             model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
-            task (str | None): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
+            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
                 Defaults to auto-detection based on model.
             verbose (bool): Display model info on load.

@@ -59,7 +85,7 @@ class YOLO(Model):
             self.__dict__ = new_instance.__dict__

     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
         """Map head to model, trainer, validator, and predictor classes."""
         return {
             "classify": {
@@ -96,9 +122,32 @@ class YOLO(Model):


 class YOLOWorld(Model):
-    """YOLO-World object detection model."""
+    """
+    YOLO-World object detection model.

-    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
+    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
+    without requiring training on specific classes. It extends the YOLO architecture to support real-time
+    open-vocabulary detection.
+
+    Attributes:
+        model: The loaded YOLO-World model instance.
+        task: Always set to 'detect' for object detection.
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize YOLOv8-World model with a pre-trained model file.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+        set_classes: Set the model's class names for detection.
+
+    Examples:
+        Load a YOLOv8-World model
+        >>> model = YOLOWorld("yolov8s-world.pt")
+
+        Set custom classes for detection
+        >>> model.set_classes(["person", "car", "bicycle"])
+    """
+
+    def __init__(self, model: Union[str, Path] = "yolov8s-world.pt", verbose: bool = False) -> None:
         """
         Initialize YOLOv8-World model with a pre-trained model file.

@@ -116,7 +165,7 @@ class YOLOWorld(Model):
             self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")

     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
         """Map head to model, validator, and predictor classes."""
         return {
             "detect": {
@@ -127,12 +176,12 @@ class YOLOWorld(Model):
             }
         }

-    def set_classes(self, classes):
+    def set_classes(self, classes: List[str]) -> None:
         """
         Set the model's class names for detection.

         Args:
-            classes (list[str]): A list of categories i.e. ["person"].
+            classes (List[str]): A list of categories i.e. ["person"].
         """
         self.model.set_classes(classes)
         # Remove background if it's given
@@ -147,9 +196,43 @@ class YOLOWorld(Model):


 class YOLOE(Model):
-    """YOLOE object detection and segmentation model."""
-
-    def __init__(self, model="yoloe-11s-seg.pt", task=None, verbose=False) -> None:
+    """
+    YOLOE object detection and segmentation model.
+
+    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
+    improved performance and additional features like visual and text positional embeddings.
+
+    Attributes:
+        model: The loaded YOLOE model instance.
+        task: The task type (detect or segment).
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize YOLOE model with a pre-trained model file.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+        get_text_pe: Get text positional embeddings for the given texts.
+        get_visual_pe: Get visual positional embeddings for the given image and visual features.
+        set_vocab: Set vocabulary and class names for the YOLOE model.
+        get_vocab: Get vocabulary for the given class names.
+        set_classes: Set the model's class names and embeddings for detection.
+        val: Validate the model using text or visual prompts.
+        predict: Run prediction on images, videos, directories, streams, etc.
+
+    Examples:
+        Load a YOLOE detection model
+        >>> model = YOLOE("yoloe-11s-seg.pt")
+
+        Set vocabulary and class names
+        >>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
+
+        Predict with visual prompts
+        >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
+        >>> results = model.predict("image.jpg", visual_prompts=prompts)
+    """
+
+    def __init__(
+        self, model: Union[str, Path] = "yoloe-11s-seg.pt", task: Optional[str] = None, verbose: bool = False
+    ) -> None:
         """
         Initialize YOLOE model with a pre-trained model file.

@@ -165,7 +248,7 @@ class YOLOE(Model):
             self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")

     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
         """Map head to model, validator, and predictor classes."""
         return {
             "detect": {
@@ -210,7 +293,7 @@ class YOLOE(Model):
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_visual_pe(img, visual)

-    def set_vocab(self, vocab, names):
+    def set_vocab(self, vocab: List[str], names: List[str]) -> None:
         """
         Set vocabulary and class names for the YOLOE model.

@@ -218,8 +301,8 @@ class YOLOE(Model):
         classification tasks. The model must be an instance of YOLOEModel.

         Args:
-            vocab (list): Vocabulary list containing tokens or words used by the model for text processing.
-            names (list): List of class names that the model can detect or classify.
+            vocab (List[str]): Vocabulary list containing tokens or words used by the model for text processing.
+            names (List[str]): List of class names that the model can detect or classify.

         Raises:
             AssertionError: If the model is not an instance of YOLOEModel.
@@ -236,12 +319,12 @@ class YOLOE(Model):
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_vocab(names)

-    def set_classes(self, classes, embeddings):
+    def set_classes(self, classes: List[str], embeddings) -> None:
         """
         Set the model's class names and embeddings for detection.

         Args:
-            classes (list[str]): A list of categories i.e. ["person"].
+            classes (List[str]): A list of categories i.e. ["person"].
             embeddings (torch.Tensor): Embeddings corresponding to the classes.
         """
         assert isinstance(self.model, YOLOEModel)
@@ -257,8 +340,8 @@ class YOLOE(Model):
     def val(
         self,
         validator=None,
-        load_vp=False,
-        refer_data=None,
+        load_vp: bool = False,
+        refer_data: Optional[str] = None,
         **kwargs,
     ):
         """
@@ -285,7 +368,7 @@ class YOLOE(Model):
         self,
         source=None,
         stream: bool = False,
-        visual_prompts: dict = {},
+        visual_prompts: Dict[str, List] = {},
         refer_image=None,
         predictor=None,
         **kwargs,
@@ -298,8 +381,8 @@ class YOLOE(Model):
                 directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
             stream (bool): Whether to stream the prediction results. If True, results are yielded as a
                 generator as they are computed.
-            visual_prompts (dict): Dictionary containing visual prompts for the model. Must include 'bboxes' and
-                'cls' keys when non-empty.
+            visual_prompts (Dict[str, List]): Dictionary containing visual prompts for the model. Must include
+                'bboxes' and 'cls' keys when non-empty.
             refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
             predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
                 loaded based on the task.
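Note: the expanded YOLO docstring above documents the automatic dispatch to YOLOWorld and YOLOE based on the model filename. A minimal usage sketch of that behavior, assuming the weight files named in the docstrings (yolo11n.pt, yolov8s-world.pt, yoloe-11s-seg.pt) are available locally:

from ultralytics import YOLO

# A plain filename yields a standard YOLO detection model
model = YOLO("yolo11n.pt")

# A "-world" filename re-routes construction to YOLOWorld (open-vocabulary detection)
world = YOLO("yolov8s-world.pt")
world.set_classes(["person", "car", "bicycle"])  # now typed as List[str]

# A "yoloe" filename re-routes construction to YOLOE; visual_prompts mirrors the
# docstring example and must include 'bboxes' and 'cls' keys when non-empty
yoloe = YOLO("yoloe-11s-seg.pt")
results = yoloe.predict("image.jpg", visual_prompts={"bboxes": [[10, 20, 100, 200]], "cls": ["person"]})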
ultralytics/models/yolo/obb/predict.py
@@ -30,8 +30,6 @@ class OBBPredictor(DetectionPredictor):
         """
         Initialize OBBPredictor with optional model and data configuration overrides.

-        This constructor sets up an OBBPredictor instance for oriented bounding box detection tasks.
-
         Args:
             cfg (dict, optional): Default configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over the default config.
@@ -51,14 +49,15 @@ class OBBPredictor(DetectionPredictor):
         Construct the result object from the prediction.

         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 6) where
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
                 the last dimension contains [x, y, w, h, confidence, class_id, angle].
             img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.

         Returns:
-            (Results): The result object containing the original image, image path, class names, and oriented bounding boxes.
+            (Results): The result object containing the original image, image path, class names, and oriented bounding
+                boxes.
         """
         rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
         rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
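Note: the hunk above corrects the documented prediction shape from (N, 6) to (N, 7), matching the [x, y, w, h, confidence, class_id, angle] layout that the method body actually consumes. A small sketch of that slicing, using a random tensor in place of real model output:

import torch

pred = torch.rand(3, 7)  # hypothetical predictions: x, y, w, h, confidence, class_id, angle

# construct_result joins xywh with the trailing angle into (N, 5) rotated boxes,
# skipping the confidence and class columns in between (same cat as the diff body)
rboxes = torch.cat([pred[:, :4], pred[:, -1:]], dim=-1)  # (N, 5)
conf, cls = pred[:, 4], pred[:, 5]
print(rboxes.shape, conf.shape, cls.shape)  # torch.Size([3, 5]) torch.Size([3]) torch.Size([3])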
ultralytics/models/yolo/obb/train.py
@@ -1,6 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from copy import copy
+from pathlib import Path
+from typing import Any, List, Optional, Union

 from ultralytics.models import yolo
 from ultralytics.nn.tasks import OBBModel
@@ -11,8 +13,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.

+    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
+    detecting objects at arbitrary angles rather than just axis-aligned rectangles.
+
     Attributes:
-        loss_names (Tuple[str]): Names of the loss components used during training.
+        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
+            and dfl_loss.

     Methods:
         get_model: Return OBBModel initialized with specified config and weights.
@@ -25,7 +31,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         >>> trainer.train()
     """

-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+    def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[dict] = None, _callbacks: Optional[List[Any]] = None):
         """
         Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.

@@ -37,7 +43,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
                 model configuration.
             overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
                 will take precedence over those in cfg.
-            _callbacks (list, optional): List of callback functions to be invoked during training.
+            _callbacks (List[Any], optional): List of callback functions to be invoked during training.

         Examples:
             >>> from ultralytics.models.yolo.obb import OBBTrainer
@@ -50,14 +56,16 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         overrides["task"] = "obb"
         super().__init__(cfg, overrides, _callbacks)

-    def get_model(self, cfg=None, weights=None, verbose=True):
+    def get_model(
+        self, cfg: Optional[Union[str, dict]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True
+    ) -> OBBModel:
         """
         Return OBBModel initialized with specified config and weights.

         Args:
-            cfg (str | dict | None): Model configuration. Can be a path to a YAML config file, a dictionary
+            cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
                 containing configuration parameters, or None to use default configuration.
-            weights (str | Path | None): Path to pretrained weights file. If None, random initialization is used.
+            weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
             verbose (bool): Whether to display model information during initialization.

         Returns:
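Note: get_model now carries explicit type hints and a declared OBBModel return type. A short sketch of the trainer workflow, mirroring the docstring examples in this file (yolo11n-obb.pt and dota8.yaml are the example weights and dataset the docstrings use):

from ultralytics.models.yolo.obb import OBBTrainer

args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
trainer = OBBTrainer(overrides=args)  # forces task="obb" before calling the parent trainer
model = trainer.get_model(weights="yolo11n-obb.pt", verbose=False)  # normally called internally; returns an OBBModel
trainer.train()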
ultralytics/models/yolo/obb/val.py
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from pathlib import Path
+from typing import Dict, List, Tuple, Union

 import torch

@@ -63,34 +64,31 @@ class OBBValidator(DetectionValidator):
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format

-    def _process_batch(self, detections, gt_bboxes, gt_cls):
+    def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
         """
-        Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
+        Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.

         Args:
-            detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
-                data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle).
-            gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
-                represented as (x1, y1, x2, y2, angle).
-            gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
+            detections (torch.Tensor): Detected bounding boxes and associated data with shape (N, 7) where each
+                detection is represented as (x1, y1, x2, y2, conf, class, angle).
+            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (M, 5) where each box is represented
+                as (x1, y1, x2, y2, angle).
+            gt_cls (torch.Tensor): Class labels for the ground truth bounding boxes with shape (M,).

         Returns:
-            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
-                Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
+            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU levels for each
+                detection, indicating the accuracy of predictions compared to the ground truth.

         Examples:
             >>> detections = torch.rand(100, 7)  # 100 sample detections
             >>> gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
             >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
-            >>> correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
-
-        Note:
-            This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
+            >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
         iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
         return self.match_predictions(detections[:, 5], gt_cls, iou)

-    def _prepare_batch(self, si, batch):
+    def _prepare_batch(self, si: int, batch: Dict) -> Dict:
         """
         Prepare batch data for OBB validation with proper scaling and formatting.

@@ -104,8 +102,8 @@ class OBBValidator(DetectionValidator):
                 - img: Batch of images
                 - ratio_pad: Ratio and padding information

-        This method filters the batch data for a specific batch index, extracts class labels and bounding boxes,
-        and scales the bounding boxes to the original image dimensions.
+        Returns:
+            (dict): Prepared batch data with scaled bounding boxes and metadata.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -118,7 +116,7 @@ class OBBValidator(DetectionValidator):
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

-    def _prepare_pred(self, pred, pbatch):
+    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict) -> torch.Tensor:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.

@@ -141,7 +139,7 @@ class OBBValidator(DetectionValidator):
         )  # native-space pred
         return predn

-    def plot_predictions(self, batch, preds, ni):
+    def plot_predictions(self, batch: Dict, preds: List[torch.Tensor], ni: int):
         """
         Plot predicted bounding boxes on input images and save the result.

@@ -165,7 +163,7 @@ class OBBValidator(DetectionValidator):
             on_plot=self.on_plot,
         )  # pred

-    def pred_to_json(self, predn, filename):
+    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]):
         """
         Convert YOLO predictions to COCO JSON format with rotated bounding box information.

@@ -194,9 +192,9 @@ class OBBValidator(DetectionValidator):
                 }
             )

-    def save_one_txt(self, predn, save_conf, shape, file):
+    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]):
         """
-        Save YOLO OBB (Oriented Bounding Box) detections to a text file in normalized coordinates.
+        Save YOLO OBB detections to a text file in normalized coordinates.

         Args:
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
@@ -224,8 +222,16 @@ class OBBValidator(DetectionValidator):
             obb=obb,
         ).save_txt(file, save_conf=save_conf)

-    def eval_json(self, stats):
-        """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
+    def eval_json(self, stats: Dict) -> Dict:
+        """
+        Evaluate YOLO output in JSON format and save predictions in DOTA format.
+
+        Args:
+            stats (dict): Performance statistics dictionary.
+
+        Returns:
+            (dict): Updated performance statistics.
+        """
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json
             import re
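Note: _process_batch is now fully typed, and its docstring example calls the method on a validator instance rather than on the class. The matching computation rests on batch_probiou, as the method body shows; a standalone sketch with random tensors in the documented shapes:

import torch
from ultralytics.utils.metrics import batch_probiou

detections = torch.rand(100, 7)  # (N, 7): x1, y1, x2, y2, conf, class, angle per the docstring
gt_bboxes = torch.rand(50, 5)  # (M, 5): x1, y1, x2, y2, angle

# Same join as the method body: box coordinates plus the trailing angle column
iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
print(iou.shape)  # (M, N) = (50, 100) pairwise probabilistic IoU matrix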
ultralytics/models/yolo/pose/predict.py
@@ -16,7 +16,7 @@ class PosePredictor(DetectionPredictor):
         model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.

     Methods:
-        construct_result: Constructs the result object from the prediction, including keypoints.
+        construct_result: Construct the result object from the prediction, including keypoints.

     Examples:
         >>> from ultralytics.utils import ASSETS
@@ -28,13 +28,13 @@ class PosePredictor(DetectionPredictor):

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """
-        Initialize PosePredictor, a specialized predictor for pose estimation tasks.
+        Initialize PosePredictor for pose estimation tasks.

-        This initializer sets up a PosePredictor instance, configuring it for pose detection tasks and handling
-        device-specific warnings for Apple MPS.
+        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
+        warnings for Apple MPS.

         Args:
-            cfg (Any): Configuration for the predictor. Default is DEFAULT_CFG.
+            cfg (Any): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over cfg.
             _callbacks (list, optional): List of callback functions to be invoked during prediction.

@@ -57,8 +57,8 @@ class PosePredictor(DetectionPredictor):
         """
         Construct the result object from the prediction, including keypoints.

-        This method extends the parent class implementation by extracting keypoint data from predictions
-        and adding them to the result object.
+        Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
+        result object.

         Args:
             pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
@@ -68,7 +68,8 @@ class PosePredictor(DetectionPredictor):
             img_path (str): The path to the original image file.

         Returns:
-            (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
+            (Results): The result object containing the original image, image path, class names, bounding boxes, and
+                keypoints.
         """
         result = super().construct_result(pred, img, orig_img, img_path)
         # Extract keypoints from prediction and reshape according to model's keypoint shape
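Note: construct_result documents pose predictions as (N, 6+K*D) tensors, where K keypoints each carry D values. A sketch of the keypoint extraction step the comment above refers to, with a hypothetical kpt_shape of (17, 3):

import torch

N, K, D = 2, 17, 3  # hypothetical: 2 detections, 17 keypoints, (x, y, visibility) each
pred = torch.rand(N, 6 + K * D)

# Columns 0-5 hold box, confidence, and class; the remainder reshapes to the keypoint grid
pred_kpts = pred[:, 6:].view(N, K, D)
print(pred_kpts.shape)  # torch.Size([2, 17, 3])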
ultralytics/models/yolo/pose/train.py
@@ -1,6 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from copy import copy
+from pathlib import Path
+from typing import Any, Dict, Optional, Union

 from ultralytics.models import yolo
 from ultralytics.nn.tasks import PoseModel
@@ -19,14 +21,15 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         args (dict): Configuration arguments for training.
         model (PoseModel): The pose estimation model being trained.
         data (dict): Dataset configuration including keypoint shape information.
-        loss_names (Tuple[str]): Names of the loss components used in training.
+        loss_names (tuple): Names of the loss components used in training.

     Methods:
-        get_model: Retrieves a pose estimation model with specified configuration.
-        set_model_attributes: Sets keypoints shape attribute on the model.
-        get_validator: Creates a validator instance for model evaluation.
-        plot_training_samples: Visualizes training samples with keypoints.
-        plot_metrics: Generates and saves training/validation metric plots.
+        get_model: Retrieve a pose estimation model with specified configuration.
+        set_model_attributes: Set keypoints shape attribute on the model.
+        get_validator: Create a validator instance for model evaluation.
+        plot_training_samples: Visualize training samples with keypoints.
+        plot_metrics: Generate and save training/validation metric plots.
+        get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.

     Examples:
         >>> from ultralytics.models.yolo.pose import PoseTrainer
@@ -35,7 +38,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         >>> trainer.train()
     """

-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+    def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):
         """
         Initialize a PoseTrainer object for training YOLO pose estimation models.

@@ -68,13 +71,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
                 "See https://github.com/ultralytics/ultralytics/issues/4031."
             )

-    def get_model(self, cfg=None, weights=None, verbose=True):
+    def get_model(
+        self,
+        cfg: Optional[Union[str, Path, Dict[str, Any]]] = None,
+        weights: Optional[Union[str, Path]] = None,
+        verbose: bool = True,
+    ) -> PoseModel:
         """
         Get pose estimation model with specified configuration and weights.

         Args:
-            cfg (str | Path | dict | None): Model configuration file path or dictionary.
-            weights (str | Path | None): Path to the model weights file.
+            cfg (str | Path | dict, optional): Model configuration file path or dictionary.
+            weights (str | Path, optional): Path to the model weights file.
             verbose (bool): Whether to display model information.

         Returns:
@@ -89,18 +97,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         return model

     def set_model_attributes(self):
-        """Sets keypoints shape attribute of PoseModel."""
+        """Set keypoints shape attribute of PoseModel."""
         super().set_model_attributes()
         self.model.kpt_shape = self.data["kpt_shape"]

     def get_validator(self):
-        """Returns an instance of the PoseValidator class for validation."""
+        """Return an instance of the PoseValidator class for validation."""
         self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
         return yolo.pose.PoseValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )

-    def plot_training_samples(self, batch, ni):
+    def plot_training_samples(self, batch: Dict[str, Any], ni: int):
         """
         Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

@@ -135,12 +143,12 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         )

     def plot_metrics(self):
-        """Plots training/val metrics."""
+        """Plot training/validation metrics."""
         plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

-    def get_dataset(self):
+    def get_dataset(self) -> Dict[str, Any]:
         """
-        Retrieves the dataset and ensures it contains the required `kpt_shape` key.
+        Retrieve the dataset and ensure it contains the required `kpt_shape` key.

         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
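Note: get_dataset now declares a Dict[str, Any] return and, per the new Methods entry, is documented to guarantee the kpt_shape key that set_model_attributes copies onto the model. A short usage sketch mirroring the class docstring example (yolo11n-pose.pt and coco8-pose.yaml are the example weights and dataset the docstrings use):

from ultralytics.models.yolo.pose import PoseTrainer

args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
trainer = PoseTrainer(overrides=args)
data = trainer.get_dataset()  # expected to fail if the dataset config lacks kpt_shape
trainer.train()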