dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py CHANGED
@@ -69,7 +69,7 @@ from ultralytics.nn.modules import (
  YOLOESegment,
  v10Detect,
  )
- from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, YAML, colorstr, emojis
+ from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, YAML, colorstr, emojis
  from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
  from ultralytics.utils.loss import (
  E2EDetectLoss,
@@ -80,6 +80,7 @@ from ultralytics.utils.loss import (
  v8SegmentationLoss,
  )
  from ultralytics.utils.ops import make_divisible
+ from ultralytics.utils.patches import torch_load
  from ultralytics.utils.plotting import feature_visualization
  from ultralytics.utils.torch_utils import (
  fuse_conv_and_bn,
@@ -94,11 +95,32 @@ from ultralytics.utils.torch_utils import (


  class BaseModel(torch.nn.Module):
- """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""
+ """Base class for all YOLO models in the Ultralytics family.
+
+ This class provides common functionality for YOLO models including forward pass handling, model fusion, information
+ display, and weight loading capabilities.
+
+ Attributes:
+ model (torch.nn.Module): The neural network model.
+ save (list): List of layer indices to save outputs from.
+ stride (torch.Tensor): Model stride values.
+
+ Methods:
+ forward: Perform forward pass for training or inference.
+ predict: Perform inference on input tensor.
+ fuse: Fuse Conv2d and BatchNorm2d layers for optimization.
+ info: Print model information.
+ load: Load weights into the model.
+ loss: Compute loss for training.
+
+ Examples:
+ Create a BaseModel instance
+ >>> model = BaseModel()
+ >>> model.info() # Display model information
+ """

  def forward(self, x, *args, **kwargs):
- """
- Perform forward pass of the model for either training or inference.
+ """Perform forward pass of the model for either training or inference.

  If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

@@ -115,8 +137,7 @@ class BaseModel(torch.nn.Module):
  return self.predict(x, *args, **kwargs)

  def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
- """
- Perform a forward pass through the network.
+ """Perform a forward pass through the network.

  Args:
  x (torch.Tensor): The input tensor to the model.
@@ -133,8 +154,7 @@ class BaseModel(torch.nn.Module):
  return self._predict_once(x, profile, visualize, embed)

  def _predict_once(self, x, profile=False, visualize=False, embed=None):
- """
- Perform a forward pass through the network.
+ """Perform a forward pass through the network.

  Args:
  x (torch.Tensor): The input tensor to the model.
@@ -172,8 +192,7 @@ class BaseModel(torch.nn.Module):
  return self._predict_once(x)

  def _profile_one_layer(self, m, x, dt):
- """
- Profile the computation time and FLOPs of a single layer of the model on a given input.
+ """Profile the computation time and FLOPs of a single layer of the model on a given input.

  Args:
  m (torch.nn.Module): The layer to be profiled.
@@ -198,8 +217,7 @@ class BaseModel(torch.nn.Module):
  LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

  def fuse(self, verbose=True):
- """
- Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
+ """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
  efficiency.

  Returns:
@@ -230,8 +248,7 @@ class BaseModel(torch.nn.Module):
  return self

  def is_fused(self, thresh=10):
- """
- Check if the model has less than a certain threshold of BatchNorm layers.
+ """Check if the model has less than a certain threshold of BatchNorm layers.

  Args:
  thresh (int, optional): The threshold number of BatchNorm layers.
@@ -243,8 +260,7 @@ class BaseModel(torch.nn.Module):
  return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model

  def info(self, detailed=False, verbose=True, imgsz=640):
- """
- Print model information.
+ """Print model information.

  Args:
  detailed (bool): If True, prints out detailed information about the model.
@@ -254,8 +270,7 @@ class BaseModel(torch.nn.Module):
  return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

  def _apply(self, fn):
- """
- Apply a function to all tensors in the model that are not parameters or registered buffers.
+ """Apply a function to all tensors in the model that are not parameters or registered buffers.

  Args:
  fn (function): The function to apply to the model.
@@ -274,8 +289,7 @@ class BaseModel(torch.nn.Module):
  return self

  def load(self, weights, verbose=True):
- """
- Load weights into the model.
+ """Load weights into the model.

  Args:
  weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
@@ -300,17 +314,17 @@ class BaseModel(torch.nn.Module):
  LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")

  def loss(self, batch, preds=None):
- """
- Compute loss.
+ """Compute loss.

  Args:
  batch (dict): Batch to compute loss on.
- preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
+ preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
  """
  if getattr(self, "criterion", None) is None:
  self.criterion = self.init_criterion()

- preds = self.forward(batch["img"]) if preds is None else preds
+ if preds is None:
+ preds = self.forward(batch["img"])
  return self.criterion(preds, batch)

  def init_criterion(self):
@@ -319,11 +333,35 @@ class BaseModel(torch.nn.Module):


  class DetectionModel(BaseModel):
- """YOLO detection model."""
+ """YOLO detection model.
+
+ This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
+ inference, and loss computation for object detection tasks.
+
+ Attributes:
+ yaml (dict): Model configuration dictionary.
+ model (torch.nn.Sequential): The neural network model.
+ save (list): List of layer indices to save outputs from.
+ names (dict): Class names dictionary.
+ inplace (bool): Whether to use inplace operations.
+ end2end (bool): Whether the model uses end-to-end detection.
+ stride (torch.Tensor): Model stride values.
+
+ Methods:
+ __init__: Initialize the YOLO detection model.
+ _predict_augment: Perform augmented inference.
+ _descale_pred: De-scale predictions following augmented inference.
+ _clip_augmented: Clip YOLO augmented inference tails.
+ init_criterion: Initialize the loss criterion.
+
+ Examples:
+ Initialize a detection model
+ >>> model = DetectionModel("yolo11n.yaml", ch=3, nc=80)
+ >>> results = model.predict(image_tensor)
+ """

  def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
- """
- Initialize the YOLO detection model with the given config and parameters.
+ """Initialize the YOLO detection model with the given config and parameters.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -362,8 +400,11 @@ class DetectionModel(BaseModel):
  return self.forward(x)["one2many"]
  return self.forward(x)[0] if isinstance(m, (Segment, YOLOESegment, Pose, OBB)) else self.forward(x)

+ self.model.eval() # Avoid changing batch statistics until training begins
+ m.training = True # Setting it to True to properly return strides
  m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
  self.stride = m.stride
+ self.model.train() # Set model back to training(default) mode
  m.bias_init() # only run once
  else:
  self.stride = torch.Tensor([32]) # default stride for i.e. RTDETR
@@ -375,8 +416,7 @@ class DetectionModel(BaseModel):
  LOGGER.info("")

  def _predict_augment(self, x):
- """
- Perform augmentations on input image x and return augmented inference and train outputs.
+ """Perform augmentations on input image x and return augmented inference and train outputs.

  Args:
  x (torch.Tensor): Input image tensor.
@@ -401,8 +441,7 @@ class DetectionModel(BaseModel):

  @staticmethod
  def _descale_pred(p, flips, scale, img_size, dim=1):
- """
- De-scale predictions following augmented inference (inverse operation).
+ """De-scale predictions following augmented inference (inverse operation).

  Args:
  p (torch.Tensor): Predictions tensor.
@@ -423,14 +462,13 @@ class DetectionModel(BaseModel):
  return torch.cat((x, y, wh, cls), dim)

  def _clip_augmented(self, y):
- """
- Clip YOLO augmented inference tails.
+ """Clip YOLO augmented inference tails.

  Args:
- y (List[torch.Tensor]): List of detection tensors.
+ y (list[torch.Tensor]): List of detection tensors.

  Returns:
- (List[torch.Tensor]): Clipped detection tensors.
+ (list[torch.Tensor]): Clipped detection tensors.
  """
  nl = self.model[-1].nl # number of detection layers (P3-P5)
  g = sum(4**x for x in range(nl)) # grid points
@@ -447,11 +485,23 @@ class DetectionModel(BaseModel):


  class OBBModel(DetectionModel):
- """YOLO Oriented Bounding Box (OBB) model."""
+ """YOLO Oriented Bounding Box (OBB) model.
+
+ This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
+ computation for rotated object detection.
+
+ Methods:
+ __init__: Initialize YOLO OBB model.
+ init_criterion: Initialize the loss criterion for OBB detection.
+
+ Examples:
+ Initialize an OBB model
+ >>> model = OBBModel("yolo11n-obb.yaml", ch=3, nc=80)
+ >>> results = model.predict(image_tensor)
+ """

  def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
- """
- Initialize YOLO OBB model with given config and parameters.
+ """Initialize YOLO OBB model with given config and parameters.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -467,11 +517,23 @@ class OBBModel(DetectionModel):


  class SegmentationModel(DetectionModel):
- """YOLO segmentation model."""
+ """YOLO segmentation model.
+
+ This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
+ pixel-level object detection and segmentation.
+
+ Methods:
+ __init__: Initialize YOLO segmentation model.
+ init_criterion: Initialize the loss criterion for segmentation.
+
+ Examples:
+ Initialize a segmentation model
+ >>> model = SegmentationModel("yolo11n-seg.yaml", ch=3, nc=80)
+ >>> results = model.predict(image_tensor)
+ """

  def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
- """
- Initialize Ultralytics YOLO segmentation model with given config and parameters.
+ """Initialize Ultralytics YOLO segmentation model with given config and parameters.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -487,11 +549,26 @@ class SegmentationModel(DetectionModel):


  class PoseModel(DetectionModel):
- """YOLO pose model."""
+ """YOLO pose model.
+
+ This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
+ keypoint detection and pose estimation.
+
+ Attributes:
+ kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).
+
+ Methods:
+ __init__: Initialize YOLO pose model.
+ init_criterion: Initialize the loss criterion for pose estimation.
+
+ Examples:
+ Initialize a pose model
+ >>> model = PoseModel("yolo11n-pose.yaml", ch=3, nc=1, data_kpt_shape=(17, 3))
+ >>> results = model.predict(image_tensor)
+ """

  def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
- """
- Initialize Ultralytics YOLO Pose model.
+ """Initialize Ultralytics YOLO Pose model.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -513,11 +590,31 @@ class PoseModel(DetectionModel):


  class ClassificationModel(BaseModel):
- """YOLO classification model."""
+ """YOLO classification model.
+
+ This class implements the YOLO classification architecture for image classification tasks, providing model
+ initialization, configuration, and output reshaping capabilities.
+
+ Attributes:
+ yaml (dict): Model configuration dictionary.
+ model (torch.nn.Sequential): The neural network model.
+ stride (torch.Tensor): Model stride values.
+ names (dict): Class names dictionary.
+
+ Methods:
+ __init__: Initialize ClassificationModel.
+ _from_yaml: Set model configurations and define architecture.
+ reshape_outputs: Update model to specified class count.
+ init_criterion: Initialize the loss criterion.
+
+ Examples:
+ Initialize a classification model
+ >>> model = ClassificationModel("yolo11n-cls.yaml", ch=3, nc=1000)
+ >>> results = model.predict(image_tensor)
+ """

  def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
- """
- Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
+ """Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -529,8 +626,7 @@ class ClassificationModel(BaseModel):
  self._from_yaml(cfg, ch, nc, verbose)

  def _from_yaml(self, cfg, ch, nc, verbose):
- """
- Set Ultralytics YOLO model configurations and define the model architecture.
+ """Set Ultralytics YOLO model configurations and define the model architecture.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -554,8 +650,7 @@ class ClassificationModel(BaseModel):

  @staticmethod
  def reshape_outputs(model, nc):
- """
- Update a TorchVision classification model to class count 'n' if required.
+ """Update a TorchVision classification model to class count 'n' if required.

  Args:
  model (torch.nn.Module): Model to update.
@@ -587,22 +682,30 @@ class ClassificationModel(BaseModel):


  class RTDETRDetectionModel(DetectionModel):
- """
- RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
+ """RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.

  This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
  the training and inference processes. RTDETR is an object detection and tracking model that extends from the
  DetectionModel base class.

+ Attributes:
+ nc (int): Number of classes for detection.
+ criterion (RTDETRDetectionLoss): Loss function for training.
+
  Methods:
- init_criterion: Initializes the criterion used for loss calculation.
- loss: Computes and returns the loss during training.
- predict: Performs a forward pass through the network and returns the output.
+ __init__: Initialize the RTDETRDetectionModel.
+ init_criterion: Initialize the loss criterion.
+ loss: Compute loss for training.
+ predict: Perform forward pass through the model.
+
+ Examples:
+ Initialize an RTDETR model
+ >>> model = RTDETRDetectionModel("rtdetr-l.yaml", ch=3, nc=80)
+ >>> results = model.predict(image_tensor)
  """

  def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
- """
- Initialize the RTDETRDetectionModel.
+ """Initialize the RTDETRDetectionModel.

  Args:
  cfg (str | dict): Configuration file name or path.
@@ -612,6 +715,21 @@ class RTDETRDetectionModel(DetectionModel):
  """
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

+ def _apply(self, fn):
+ """Apply a function to all tensors in the model that are not parameters or registered buffers.
+
+ Args:
+ fn (function): The function to apply to the model.
+
+ Returns:
+ (RTDETRDetectionModel): An updated BaseModel object.
+ """
+ self = super()._apply(fn)
+ m = self.model[-1]
+ m.anchors = fn(m.anchors)
+ m.valid_mask = fn(m.valid_mask)
+ return self
+
  def init_criterion(self):
  """Initialize the loss criterion for the RTDETRDetectionModel."""
  from ultralytics.models.utils.loss import RTDETRDetectionLoss
@@ -619,22 +737,22 @@ class RTDETRDetectionModel(DetectionModel):
  return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)

  def loss(self, batch, preds=None):
- """
- Compute the loss for the given batch of data.
+ """Compute the loss for the given batch of data.

  Args:
  batch (dict): Dictionary containing image and label data.
  preds (torch.Tensor, optional): Precomputed model predictions.

  Returns:
- (tuple): A tuple containing the total loss and main three losses in a tensor.
+ loss_sum (torch.Tensor): Total loss value.
+ loss_items (torch.Tensor): Main three losses in a tensor.
  """
  if not hasattr(self, "criterion"):
  self.criterion = self.init_criterion()

  img = batch["img"]
  # NOTE: preprocess gt_bbox and gt_labels to list.
- bs = len(img)
+ bs = img.shape[0]
  batch_idx = batch["batch_idx"]
  gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
  targets = {
@@ -644,7 +762,8 @@ class RTDETRDetectionModel(DetectionModel):
  "gt_groups": gt_groups,
  }

- preds = self.predict(img, batch=targets) if preds is None else preds
+ if preds is None:
+ preds = self.predict(img, batch=targets)
  dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
  if dn_meta is None:
  dn_bboxes, dn_scores = None, None
@@ -664,8 +783,7 @@ class RTDETRDetectionModel(DetectionModel):
  )

  def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
- """
- Perform a forward pass through the model.
+ """Perform a forward pass through the model.

  Args:
  x (torch.Tensor): The input tensor.
@@ -700,11 +818,31 @@ class RTDETRDetectionModel(DetectionModel):


  class WorldModel(DetectionModel):
- """YOLOv8 World Model."""
+ """YOLOv8 World Model.
+
+ This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
+ specification and CLIP model integration for zero-shot detection capabilities.
+
+ Attributes:
+ txt_feats (torch.Tensor): Text feature embeddings for classes.
+ clip_model (torch.nn.Module): CLIP model for text encoding.
+
+ Methods:
+ __init__: Initialize YOLOv8 world model.
+ set_classes: Set classes for offline inference.
+ get_text_pe: Get text positional embeddings.
+ predict: Perform forward pass with text features.
+ loss: Compute loss with text features.
+
+ Examples:
+ Initialize a world model
+ >>> model = WorldModel("yolov8s-world.yaml", ch=3, nc=80)
+ >>> model.set_classes(["person", "car", "bicycle"])
+ >>> results = model.predict(image_tensor)
+ """

  def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
- """
- Initialize YOLOv8 world model with given config and parameters.
+ """Initialize YOLOv8 world model with given config and parameters.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -717,24 +855,21 @@ class WorldModel(DetectionModel):
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

  def set_classes(self, text, batch=80, cache_clip_model=True):
- """
- Set classes in advance so that model could do offline-inference without clip model.
+ """Set classes in advance so that model could do offline-inference without clip model.

  Args:
- text (List[str]): List of class names.
+ text (list[str]): List of class names.
  batch (int): Batch size for processing text tokens.
  cache_clip_model (bool): Whether to cache the CLIP model.
  """
  self.txt_feats = self.get_text_pe(text, batch=batch, cache_clip_model=cache_clip_model)
  self.model[-1].nc = len(text)

- @smart_inference_mode()
  def get_text_pe(self, text, batch=80, cache_clip_model=True):
- """
- Set classes in advance so that model could do offline-inference without clip model.
+ """Set classes in advance so that model could do offline-inference without clip model.

  Args:
- text (List[str]): List of class names.
+ text (list[str]): List of class names.
  batch (int): Batch size for processing text tokens.
  cache_clip_model (bool): Whether to cache the CLIP model.

@@ -754,8 +889,7 @@ class WorldModel(DetectionModel):
  return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])

  def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
- """
- Perform a forward pass through the model.
+ """Perform a forward pass through the model.

  Args:
  x (torch.Tensor): The input tensor.
@@ -769,7 +903,7 @@ class WorldModel(DetectionModel):
  (torch.Tensor): Model's output tensor.
  """
  txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
- if len(txt_feats) != len(x) or self.model[-1].export:
+ if txt_feats.shape[0] != x.shape[0] or self.model[-1].export:
  txt_feats = txt_feats.expand(x.shape[0], -1, -1)
  ori_txt_feats = txt_feats.clone()
  y, dt, embeddings = [], [], [] # outputs
@@ -799,12 +933,11 @@ class WorldModel(DetectionModel):
  return x

  def loss(self, batch, preds=None):
- """
- Compute loss.
+ """Compute loss.

  Args:
  batch (dict): Batch to compute loss on.
- preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
+ preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
  """
  if not hasattr(self, "criterion"):
  self.criterion = self.init_criterion()
@@ -815,11 +948,34 @@ class WorldModel(DetectionModel):


  class YOLOEModel(DetectionModel):
- """YOLOE detection model."""
+ """YOLOE detection model.
+
+ This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
+ both prompt-based and prompt-free inference modes.
+
+ Attributes:
+ pe (torch.Tensor): Prompt embeddings for classes.
+ clip_model (torch.nn.Module): CLIP model for text encoding.
+
+ Methods:
+ __init__: Initialize YOLOE model.
+ get_text_pe: Get text positional embeddings.
+ get_visual_pe: Get visual embeddings.
+ set_vocab: Set vocabulary for prompt-free model.
+ get_vocab: Get fused vocabulary layer.
+ set_classes: Set classes for offline inference.
+ get_cls_pe: Get class positional embeddings.
+ predict: Perform forward pass with prompts.
+ loss: Compute loss with prompts.
+
+ Examples:
+ Initialize a YOLOE model
+ >>> model = YOLOEModel("yoloe-v8s.yaml", ch=3, nc=80)
+ >>> results = model.predict(image_tensor, tpe=text_embeddings)
+ """

  def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
- """
- Initialize YOLOE model with given config and parameters.
+ """Initialize YOLOE model with given config and parameters.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -831,11 +987,10 @@ class YOLOEModel(DetectionModel):

  @smart_inference_mode()
  def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
- """
- Set classes in advance so that model could do offline-inference without clip model.
+ """Set classes in advance so that model could do offline-inference without clip model.

  Args:
- text (List[str]): List of class names.
+ text (list[str]): List of class names.
  batch (int): Batch size for processing text tokens.
  cache_clip_model (bool): Whether to cache the CLIP model.
  without_reprta (bool): Whether to return text embeddings cooperated with reprta module.
@@ -858,15 +1013,13 @@ class YOLOEModel(DetectionModel):
  if without_reprta:
  return txt_feats

- assert not self.training
  head = self.model[-1]
  assert isinstance(head, YOLOEDetect)
- return head.get_tpe(txt_feats) # run axuiliary text head
+ return head.get_tpe(txt_feats) # run auxiliary text head

  @smart_inference_mode()
  def get_visual_pe(self, img, visual):
- """
- Get visual embeddings.
+ """Get visual embeddings.

  Args:
  img (torch.Tensor): Input image tensor.
@@ -878,12 +1031,11 @@ class YOLOEModel(DetectionModel):
  return self(img, vpe=visual, return_vpe=True)

  def set_vocab(self, vocab, names):
- """
- Set vocabulary for the prompt-free model.
+ """Set vocabulary for the prompt-free model.

  Args:
  vocab (nn.ModuleList): List of vocabulary items.
- names (List[str]): List of class names.
+ names (list[str]): List of class names.
  """
  assert not self.training
  head = self.model[-1]
@@ -907,8 +1059,7 @@ class YOLOEModel(DetectionModel):
  self.names = check_class_names(names)

  def get_vocab(self, names):
- """
- Get fused vocabulary layer from the model.
+ """Get fused vocabulary layer from the model.

  Args:
  names (list): List of class names.
@@ -933,11 +1084,10 @@ class YOLOEModel(DetectionModel):
  return vocab

  def set_classes(self, names, embeddings):
- """
- Set classes in advance so that model could do offline-inference without clip model.
+ """Set classes in advance so that model could do offline-inference without clip model.

  Args:
- names (List[str]): List of class names.
+ names (list[str]): List of class names.
  embeddings (torch.Tensor): Embeddings tensor.
  """
  assert not hasattr(self.model[-1], "lrpc"), (
@@ -949,8 +1099,7 @@ class YOLOEModel(DetectionModel):
  self.names = check_class_names(names)

  def get_cls_pe(self, tpe, vpe):
- """
- Get class positional embeddings.
+ """Get class positional embeddings.

  Args:
  tpe (torch.Tensor, optional): Text positional embeddings.
@@ -973,8 +1122,7 @@ class YOLOEModel(DetectionModel):
  def predict(
  self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
  ):
- """
- Perform a forward pass through the model.
+ """Perform a forward pass through the model.

  Args:
  x (torch.Tensor): The input tensor.
@@ -1021,12 +1169,11 @@ class YOLOEModel(DetectionModel):
  return x

  def loss(self, batch, preds=None):
- """
- Compute loss.
+ """Compute loss.

  Args:
  batch (dict): Batch to compute loss on.
- preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
+ preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
  """
  if not hasattr(self, "criterion"):
  from ultralytics.utils.loss import TVPDetectLoss
@@ -1040,11 +1187,23 @@ class YOLOEModel(DetectionModel):


  class YOLOESegModel(YOLOEModel, SegmentationModel):
- """YOLOE segmentation model."""
+ """YOLOE segmentation model.
+
+ This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
+ specialized loss computation for pixel-level object detection and segmentation.
+
+ Methods:
+ __init__: Initialize YOLOE segmentation model.
+ loss: Compute loss with prompts for segmentation.
+
+ Examples:
+ Initialize a YOLOE segmentation model
+ >>> model = YOLOESegModel("yoloe-v8s-seg.yaml", ch=3, nc=80)
+ >>> results = model.predict(image_tensor, tpe=text_embeddings)
+ """

  def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
- """
- Initialize YOLOE segmentation model with given config and parameters.
+ """Initialize YOLOE segmentation model with given config and parameters.

  Args:
  cfg (str | dict): Model configuration file path or dictionary.
@@ -1055,12 +1214,11 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

  def loss(self, batch, preds=None):
- """
- Compute loss.
+ """Compute loss.

  Args:
  batch (dict): Batch to compute loss on.
- preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
+ preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
  """
  if not hasattr(self, "criterion"):
  from ultralytics.utils.loss import TVPSegmentLoss
@@ -1074,15 +1232,29 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):


  class Ensemble(torch.nn.ModuleList):
- """Ensemble of models."""
+ """Ensemble of models.
+
+ This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
+ or other ensemble techniques.
+
+ Methods:
+ __init__: Initialize an ensemble of models.
+ forward: Generate predictions from all models in the ensemble.
+
+ Examples:
+ Create an ensemble of models
+ >>> ensemble = Ensemble()
+ >>> ensemble.append(model1)
+ >>> ensemble.append(model2)
+ >>> results = ensemble(image_tensor)
+ """

  def __init__(self):
  """Initialize an ensemble of models."""
  super().__init__()

  def forward(self, x, augment=False, profile=False, visualize=False):
- """
- Generate the YOLO network's final layer.
+ """Generate the YOLO network's final layer.

  Args:
  x (torch.Tensor): Input tensor.
@@ -1091,7 +1263,8 @@ class Ensemble(torch.nn.ModuleList):
  visualize (bool): Whether to visualize the features.

  Returns:
- (tuple): Tuple containing the concatenated predictions and None.
+ y (torch.Tensor): Concatenated predictions from all models.
+ train_out (None): Always None for ensemble inference.
  """
  y = [module(x, augment, profile, visualize)[0] for module in self]
  # y = torch.stack(y).max(0)[0] # max ensemble
@@ -1105,12 +1278,11 @@ class Ensemble(torch.nn.ModuleList):

  @contextlib.contextmanager
  def temporary_modules(modules=None, attributes=None):
- """
- Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
+ """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

- This function can be used to change the module paths during runtime. It's useful when refactoring code,
- where you've moved a module from one location to another, but you still want to support the old import
- paths for backwards compatibility.
+ This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
+ moved a module from one location to another, but you still want to support the old import paths for backwards
+ compatibility.

  Args:
  modules (dict, optional): A dictionary mapping old module paths to new module paths.
@@ -1121,7 +1293,7 @@ def temporary_modules(modules=None, attributes=None):
  >>> import old.module # this will now import new.module
  >>> from old.module import attribute # this will now import new.module.attribute

- Note:
+ Notes:
  The changes are only in effect inside the context manager and are undone once the context manager exits.
  Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
  applications or libraries. Use this function with caution.
@@ -1168,8 +1340,7 @@ class SafeUnpickler(pickle.Unpickler):
  """Custom Unpickler that replaces unknown classes with SafeClass."""

  def find_class(self, module, name):
- """
- Attempt to find a class, returning SafeClass if not among safe modules.
+ """Attempt to find a class, returning SafeClass if not among safe modules.

  Args:
  module (str): Module name.
@@ -1194,10 +1365,9 @@ class SafeUnpickler(pickle.Unpickler):


  def torch_safe_load(weight, safe_only=False):
- """
- Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
- error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
- After installation, the function again attempts to load the model using torch.load().
+ """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
+ the error, logs a warning message, and attempts to install the missing module via the check_requirements()
+ function. After installation, the function again attempts to load the model using torch.load().

  Args:
  weight (str): The file path of the PyTorch model.
@@ -1234,9 +1404,9 @@ def torch_safe_load(weight, safe_only=False):
  safe_pickle.Unpickler = SafeUnpickler
  safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
  with open(file, "rb") as f:
- ckpt = torch.load(f, pickle_module=safe_pickle)
+ ckpt = torch_load(f, pickle_module=safe_pickle)
  else:
- ckpt = torch.load(file, map_location="cpu")
+ ckpt = torch_load(file, map_location="cpu")

  except ModuleNotFoundError as e: # e.name is missing module name
  if e.name == "models":
@@ -1249,6 +1419,12 @@ def torch_safe_load(weight, safe_only=False):
  f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
  )
  ) from e
+ elif e.name == "numpy._core":
+ raise ModuleNotFoundError(
+ emojis(
+ f"ERROR ❌️ {weight} requires numpy>=1.26.1, however numpy=={__import__('numpy').__version__} is installed."
+ )
+ ) from e
  LOGGER.warning(
  f"{weight} appears to require '{e.name}', which is not in Ultralytics requirements."
  f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
@@ -1256,7 +1432,7 @@ def torch_safe_load(weight, safe_only=False):
  f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
  )
  check_requirements(e.name) # install missing module
- ckpt = torch.load(file, map_location="cpu")
+ ckpt = torch_load(file, map_location="cpu")

  if not isinstance(ckpt, dict):
  # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
@@ -1269,80 +1445,31 @@ def torch_safe_load(weight, safe_only=False):
  return ckpt, file


- def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
- """
- Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.
-
- Args:
- weights (str | List[str]): Model weights path(s).
- device (torch.device, optional): Device to load model to.
- inplace (bool): Whether to do inplace operations.
- fuse (bool): Whether to fuse model.
-
- Returns:
- (torch.nn.Module): Loaded model.
- """
- ensemble = Ensemble()
- for w in weights if isinstance(weights, list) else [weights]:
- ckpt, w = torch_safe_load(w) # load ckpt
- args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
- model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
-
- # Model compatibility updates
- model.args = args # attach args to model
- model.pt_path = w # attach *.pt file path to model
- model.task = guess_model_task(model)
- if not hasattr(model, "stride"):
- model.stride = torch.tensor([32.0])
-
- # Append
- ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode
-
- # Module updates
- for m in ensemble.modules():
- if hasattr(m, "inplace"):
- m.inplace = inplace
- elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
- m.recompute_scale_factor = None # torch 1.11.0 compatibility
-
- # Return model
- if len(ensemble) == 1:
- return ensemble[-1]
-
- # Return ensemble
- LOGGER.info(f"Ensemble created with {weights}\n")
- for k in "names", "nc", "yaml":
- setattr(ensemble, k, getattr(ensemble[0], k))
- ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride
- assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
- return ensemble
-
-
- def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
- """
- Load a single model weights.
+ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
+ """Load a single model weights.

  Args:
- weight (str): Model weight path.
+ weight (str | Path): Model weight path.
  device (torch.device, optional): Device to load model to.
  inplace (bool): Whether to do inplace operations.
  fuse (bool): Whether to fuse model.

  Returns:
- (tuple): Tuple containing the model and checkpoint.
+ model (torch.nn.Module): Loaded model.
+ ckpt (dict): Model checkpoint dictionary.
  """
  ckpt, weight = torch_safe_load(weight) # load ckpt
  args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
- model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
+ model = (ckpt.get("ema") or ckpt["model"]).float() # FP32 model

  # Model compatibility updates
- model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
+ model.args = args # attach args to model
  model.pt_path = weight # attach *.pt file path to model
- model.task = guess_model_task(model)
+ model.task = getattr(model, "task", guess_model_task(model))
  if not hasattr(model, "stride"):
  model.stride = torch.tensor([32.0])

- model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode
+ model = (model.fuse() if fuse and hasattr(model, "fuse") else model).eval().to(device) # model in eval mode

  # Module updates
  for m in model.modules():
@@ -1355,9 +1482,8 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
  return model, ckpt


- def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
- """
- Parse a YOLO model.yaml dictionary into a PyTorch model.
+ def parse_model(d, ch, verbose=True):
+ """Parse a YOLO model.yaml dictionary into a PyTorch model.

  Args:
  d (dict): Model dictionary.
@@ -1365,7 +1491,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
  verbose (bool): Whether to print model details.

  Returns:
- (tuple): Tuple containing the PyTorch model and sorted list of output layers.
+ model (torch.nn.Sequential): PyTorch model.
+ save (list): Sorted list of output layers.
  """
  import ast

@@ -1374,10 +1501,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
  max_channels = float("inf")
  nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
  depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
+ scale = d.get("scale")
  if scales:
- scale = d.get("scale")
  if not scale:
- scale = tuple(scales.keys())[0]
+ scale = next(iter(scales.keys()))
  LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
  depth, width, max_channels = scales[scale]

@@ -1524,7 +1651,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
  m_.np = sum(x.numel() for x in m_.parameters()) # number params
  m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
  if verbose:
- LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}") # print
+ LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f} {t:<45}{args!s:<30}") # print
  save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  layers.append(m_)
  if i == 0:
@@ -1534,8 +1661,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)


  def yaml_model_load(path):
- """
- Load a YOLOv8 model from a YAML file.
+ """Load a YOLOv8 model from a YAML file.

  Args:
  path (str | Path): Path to the YAML file.
@@ -1558,8 +1684,7 @@ def yaml_model_load(path):


  def guess_model_scale(model_path):
- """
- Extract the size character n, s, m, l, or x of the model's scale from the model path.
+ """Extract the size character n, s, m, l, or x of the model's scale from the model path.

  Args:
  model_path (str | Path): The path to the YOLO model's YAML file.
@@ -1568,14 +1693,13 @@ def guess_model_scale(model_path):
  (str): The size character of the model's scale (n, s, m, l, or x).
  """
  try:
- return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2) # noqa
+ return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
  except AttributeError:
  return ""


  def guess_model_task(model):
- """
- Guess the task of a PyTorch model from its architecture or configuration.
+ """Guess the task of a PyTorch model from its architecture or configuration.

  Args:
  model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
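Note on the loader rename visible in the diff above: `attempt_load_weights` and `attempt_load_one_weight` are removed from `ultralytics/nn/tasks.py` in 8.3.224 and a single `load_checkpoint` function takes their place, with checkpoint reads routed through `torch_load` from `ultralytics.utils.patches`. A minimal sketch of calling the renamed loader, based only on the signature and return values shown in this diff; the checkpoint path is an illustrative placeholder and must exist locally:

```python
from ultralytics.nn.tasks import load_checkpoint

# Per the 8.3.224 diff: load_checkpoint(weight, device=None, inplace=True, fuse=False)
# returns (model, ckpt). "yolo11n.pt" is a placeholder weight file, not a download helper.
model, ckpt = load_checkpoint("yolo11n.pt", device="cpu", fuse=True)
print(model.task, model.stride)  # task and stride are attached during loading per the diff
```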