ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +37 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +111 -41
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +579 -244
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +191 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +226 -82
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +172 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +305 -112
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.63.dist-info/METADATA +370 -0
  235. ultralytics-8.3.63.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py CHANGED
@@ -1,9 +1,13 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import contextlib
+import pickle
+import re
+import types
 from copy import deepcopy
 from pathlib import Path
 
+import thop
 import torch
 import torch.nn as nn
 
@@ -11,18 +15,28 @@ from ultralytics.nn.modules import (
     AIFI,
     C1,
     C2,
+    C2PSA,
     C3,
     C3TR,
+    ELAN1,
     OBB,
+    PSA,
     SPP,
+    SPPELAN,
     SPPF,
+    AConv,
+    ADown,
     Bottleneck,
     BottleneckCSP,
     C2f,
     C2fAttn,
-    ImagePoolingAttn,
+    C2fCIB,
+    C2fPSA,
     C3Ghost,
+    C3k2,
     C3x,
+    CBFuse,
+    CBLinear,
     Classify,
     Concat,
     Conv,
@@ -36,53 +50,60 @@ from ultralytics.nn.modules import (
     GhostConv,
     HGBlock,
     HGStem,
+    ImagePoolingAttn,
+    Index,
     Pose,
     RepC3,
     RepConv,
+    RepNCSPELAN4,
+    RepVGGDW,
     ResNetLayer,
     RTDETRDecoder,
+    SCDown,
     Segment,
+    TorchVision,
     WorldDetect,
-    RepNCSPELAN4,
-    ADown,
-    SPPELAN,
-    CBFuse,
-    CBLinear,
-    Silence,
+    v10Detect,
 )
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
+from ultralytics.utils.loss import (
+    E2EDetectLoss,
+    v8ClassificationLoss,
+    v8DetectionLoss,
+    v8OBBLoss,
+    v8PoseLoss,
+    v8SegmentationLoss,
+)
+from ultralytics.utils.ops import make_divisible
 from ultralytics.utils.plotting import feature_visualization
 from ultralytics.utils.torch_utils import (
     fuse_conv_and_bn,
     fuse_deconv_and_bn,
     initialize_weights,
     intersect_dicts,
-    make_divisible,
     model_info,
     scale_img,
     time_sync,
 )
 
-try:
-    import thop
-except ImportError:
-    thop = None
-
 
 class BaseModel(nn.Module):
     """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""
 
     def forward(self, x, *args, **kwargs):
         """
-        Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
+        Perform forward pass of the model for either training or inference.
+
+        If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
 
         Args:
-            x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
+            x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
 
         Returns:
-            (torch.Tensor): The output of the network.
+            (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
         """
         if isinstance(x, dict):  # for cases of training and validating while training.
             return self.loss(x, *args, **kwargs)
@@ -138,8 +159,8 @@ class BaseModel(nn.Module):
     def _predict_augment(self, x):
         """Perform augmentations on input image x and return augmented inference."""
         LOGGER.warning(
-            f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
-            f"Reverting to single-scale inference instead."
+            f"WARNING ⚠️ {self.__class__.__name__} does not support 'augment=True' prediction. "
+            f"Reverting to single-scale prediction."
         )
         return self._predict_once(x)
 
@@ -157,7 +178,7 @@ class BaseModel(nn.Module):
             None
         """
         c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
-        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # FLOPs
+        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
         t = time_sync()
         for _ in range(10):
             m(x.copy() if c else x)
@@ -191,6 +212,9 @@ class BaseModel(nn.Module):
             if isinstance(m, RepConv):
                 m.fuse_convs()
                 m.forward = m.forward_fuse  # update forward
+            if isinstance(m, RepVGGDW):
+                m.fuse()
+                m.forward = m.forward_fuse
         self.info(verbose=verbose)
 
         return self
@@ -260,7 +284,7 @@ class BaseModel(nn.Module):
             batch (dict): Batch to compute loss on
             preds (torch.Tensor | List[torch.Tensor]): Predictions.
         """
-        if not hasattr(self, "criterion"):
+        if getattr(self, "criterion", None) is None:
             self.criterion = self.init_criterion()
 
         preds = self.forward(batch["img"]) if preds is None else preds
@@ -278,6 +302,12 @@ class DetectionModel(BaseModel):
         """Initialize the YOLOv8 detection model with the given config and parameters."""
         super().__init__()
         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
+        if self.yaml["backbone"][0][2] == "Silence":
+            LOGGER.warning(
+                "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. "
+                "Please delete local *.pt file and re-download the latest model checkpoint."
+            )
+            self.yaml["backbone"][0][2] = "nn.Identity"
 
         # Define model
         ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
@@ -287,14 +317,21 @@ class DetectionModel(BaseModel):
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
         self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
         self.inplace = self.yaml.get("inplace", True)
+        self.end2end = getattr(self.model[-1], "end2end", False)
 
         # Build strides
         m = self.model[-1]  # Detect()
         if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
             s = 256  # 2x min stride
             m.inplace = self.inplace
-            forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
-            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
+
+            def _forward(x):
+                """Performs a forward pass through the model, handling different Detect subclass types accordingly."""
+                if self.end2end:
+                    return self.forward(x)["one2many"]
+                return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
+
+            m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
             self.stride = m.stride
             m.bias_init()  # only run once
         else:
@@ -308,6 +345,9 @@ class DetectionModel(BaseModel):
 
     def _predict_augment(self, x):
         """Perform augmentations on input image x and return augmented inference and train outputs."""
+        if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
+            LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.")
+            return self._predict_once(x)
         img_size = x.shape[-2:]  # height, width
         s = [1, 0.83, 0.67]  # scales
         f = [None, 3, None]  # flips (2-ud, 3-lr)
@@ -344,7 +384,7 @@ class DetectionModel(BaseModel):
 
     def init_criterion(self):
         """Initialize the loss criterion for the DetectionModel."""
-        return v8DetectionLoss(self)
+        return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)
 
 
 class OBBModel(DetectionModel):
@@ -425,11 +465,11 @@ class ClassificationModel(BaseModel):
         elif isinstance(m, nn.Sequential):
             types = [type(x) for x in m]
             if nn.Linear in types:
-                i = types.index(nn.Linear)  # nn.Linear index
+                i = len(types) - 1 - types[::-1].index(nn.Linear)  # last nn.Linear index
                 if m[i].out_features != nc:
                     m[i] = nn.Linear(m[i].in_features, nc)
             elif nn.Conv2d in types:
-                i = types.index(nn.Conv2d)  # nn.Conv2d index
+                i = len(types) - 1 - types[::-1].index(nn.Conv2d)  # last nn.Conv2d index
                 if m[i].out_channels != nc:
                     m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
 
@@ -560,30 +600,32 @@ class WorldModel(DetectionModel):
 
     def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
         """Initialize YOLOv8 world model with given config and parameters."""
-        self.txt_feats = torch.randn(1, nc or 80, 512)  # placeholder
+        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
+        self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
-    def set_classes(self, text):
-        """Perform a forward pass with optional profiling, visualization, and embedding extraction."""
+    def set_classes(self, text, batch=80, cache_clip_model=True):
+        """Set classes in advance so that model could do offline-inference without clip model."""
         try:
             import clip
         except ImportError:
-            check_requirements("git+https://github.com/openai/CLIP.git")
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
 
-        model, _ = clip.load("ViT-B/32")
+        if (
+            not getattr(self, "clip_model", None) and cache_clip_model
+        ):  # for backwards compatibility of models lacking clip_model attribute
+            self.clip_model = clip.load("ViT-B/32")[0]
+        model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
         device = next(model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
+        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
+        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
+        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
         self.model[-1].nc = len(text)
 
-    def init_criterion(self):
-        """Initialize the loss criterion for the model."""
-        raise NotImplementedError
-
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.
 
@@ -591,13 +633,14 @@ class WorldModel(DetectionModel):
             x (torch.Tensor): The input tensor.
             profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
+            txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
             embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
+        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
         if len(txt_feats) != len(x):
             txt_feats = txt_feats.repeat(len(x), 1, 1)
         ori_txt_feats = txt_feats.clone()
@@ -625,6 +668,21 @@ class WorldModel(DetectionModel):
             return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
+    def loss(self, batch, preds=None):
+        """
+        Compute loss.
+
+        Args:
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+        """
+        if not hasattr(self, "criterion"):
+            self.criterion = self.init_criterion()
+
+        if preds is None:
+            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
+        return self.criterion(preds, batch)
+
 
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
@@ -646,7 +704,7 @@ class Ensemble(nn.ModuleList):
 
 
 @contextlib.contextmanager
-def temporary_modules(modules=None):
+def temporary_modules(modules=None, attributes=None):
     """
     Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
 
@@ -656,11 +714,13 @@ def temporary_modules(modules=None):
     Args:
         modules (dict, optional): A dictionary mapping old module paths to new module paths.
+        attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.
 
     Example:
         ```python
-        with temporary_modules({'old.module.path': 'new.module.path'}):
-            import old.module.path  # this will now import new.module.path
+        with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
+            import old.module  # this will now import new.module
+            from old.module import attribute  # this will now import new.module.attribute
         ```
 
     Note:
@@ -668,16 +728,23 @@ def temporary_modules(modules=None):
     Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
     applications or libraries. Use this function with caution.
     """
-    if not modules:
+    if modules is None:
        modules = {}
-
-    import importlib
+    if attributes is None:
+        attributes = {}
     import sys
+    from importlib import import_module
 
     try:
+        # Set attributes in sys.modules under their old name
+        for old, new in attributes.items():
+            old_module, old_attr = old.rsplit(".", 1)
+            new_module, new_attr = new.rsplit(".", 1)
+            setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))
+
         # Set modules in sys.modules under their old name
         for old, new in modules.items():
-            sys.modules[old] = importlib.import_module(new)
+            sys.modules[old] = import_module(new)
 
         yield
     finally:
@@ -687,17 +754,58 @@ def temporary_modules(modules=None):
                 del sys.modules[old]
 
 
-def torch_safe_load(weight):
+class SafeClass:
+    """A placeholder class to replace unknown classes during unpickling."""
+
+    def __init__(self, *args, **kwargs):
+        """Initialize SafeClass instance, ignoring all arguments."""
+        pass
+
+    def __call__(self, *args, **kwargs):
+        """Run SafeClass instance, ignoring all arguments."""
+        pass
+
+
+class SafeUnpickler(pickle.Unpickler):
+    """Custom Unpickler that replaces unknown classes with SafeClass."""
+
+    def find_class(self, module, name):
+        """Attempt to find a class, returning SafeClass if not among safe modules."""
+        safe_modules = (
+            "torch",
+            "collections",
+            "collections.abc",
+            "builtins",
+            "math",
+            "numpy",
+            # Add other modules considered safe
+        )
+        if module in safe_modules:
+            return super().find_class(module, name)
+        else:
+            return SafeClass
+
+
+def torch_safe_load(weight, safe_only=False):
     """
-    This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
-    it catches the error, logs a warning message, and attempts to install the missing module via the
-    check_requirements() function. After installation, the function again attempts to load the model using torch.load().
+    Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
+    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
+    After installation, the function again attempts to load the model using torch.load().
 
     Args:
         weight (str): The file path of the PyTorch model.
+        safe_only (bool): If True, replace unknown classes with SafeClass during loading.
+
+    Example:
+    ```python
+    from ultralytics.nn.tasks import torch_safe_load
+
+    ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
+    ```
 
     Returns:
-        (dict): The loaded PyTorch model.
+        ckpt (dict): The loaded model checkpoint.
+        file (str): The loaded filename
     """
     from ultralytics.utils.downloads import attempt_download_asset
 
@@ -705,13 +813,26 @@ def torch_safe_load(weight):
     file = attempt_download_asset(weight)  # search online if missing locally
     try:
         with temporary_modules(
-            {
+            modules={
                 "ultralytics.yolo.utils": "ultralytics.utils",
                 "ultralytics.yolo.v8": "ultralytics.models.yolo",
                 "ultralytics.yolo.data": "ultralytics.data",
-            }
-        ):  # for legacy 8.0 Classify and Pose models
-            ckpt = torch.load(file, map_location="cpu")
+            },
+            attributes={
+                "ultralytics.nn.modules.block.Silence": "torch.nn.Identity",  # YOLOv9e
+                "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel",  # YOLOv10
+                "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss",  # YOLOv10
+            },
+        ):
+            if safe_only:
+                # Load via custom pickle module
+                safe_pickle = types.ModuleType("safe_pickle")
+                safe_pickle.Unpickler = SafeUnpickler
+                safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
+                with open(file, "rb") as f:
+                    ckpt = torch.load(f, pickle_module=safe_pickle)
+            else:
+                ckpt = torch.load(file, map_location="cpu")
 
     except ModuleNotFoundError as e:  # e.name is missing module name
         if e.name == "models":
@@ -721,14 +842,14 @@ def torch_safe_load(weight):
                 f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                 f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                 f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
-                f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
+                f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
             )
         ) from e
         LOGGER.warning(
-            f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
+            f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
             f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
             f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
-            f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
+            f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
         )
         check_requirements(e.name)  # install missing module
         ckpt = torch.load(file, map_location="cpu")
@@ -741,12 +862,11 @@ def torch_safe_load(weight):
         )
         ckpt = {"model": ckpt.model}
 
-    return ckpt, file  # load
+    return ckpt, file
 
 
 def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
     """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
-
     ensemble = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt, w = torch_safe_load(w)  # load ckpt
@@ -814,6 +934,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
     import ast
 
     # Args
+    legacy = True  # backward compatibility for v3/v5/v8/v9 models
     max_channels = float("inf")
     nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
     depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
@@ -839,9 +960,8 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             if isinstance(a, str):
                 with contextlib.suppress(ValueError):
                     args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
-
         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
-        if m in (
+        if m in {
             Classify,
             Conv,
             ConvTranspose,
@@ -850,14 +970,19 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             GhostBottleneck,
             SPP,
             SPPF,
+            C2fPSA,
+            C2PSA,
             DWConv,
             Focus,
             BottleneckCSP,
             C1,
             C2,
             C2f,
+            C3k2,
             RepNCSPELAN4,
+            ELAN1,
             ADown,
+            AConv,
             SPPELAN,
             C2fAttn,
             C3,
@@ -867,7 +992,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             DWConvTranspose2d,
             C3x,
             RepC3,
-        ):
+            PSA,
+            SCDown,
+            C2fCIB,
+        }:
             c1, c2 = ch[f], args[0]
             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(min(c2, max_channels) * width, 8)
@@ -878,12 +1006,31 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                )  # num heads
 
            args = [c1, c2, *args[1:]]
-            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
+            if m in {
+                BottleneckCSP,
+                C1,
+                C2,
+                C2f,
+                C3k2,
+                C2fAttn,
+                C3,
+                C3TR,
+                C3Ghost,
+                C3x,
+                RepC3,
+                C2fPSA,
+                C2fCIB,
+                C2PSA,
+            }:
                args.insert(2, n)  # number of repeats
                n = 1
+            if m is C3k2:  # for M/L/X sizes
+                legacy = False
+                if scale in "mlx":
+                    args[3] = True
        elif m is AIFI:
            args = [ch[f], *args]
-        elif m in (HGStem, HGBlock):
+        elif m in {HGStem, HGBlock}:
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
@@ -895,13 +1042,15 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
-        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
+        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
+            if m in {Detect, Segment, Pose, OBB}:
+                m.legacy = legacy
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
-        elif m is CBLinear:
+        elif m in {CBLinear, TorchVision, Index}:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
@@ -912,10 +1061,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
-        m.np = sum(x.numel() for x in m_.parameters())  # number params
+        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
-            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")  # print
+            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
@@ -926,8 +1075,6 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
 def yaml_model_load(path):
     """Load a YOLOv8 model from a YAML file."""
-    import re
-
     path = Path(path)
     if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
         new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
@@ -954,11 +1101,10 @@ def guess_model_scale(model_path):
     Returns:
         (str): The size character of the model's scale, which can be n, s, m, l, or x.
     """
-    with contextlib.suppress(AttributeError):
-        import re
-
-        return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
-    return ""
+    try:
+        return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1)  # noqa, returns n, s, m, l, or x
+    except AttributeError:
+        return ""
 
 
 def guess_model_task(model):
@@ -978,9 +1124,9 @@ def guess_model_task(model):
     def cfg2task(cfg):
         """Guess from YAML dictionary."""
         m = cfg["head"][-1][-2].lower()  # output module name
-        if m in ("classify", "classifier", "cls", "fc"):
+        if m in {"classify", "classifier", "cls", "fc"}:
             return "classify"
-        if m == "detect":
+        if "detect" in m:
             return "detect"
         if m == "segment":
             return "segment"
@@ -993,7 +1139,6 @@ def guess_model_task(model):
     if isinstance(model, dict):
         with contextlib.suppress(Exception):
             return cfg2task(model)
-
     # Guess from PyTorch model
     if isinstance(model, nn.Module):  # PyTorch model
         for x in "model.args", "model.model.args", "model.model.model.args":
@@ -1002,7 +1147,6 @@ def guess_model_task(model):
         for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
            with contextlib.suppress(Exception):
                return cfg2task(eval(x))
-
        for m in model.modules():
            if isinstance(m, Segment):
                return "segment"
@@ -1012,7 +1156,7 @@ def guess_model_task(model):
            return "pose"
        elif isinstance(m, OBB):
            return "obb"
-        elif isinstance(m, (Detect, WorldDetect)):
+        elif isinstance(m, (Detect, WorldDetect, v10Detect)):
            return "detect"
 
    # Guess from model filename
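Worth calling out from the diff above: the new `safe_only=True` mode of `torch_safe_load` works by handing `torch.load` a `pickle_module` whose `Unpickler.find_class` resolves classes only from an allowlist of modules and maps everything else to the inert `SafeClass`. Below is a minimal standalone sketch of that same allowlist pattern using plain `pickle`; the allowlist and payload here are illustrative, not the package's own:

```python
import io
import pickle
from collections import OrderedDict


class SafeClass:
    """Inert stand-in for any class outside the allowlist."""

    def __init__(self, *args, **kwargs):
        pass  # swallow constructor arguments instead of running unknown code


class SafeUnpickler(pickle.Unpickler):
    """Resolve classes from allowlisted modules only; map the rest to SafeClass."""

    SAFE_MODULES = {"builtins", "collections"}  # illustrative allowlist

    def find_class(self, module, name):
        if module in self.SAFE_MODULES:
            return super().find_class(module, name)  # normal resolution
        return SafeClass  # unknown class -> harmless placeholder


payload = pickle.dumps(OrderedDict(epochs=100, names=["person", "car"]))
print(SafeUnpickler(io.BytesIO(payload)).load())  # round-trips; 'collections' is allowlisted
```

In the diff itself, `types.ModuleType("safe_pickle")` dresses this `Unpickler`/`load` pair up as a module so it can be passed as `torch.load(f, pickle_module=safe_pickle)`, which is why `import pickle` and `import types` appear at the top of the rewritten file.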