dgenerate-ultralytics-headless: 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
@@ -1,6 +1,11 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from pathlib import Path
+from typing import Any
+
+import torch
 
 from ultralytics.data.build import load_inference_source
 from ultralytics.engine.model import Model
@@ -19,19 +24,42 @@ from ultralytics.utils import ROOT, YAML
 
 
 class YOLO(Model):
-    """YOLO (You Only Look Once) object detection model."""
+    """YOLO (You Only Look Once) object detection model.
 
-    def __init__(self, model="yolo11n.pt", task=None, verbose=False):
-        """
-        Initialize a YOLO model.
+    This class provides a unified interface for YOLO models, automatically switching to specialized model types
+    (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
+    detection, segmentation, classification, pose estimation, and oriented bounding box detection.
+
+    Attributes:
+        model: The loaded YOLO model instance.
+        task: The task type (detect, segment, classify, pose, obb).
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize a YOLO model with automatic type detection.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+
+    Examples:
+        Load a pretrained YOLOv11n detection model
+        >>> model = YOLO("yolo11n.pt")
+
+        Load a pretrained YOLO11n segmentation model
+        >>> model = YOLO("yolo11n-seg.pt")
+
+        Initialize from a YAML configuration
+        >>> model = YOLO("yolo11n.yaml")
+    """
 
-        This constructor initializes a YOLO model, automatically switching to specialized model types
-        (YOLOWorld or YOLOE) based on the model filename.
+    def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
+        """Initialize a YOLO model.
+
+        This constructor initializes a YOLO model, automatically switching to specialized model types (YOLOWorld or
+        YOLOE) based on the model filename.
 
         Args:
             model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
-            task (str | None): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
-                Defaults to auto-detection based on model.
+            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'. Defaults
+                to auto-detection based on model.
             verbose (bool): Display model info on load.
 
         Examples:
@@ -39,7 +67,7 @@ class YOLO(Model):
             >>> model = YOLO("yolo11n.pt")  # load a pretrained YOLOv11n detection model
             >>> model = YOLO("yolo11n-seg.pt")  # load a pretrained YOLO11n segmentation model
         """
-        path = Path(model)
+        path = Path(model if isinstance(model, (str, Path)) else "")
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(path, verbose=verbose)
             self.__class__ = type(new_instance)
@@ -51,9 +79,15 @@ class YOLO(Model):
         else:
             # Continue with default YOLO initialization
             super().__init__(model=model, task=task, verbose=verbose)
+            if hasattr(self.model, "model") and "RTDETR" in self.model.model[-1]._get_name():  # if RTDETR head
+                from ultralytics import RTDETR
+
+                new_instance = RTDETR(self)
+                self.__class__ = type(new_instance)
+                self.__dict__ = new_instance.__dict__
 
     @property
-    def task_map(self):
+    def task_map(self) -> dict[str, dict[str, Any]]:
         """Map head to model, trainer, validator, and predictor classes."""
         return {
             "classify": {
@@ -90,14 +124,35 @@
 
 
 class YOLOWorld(Model):
-    """YOLO-World object detection model."""
+    """YOLO-World object detection model.
 
-    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
-        """
-        Initialize YOLOv8-World model with a pre-trained model file.
+    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions without
+    requiring training on specific classes. It extends the YOLO architecture to support real-time open-vocabulary
+    detection.
+
+    Attributes:
+        model: The loaded YOLO-World model instance.
+        task: Always set to 'detect' for object detection.
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize YOLOv8-World model with a pre-trained model file.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+        set_classes: Set the model's class names for detection.
+
+    Examples:
+        Load a YOLOv8-World model
+        >>> model = YOLOWorld("yolov8s-world.pt")
+
+        Set custom classes for detection
+        >>> model.set_classes(["person", "car", "bicycle"])
+    """
+
+    def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
+        """Initialize YOLOv8-World model with a pre-trained model file.
 
-        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
-        COCO class names.
+        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default COCO
+        class names.
 
         Args:
             model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@@ -110,7 +165,7 @@ class YOLOWorld(Model):
             self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
 
     @property
-    def task_map(self):
+    def task_map(self) -> dict[str, dict[str, Any]]:
         """Map head to model, validator, and predictor classes."""
         return {
             "detect": {
@@ -121,9 +176,8 @@ class YOLOWorld(Model):
             }
         }
 
-    def set_classes(self, classes):
-        """
-        Set the model's class names for detection.
+    def set_classes(self, classes: list[str]) -> None:
+        """Set the model's class names for detection.
 
         Args:
             classes (list[str]): A list of categories i.e. ["person"].
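
Note: set_classes() is the entry point for open-vocabulary detection with YOLO-World. A short end-to-end sketch using only methods shown in this diff ("bus.jpg" is a placeholder image path):

    from ultralytics import YOLOWorld

    model = YOLOWorld("yolov8s-world.pt")
    model.set_classes(["person", "bicycle"])  # restrict detection to custom text-defined classes
    results = model.predict("bus.jpg")
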
@@ -141,11 +195,41 @@ class YOLOWorld(Model):
 
 
 class YOLOE(Model):
-    """YOLOE object detection and segmentation model."""
-
-    def __init__(self, model="yoloe-11s-seg.pt", task=None, verbose=False) -> None:
-        """
-        Initialize YOLOE model with a pre-trained model file.
+    """YOLOE object detection and segmentation model.
+
+    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with improved
+    performance and additional features like visual and text positional embeddings.
+
+    Attributes:
+        model: The loaded YOLOE model instance.
+        task: The task type (detect or segment).
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize YOLOE model with a pre-trained model file.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+        get_text_pe: Get text positional embeddings for the given texts.
+        get_visual_pe: Get visual positional embeddings for the given image and visual features.
+        set_vocab: Set vocabulary and class names for the YOLOE model.
+        get_vocab: Get vocabulary for the given class names.
+        set_classes: Set the model's class names and embeddings for detection.
+        val: Validate the model using text or visual prompts.
+        predict: Run prediction on images, videos, directories, streams, etc.
+
+    Examples:
+        Load a YOLOE detection model
+        >>> model = YOLOE("yoloe-11s-seg.pt")
+
+        Set vocabulary and class names
+        >>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
+
+        Predict with visual prompts
+        >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
+        >>> results = model.predict("image.jpg", visual_prompts=prompts)
+    """
+
+    def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
+        """Initialize YOLOE model with a pre-trained model file.
 
         Args:
             model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@@ -154,12 +238,8 @@ class YOLOE(Model):
         """
         super().__init__(model=model, task=task, verbose=verbose)
 
-        # Assign default COCO class names when there are no custom names
-        if not hasattr(self.model, "names"):
-            self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
-
     @property
-    def task_map(self):
+    def task_map(self) -> dict[str, dict[str, Any]]:
         """Map head to model, validator, and predictor classes."""
         return {
             "detect": {
@@ -182,11 +262,10 @@ class YOLOE(Model):
         return self.model.get_text_pe(texts)
 
     def get_visual_pe(self, img, visual):
-        """
-        Get visual positional embeddings for the given image and visual features.
+        """Get visual positional embeddings for the given image and visual features.
 
-        This method extracts positional embeddings from visual features based on the input image. It requires
-        that the model is an instance of YOLOEModel.
+        This method extracts positional embeddings from visual features based on the input image. It requires that the
+        model is an instance of YOLOEModel.
 
         Args:
             img (torch.Tensor): Input image tensor.
@@ -198,22 +277,21 @@ class YOLOE(Model):
         Examples:
             >>> model = YOLOE("yoloe-11s-seg.pt")
             >>> img = torch.rand(1, 3, 640, 640)
-            >>> visual_features = model.model.backbone(img)
+            >>> visual_features = torch.rand(1, 1, 80, 80)
             >>> pe = model.get_visual_pe(img, visual_features)
         """
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_visual_pe(img, visual)
 
-    def set_vocab(self, vocab, names):
-        """
-        Set vocabulary and class names for the YOLOE model.
+    def set_vocab(self, vocab: list[str], names: list[str]) -> None:
+        """Set vocabulary and class names for the YOLOE model.
 
-        This method configures the vocabulary and class names used by the model for text processing and
-        classification tasks. The model must be an instance of YOLOEModel.
+        This method configures the vocabulary and class names used by the model for text processing and classification
+        tasks. The model must be an instance of YOLOEModel.
 
         Args:
-            vocab (list): Vocabulary list containing tokens or words used by the model for text processing.
-            names (list): List of class names that the model can detect or classify.
+            vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
+            names (list[str]): List of class names that the model can detect or classify.
 
         Raises:
             AssertionError: If the model is not an instance of YOLOEModel.
@@ -230,15 +308,16 @@ class YOLOE(Model):
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_vocab(names)
 
-    def set_classes(self, classes, embeddings):
-        """
-        Set the model's class names and embeddings for detection.
+    def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
+        """Set the model's class names and embeddings for detection.
 
         Args:
             classes (list[str]): A list of categories i.e. ["person"].
             embeddings (torch.Tensor): Embeddings corresponding to the classes.
         """
         assert isinstance(self.model, YOLOEModel)
+        if embeddings is None:
+            embeddings = self.get_text_pe(classes)  # generate text embeddings if not provided
         self.model.set_classes(classes, embeddings)
         # Verify no background class is present
         assert " " not in classes
@@ -251,12 +330,11 @@ class YOLOE(Model):
     def val(
         self,
         validator=None,
-        load_vp=False,
-        refer_data=None,
+        load_vp: bool = False,
+        refer_data: str | None = None,
         **kwargs,
     ):
-        """
-        Validate the model using text or visual prompts.
+        """Validate the model using text or visual prompts.
 
         Args:
             validator (callable, optional): A callable validator function. If None, a default validator is loaded.
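
Note: a hedged usage sketch for the typed val() signature above; the dataset YAML and flag values are illustrative, not prescribed by this diff:

    model = YOLOE("yoloe-11s-seg.pt")
    metrics = model.val(data="coco8-seg.yaml", load_vp=True)  # validate using visual prompts
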
@@ -279,28 +357,27 @@ class YOLOE(Model):
         self,
         source=None,
         stream: bool = False,
-        visual_prompts: dict = {},
+        visual_prompts: dict[str, list] = {},
         refer_image=None,
-        predictor=None,
+        predictor=yolo.yoloe.YOLOEVPDetectPredictor,
         **kwargs,
     ):
-        """
-        Run prediction on images, videos, directories, streams, etc.
+        """Run prediction on images, videos, directories, streams, etc.
 
         Args:
-            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
-                directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
-            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
-                generator as they are computed.
-            visual_prompts (dict): Dictionary containing visual prompts for the model. Must include 'bboxes' and
-                'cls' keys when non-empty.
+            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths, directory
+                paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
+            stream (bool): Whether to stream the prediction results. If True, results are yielded as a generator as they
+                are computed.
+            visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include 'bboxes'
+                and 'cls' keys when non-empty.
             refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
-            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
-                loaded based on the task.
+            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically loaded
+                based on the task.
             **kwargs (Any): Additional keyword arguments passed to the predictor.
 
         Returns:
-            (List | generator): List of Results objects or generator of Results objects if stream=True.
+            (list | generator): List of Results objects or generator of Results objects if stream=True.
 
         Examples:
             >>> model = YOLOE("yoloe-11s-seg.pt")
@@ -317,18 +394,21 @@ class YOLOE(Model):
                 f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
                 f"{len(visual_prompts['cls'])} respectively"
             )
-        self.predictor = (predictor or self._smart_load("predictor"))(
-            overrides={
-                "task": self.model.task,
-                "mode": "predict",
-                "save": False,
-                "verbose": refer_image is None,
-                "batch": 1,
-            },
-            _callbacks=self.callbacks,
-        )
+            if type(self.predictor) is not predictor:
+                self.predictor = predictor(
+                    overrides={
+                        "task": self.model.task,
+                        "mode": "predict",
+                        "save": False,
+                        "verbose": refer_image is None,
+                        "batch": 1,
+                        "device": kwargs.get("device", None),
+                        "half": kwargs.get("half", False),
+                        "imgsz": kwargs.get("imgsz", self.overrides["imgsz"]),
+                    },
+                    _callbacks=self.callbacks,
+                )
 
-        if len(visual_prompts):
             num_cls = (
                 max(len(set(c)) for c in visual_prompts["cls"])
                 if isinstance(source, list) and refer_image is None  # means multiple images
@@ -337,18 +417,19 @@ class YOLOE(Model):
             self.model.model[-1].nc = num_cls
             self.model.names = [f"object{i}" for i in range(num_cls)]
             self.predictor.set_prompts(visual_prompts.copy())
-
-        self.predictor.setup_model(model=self.model)
-
-        if refer_image is None and source is not None:
-            dataset = load_inference_source(source)
-            if dataset.mode in {"video", "stream"}:
-                # NOTE: set the first frame as refer image for videos/streams inference
-                refer_image = next(iter(dataset))[1][0]
-        if refer_image is not None and len(visual_prompts):
-            vpe = self.predictor.get_vpe(refer_image)
-            self.model.set_classes(self.model.names, vpe)
-            self.task = "segment" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else "detect"
-            self.predictor = None  # reset predictor
+            self.predictor.setup_model(model=self.model)
+
+            if refer_image is None and source is not None:
+                dataset = load_inference_source(source)
+                if dataset.mode in {"video", "stream"}:
+                    # NOTE: set the first frame as refer image for videos/streams inference
+                    refer_image = next(iter(dataset))[1][0]
+            if refer_image is not None:
+                vpe = self.predictor.get_vpe(refer_image)
+                self.model.set_classes(self.model.names, vpe)
+                self.task = "segment" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else "detect"
+                self.predictor = None  # reset predictor
+        elif isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):
+            self.predictor = None  # reset predictor if no visual prompts
 
         return super().predict(source, stream, **kwargs)
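
Note: the reworked predict() caches the visual-prompt predictor (it is rebuilt only when `type(self.predictor) is not predictor`) and now resets it when a later call arrives without prompts. A sketch mirroring the docstring example in this diff (values illustrative):

    from ultralytics import YOLOE

    model = YOLOE("yoloe-11s-seg.pt")
    prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}  # one reference box and label
    results = model.predict("image.jpg", visual_prompts=prompts)
    results = model.predict("image.jpg")  # no prompts: the VP predictor is reset (new elif branch)
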
ultralytics/models/yolo/obb/predict.py
@@ -8,8 +8,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class OBBPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
 
     This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
     bounding boxes.
@@ -27,10 +26,7 @@ class OBBPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize OBBPredictor with optional model and data configuration overrides.
-
-        This constructor sets up an OBBPredictor instance for oriented bounding box detection tasks.
+        """Initialize OBBPredictor with optional model and data configuration overrides.
 
         Args:
             cfg (dict, optional): Default configuration for the predictor.
@@ -47,18 +43,18 @@ class OBBPredictor(DetectionPredictor):
         self.args.task = "obb"
 
     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct the result object from the prediction.
+        """Construct the result object from the prediction.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 6) where
-                the last dimension contains [x, y, w, h, confidence, class_id, angle].
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where the
+                last dimension contains [x, y, w, h, confidence, class_id, angle].
             img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
 
         Returns:
-            (Results): The result object containing the original image, image path, class names, and oriented bounding boxes.
+            (Results): The result object containing the original image, image path, class names, and oriented bounding
+                boxes.
         """
         rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
         rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
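
Note: the docstring correction above reflects that each OBB prediction row carries seven values, [x, y, w, h, confidence, class_id, angle], not six. A typical way to consume the resulting rotated boxes (sketch; "boats.jpg" is a placeholder path):

    from ultralytics import YOLO

    model = YOLO("yolo11n-obb.pt")
    results = model("boats.jpg")
    print(results[0].obb.xywhr)  # rotated boxes as (x_center, y_center, width, height, rotation)
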
ultralytics/models/yolo/obb/train.py
@@ -1,6 +1,10 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from copy import copy
+from pathlib import Path
+from typing import Any
 
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import OBBModel
@@ -8,11 +12,14 @@ from ultralytics.utils import DEFAULT_CFG, RANK
 
 
 class OBBTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
+
+    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for detecting
+    objects at arbitrary angles rather than just axis-aligned rectangles.
 
     Attributes:
-        loss_names (Tuple[str]): Names of the loss components used during training.
+        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss, and
+            dfl_loss.
 
     Methods:
         get_model: Return OBBModel initialized with specified config and weights.
@@ -25,39 +32,30 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         >>> trainer.train()
     """
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
-
-        This trainer extends the DetectionTrainer class to specialize in training models that detect oriented
-        bounding boxes. It automatically sets the task to 'obb' in the configuration.
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
+        """Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
 
         Args:
-            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
-                model configuration.
-            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
-                will take precedence over those in cfg.
-            _callbacks (list, optional): List of callback functions to be invoked during training.
-
-        Examples:
-            >>> from ultralytics.models.yolo.obb import OBBTrainer
-            >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
-            >>> trainer = OBBTrainer(overrides=args)
-            >>> trainer.train()
+            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and model
+                configuration.
+            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here will
+                take precedence over those in cfg.
+            _callbacks (list[Any], optional): List of callback functions to be invoked during training.
         """
         if overrides is None:
             overrides = {}
         overrides["task"] = "obb"
         super().__init__(cfg, overrides, _callbacks)
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Return OBBModel initialized with specified config and weights.
+    def get_model(
+        self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
+    ) -> OBBModel:
+        """Return OBBModel initialized with specified config and weights.
 
         Args:
-            cfg (str | dict | None): Model configuration. Can be a path to a YAML config file, a dictionary
+            cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
                 containing configuration parameters, or None to use default configuration.
-            weights (str | Path | None): Path to pretrained weights file. If None, random initialization is used.
+            weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
             verbose (bool): Whether to display model information during initialization.
 
         Returns: