dgenerate_ultralytics_headless-8.3.253-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
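
The diff that follows is for ultralytics/nn/tasks.py (+1768 lines), which defines the task model classes (BaseModel, DetectionModel, SegmentationModel, PoseModel, ClassificationModel, RTDETRDetectionModel, WorldModel, YOLOEModel). For orientation, here is a minimal usage sketch based on the examples in those classes' own docstrings; the config name, channel count, and input size are illustrative assumptions rather than values taken from this diff.

import torch
from ultralytics.nn.tasks import DetectionModel

# Build a detection model from one of the bundled YAML configs (see ultralytics/cfg/models/11/).
model = DetectionModel(cfg="yolo11n.yaml", ch=3, nc=80, verbose=False)
model.eval()

# Inference path: a plain tensor goes through BaseModel.forward -> predict().
dummy = torch.zeros(1, 3, 640, 640)  # illustrative 640x640 RGB batch
with torch.no_grad():
    preds = model.predict(dummy)

# Training path: passing a dict of images and labels computes the loss instead,
# because BaseModel.forward dispatches to loss() when x is a dict, e.g.
# batch = {"img": images, "cls": classes, "bboxes": boxes, "batch_idx": indices}
# total_loss, loss_items = model.loss(batch)
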
ultralytics/nn/tasks.py
@@ -0,0 +1,1768 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import contextlib
4
+ import pickle
5
+ import re
6
+ import types
7
+ from copy import deepcopy
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from ultralytics.nn.autobackend import check_class_names
14
+ from ultralytics.nn.modules import (
15
+ AIFI,
16
+ C1,
17
+ C2,
18
+ C2PSA,
19
+ C3,
20
+ C3TR,
21
+ ELAN1,
22
+ OBB,
23
+ PSA,
24
+ SPP,
25
+ SPPELAN,
26
+ SPPF,
27
+ A2C2f,
28
+ AConv,
29
+ ADown,
30
+ Bottleneck,
31
+ BottleneckCSP,
32
+ C2f,
33
+ C2fAttn,
34
+ C2fCIB,
35
+ C2fPSA,
36
+ C3Ghost,
37
+ C3k2,
38
+ C3x,
39
+ CBFuse,
40
+ CBLinear,
41
+ Classify,
42
+ Concat,
43
+ Conv,
44
+ Conv2,
45
+ ConvTranspose,
46
+ Detect,
47
+ DWConv,
48
+ DWConvTranspose2d,
49
+ Focus,
50
+ GhostBottleneck,
51
+ GhostConv,
52
+ HGBlock,
53
+ HGStem,
54
+ ImagePoolingAttn,
55
+ Index,
56
+ LRPCHead,
57
+ Pose,
58
+ RepC3,
59
+ RepConv,
60
+ RepNCSPELAN4,
61
+ RepVGGDW,
62
+ ResNetLayer,
63
+ RTDETRDecoder,
64
+ SCDown,
65
+ Segment,
66
+ TorchVision,
67
+ WorldDetect,
68
+ YOLOEDetect,
69
+ YOLOESegment,
70
+ v10Detect,
71
+ )
72
+ from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, YAML, colorstr, emojis
73
+ from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
74
+ from ultralytics.utils.loss import (
75
+ E2EDetectLoss,
76
+ v8ClassificationLoss,
77
+ v8DetectionLoss,
78
+ v8OBBLoss,
79
+ v8PoseLoss,
80
+ v8SegmentationLoss,
81
+ )
82
+ from ultralytics.utils.ops import make_divisible
83
+ from ultralytics.utils.patches import torch_load
84
+ from ultralytics.utils.plotting import feature_visualization
85
+ from ultralytics.utils.torch_utils import (
86
+ fuse_conv_and_bn,
87
+ fuse_deconv_and_bn,
88
+ initialize_weights,
89
+ intersect_dicts,
90
+ model_info,
91
+ scale_img,
92
+ smart_inference_mode,
93
+ time_sync,
94
+ )
95
+
96
+
97
+ class BaseModel(torch.nn.Module):
98
+ """Base class for all YOLO models in the Ultralytics family.
99
+
100
+ This class provides common functionality for YOLO models including forward pass handling, model fusion, information
101
+ display, and weight loading capabilities.
102
+
103
+ Attributes:
104
+ model (torch.nn.Module): The neural network model.
105
+ save (list): List of layer indices to save outputs from.
106
+ stride (torch.Tensor): Model stride values.
107
+
108
+ Methods:
109
+ forward: Perform forward pass for training or inference.
110
+ predict: Perform inference on input tensor.
111
+ fuse: Fuse Conv2d and BatchNorm2d layers for optimization.
112
+ info: Print model information.
113
+ load: Load weights into the model.
114
+ loss: Compute loss for training.
115
+
116
+ Examples:
117
+ Create a BaseModel instance
118
+ >>> model = BaseModel()
119
+ >>> model.info() # Display model information
120
+ """
121
+
122
+ def forward(self, x, *args, **kwargs):
123
+ """Perform forward pass of the model for either training or inference.
124
+
125
+ If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
126
+
127
+ Args:
128
+ x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
129
+ *args (Any): Variable length argument list.
130
+ **kwargs (Any): Arbitrary keyword arguments.
131
+
132
+ Returns:
133
+ (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
134
+ """
135
+ if isinstance(x, dict): # for cases of training and validating while training.
136
+ return self.loss(x, *args, **kwargs)
137
+ return self.predict(x, *args, **kwargs)
138
+
139
+ def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
140
+ """Perform a forward pass through the network.
141
+
142
+ Args:
143
+ x (torch.Tensor): The input tensor to the model.
144
+ profile (bool): Print the computation time of each layer if True.
145
+ visualize (bool): Save the feature maps of the model if True.
146
+ augment (bool): Augment image during prediction.
147
+ embed (list, optional): A list of feature vectors/embeddings to return.
148
+
149
+ Returns:
150
+ (torch.Tensor): The last output of the model.
151
+ """
152
+ if augment:
153
+ return self._predict_augment(x)
154
+ return self._predict_once(x, profile, visualize, embed)
155
+
156
+ def _predict_once(self, x, profile=False, visualize=False, embed=None):
157
+ """Perform a forward pass through the network.
158
+
159
+ Args:
160
+ x (torch.Tensor): The input tensor to the model.
161
+ profile (bool): Print the computation time of each layer if True.
162
+ visualize (bool): Save the feature maps of the model if True.
163
+ embed (list, optional): A list of feature vectors/embeddings to return.
164
+
165
+ Returns:
166
+ (torch.Tensor): The last output of the model.
167
+ """
168
+ y, dt, embeddings = [], [], [] # outputs
169
+ embed = frozenset(embed) if embed is not None else {-1}
170
+ max_idx = max(embed)
171
+ for m in self.model:
172
+ if m.f != -1: # if not from previous layer
173
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
174
+ if profile:
175
+ self._profile_one_layer(m, x, dt)
176
+ x = m(x) # run
177
+ y.append(x if m.i in self.save else None) # save output
178
+ if visualize:
179
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
180
+ if m.i in embed:
181
+ embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
182
+ if m.i == max_idx:
183
+ return torch.unbind(torch.cat(embeddings, 1), dim=0)
184
+ return x
185
+
186
+ def _predict_augment(self, x):
187
+ """Perform augmentations on input image x and return augmented inference."""
188
+ LOGGER.warning(
189
+ f"{self.__class__.__name__} does not support 'augment=True' prediction. "
190
+ f"Reverting to single-scale prediction."
191
+ )
192
+ return self._predict_once(x)
193
+
194
+ def _profile_one_layer(self, m, x, dt):
195
+ """Profile the computation time and FLOPs of a single layer of the model on a given input.
196
+
197
+ Args:
198
+ m (torch.nn.Module): The layer to be profiled.
199
+ x (torch.Tensor): The input data to the layer.
200
+ dt (list): A list to store the computation time of the layer.
201
+ """
202
+ try:
203
+ import thop
204
+ except ImportError:
205
+ thop = None # conda support without 'ultralytics-thop' installed
206
+
207
+ c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
208
+ flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
209
+ t = time_sync()
210
+ for _ in range(10):
211
+ m(x.copy() if c else x)
212
+ dt.append((time_sync() - t) * 100)
213
+ if m == self.model[0]:
214
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
215
+ LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
216
+ if c:
217
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
218
+
219
+ def fuse(self, verbose=True):
220
+ """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
221
+ efficiency.
222
+
223
+ Returns:
224
+ (torch.nn.Module): The fused model is returned.
225
+ """
226
+ if not self.is_fused():
227
+ for m in self.model.modules():
228
+ if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
229
+ if isinstance(m, Conv2):
230
+ m.fuse_convs()
231
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
232
+ delattr(m, "bn") # remove batchnorm
233
+ m.forward = m.forward_fuse # update forward
234
+ if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
235
+ m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
236
+ delattr(m, "bn") # remove batchnorm
237
+ m.forward = m.forward_fuse # update forward
238
+ if isinstance(m, RepConv):
239
+ m.fuse_convs()
240
+ m.forward = m.forward_fuse # update forward
241
+ if isinstance(m, RepVGGDW):
242
+ m.fuse()
243
+ m.forward = m.forward_fuse
244
+ if isinstance(m, v10Detect):
245
+ m.fuse() # remove one2many head
246
+ self.info(verbose=verbose)
247
+
248
+ return self
249
+
250
+ def is_fused(self, thresh=10):
251
+ """Check if the model has less than a certain threshold of BatchNorm layers.
252
+
253
+ Args:
254
+ thresh (int, optional): The threshold number of BatchNorm layers.
255
+
256
+ Returns:
257
+ (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
258
+ """
259
+ bn = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
260
+ return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
261
+
262
+ def info(self, detailed=False, verbose=True, imgsz=640):
263
+ """Print model information.
264
+
265
+ Args:
266
+ detailed (bool): If True, prints out detailed information about the model.
267
+ verbose (bool): If True, prints out the model information.
268
+ imgsz (int): The size of the image that the model will be trained on.
269
+ """
270
+ return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
271
+
272
+ def _apply(self, fn):
273
+ """Apply a function to all tensors in the model that are not parameters or registered buffers.
274
+
275
+ Args:
276
+ fn (function): The function to apply to the model.
277
+
278
+ Returns:
279
+ (BaseModel): An updated BaseModel object.
280
+ """
281
+ self = super()._apply(fn)
282
+ m = self.model[-1] # Detect()
283
+ if isinstance(
284
+ m, Detect
285
+ ): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect, YOLOESegment
286
+ m.stride = fn(m.stride)
287
+ m.anchors = fn(m.anchors)
288
+ m.strides = fn(m.strides)
289
+ return self
290
+
291
+ def load(self, weights, verbose=True):
292
+ """Load weights into the model.
293
+
294
+ Args:
295
+ weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
296
+ verbose (bool, optional): Whether to log the transfer progress.
297
+ """
298
+ model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
299
+ csd = model.float().state_dict() # checkpoint state_dict as FP32
300
+ updated_csd = intersect_dicts(csd, self.state_dict()) # intersect
301
+ self.load_state_dict(updated_csd, strict=False) # load
302
+ len_updated_csd = len(updated_csd)
303
+ first_conv = "model.0.conv.weight" # hard-coded to yolo models for now
304
+ # mostly used to boost multi-channel training
305
+ state_dict = self.state_dict()
306
+ if first_conv not in updated_csd and first_conv in state_dict:
307
+ c1, c2, h, w = state_dict[first_conv].shape
308
+ cc1, cc2, ch, cw = csd[first_conv].shape
309
+ if ch == h and cw == w:
310
+ c1, c2 = min(c1, cc1), min(c2, cc2)
311
+ state_dict[first_conv][:c1, :c2] = csd[first_conv][:c1, :c2]
312
+ len_updated_csd += 1
313
+ if verbose:
314
+ LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")
315
+
316
+ def loss(self, batch, preds=None):
317
+ """Compute loss.
318
+
319
+ Args:
320
+ batch (dict): Batch to compute loss on.
321
+ preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
322
+ """
323
+ if getattr(self, "criterion", None) is None:
324
+ self.criterion = self.init_criterion()
325
+
326
+ if preds is None:
327
+ preds = self.forward(batch["img"])
328
+ return self.criterion(preds, batch)
329
+
330
+ def init_criterion(self):
331
+ """Initialize the loss criterion for the BaseModel."""
332
+ raise NotImplementedError("compute_loss() needs to be implemented by task heads")
333
+
334
+
335
+ class DetectionModel(BaseModel):
336
+ """YOLO detection model.
337
+
338
+ This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
339
+ inference, and loss computation for object detection tasks.
340
+
341
+ Attributes:
342
+ yaml (dict): Model configuration dictionary.
343
+ model (torch.nn.Sequential): The neural network model.
344
+ save (list): List of layer indices to save outputs from.
345
+ names (dict): Class names dictionary.
346
+ inplace (bool): Whether to use inplace operations.
347
+ end2end (bool): Whether the model uses end-to-end detection.
348
+ stride (torch.Tensor): Model stride values.
349
+
350
+ Methods:
351
+ __init__: Initialize the YOLO detection model.
352
+ _predict_augment: Perform augmented inference.
353
+ _descale_pred: De-scale predictions following augmented inference.
354
+ _clip_augmented: Clip YOLO augmented inference tails.
355
+ init_criterion: Initialize the loss criterion.
356
+
357
+ Examples:
358
+ Initialize a detection model
359
+ >>> model = DetectionModel("yolo11n.yaml", ch=3, nc=80)
360
+ >>> results = model.predict(image_tensor)
361
+ """
362
+
363
+ def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
364
+ """Initialize the YOLO detection model with the given config and parameters.
365
+
366
+ Args:
367
+ cfg (str | dict): Model configuration file path or dictionary.
368
+ ch (int): Number of input channels.
369
+ nc (int, optional): Number of classes.
370
+ verbose (bool): Whether to display model information.
371
+ """
372
+ super().__init__()
373
+ self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
374
+ if self.yaml["backbone"][0][2] == "Silence":
375
+ LOGGER.warning(
376
+ "YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. "
377
+ "Please delete local *.pt file and re-download the latest model checkpoint."
378
+ )
379
+ self.yaml["backbone"][0][2] = "nn.Identity"
380
+
381
+ # Define model
382
+ self.yaml["channels"] = ch # save channels
383
+ if nc and nc != self.yaml["nc"]:
384
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
385
+ self.yaml["nc"] = nc # override YAML value
386
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
387
+ self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
388
+ self.inplace = self.yaml.get("inplace", True)
389
+ self.end2end = getattr(self.model[-1], "end2end", False)
390
+
391
+ # Build strides
392
+ m = self.model[-1] # Detect()
393
+ if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, YOLOEDetect, YOLOESegment
394
+ s = 256 # 2x min stride
395
+ m.inplace = self.inplace
396
+
397
+ def _forward(x):
398
+ """Perform a forward pass through the model, handling different Detect subclass types accordingly."""
399
+ if self.end2end:
400
+ return self.forward(x)["one2many"]
401
+ return self.forward(x)[0] if isinstance(m, (Segment, YOLOESegment, Pose, OBB)) else self.forward(x)
402
+
403
+ self.model.eval() # Avoid changing batch statistics until training begins
404
+ m.training = True # Setting it to True to properly return strides
405
+ m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
406
+ self.stride = m.stride
407
+ self.model.train() # Set model back to training(default) mode
408
+ m.bias_init() # only run once
409
+ else:
410
+ self.stride = torch.Tensor([32]) # default stride, e.g., RTDETR
411
+
412
+ # Init weights, biases
413
+ initialize_weights(self)
414
+ if verbose:
415
+ self.info()
416
+ LOGGER.info("")
417
+
418
+ def _predict_augment(self, x):
419
+ """Perform augmentations on input image x and return augmented inference and train outputs.
420
+
421
+ Args:
422
+ x (torch.Tensor): Input image tensor.
423
+
424
+ Returns:
425
+ (torch.Tensor): Augmented inference output.
426
+ """
427
+ if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
428
+ LOGGER.warning("Model does not support 'augment=True', reverting to single-scale prediction.")
429
+ return self._predict_once(x)
430
+ img_size = x.shape[-2:] # height, width
431
+ s = [1, 0.83, 0.67] # scales
432
+ f = [None, 3, None] # flips (2-ud, 3-lr)
433
+ y = [] # outputs
434
+ for si, fi in zip(s, f):
435
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
436
+ yi = super().predict(xi)[0] # forward
437
+ yi = self._descale_pred(yi, fi, si, img_size)
438
+ y.append(yi)
439
+ y = self._clip_augmented(y) # clip augmented tails
440
+ return torch.cat(y, -1), None # augmented inference, train
441
+
442
+ @staticmethod
443
+ def _descale_pred(p, flips, scale, img_size, dim=1):
444
+ """De-scale predictions following augmented inference (inverse operation).
445
+
446
+ Args:
447
+ p (torch.Tensor): Predictions tensor.
448
+ flips (int): Flip type (0=none, 2=ud, 3=lr).
449
+ scale (float): Scale factor.
450
+ img_size (tuple): Original image size (height, width).
451
+ dim (int): Dimension to split at.
452
+
453
+ Returns:
454
+ (torch.Tensor): De-scaled predictions.
455
+ """
456
+ p[:, :4] /= scale # de-scale
457
+ x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
458
+ if flips == 2:
459
+ y = img_size[0] - y # de-flip ud
460
+ elif flips == 3:
461
+ x = img_size[1] - x # de-flip lr
462
+ return torch.cat((x, y, wh, cls), dim)
463
+
464
+ def _clip_augmented(self, y):
465
+ """Clip YOLO augmented inference tails.
466
+
467
+ Args:
468
+ y (list[torch.Tensor]): List of detection tensors.
469
+
470
+ Returns:
471
+ (list[torch.Tensor]): Clipped detection tensors.
472
+ """
473
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
474
+ g = sum(4**x for x in range(nl)) # grid points
475
+ e = 1 # exclude layer count
476
+ i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices
477
+ y[0] = y[0][..., :-i] # large
478
+ i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
479
+ y[-1] = y[-1][..., i:] # small
480
+ return y
481
+
482
+ def init_criterion(self):
483
+ """Initialize the loss criterion for the DetectionModel."""
484
+ return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)
485
+
486
+
487
+ class OBBModel(DetectionModel):
488
+ """YOLO Oriented Bounding Box (OBB) model.
489
+
490
+ This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
491
+ computation for rotated object detection.
492
+
493
+ Methods:
494
+ __init__: Initialize YOLO OBB model.
495
+ init_criterion: Initialize the loss criterion for OBB detection.
496
+
497
+ Examples:
498
+ Initialize an OBB model
499
+ >>> model = OBBModel("yolo11n-obb.yaml", ch=3, nc=80)
500
+ >>> results = model.predict(image_tensor)
501
+ """
502
+
503
+ def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
504
+ """Initialize YOLO OBB model with given config and parameters.
505
+
506
+ Args:
507
+ cfg (str | dict): Model configuration file path or dictionary.
508
+ ch (int): Number of input channels.
509
+ nc (int, optional): Number of classes.
510
+ verbose (bool): Whether to display model information.
511
+ """
512
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
513
+
514
+ def init_criterion(self):
515
+ """Initialize the loss criterion for the model."""
516
+ return v8OBBLoss(self)
517
+
518
+
519
+ class SegmentationModel(DetectionModel):
520
+ """YOLO segmentation model.
521
+
522
+ This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
523
+ pixel-level object detection and segmentation.
524
+
525
+ Methods:
526
+ __init__: Initialize YOLO segmentation model.
527
+ init_criterion: Initialize the loss criterion for segmentation.
528
+
529
+ Examples:
530
+ Initialize a segmentation model
531
+ >>> model = SegmentationModel("yolo11n-seg.yaml", ch=3, nc=80)
532
+ >>> results = model.predict(image_tensor)
533
+ """
534
+
535
+ def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
536
+ """Initialize Ultralytics YOLO segmentation model with given config and parameters.
537
+
538
+ Args:
539
+ cfg (str | dict): Model configuration file path or dictionary.
540
+ ch (int): Number of input channels.
541
+ nc (int, optional): Number of classes.
542
+ verbose (bool): Whether to display model information.
543
+ """
544
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
545
+
546
+ def init_criterion(self):
547
+ """Initialize the loss criterion for the SegmentationModel."""
548
+ return v8SegmentationLoss(self)
549
+
550
+
551
+ class PoseModel(DetectionModel):
552
+ """YOLO pose model.
553
+
554
+ This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
555
+ keypoint detection and pose estimation.
556
+
557
+ Attributes:
558
+ kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).
559
+
560
+ Methods:
561
+ __init__: Initialize YOLO pose model.
562
+ init_criterion: Initialize the loss criterion for pose estimation.
563
+
564
+ Examples:
565
+ Initialize a pose model
566
+ >>> model = PoseModel("yolo11n-pose.yaml", ch=3, nc=1, data_kpt_shape=(17, 3))
567
+ >>> results = model.predict(image_tensor)
568
+ """
569
+
570
+ def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
571
+ """Initialize Ultralytics YOLO Pose model.
572
+
573
+ Args:
574
+ cfg (str | dict): Model configuration file path or dictionary.
575
+ ch (int): Number of input channels.
576
+ nc (int, optional): Number of classes.
577
+ data_kpt_shape (tuple): Shape of keypoints data.
578
+ verbose (bool): Whether to display model information.
579
+ """
580
+ if not isinstance(cfg, dict):
581
+ cfg = yaml_model_load(cfg) # load model YAML
582
+ if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
583
+ LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
584
+ cfg["kpt_shape"] = data_kpt_shape
585
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
586
+
587
+ def init_criterion(self):
588
+ """Initialize the loss criterion for the PoseModel."""
589
+ return v8PoseLoss(self)
590
+
591
+
592
+ class ClassificationModel(BaseModel):
593
+ """YOLO classification model.
594
+
595
+ This class implements the YOLO classification architecture for image classification tasks, providing model
596
+ initialization, configuration, and output reshaping capabilities.
597
+
598
+ Attributes:
599
+ yaml (dict): Model configuration dictionary.
600
+ model (torch.nn.Sequential): The neural network model.
601
+ stride (torch.Tensor): Model stride values.
602
+ names (dict): Class names dictionary.
603
+
604
+ Methods:
605
+ __init__: Initialize ClassificationModel.
606
+ _from_yaml: Set model configurations and define architecture.
607
+ reshape_outputs: Update model to specified class count.
608
+ init_criterion: Initialize the loss criterion.
609
+
610
+ Examples:
611
+ Initialize a classification model
612
+ >>> model = ClassificationModel("yolo11n-cls.yaml", ch=3, nc=1000)
613
+ >>> results = model.predict(image_tensor)
614
+ """
615
+
616
+ def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
617
+ """Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
618
+
619
+ Args:
620
+ cfg (str | dict): Model configuration file path or dictionary.
621
+ ch (int): Number of input channels.
622
+ nc (int, optional): Number of classes.
623
+ verbose (bool): Whether to display model information.
624
+ """
625
+ super().__init__()
626
+ self._from_yaml(cfg, ch, nc, verbose)
627
+
628
+ def _from_yaml(self, cfg, ch, nc, verbose):
629
+ """Set Ultralytics YOLO model configurations and define the model architecture.
630
+
631
+ Args:
632
+ cfg (str | dict): Model configuration file path or dictionary.
633
+ ch (int): Number of input channels.
634
+ nc (int, optional): Number of classes.
635
+ verbose (bool): Whether to display model information.
636
+ """
637
+ self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
638
+
639
+ # Define model
640
+ ch = self.yaml["channels"] = self.yaml.get("channels", ch) # input channels
641
+ if nc and nc != self.yaml["nc"]:
642
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
643
+ self.yaml["nc"] = nc # override YAML value
644
+ elif not nc and not self.yaml.get("nc", None):
645
+ raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
646
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
647
+ self.stride = torch.Tensor([1]) # no stride constraints
648
+ self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
649
+ self.info()
650
+
651
+ @staticmethod
652
+ def reshape_outputs(model, nc):
653
+ """Update a TorchVision classification model to class count 'n' if required.
654
+
655
+ Args:
656
+ model (torch.nn.Module): Model to update.
657
+ nc (int): New number of classes.
658
+ """
659
+ name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
660
+ if isinstance(m, Classify): # YOLO Classify() head
661
+ if m.linear.out_features != nc:
662
+ m.linear = torch.nn.Linear(m.linear.in_features, nc)
663
+ elif isinstance(m, torch.nn.Linear): # ResNet, EfficientNet
664
+ if m.out_features != nc:
665
+ setattr(model, name, torch.nn.Linear(m.in_features, nc))
666
+ elif isinstance(m, torch.nn.Sequential):
667
+ types = [type(x) for x in m]
668
+ if torch.nn.Linear in types:
669
+ i = len(types) - 1 - types[::-1].index(torch.nn.Linear) # last torch.nn.Linear index
670
+ if m[i].out_features != nc:
671
+ m[i] = torch.nn.Linear(m[i].in_features, nc)
672
+ elif torch.nn.Conv2d in types:
673
+ i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d) # last torch.nn.Conv2d index
674
+ if m[i].out_channels != nc:
675
+ m[i] = torch.nn.Conv2d(
676
+ m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None
677
+ )
678
+
679
+ def init_criterion(self):
680
+ """Initialize the loss criterion for the ClassificationModel."""
681
+ return v8ClassificationLoss()
682
+
683
+
684
+ class RTDETRDetectionModel(DetectionModel):
685
+ """RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
686
+
687
+ This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
688
+ the training and inference processes. RTDETR is an object detection and tracking model that extends from the
689
+ DetectionModel base class.
690
+
691
+ Attributes:
692
+ nc (int): Number of classes for detection.
693
+ criterion (RTDETRDetectionLoss): Loss function for training.
694
+
695
+ Methods:
696
+ __init__: Initialize the RTDETRDetectionModel.
697
+ init_criterion: Initialize the loss criterion.
698
+ loss: Compute loss for training.
699
+ predict: Perform forward pass through the model.
700
+
701
+ Examples:
702
+ Initialize an RTDETR model
703
+ >>> model = RTDETRDetectionModel("rtdetr-l.yaml", ch=3, nc=80)
704
+ >>> results = model.predict(image_tensor)
705
+ """
706
+
707
+ def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
708
+ """Initialize the RTDETRDetectionModel.
709
+
710
+ Args:
711
+ cfg (str | dict): Configuration file name or path.
712
+ ch (int): Number of input channels.
713
+ nc (int, optional): Number of classes.
714
+ verbose (bool): Print additional information during initialization.
715
+ """
716
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
717
+
718
+ def _apply(self, fn):
719
+ """Apply a function to all tensors in the model that are not parameters or registered buffers.
720
+
721
+ Args:
722
+ fn (function): The function to apply to the model.
723
+
724
+ Returns:
725
+ (RTDETRDetectionModel): An updated BaseModel object.
726
+ """
727
+ self = super()._apply(fn)
728
+ m = self.model[-1]
729
+ m.anchors = fn(m.anchors)
730
+ m.valid_mask = fn(m.valid_mask)
731
+ return self
732
+
733
+ def init_criterion(self):
734
+ """Initialize the loss criterion for the RTDETRDetectionModel."""
735
+ from ultralytics.models.utils.loss import RTDETRDetectionLoss
736
+
737
+ return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
738
+
739
+ def loss(self, batch, preds=None):
740
+ """Compute the loss for the given batch of data.
741
+
742
+ Args:
743
+ batch (dict): Dictionary containing image and label data.
744
+ preds (torch.Tensor, optional): Precomputed model predictions.
745
+
746
+ Returns:
747
+ loss_sum (torch.Tensor): Total loss value.
748
+ loss_items (torch.Tensor): Main three losses in a tensor.
749
+ """
750
+ if not hasattr(self, "criterion"):
751
+ self.criterion = self.init_criterion()
752
+
753
+ img = batch["img"]
754
+ # NOTE: preprocess gt_bbox and gt_labels to list.
755
+ bs = img.shape[0]
756
+ batch_idx = batch["batch_idx"]
757
+ gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
758
+ targets = {
759
+ "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
760
+ "bboxes": batch["bboxes"].to(device=img.device),
761
+ "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
762
+ "gt_groups": gt_groups,
763
+ }
764
+
765
+ if preds is None:
766
+ preds = self.predict(img, batch=targets)
767
+ dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
768
+ if dn_meta is None:
769
+ dn_bboxes, dn_scores = None, None
770
+ else:
771
+ dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
772
+ dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)
773
+
774
+ dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
775
+ dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
776
+
777
+ loss = self.criterion(
778
+ (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
779
+ )
780
+ # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
781
+ return sum(loss.values()), torch.as_tensor(
782
+ [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
783
+ )
784
+
785
+ def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
786
+ """Perform a forward pass through the model.
787
+
788
+ Args:
789
+ x (torch.Tensor): The input tensor.
790
+ profile (bool): If True, profile the computation time for each layer.
791
+ visualize (bool): If True, save feature maps for visualization.
792
+ batch (dict, optional): Ground truth data for evaluation.
793
+ augment (bool): If True, perform data augmentation during inference.
794
+ embed (list, optional): A list of feature vectors/embeddings to return.
795
+
796
+ Returns:
797
+ (torch.Tensor): Model's output tensor.
798
+ """
799
+ y, dt, embeddings = [], [], [] # outputs
800
+ embed = frozenset(embed) if embed is not None else {-1}
801
+ max_idx = max(embed)
802
+ for m in self.model[:-1]: # except the head part
803
+ if m.f != -1: # if not from previous layer
804
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
805
+ if profile:
806
+ self._profile_one_layer(m, x, dt)
807
+ x = m(x) # run
808
+ y.append(x if m.i in self.save else None) # save output
809
+ if visualize:
810
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
811
+ if m.i in embed:
812
+ embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
813
+ if m.i == max_idx:
814
+ return torch.unbind(torch.cat(embeddings, 1), dim=0)
815
+ head = self.model[-1]
816
+ x = head([y[j] for j in head.f], batch) # head inference
817
+ return x
818
+
819
+
820
+ class WorldModel(DetectionModel):
821
+ """YOLOv8 World Model.
822
+
823
+ This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
824
+ specification and CLIP model integration for zero-shot detection capabilities.
825
+
826
+ Attributes:
827
+ txt_feats (torch.Tensor): Text feature embeddings for classes.
828
+ clip_model (torch.nn.Module): CLIP model for text encoding.
829
+
830
+ Methods:
831
+ __init__: Initialize YOLOv8 world model.
832
+ set_classes: Set classes for offline inference.
833
+ get_text_pe: Get text positional embeddings.
834
+ predict: Perform forward pass with text features.
835
+ loss: Compute loss with text features.
836
+
837
+ Examples:
838
+ Initialize a world model
839
+ >>> model = WorldModel("yolov8s-world.yaml", ch=3, nc=80)
840
+ >>> model.set_classes(["person", "car", "bicycle"])
841
+ >>> results = model.predict(image_tensor)
842
+ """
843
+
844
+ def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
845
+ """Initialize YOLOv8 world model with given config and parameters.
846
+
847
+ Args:
848
+ cfg (str | dict): Model configuration file path or dictionary.
849
+ ch (int): Number of input channels.
850
+ nc (int, optional): Number of classes.
851
+ verbose (bool): Whether to display model information.
852
+ """
853
+ self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder
854
+ self.clip_model = None # CLIP model placeholder
855
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
856
+
857
+ def set_classes(self, text, batch=80, cache_clip_model=True):
858
+ """Set classes in advance so that model could do offline-inference without clip model.
859
+
860
+ Args:
861
+ text (list[str]): List of class names.
862
+ batch (int): Batch size for processing text tokens.
863
+ cache_clip_model (bool): Whether to cache the CLIP model.
864
+ """
865
+ self.txt_feats = self.get_text_pe(text, batch=batch, cache_clip_model=cache_clip_model)
866
+ self.model[-1].nc = len(text)
867
+
868
+ def get_text_pe(self, text, batch=80, cache_clip_model=True):
869
+ """Get text positional embeddings for offline inference without CLIP model.
870
+
871
+ Args:
872
+ text (list[str]): List of class names.
873
+ batch (int): Batch size for processing text tokens.
874
+ cache_clip_model (bool): Whether to cache the CLIP model.
875
+
876
+ Returns:
877
+ (torch.Tensor): Text positional embeddings.
878
+ """
879
+ from ultralytics.nn.text_model import build_text_model
880
+
881
+ device = next(self.model.parameters()).device
882
+ if not getattr(self, "clip_model", None) and cache_clip_model:
883
+ # For backwards compatibility of models lacking clip_model attribute
884
+ self.clip_model = build_text_model("clip:ViT-B/32", device=device)
885
+ model = self.clip_model if cache_clip_model else build_text_model("clip:ViT-B/32", device=device)
886
+ text_token = model.tokenize(text)
887
+ txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
888
+ txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
889
+ return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
890
+
891
+ def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
892
+ """Perform a forward pass through the model.
893
+
894
+ Args:
895
+ x (torch.Tensor): The input tensor.
896
+ profile (bool): If True, profile the computation time for each layer.
897
+ visualize (bool): If True, save feature maps for visualization.
898
+ txt_feats (torch.Tensor, optional): The text features, use it if it's given.
899
+ augment (bool): If True, perform data augmentation during inference.
900
+ embed (list, optional): A list of feature vectors/embeddings to return.
901
+
902
+ Returns:
903
+ (torch.Tensor): Model's output tensor.
904
+ """
905
+ txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
906
+ if txt_feats.shape[0] != x.shape[0] or self.model[-1].export:
907
+ txt_feats = txt_feats.expand(x.shape[0], -1, -1)
908
+ ori_txt_feats = txt_feats.clone()
909
+ y, dt, embeddings = [], [], [] # outputs
910
+ embed = frozenset(embed) if embed is not None else {-1}
911
+ max_idx = max(embed)
912
+ for m in self.model: # except the head part
913
+ if m.f != -1: # if not from previous layer
914
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
915
+ if profile:
916
+ self._profile_one_layer(m, x, dt)
917
+ if isinstance(m, C2fAttn):
918
+ x = m(x, txt_feats)
919
+ elif isinstance(m, WorldDetect):
920
+ x = m(x, ori_txt_feats)
921
+ elif isinstance(m, ImagePoolingAttn):
922
+ txt_feats = m(x, txt_feats)
923
+ else:
924
+ x = m(x) # run
925
+
926
+ y.append(x if m.i in self.save else None) # save output
927
+ if visualize:
928
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
929
+ if m.i in embed:
930
+ embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
931
+ if m.i == max_idx:
932
+ return torch.unbind(torch.cat(embeddings, 1), dim=0)
933
+ return x
934
+
935
+ def loss(self, batch, preds=None):
936
+ """Compute loss.
937
+
938
+ Args:
939
+ batch (dict): Batch to compute loss on.
940
+ preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
941
+ """
942
+ if not hasattr(self, "criterion"):
943
+ self.criterion = self.init_criterion()
944
+
945
+ if preds is None:
946
+ preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
947
+ return self.criterion(preds, batch)
948
+
949
+
950
+ class YOLOEModel(DetectionModel):
951
+ """YOLOE detection model.
952
+
953
+ This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
954
+ both prompt-based and prompt-free inference modes.
955
+
956
+ Attributes:
957
+ pe (torch.Tensor): Prompt embeddings for classes.
958
+ clip_model (torch.nn.Module): CLIP model for text encoding.
959
+
960
+ Methods:
961
+ __init__: Initialize YOLOE model.
962
+ get_text_pe: Get text positional embeddings.
963
+ get_visual_pe: Get visual embeddings.
964
+ set_vocab: Set vocabulary for prompt-free model.
965
+ get_vocab: Get fused vocabulary layer.
966
+ set_classes: Set classes for offline inference.
967
+ get_cls_pe: Get class positional embeddings.
968
+ predict: Perform forward pass with prompts.
969
+ loss: Compute loss with prompts.
970
+
971
+ Examples:
972
+ Initialize a YOLOE model
973
+ >>> model = YOLOEModel("yoloe-v8s.yaml", ch=3, nc=80)
974
+ >>> results = model.predict(image_tensor, tpe=text_embeddings)
975
+ """
976
+
977
+ def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
978
+ """Initialize YOLOE model with given config and parameters.
979
+
980
+ Args:
981
+ cfg (str | dict): Model configuration file path or dictionary.
982
+ ch (int): Number of input channels.
983
+ nc (int, optional): Number of classes.
984
+ verbose (bool): Whether to display model information.
985
+ """
986
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
987
+
988
+ @smart_inference_mode()
989
+ def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
990
+ """Get text positional embeddings for offline inference without CLIP model.
991
+
992
+ Args:
993
+ text (list[str]): List of class names.
994
+ batch (int): Batch size for processing text tokens.
995
+ cache_clip_model (bool): Whether to cache the CLIP model.
996
+ without_reprta (bool): Whether to return text embeddings without reprta module processing.
997
+
998
+ Returns:
999
+ (torch.Tensor): Text positional embeddings.
1000
+ """
1001
+ from ultralytics.nn.text_model import build_text_model
1002
+
1003
+ device = next(self.model.parameters()).device
1004
+ if not getattr(self, "clip_model", None) and cache_clip_model:
1005
+ # For backwards compatibility of models lacking clip_model attribute
1006
+ self.clip_model = build_text_model("mobileclip:blt", device=device)
1007
+
1008
+ model = self.clip_model if cache_clip_model else build_text_model("mobileclip:blt", device=device)
1009
+ text_token = model.tokenize(text)
1010
+ txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
1011
+ txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
1012
+ txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
1013
+ if without_reprta:
1014
+ return txt_feats
1015
+
1016
+ head = self.model[-1]
1017
+ assert isinstance(head, YOLOEDetect)
1018
+ return head.get_tpe(txt_feats) # run auxiliary text head
1019
+
1020
+ @smart_inference_mode()
+ def get_visual_pe(self, img, visual):
+ """Get visual embeddings.
+
+ Args:
+ img (torch.Tensor): Input image tensor.
+ visual (torch.Tensor): Visual features.
+
+ Returns:
+ (torch.Tensor): Visual positional embeddings.
+ """
+ return self(img, vpe=visual, return_vpe=True)
+
+ def set_vocab(self, vocab, names):
+ """Set vocabulary for the prompt-free model.
+
+ Args:
+ vocab (nn.ModuleList): List of vocabulary items.
+ names (list[str]): List of class names.
+ """
+ assert not self.training
+ head = self.model[-1]
+ assert isinstance(head, YOLOEDetect)
+
+ # Cache anchors for head
+ device = next(self.parameters()).device
+ self(torch.empty(1, 3, self.args["imgsz"], self.args["imgsz"]).to(device)) # warmup
+
+ # re-parameterization for prompt-free model
+ self.model[-1].lrpc = nn.ModuleList(
+ LRPCHead(cls, pf[-1], loc[-1], enabled=i != 2)
+ for i, (cls, pf, loc) in enumerate(zip(vocab, head.cv3, head.cv2))
+ )
+ for loc_head, cls_head in zip(head.cv2, head.cv3):
+ assert isinstance(loc_head, nn.Sequential)
+ assert isinstance(cls_head, nn.Sequential)
+ del loc_head[-1]
+ del cls_head[-1]
+ self.model[-1].nc = len(names)
+ self.names = check_class_names(names)
+
+ def get_vocab(self, names):
+ """Get fused vocabulary layer from the model.
+
+ Args:
+ names (list): List of class names.
+
+ Returns:
+ (nn.ModuleList): List of vocabulary modules.
+ """
+ assert not self.training
+ head = self.model[-1]
+ assert isinstance(head, YOLOEDetect)
+ assert not head.is_fused
+
+ tpe = self.get_text_pe(names)
+ self.set_classes(names, tpe)
+ device = next(self.model.parameters()).device
+ head.fuse(self.pe.to(device)) # fuse prompt embeddings to classify head
+
+ vocab = nn.ModuleList()
+ for cls_head in head.cv3:
+ assert isinstance(cls_head, nn.Sequential)
+ vocab.append(cls_head[-1])
+ return vocab
+
+ def set_classes(self, names, embeddings):
+ """Set classes in advance so that the model can do offline inference without a CLIP model.
+
+ Args:
+ names (list[str]): List of class names.
+ embeddings (torch.Tensor): Embeddings tensor.
+ """
+ assert not hasattr(self.model[-1], "lrpc"), (
+ "Prompt-free model does not support setting classes. Please try with Text/Visual prompt models."
+ )
+ assert embeddings.ndim == 3
+ self.pe = embeddings
+ self.model[-1].nc = len(names)
+ self.names = check_class_names(names)
+
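To make the offline-inference path concrete, here is a small sketch under the same assumptions as above (a constructed YOLOEModel and available text-encoder weights): the embeddings are computed once, attached with set_classes, and predict then falls back to the cached model.pe inside get_cls_pe.

import torch
from ultralytics.nn.tasks import YOLOEModel

model = YOLOEModel("yoloe-v8s.yaml", ch=3, nc=80)
names = ["cat", "dog"]
model.set_classes(names, model.get_text_pe(names))  # cache prompt embeddings in model.pe
out = model.predict(torch.zeros(1, 3, 640, 640))    # no CLIP model needed at inference time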
+ def get_cls_pe(self, tpe, vpe):
+ """Get class positional embeddings.
+
+ Args:
+ tpe (torch.Tensor, optional): Text positional embeddings.
+ vpe (torch.Tensor, optional): Visual positional embeddings.
+
+ Returns:
+ (torch.Tensor): Class positional embeddings.
+ """
+ all_pe = []
+ if tpe is not None:
+ assert tpe.ndim == 3
+ all_pe.append(tpe)
+ if vpe is not None:
+ assert vpe.ndim == 3
+ all_pe.append(vpe)
+ if not all_pe:
+ all_pe.append(getattr(self, "pe", torch.zeros(1, 80, 512)))
+ return torch.cat(all_pe, dim=1)
+
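A shape-level illustration of what get_cls_pe does, in plain PyTorch with hypothetical sizes: text and visual prompt embeddings are both (B, N, D) and are concatenated along the prompt dimension; when neither is given, the cached model.pe (or a zero tensor) is used instead.

import torch

tpe = torch.zeros(1, 3, 512)           # e.g. 3 text prompts
vpe = torch.zeros(1, 2, 512)           # e.g. 2 visual prompts
cls_pe = torch.cat([tpe, vpe], dim=1)  # (1, 5, 512), matching the concatenation in get_cls_pe
print(cls_pe.shape)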
+ def predict(
+ self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
+ ):
+ """Perform a forward pass through the model.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+ profile (bool): If True, profile the computation time for each layer.
+ visualize (bool): If True, save feature maps for visualization.
+ tpe (torch.Tensor, optional): Text positional embeddings.
+ augment (bool): If True, perform data augmentation during inference.
+ embed (list, optional): A list of feature vectors/embeddings to return.
+ vpe (torch.Tensor, optional): Visual positional embeddings.
+ return_vpe (bool): If True, return visual positional embeddings.
+
+ Returns:
+ (torch.Tensor): Model's output tensor.
+ """
+ y, dt, embeddings = [], [], [] # outputs
+ b = x.shape[0]
+ embed = frozenset(embed) if embed is not None else {-1}
+ max_idx = max(embed)
+ for m in self.model: # except the head part
+ if m.f != -1: # if not from previous layer
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
+ if profile:
+ self._profile_one_layer(m, x, dt)
+ if isinstance(m, YOLOEDetect):
+ vpe = m.get_vpe(x, vpe) if vpe is not None else None
+ if return_vpe:
+ assert vpe is not None
+ assert not self.training
+ return vpe
+ cls_pe = self.get_cls_pe(m.get_tpe(tpe), vpe).to(device=x[0].device, dtype=x[0].dtype)
+ if cls_pe.shape[0] != b or m.export:
+ cls_pe = cls_pe.expand(b, -1, -1)
+ x = m(x, cls_pe)
+ else:
+ x = m(x) # run
+
+ y.append(x if m.i in self.save else None) # save output
+ if visualize:
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
+ if m.i in embed:
+ embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
+ if m.i == max_idx:
+ return torch.unbind(torch.cat(embeddings, 1), dim=0)
+ return x
+
+ def loss(self, batch, preds=None):
+ """Compute loss.
+
+ Args:
+ batch (dict): Batch to compute loss on.
+ preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
+ """
+ if not hasattr(self, "criterion"):
+ from ultralytics.utils.loss import TVPDetectLoss
+
+ visual_prompt = batch.get("visuals", None) is not None # TODO
+ self.criterion = TVPDetectLoss(self) if visual_prompt else self.init_criterion()
+
+ if preds is None:
+ preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
+ return self.criterion(preds, batch)
+
+
+ class YOLOESegModel(YOLOEModel, SegmentationModel):
+ """YOLOE segmentation model.
+
+ This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
+ specialized loss computation for pixel-level object detection and segmentation.
+
+ Methods:
+ __init__: Initialize YOLOE segmentation model.
+ loss: Compute loss with prompts for segmentation.
+
+ Examples:
+ Initialize a YOLOE segmentation model
+ >>> model = YOLOESegModel("yoloe-v8s-seg.yaml", ch=3, nc=80)
+ >>> results = model.predict(image_tensor, tpe=text_embeddings)
+ """
+
+ def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
+ """Initialize YOLOE segmentation model with given config and parameters.
+
+ Args:
+ cfg (str | dict): Model configuration file path or dictionary.
+ ch (int): Number of input channels.
+ nc (int, optional): Number of classes.
+ verbose (bool): Whether to display model information.
+ """
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
+
+ def loss(self, batch, preds=None):
+ """Compute loss.
+
+ Args:
+ batch (dict): Batch to compute loss on.
+ preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
+ """
+ if not hasattr(self, "criterion"):
+ from ultralytics.utils.loss import TVPSegmentLoss
+
+ visual_prompt = batch.get("visuals", None) is not None # TODO
+ self.criterion = TVPSegmentLoss(self) if visual_prompt else self.init_criterion()
+
+ if preds is None:
+ preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
+ return self.criterion(preds, batch)
+
+
+ class Ensemble(torch.nn.ModuleList):
+ """Ensemble of models.
+
+ This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
+ or other ensemble techniques.
+
+ Methods:
+ __init__: Initialize an ensemble of models.
+ forward: Generate predictions from all models in the ensemble.
+
+ Examples:
+ Create an ensemble of models
+ >>> ensemble = Ensemble()
+ >>> ensemble.append(model1)
+ >>> ensemble.append(model2)
+ >>> results = ensemble(image_tensor)
+ """
+
+ def __init__(self):
+ """Initialize an ensemble of models."""
+ super().__init__()
+
+ def forward(self, x, augment=False, profile=False, visualize=False):
+ """Generate predictions from all models in the ensemble.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ augment (bool): Whether to augment the input.
+ profile (bool): Whether to profile the model.
+ visualize (bool): Whether to visualize the features.
+
+ Returns:
+ y (torch.Tensor): Concatenated predictions from all models.
+ train_out (None): Always None for ensemble inference.
+ """
+ y = [module(x, augment, profile, visualize)[0] for module in self]
+ # y = torch.stack(y).max(0)[0] # max ensemble
+ # y = torch.stack(y).mean(0) # mean ensemble
+ y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C)
+ return y, None # inference, train output
+
+
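An illustrative way to build an Ensemble from two checkpoints using the load_checkpoint helper defined later in this file; the .pt paths are hypothetical and must exist locally (or be downloadable).

import torch
from ultralytics.nn.tasks import Ensemble, load_checkpoint

ensemble = Ensemble()
for w in ("model_a.pt", "model_b.pt"):            # hypothetical checkpoint paths
    model, _ = load_checkpoint(w, device="cpu")
    ensemble.append(model)
preds, _ = ensemble(torch.zeros(1, 3, 640, 640))  # per-model outputs concatenated on dim 2, as in forward()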
+ # Functions ------------------------------------------------------------------------------------------------------------
+
+
+ @contextlib.contextmanager
+ def temporary_modules(modules=None, attributes=None):
+ """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
+
+ This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
+ moved a module from one location to another, but you still want to support the old import paths for backwards
+ compatibility.
+
+ Args:
+ modules (dict, optional): A dictionary mapping old module paths to new module paths.
+ attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.
+
+ Examples:
+ >>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
+ >>> import old.module # this will now import new.module
+ >>> from old.module import attribute # this will now import new.module.attribute
+
+ Notes:
+ The changes are only in effect inside the context manager and are undone once the context manager exits.
+ Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
+ applications or libraries. Use this function with caution.
+ """
+ if modules is None:
+ modules = {}
+ if attributes is None:
+ attributes = {}
+ import sys
+ from importlib import import_module
+
+ try:
+ # Set attributes in sys.modules under their old name
+ for old, new in attributes.items():
+ old_module, old_attr = old.rsplit(".", 1)
+ new_module, new_attr = new.rsplit(".", 1)
+ setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))
+
+ # Set modules in sys.modules under their old name
+ for old, new in modules.items():
+ sys.modules[old] = import_module(new)
+
+ yield
+ finally:
+ # Remove the temporary module paths
+ for old in modules:
+ if old in sys.modules:
+ del sys.modules[old]
+
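A standalone illustration of temporary_modules, reusing the same legacy mapping that torch_safe_load applies below; importlib is used so the aliased name is resolved from sys.modules inside the block.

from importlib import import_module
from ultralytics.nn.tasks import temporary_modules

with temporary_modules(modules={"ultralytics.yolo.utils": "ultralytics.utils"}):
    legacy_utils = import_module("ultralytics.yolo.utils")  # resolves to ultralytics.utils
print(legacy_utils.__name__)                                 # "ultralytics.utils"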
+
+ class SafeClass:
+ """A placeholder class to replace unknown classes during unpickling."""
+
+ def __init__(self, *args, **kwargs):
+ """Initialize SafeClass instance, ignoring all arguments."""
+ pass
+
+ def __call__(self, *args, **kwargs):
+ """Run SafeClass instance, ignoring all arguments."""
+ pass
+
+
+ class SafeUnpickler(pickle.Unpickler):
+ """Custom Unpickler that replaces unknown classes with SafeClass."""
+
+ def find_class(self, module, name):
+ """Attempt to find a class, returning SafeClass if not among safe modules.
+
+ Args:
+ module (str): Module name.
+ name (str): Class name.
+
+ Returns:
+ (type): Found class or SafeClass.
+ """
+ safe_modules = (
+ "torch",
+ "collections",
+ "collections.abc",
+ "builtins",
+ "math",
+ "numpy",
+ # Add other modules considered safe
+ )
+ if module in safe_modules:
+ return super().find_class(module, name)
+ else:
+ return SafeClass
+
+
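A toy demonstration of the SafeUnpickler behavior: objects built only from safe modules load normally, while a class from a module outside the list (pathlib in this sketch) is expected to come back as a SafeClass placeholder rather than being reconstructed.

import io
import pickle
from pathlib import Path
from ultralytics.nn.tasks import SafeClass, SafeUnpickler

safe_obj = SafeUnpickler(io.BytesIO(pickle.dumps({"x": 1}))).load()
print(safe_obj)                                        # {'x': 1}

blocked = SafeUnpickler(io.BytesIO(pickle.dumps(Path("a")))).load()
print(isinstance(blocked, SafeClass))                  # True: pathlib is not whitelisted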
+ def torch_safe_load(weight, safe_only=False):
+ """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
+ the error, logs a warning message, and attempts to install the missing module via the check_requirements()
+ function. After installation, the function again attempts to load the model using torch.load().
+
+ Args:
+ weight (str): The file path of the PyTorch model.
+ safe_only (bool): If True, replace unknown classes with SafeClass during loading.
+
+ Returns:
+ ckpt (dict): The loaded model checkpoint.
+ file (str): The loaded filename.
+
+ Examples:
+ >>> from ultralytics.nn.tasks import torch_safe_load
+ >>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
+ """
+ from ultralytics.utils.downloads import attempt_download_asset
+
+ check_suffix(file=weight, suffix=".pt")
+ file = attempt_download_asset(weight) # search online if missing locally
+ try:
+ with temporary_modules(
+ modules={
+ "ultralytics.yolo.utils": "ultralytics.utils",
+ "ultralytics.yolo.v8": "ultralytics.models.yolo",
+ "ultralytics.yolo.data": "ultralytics.data",
+ },
+ attributes={
+ "ultralytics.nn.modules.block.Silence": "torch.nn.Identity", # YOLOv9e
+ "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel", # YOLOv10
+ "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10
+ },
+ ):
+ if safe_only:
+ # Load via custom pickle module
+ safe_pickle = types.ModuleType("safe_pickle")
+ safe_pickle.Unpickler = SafeUnpickler
+ safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
+ with open(file, "rb") as f:
+ ckpt = torch_load(f, pickle_module=safe_pickle)
+ else:
+ ckpt = torch_load(file, map_location="cpu")
+
+ except ModuleNotFoundError as e: # e.name is missing module name
+ if e.name == "models":
+ raise TypeError(
+ emojis(
+ f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
+ f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
+ f"YOLOv8 at https://github.com/ultralytics/ultralytics."
+ f"\nRecommended fixes are to train a new model using the latest 'ultralytics' package or to "
+ f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
+ )
+ ) from e
+ elif e.name == "numpy._core":
+ raise ModuleNotFoundError(
+ emojis(
+ f"ERROR ❌️ {weight} requires numpy>=1.26.1, however numpy=={__import__('numpy').__version__} is installed."
+ )
+ ) from e
+ LOGGER.warning(
+ f"{weight} appears to require '{e.name}', which is not in Ultralytics requirements."
+ f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
+ f"\nRecommended fixes are to train a new model using the latest 'ultralytics' package or to "
+ f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
+ )
+ check_requirements(e.name) # install missing module
+ ckpt = torch_load(file, map_location="cpu")
+
+ if not isinstance(ckpt, dict):
+ # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
+ LOGGER.warning(
+ f"The file '{weight}' appears to be improperly saved or formatted. "
+ f"For optimal results, use model.save('filename.pt') to correctly save YOLO models."
+ )
+ ckpt = {"model": ckpt.model}
+
+ return ckpt, file
+
+
+ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
+ """Load a single model's weights.
+
+ Args:
+ weight (str | Path): Model weight path.
+ device (torch.device, optional): Device to load model to.
+ inplace (bool): Whether to do inplace operations.
+ fuse (bool): Whether to fuse the model.
+
+ Returns:
+ model (torch.nn.Module): Loaded model.
+ ckpt (dict): Model checkpoint dictionary.
+ """
+ ckpt, weight = torch_safe_load(weight) # load ckpt
+ args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
+ model = (ckpt.get("ema") or ckpt["model"]).float() # FP32 model
+
+ # Model compatibility updates
+ model.args = args # attach args to model
+ model.pt_path = weight # attach *.pt file path to model
+ model.task = getattr(model, "task", guess_model_task(model))
+ if not hasattr(model, "stride"):
+ model.stride = torch.tensor([32.0])
+
+ model = (model.fuse() if fuse and hasattr(model, "fuse") else model).eval().to(device) # model in eval mode
+
+ # Module updates
+ for m in model.modules():
+ if hasattr(m, "inplace"):
+ m.inplace = inplace
+ elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
+
+ # Return model and ckpt
+ return model, ckpt
+
+
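A hedged usage sketch for load_checkpoint; it assumes an official checkpoint such as yolo11n.pt is present locally or can be fetched by attempt_download_asset.

from ultralytics.nn.tasks import load_checkpoint

model, ckpt = load_checkpoint("yolo11n.pt", device="cpu", fuse=True)
print(model.task, float(model.stride.max()), sorted(ckpt)[:3])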
+ def parse_model(d, ch, verbose=True):
+ """Parse a YOLO model.yaml dictionary into a PyTorch model.
+
+ Args:
+ d (dict): Model dictionary.
+ ch (int): Input channels.
+ verbose (bool): Whether to print model details.
+
+ Returns:
+ model (torch.nn.Sequential): PyTorch model.
+ save (list): Sorted list of output layers.
+ """
+ import ast
+
+ # Args
+ legacy = True # backward compatibility for v3/v5/v8/v9 models
+ max_channels = float("inf")
+ nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
+ depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
+ scale = d.get("scale")
+ if scales:
+ if not scale:
+ scale = next(iter(scales.keys()))
+ LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
+ depth, width, max_channels = scales[scale]
+
+ if act:
+ Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
+ if verbose:
+ LOGGER.info(f"{colorstr('activation:')} {act}") # print
+
+ if verbose:
+ LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
+ ch = [ch]
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
+ base_modules = frozenset(
+ {
+ Classify,
+ Conv,
+ ConvTranspose,
+ GhostConv,
+ Bottleneck,
+ GhostBottleneck,
+ SPP,
+ SPPF,
+ C2fPSA,
+ C2PSA,
+ DWConv,
+ Focus,
+ BottleneckCSP,
+ C1,
+ C2,
+ C2f,
+ C3k2,
+ RepNCSPELAN4,
+ ELAN1,
+ ADown,
+ AConv,
+ SPPELAN,
+ C2fAttn,
+ C3,
+ C3TR,
+ C3Ghost,
+ torch.nn.ConvTranspose2d,
+ DWConvTranspose2d,
+ C3x,
+ RepC3,
+ PSA,
+ SCDown,
+ C2fCIB,
+ A2C2f,
+ }
+ )
+ repeat_modules = frozenset( # modules with 'repeat' arguments
+ {
+ BottleneckCSP,
+ C1,
+ C2,
+ C2f,
+ C3k2,
+ C2fAttn,
+ C3,
+ C3TR,
+ C3Ghost,
+ C3x,
+ RepC3,
+ C2fPSA,
+ C2fCIB,
+ C2PSA,
+ A2C2f,
+ }
+ )
+ for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
+ m = (
+ getattr(torch.nn, m[3:])
+ if "nn." in m
+ else getattr(__import__("torchvision").ops, m[16:])
+ if "torchvision.ops." in m
+ else globals()[m]
+ ) # get module
+ for j, a in enumerate(args):
+ if isinstance(a, str):
+ with contextlib.suppress(ValueError):
+ args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
+ n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
+ if m in base_modules:
+ c1, c2 = ch[f], args[0]
+ if c2 != nc: # if c2 != nc (e.g., Classify() output)
+ c2 = make_divisible(min(c2, max_channels) * width, 8)
+ if m is C2fAttn: # set 1) embed channels and 2) num heads
+ args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
+ args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
+
+ args = [c1, c2, *args[1:]]
+ if m in repeat_modules:
+ args.insert(2, n) # number of repeats
+ n = 1
+ if m is C3k2: # for M/L/X sizes
+ legacy = False
+ if scale in "mlx":
+ args[3] = True
+ if m is A2C2f:
+ legacy = False
+ if scale in "lx": # for L/X sizes
+ args.extend((True, 1.2))
+ if m is C2fCIB:
+ legacy = False
+ elif m is AIFI:
+ args = [ch[f], *args]
+ elif m in frozenset({HGStem, HGBlock}):
+ c1, cm, c2 = ch[f], args[0], args[1]
+ args = [c1, cm, c2, *args[2:]]
+ if m is HGBlock:
+ args.insert(4, n) # number of repeats
+ n = 1
+ elif m is ResNetLayer:
+ c2 = args[1] if args[3] else args[1] * 4
+ elif m is torch.nn.BatchNorm2d:
+ args = [ch[f]]
+ elif m is Concat:
+ c2 = sum(ch[x] for x in f)
+ elif m in frozenset(
+ {Detect, WorldDetect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB, ImagePoolingAttn, v10Detect}
+ ):
+ args.append([ch[x] for x in f])
+ if m is Segment or m is YOLOESegment:
+ args[2] = make_divisible(min(args[2], max_channels) * width, 8)
+ if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:
+ m.legacy = legacy
+ elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
+ args.insert(1, [ch[x] for x in f])
+ elif m is CBLinear:
+ c2 = args[0]
+ c1 = ch[f]
+ args = [c1, c2, *args[1:]]
+ elif m is CBFuse:
+ c2 = ch[f[-1]]
+ elif m in frozenset({TorchVision, Index}):
+ c2 = args[0]
+ c1 = ch[f]
+ args = [*args[1:]]
+ else:
+ c2 = ch[f]
+
+ m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
+ t = str(m)[8:-2].replace("__main__.", "") # module type
+ m_.np = sum(x.numel() for x in m_.parameters()) # number params
+ m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
+ if verbose:
+ LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f} {t:<45}{args!s:<30}") # print
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
+ layers.append(m_)
+ if i == 0:
+ ch = []
+ ch.append(c2)
+ return torch.nn.Sequential(*layers), sorted(save)
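To make the expected dictionary layout concrete, here is a minimal, unofficial config (not from the package) that parse_model should be able to consume: each backbone/head row is [from, repeats, module, args], and string args such as "nc" are resolved against parse_model's local variables.

from ultralytics.nn.tasks import parse_model

tiny = {
    "nc": 2,
    "backbone": [[-1, 1, "Conv", [16, 3, 2]], [-1, 1, "Conv", [32, 3, 2]]],
    "head": [[[1], 1, "Detect", ["nc"]]],
}
model, save = parse_model(tiny, ch=3, verbose=False)
print(model, save)  # a 3-module Sequential and the save list for the Detect head's inputs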
+
+
+ def yaml_model_load(path):
+ """Load a YOLOv8 model from a YAML file.
+
+ Args:
+ path (str | Path): Path to the YAML file.
+
+ Returns:
+ (dict): Model dictionary.
+ """
+ path = Path(path)
+ if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
+ new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
+ LOGGER.warning(f"Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
+ path = path.with_name(new_stem + path.suffix)
+
+ unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
+ yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
+ d = YAML.load(yaml_file) # model dict
+ d["scale"] = guess_model_scale(path)
+ d["yaml_file"] = str(path)
+ return d
+
+
+ def guess_model_scale(model_path):
+ """Extract the size character n, s, m, l, or x of the model's scale from the model path.
+
+ Args:
+ model_path (str | Path): The path to the YOLO model's YAML file.
+
+ Returns:
+ (str): The size character of the model's scale (n, s, m, l, or x).
+ """
+ try:
+ return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
+ except AttributeError:
+ return ""
+
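A quick behavioral check of guess_model_scale on a few hypothetical filenames: the scale letter is pulled out of the stem, and an empty string is returned when no scale letter is present.

from ultralytics.nn.tasks import guess_model_scale

print(guess_model_scale("yolo11n.yaml"))      # "n"
print(guess_model_scale("yolov8x-seg.yaml"))  # "x"
print(guess_model_scale("yolo11.yaml"))       # "" (no scale letter in the stem)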
+
+ def guess_model_task(model):
+ """Guess the task of a PyTorch model from its architecture or configuration.
+
+ Args:
+ model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
+
+ Returns:
+ (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').
+ """
+
+ def cfg2task(cfg):
+ """Guess from YAML dictionary."""
+ m = cfg["head"][-1][-2].lower() # output module name
+ if m in {"classify", "classifier", "cls", "fc"}:
+ return "classify"
+ if "detect" in m:
+ return "detect"
+ if "segment" in m:
+ return "segment"
+ if m == "pose":
+ return "pose"
+ if m == "obb":
+ return "obb"
+
+ # Guess from model cfg
+ if isinstance(model, dict):
+ with contextlib.suppress(Exception):
+ return cfg2task(model)
+ # Guess from PyTorch model
+ if isinstance(model, torch.nn.Module): # PyTorch model
+ for x in "model.args", "model.model.args", "model.model.model.args":
+ with contextlib.suppress(Exception):
+ return eval(x)["task"] # nosec B307: safe eval of known attribute paths
+ for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
+ with contextlib.suppress(Exception):
+ return cfg2task(eval(x)) # nosec B307: safe eval of known attribute paths
+ for m in model.modules():
+ if isinstance(m, (Segment, YOLOESegment)):
+ return "segment"
+ elif isinstance(m, Classify):
+ return "classify"
+ elif isinstance(m, Pose):
+ return "pose"
+ elif isinstance(m, OBB):
+ return "obb"
+ elif isinstance(m, (Detect, WorldDetect, YOLOEDetect, v10Detect)):
+ return "detect"
+
+ # Guess from model filename
+ if isinstance(model, (str, Path)):
+ model = Path(model)
+ if "-seg" in model.stem or "segment" in model.parts:
+ return "segment"
+ elif "-cls" in model.stem or "classify" in model.parts:
+ return "classify"
+ elif "-pose" in model.stem or "pose" in model.parts:
+ return "pose"
+ elif "-obb" in model.stem or "obb" in model.parts:
+ return "obb"
+ elif "detect" in model.parts:
+ return "detect"
+
+ # Unable to determine task from model
+ LOGGER.warning(
+ "Unable to automatically guess model task, assuming 'task=detect'. "
+ "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', 'pose' or 'obb'."
+ )
+ return "detect" # assume detect