dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243)
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py CHANGED
@@ -95,11 +95,10 @@ from ultralytics.utils.torch_utils import (
95
95
 
96
96
 
97
97
  class BaseModel(torch.nn.Module):
98
- """
99
- Base class for all YOLO models in the Ultralytics family.
98
+ """Base class for all YOLO models in the Ultralytics family.
100
99
 
101
- This class provides common functionality for YOLO models including forward pass handling, model fusion,
102
- information display, and weight loading capabilities.
100
+ This class provides common functionality for YOLO models including forward pass handling, model fusion, information
101
+ display, and weight loading capabilities.
103
102
 
104
103
  Attributes:
105
104
  model (torch.nn.Module): The neural network model.
@@ -121,8 +120,7 @@ class BaseModel(torch.nn.Module):
121
120
  """
122
121
 
123
122
  def forward(self, x, *args, **kwargs):
124
- """
125
- Perform forward pass of the model for either training or inference.
123
+ """Perform forward pass of the model for either training or inference.
126
124
 
127
125
  If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
128
126
 
@@ -139,8 +137,7 @@ class BaseModel(torch.nn.Module):
139
137
  return self.predict(x, *args, **kwargs)
140
138
 
141
139
  def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
142
- """
143
- Perform a forward pass through the network.
140
+ """Perform a forward pass through the network.
144
141
 
145
142
  Args:
146
143
  x (torch.Tensor): The input tensor to the model.
@@ -157,8 +154,7 @@ class BaseModel(torch.nn.Module):
157
154
  return self._predict_once(x, profile, visualize, embed)
158
155
 
159
156
  def _predict_once(self, x, profile=False, visualize=False, embed=None):
160
- """
161
- Perform a forward pass through the network.
157
+ """Perform a forward pass through the network.
162
158
 
163
159
  Args:
164
160
  x (torch.Tensor): The input tensor to the model.
@@ -196,8 +192,7 @@ class BaseModel(torch.nn.Module):
196
192
  return self._predict_once(x)
197
193
 
198
194
  def _profile_one_layer(self, m, x, dt):
199
- """
200
- Profile the computation time and FLOPs of a single layer of the model on a given input.
195
+ """Profile the computation time and FLOPs of a single layer of the model on a given input.
201
196
 
202
197
  Args:
203
198
  m (torch.nn.Module): The layer to be profiled.
@@ -222,8 +217,7 @@ class BaseModel(torch.nn.Module):
222
217
  LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
223
218
 
224
219
  def fuse(self, verbose=True):
225
- """
226
- Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
220
+ """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
227
221
  efficiency.
228
222
 
229
223
  Returns:
@@ -254,8 +248,7 @@ class BaseModel(torch.nn.Module):
254
248
  return self
255
249
 
256
250
  def is_fused(self, thresh=10):
257
- """
258
- Check if the model has less than a certain threshold of BatchNorm layers.
251
+ """Check if the model has less than a certain threshold of BatchNorm layers.
259
252
 
260
253
  Args:
261
254
  thresh (int, optional): The threshold number of BatchNorm layers.
@@ -267,8 +260,7 @@ class BaseModel(torch.nn.Module):
267
260
  return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
268
261
 
269
262
  def info(self, detailed=False, verbose=True, imgsz=640):
270
- """
271
- Print model information.
263
+ """Print model information.
272
264
 
273
265
  Args:
274
266
  detailed (bool): If True, prints out detailed information about the model.
@@ -278,8 +270,7 @@ class BaseModel(torch.nn.Module):
278
270
  return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
279
271
 
280
272
  def _apply(self, fn):
281
- """
282
- Apply a function to all tensors in the model that are not parameters or registered buffers.
273
+ """Apply a function to all tensors in the model that are not parameters or registered buffers.
283
274
 
284
275
  Args:
285
276
  fn (function): The function to apply to the model.
@@ -298,8 +289,7 @@ class BaseModel(torch.nn.Module):
298
289
  return self
299
290
 
300
291
  def load(self, weights, verbose=True):
301
- """
302
- Load weights into the model.
292
+ """Load weights into the model.
303
293
 
304
294
  Args:
305
295
  weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
@@ -324,8 +314,7 @@ class BaseModel(torch.nn.Module):
324
314
  LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")
325
315
 
326
316
  def loss(self, batch, preds=None):
327
- """
328
- Compute loss.
317
+ """Compute loss.
329
318
 
330
319
  Args:
331
320
  batch (dict): Batch to compute loss on.
@@ -344,11 +333,10 @@ class BaseModel(torch.nn.Module):
344
333
 
345
334
 
346
335
  class DetectionModel(BaseModel):
347
- """
348
- YOLO detection model.
336
+ """YOLO detection model.
349
337
 
350
- This class implements the YOLO detection architecture, handling model initialization, forward pass,
351
- augmented inference, and loss computation for object detection tasks.
338
+ This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
339
+ inference, and loss computation for object detection tasks.
352
340
 
353
341
  Attributes:
354
342
  yaml (dict): Model configuration dictionary.
@@ -373,8 +361,7 @@ class DetectionModel(BaseModel):
373
361
  """
374
362
 
375
363
  def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
376
- """
377
- Initialize the YOLO detection model with the given config and parameters.
364
+ """Initialize the YOLO detection model with the given config and parameters.
378
365
 
379
366
  Args:
380
367
  cfg (str | dict): Model configuration file path or dictionary.
@@ -420,7 +407,7 @@ class DetectionModel(BaseModel):
420
407
  self.model.train() # Set model back to training(default) mode
421
408
  m.bias_init() # only run once
422
409
  else:
423
- self.stride = torch.Tensor([32]) # default stride for i.e. RTDETR
410
+ self.stride = torch.Tensor([32]) # default stride, e.g., RTDETR
424
411
 
425
412
  # Init weights, biases
426
413
  initialize_weights(self)
@@ -429,8 +416,7 @@ class DetectionModel(BaseModel):
429
416
  LOGGER.info("")
430
417
 
431
418
  def _predict_augment(self, x):
432
- """
433
- Perform augmentations on input image x and return augmented inference and train outputs.
419
+ """Perform augmentations on input image x and return augmented inference and train outputs.
434
420
 
435
421
  Args:
436
422
  x (torch.Tensor): Input image tensor.
@@ -455,8 +441,7 @@ class DetectionModel(BaseModel):
455
441
 
456
442
  @staticmethod
457
443
  def _descale_pred(p, flips, scale, img_size, dim=1):
458
- """
459
- De-scale predictions following augmented inference (inverse operation).
444
+ """De-scale predictions following augmented inference (inverse operation).
460
445
 
461
446
  Args:
462
447
  p (torch.Tensor): Predictions tensor.
@@ -477,8 +462,7 @@ class DetectionModel(BaseModel):
477
462
  return torch.cat((x, y, wh, cls), dim)
478
463
 
479
464
  def _clip_augmented(self, y):
480
- """
481
- Clip YOLO augmented inference tails.
465
+ """Clip YOLO augmented inference tails.
482
466
 
483
467
  Args:
484
468
  y (list[torch.Tensor]): List of detection tensors.
@@ -501,11 +485,10 @@ class DetectionModel(BaseModel):
501
485
 
502
486
 
503
487
  class OBBModel(DetectionModel):
504
- """
505
- YOLO Oriented Bounding Box (OBB) model.
488
+ """YOLO Oriented Bounding Box (OBB) model.
506
489
 
507
- This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized
508
- loss computation for rotated object detection.
490
+ This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
491
+ computation for rotated object detection.
509
492
 
510
493
  Methods:
511
494
  __init__: Initialize YOLO OBB model.
@@ -518,8 +501,7 @@ class OBBModel(DetectionModel):
518
501
  """
519
502
 
520
503
  def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
521
- """
522
- Initialize YOLO OBB model with given config and parameters.
504
+ """Initialize YOLO OBB model with given config and parameters.
523
505
 
524
506
  Args:
525
507
  cfg (str | dict): Model configuration file path or dictionary.
@@ -535,11 +517,10 @@ class OBBModel(DetectionModel):
535
517
 
536
518
 
537
519
  class SegmentationModel(DetectionModel):
538
- """
539
- YOLO segmentation model.
520
+ """YOLO segmentation model.
540
521
 
541
- This class extends DetectionModel to handle instance segmentation tasks, providing specialized
542
- loss computation for pixel-level object detection and segmentation.
522
+ This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
523
+ pixel-level object detection and segmentation.
543
524
 
544
525
  Methods:
545
526
  __init__: Initialize YOLO segmentation model.
@@ -552,8 +533,7 @@ class SegmentationModel(DetectionModel):
552
533
  """
553
534
 
554
535
  def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
555
- """
556
- Initialize Ultralytics YOLO segmentation model with given config and parameters.
536
+ """Initialize Ultralytics YOLO segmentation model with given config and parameters.
557
537
 
558
538
  Args:
559
539
  cfg (str | dict): Model configuration file path or dictionary.
@@ -569,11 +549,10 @@ class SegmentationModel(DetectionModel):
569
549
 
570
550
 
571
551
  class PoseModel(DetectionModel):
572
- """
573
- YOLO pose model.
552
+ """YOLO pose model.
574
553
 
575
- This class extends DetectionModel to handle human pose estimation tasks, providing specialized
576
- loss computation for keypoint detection and pose estimation.
554
+ This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
555
+ keypoint detection and pose estimation.
577
556
 
578
557
  Attributes:
579
558
  kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).
@@ -589,8 +568,7 @@ class PoseModel(DetectionModel):
589
568
  """
590
569
 
591
570
  def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
592
- """
593
- Initialize Ultralytics YOLO Pose model.
571
+ """Initialize Ultralytics YOLO Pose model.
594
572
 
595
573
  Args:
596
574
  cfg (str | dict): Model configuration file path or dictionary.
@@ -612,11 +590,10 @@ class PoseModel(DetectionModel):
612
590
 
613
591
 
614
592
  class ClassificationModel(BaseModel):
615
- """
616
- YOLO classification model.
593
+ """YOLO classification model.
617
594
 
618
- This class implements the YOLO classification architecture for image classification tasks,
619
- providing model initialization, configuration, and output reshaping capabilities.
595
+ This class implements the YOLO classification architecture for image classification tasks, providing model
596
+ initialization, configuration, and output reshaping capabilities.
620
597
 
621
598
  Attributes:
622
599
  yaml (dict): Model configuration dictionary.
@@ -637,8 +614,7 @@ class ClassificationModel(BaseModel):
637
614
  """
638
615
 
639
616
  def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
640
- """
641
- Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
617
+ """Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
642
618
 
643
619
  Args:
644
620
  cfg (str | dict): Model configuration file path or dictionary.
@@ -650,8 +626,7 @@ class ClassificationModel(BaseModel):
650
626
  self._from_yaml(cfg, ch, nc, verbose)
651
627
 
652
628
  def _from_yaml(self, cfg, ch, nc, verbose):
653
- """
654
- Set Ultralytics YOLO model configurations and define the model architecture.
629
+ """Set Ultralytics YOLO model configurations and define the model architecture.
655
630
 
656
631
  Args:
657
632
  cfg (str | dict): Model configuration file path or dictionary.
@@ -675,8 +650,7 @@ class ClassificationModel(BaseModel):
675
650
 
676
651
  @staticmethod
677
652
  def reshape_outputs(model, nc):
678
- """
679
- Update a TorchVision classification model to class count 'n' if required.
653
+ """Update a TorchVision classification model to class count 'n' if required.
680
654
 
681
655
  Args:
682
656
  model (torch.nn.Module): Model to update.
@@ -708,8 +682,7 @@ class ClassificationModel(BaseModel):
708
682
 
709
683
 
710
684
  class RTDETRDetectionModel(DetectionModel):
711
- """
712
- RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
685
+ """RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
713
686
 
714
687
  This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
715
688
  the training and inference processes. RTDETR is an object detection and tracking model that extends from the
@@ -732,8 +705,7 @@ class RTDETRDetectionModel(DetectionModel):
732
705
  """
733
706
 
734
707
  def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
735
- """
736
- Initialize the RTDETRDetectionModel.
708
+ """Initialize the RTDETRDetectionModel.
737
709
 
738
710
  Args:
739
711
  cfg (str | dict): Configuration file name or path.
@@ -743,6 +715,21 @@ class RTDETRDetectionModel(DetectionModel):
743
715
  """
744
716
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
745
717
 
718
+ def _apply(self, fn):
719
+ """Apply a function to all tensors in the model that are not parameters or registered buffers.
720
+
721
+ Args:
722
+ fn (function): The function to apply to the model.
723
+
724
+ Returns:
725
+ (RTDETRDetectionModel): An updated BaseModel object.
726
+ """
727
+ self = super()._apply(fn)
728
+ m = self.model[-1]
729
+ m.anchors = fn(m.anchors)
730
+ m.valid_mask = fn(m.valid_mask)
731
+ return self
732
+
746
733
  def init_criterion(self):
747
734
  """Initialize the loss criterion for the RTDETRDetectionModel."""
748
735
  from ultralytics.models.utils.loss import RTDETRDetectionLoss
@@ -750,8 +737,7 @@ class RTDETRDetectionModel(DetectionModel):
750
737
  return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
751
738
 
752
739
  def loss(self, batch, preds=None):
753
- """
754
- Compute the loss for the given batch of data.
740
+ """Compute the loss for the given batch of data.
755
741
 
756
742
  Args:
757
743
  batch (dict): Dictionary containing image and label data.
@@ -766,7 +752,7 @@ class RTDETRDetectionModel(DetectionModel):
766
752
 
767
753
  img = batch["img"]
768
754
  # NOTE: preprocess gt_bbox and gt_labels to list.
769
- bs = len(img)
755
+ bs = img.shape[0]
770
756
  batch_idx = batch["batch_idx"]
771
757
  gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
772
758
  targets = {
@@ -797,8 +783,7 @@ class RTDETRDetectionModel(DetectionModel):
797
783
  )
798
784
 
799
785
  def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
800
- """
801
- Perform a forward pass through the model.
786
+ """Perform a forward pass through the model.
802
787
 
803
788
  Args:
804
789
  x (torch.Tensor): The input tensor.
@@ -833,11 +818,10 @@ class RTDETRDetectionModel(DetectionModel):
833
818
 
834
819
 
835
820
  class WorldModel(DetectionModel):
836
- """
837
- YOLOv8 World Model.
821
+ """YOLOv8 World Model.
838
822
 
839
- This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based
840
- class specification and CLIP model integration for zero-shot detection capabilities.
823
+ This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
824
+ specification and CLIP model integration for zero-shot detection capabilities.
841
825
 
842
826
  Attributes:
843
827
  txt_feats (torch.Tensor): Text feature embeddings for classes.
@@ -858,8 +842,7 @@ class WorldModel(DetectionModel):
858
842
  """
859
843
 
860
844
  def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
861
- """
862
- Initialize YOLOv8 world model with given config and parameters.
845
+ """Initialize YOLOv8 world model with given config and parameters.
863
846
 
864
847
  Args:
865
848
  cfg (str | dict): Model configuration file path or dictionary.
@@ -872,8 +855,7 @@ class WorldModel(DetectionModel):
872
855
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
873
856
 
874
857
  def set_classes(self, text, batch=80, cache_clip_model=True):
875
- """
876
- Set classes in advance so that model could do offline-inference without clip model.
858
+ """Set classes in advance so that model could do offline-inference without clip model.
877
859
 
878
860
  Args:
879
861
  text (list[str]): List of class names.
@@ -884,8 +866,7 @@ class WorldModel(DetectionModel):
884
866
  self.model[-1].nc = len(text)
885
867
 
886
868
  def get_text_pe(self, text, batch=80, cache_clip_model=True):
887
- """
888
- Set classes in advance so that model could do offline-inference without clip model.
869
+ """Get text positional embeddings for offline inference without CLIP model.
889
870
 
890
871
  Args:
891
872
  text (list[str]): List of class names.
@@ -908,8 +889,7 @@ class WorldModel(DetectionModel):
908
889
  return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
909
890
 
910
891
  def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
911
- """
912
- Perform a forward pass through the model.
892
+ """Perform a forward pass through the model.
913
893
 
914
894
  Args:
915
895
  x (torch.Tensor): The input tensor.
@@ -923,7 +903,7 @@ class WorldModel(DetectionModel):
923
903
  (torch.Tensor): Model's output tensor.
924
904
  """
925
905
  txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
926
- if len(txt_feats) != len(x) or self.model[-1].export:
906
+ if txt_feats.shape[0] != x.shape[0] or self.model[-1].export:
927
907
  txt_feats = txt_feats.expand(x.shape[0], -1, -1)
928
908
  ori_txt_feats = txt_feats.clone()
929
909
  y, dt, embeddings = [], [], [] # outputs
@@ -953,8 +933,7 @@ class WorldModel(DetectionModel):
953
933
  return x
954
934
 
955
935
  def loss(self, batch, preds=None):
956
- """
957
- Compute loss.
936
+ """Compute loss.
958
937
 
959
938
  Args:
960
939
  batch (dict): Batch to compute loss on.
@@ -969,11 +948,10 @@ class WorldModel(DetectionModel):
969
948
 
970
949
 
971
950
  class YOLOEModel(DetectionModel):
972
- """
973
- YOLOE detection model.
951
+ """YOLOE detection model.
974
952
 
975
- This class implements the YOLOE architecture for efficient object detection with text and visual prompts,
976
- supporting both prompt-based and prompt-free inference modes.
953
+ This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
954
+ both prompt-based and prompt-free inference modes.
977
955
 
978
956
  Attributes:
979
957
  pe (torch.Tensor): Prompt embeddings for classes.
@@ -997,8 +975,7 @@ class YOLOEModel(DetectionModel):
997
975
  """
998
976
 
999
977
  def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
1000
- """
1001
- Initialize YOLOE model with given config and parameters.
978
+ """Initialize YOLOE model with given config and parameters.
1002
979
 
1003
980
  Args:
1004
981
  cfg (str | dict): Model configuration file path or dictionary.
@@ -1010,14 +987,13 @@ class YOLOEModel(DetectionModel):
1010
987
 
1011
988
  @smart_inference_mode()
1012
989
  def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
1013
- """
1014
- Set classes in advance so that model could do offline-inference without clip model.
990
+ """Get text positional embeddings for offline inference without CLIP model.
1015
991
 
1016
992
  Args:
1017
993
  text (list[str]): List of class names.
1018
994
  batch (int): Batch size for processing text tokens.
1019
995
  cache_clip_model (bool): Whether to cache the CLIP model.
1020
- without_reprta (bool): Whether to return text embeddings cooperated with reprta module.
996
+ without_reprta (bool): Whether to return text embeddings without reprta module processing.
1021
997
 
1022
998
  Returns:
1023
999
  (torch.Tensor): Text positional embeddings.
@@ -1037,15 +1013,13 @@ class YOLOEModel(DetectionModel):
1037
1013
  if without_reprta:
1038
1014
  return txt_feats
1039
1015
 
1040
- assert not self.training
1041
1016
  head = self.model[-1]
1042
1017
  assert isinstance(head, YOLOEDetect)
1043
1018
  return head.get_tpe(txt_feats) # run auxiliary text head
1044
1019
 
1045
1020
  @smart_inference_mode()
1046
1021
  def get_visual_pe(self, img, visual):
1047
- """
1048
- Get visual embeddings.
1022
+ """Get visual embeddings.
1049
1023
 
1050
1024
  Args:
1051
1025
  img (torch.Tensor): Input image tensor.
@@ -1057,8 +1031,7 @@ class YOLOEModel(DetectionModel):
1057
1031
  return self(img, vpe=visual, return_vpe=True)
1058
1032
 
1059
1033
  def set_vocab(self, vocab, names):
1060
- """
1061
- Set vocabulary for the prompt-free model.
1034
+ """Set vocabulary for the prompt-free model.
1062
1035
 
1063
1036
  Args:
1064
1037
  vocab (nn.ModuleList): List of vocabulary items.
@@ -1086,8 +1059,7 @@ class YOLOEModel(DetectionModel):
1086
1059
  self.names = check_class_names(names)
1087
1060
 
1088
1061
  def get_vocab(self, names):
1089
- """
1090
- Get fused vocabulary layer from the model.
1062
+ """Get fused vocabulary layer from the model.
1091
1063
 
1092
1064
  Args:
1093
1065
  names (list): List of class names.
@@ -1112,8 +1084,7 @@ class YOLOEModel(DetectionModel):
1112
1084
  return vocab
1113
1085
 
1114
1086
  def set_classes(self, names, embeddings):
1115
- """
1116
- Set classes in advance so that model could do offline-inference without clip model.
1087
+ """Set classes in advance so that model could do offline-inference without clip model.
1117
1088
 
1118
1089
  Args:
1119
1090
  names (list[str]): List of class names.
@@ -1128,8 +1099,7 @@ class YOLOEModel(DetectionModel):
1128
1099
  self.names = check_class_names(names)
1129
1100
 
1130
1101
  def get_cls_pe(self, tpe, vpe):
1131
- """
1132
- Get class positional embeddings.
1102
+ """Get class positional embeddings.
1133
1103
 
1134
1104
  Args:
1135
1105
  tpe (torch.Tensor, optional): Text positional embeddings.
@@ -1152,8 +1122,7 @@ class YOLOEModel(DetectionModel):
1152
1122
  def predict(
1153
1123
  self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
1154
1124
  ):
1155
- """
1156
- Perform a forward pass through the model.
1125
+ """Perform a forward pass through the model.
1157
1126
 
1158
1127
  Args:
1159
1128
  x (torch.Tensor): The input tensor.
@@ -1200,8 +1169,7 @@ class YOLOEModel(DetectionModel):
1200
1169
  return x
1201
1170
 
1202
1171
  def loss(self, batch, preds=None):
1203
- """
1204
- Compute loss.
1172
+ """Compute loss.
1205
1173
 
1206
1174
  Args:
1207
1175
  batch (dict): Batch to compute loss on.
@@ -1219,11 +1187,10 @@ class YOLOEModel(DetectionModel):
1219
1187
 
1220
1188
 
1221
1189
  class YOLOESegModel(YOLOEModel, SegmentationModel):
1222
- """
1223
- YOLOE segmentation model.
1190
+ """YOLOE segmentation model.
1224
1191
 
1225
- This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts,
1226
- providing specialized loss computation for pixel-level object detection and segmentation.
1192
+ This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
1193
+ specialized loss computation for pixel-level object detection and segmentation.
1227
1194
 
1228
1195
  Methods:
1229
1196
  __init__: Initialize YOLOE segmentation model.
@@ -1236,8 +1203,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
1236
1203
  """
1237
1204
 
1238
1205
  def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
1239
- """
1240
- Initialize YOLOE segmentation model with given config and parameters.
1206
+ """Initialize YOLOE segmentation model with given config and parameters.
1241
1207
 
1242
1208
  Args:
1243
1209
  cfg (str | dict): Model configuration file path or dictionary.
@@ -1248,8 +1214,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
1248
1214
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
1249
1215
 
1250
1216
  def loss(self, batch, preds=None):
1251
- """
1252
- Compute loss.
1217
+ """Compute loss.
1253
1218
 
1254
1219
  Args:
1255
1220
  batch (dict): Batch to compute loss on.
@@ -1267,11 +1232,10 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
1267
1232
 
1268
1233
 
1269
1234
  class Ensemble(torch.nn.ModuleList):
1270
- """
1271
- Ensemble of models.
1235
+ """Ensemble of models.
1272
1236
 
1273
- This class allows combining multiple YOLO models into an ensemble for improved performance through
1274
- model averaging or other ensemble techniques.
1237
+ This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
1238
+ or other ensemble techniques.
1275
1239
 
1276
1240
  Methods:
1277
1241
  __init__: Initialize an ensemble of models.
@@ -1290,8 +1254,7 @@ class Ensemble(torch.nn.ModuleList):
1290
1254
  super().__init__()
1291
1255
 
1292
1256
  def forward(self, x, augment=False, profile=False, visualize=False):
1293
- """
1294
- Generate the YOLO network's final layer.
1257
+ """Generate the YOLO network's final layer.
1295
1258
 
1296
1259
  Args:
1297
1260
  x (torch.Tensor): Input tensor.
@@ -1315,12 +1278,11 @@ class Ensemble(torch.nn.ModuleList):
1315
1278
 
1316
1279
  @contextlib.contextmanager
1317
1280
  def temporary_modules(modules=None, attributes=None):
1318
- """
1319
- Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
1281
+ """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
1320
1282
 
1321
- This function can be used to change the module paths during runtime. It's useful when refactoring code,
1322
- where you've moved a module from one location to another, but you still want to support the old import
1323
- paths for backwards compatibility.
1283
+ This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
1284
+ moved a module from one location to another, but you still want to support the old import paths for backwards
1285
+ compatibility.
1324
1286
 
1325
1287
  Args:
1326
1288
  modules (dict, optional): A dictionary mapping old module paths to new module paths.
@@ -1331,7 +1293,7 @@ def temporary_modules(modules=None, attributes=None):
1331
1293
  >>> import old.module # this will now import new.module
1332
1294
  >>> from old.module import attribute # this will now import new.module.attribute
1333
1295
 
1334
- Note:
1296
+ Notes:
1335
1297
  The changes are only in effect inside the context manager and are undone once the context manager exits.
1336
1298
  Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
1337
1299
  applications or libraries. Use this function with caution.
@@ -1378,8 +1340,7 @@ class SafeUnpickler(pickle.Unpickler):
1378
1340
  """Custom Unpickler that replaces unknown classes with SafeClass."""
1379
1341
 
1380
1342
  def find_class(self, module, name):
1381
- """
1382
- Attempt to find a class, returning SafeClass if not among safe modules.
1343
+ """Attempt to find a class, returning SafeClass if not among safe modules.
1383
1344
 
1384
1345
  Args:
1385
1346
  module (str): Module name.
@@ -1404,10 +1365,9 @@ class SafeUnpickler(pickle.Unpickler):
1404
1365
 
1405
1366
 
1406
1367
  def torch_safe_load(weight, safe_only=False):
1407
- """
1408
- Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
1409
- error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
1410
- After installation, the function again attempts to load the model using torch.load().
1368
+ """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
1369
+ the error, logs a warning message, and attempts to install the missing module via the check_requirements()
1370
+ function. After installation, the function again attempts to load the model using torch.load().
1411
1371
 
1412
1372
  Args:
1413
1373
  weight (str): The file path of the PyTorch model.
@@ -1486,8 +1446,7 @@ def torch_safe_load(weight, safe_only=False):
1486
1446
 
1487
1447
 
1488
1448
  def load_checkpoint(weight, device=None, inplace=True, fuse=False):
1489
- """
1490
- Load a single model weights.
1449
+ """Load a single model weights.
1491
1450
 
1492
1451
  Args:
1493
1452
  weight (str | Path): Model weight path.
@@ -1524,8 +1483,7 @@ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
1524
1483
 
1525
1484
 
1526
1485
  def parse_model(d, ch, verbose=True):
1527
- """
1528
- Parse a YOLO model.yaml dictionary into a PyTorch model.
1486
+ """Parse a YOLO model.yaml dictionary into a PyTorch model.
1529
1487
 
1530
1488
  Args:
1531
1489
  d (dict): Model dictionary.
@@ -1543,10 +1501,10 @@ def parse_model(d, ch, verbose=True):
1543
1501
  max_channels = float("inf")
1544
1502
  nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
1545
1503
  depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
1504
+ scale = d.get("scale")
1546
1505
  if scales:
1547
- scale = d.get("scale")
1548
1506
  if not scale:
1549
- scale = tuple(scales.keys())[0]
1507
+ scale = next(iter(scales.keys()))
1550
1508
  LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
1551
1509
  depth, width, max_channels = scales[scale]
1552
1510
 
@@ -1631,7 +1589,7 @@ def parse_model(d, ch, verbose=True):
1631
1589
  n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
1632
1590
  if m in base_modules:
1633
1591
  c1, c2 = ch[f], args[0]
1634
- if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
1592
+ if c2 != nc: # if c2 != nc (e.g., Classify() output)
1635
1593
  c2 = make_divisible(min(c2, max_channels) * width, 8)
1636
1594
  if m is C2fAttn: # set 1) embed channels and 2) num heads
1637
1595
  args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
@@ -1693,7 +1651,7 @@ def parse_model(d, ch, verbose=True):
1693
1651
  m_.np = sum(x.numel() for x in m_.parameters()) # number params
1694
1652
  m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
1695
1653
  if verbose:
1696
- LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}") # print
1654
+ LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f} {t:<45}{args!s:<30}") # print
1697
1655
  save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
1698
1656
  layers.append(m_)
1699
1657
  if i == 0:
@@ -1703,8 +1661,7 @@ def parse_model(d, ch, verbose=True):
1703
1661
 
1704
1662
 
1705
1663
  def yaml_model_load(path):
1706
- """
1707
- Load a YOLOv8 model from a YAML file.
1664
+ """Load a YOLOv8 model from a YAML file.
1708
1665
 
1709
1666
  Args:
1710
1667
  path (str | Path): Path to the YAML file.
@@ -1727,8 +1684,7 @@ def yaml_model_load(path):
1727
1684
 
1728
1685
 
1729
1686
  def guess_model_scale(model_path):
1730
- """
1731
- Extract the size character n, s, m, l, or x of the model's scale from the model path.
1687
+ """Extract the size character n, s, m, l, or x of the model's scale from the model path.
1732
1688
 
1733
1689
  Args:
1734
1690
  model_path (str | Path): The path to the YOLO model's YAML file.
@@ -1737,14 +1693,13 @@ def guess_model_scale(model_path):
1737
1693
  (str): The size character of the model's scale (n, s, m, l, or x).
1738
1694
  """
1739
1695
  try:
1740
- return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2) # noqa
1696
+ return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
1741
1697
  except AttributeError:
1742
1698
  return ""
1743
1699
 
1744
1700
 
1745
1701
  def guess_model_task(model):
1746
- """
1747
- Guess the task of a PyTorch model from its architecture or configuration.
1702
+ """Guess the task of a PyTorch model from its architecture or configuration.
1748
1703
 
1749
1704
  Args:
1750
1705
  model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
@@ -1775,10 +1730,10 @@ def guess_model_task(model):
1775
1730
  if isinstance(model, torch.nn.Module): # PyTorch model
1776
1731
  for x in "model.args", "model.model.args", "model.model.model.args":
1777
1732
  with contextlib.suppress(Exception):
1778
- return eval(x)["task"]
1733
+ return eval(x)["task"] # nosec B307: safe eval of known attribute paths
1779
1734
  for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
1780
1735
  with contextlib.suppress(Exception):
1781
- return cfg2task(eval(x))
1736
+ return cfg2task(eval(x)) # nosec B307: safe eval of known attribute paths
1782
1737
  for m in model.modules():
1783
1738
  if isinstance(m, (Segment, YOLOESegment)):
1784
1739
  return "segment"