ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247) hide show
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,13 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ import gc
3
4
  import math
4
5
  import os
5
6
  import random
6
7
  import time
7
8
  from contextlib import contextmanager
8
9
  from copy import deepcopy
10
+ from datetime import datetime
9
11
  from pathlib import Path
10
12
  from typing import Union
11
13
 
@@ -14,33 +16,51 @@ import torch
14
16
  import torch.distributed as dist
15
17
  import torch.nn as nn
16
18
  import torch.nn.functional as F
17
- import torchvision
18
19
 
19
- from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
20
- from ultralytics.utils.checks import PYTHON_VERSION, check_version
20
+ from ultralytics.utils import (
21
+ DEFAULT_CFG_DICT,
22
+ DEFAULT_CFG_KEYS,
23
+ LOGGER,
24
+ NUM_THREADS,
25
+ PYTHON_VERSION,
26
+ TORCHVISION_VERSION,
27
+ WINDOWS,
28
+ __version__,
29
+ colorstr,
30
+ )
31
+ from ultralytics.utils.checks import check_version
21
32
 
22
33
  try:
23
34
  import thop
24
35
  except ImportError:
25
36
  thop = None
26
37
 
38
+ # Version checks (all default to version>=min_version)
27
39
  TORCH_1_9 = check_version(torch.__version__, "1.9.0")
28
40
  TORCH_1_13 = check_version(torch.__version__, "1.13.0")
29
41
  TORCH_2_0 = check_version(torch.__version__, "2.0.0")
30
- TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
31
- TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
32
- TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
42
+ TORCH_2_4 = check_version(torch.__version__, "2.4.0")
43
+ TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
44
+ TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
45
+ TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
46
+ TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
47
+ if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows
48
+ LOGGER.warning(
49
+ "WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
50
+ "https://github.com/ultralytics/ultralytics/issues/15049"
51
+ )
33
52
 
34
53
 
35
54
  @contextmanager
36
55
  def torch_distributed_zero_first(local_rank: int):
37
- """Decorator to make all processes in distributed training wait for each local_master to do something."""
38
- initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
39
- if initialized and local_rank not in (-1, 0):
56
+ """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
57
+ initialized = dist.is_available() and dist.is_initialized()
58
+
59
+ if initialized and local_rank not in {-1, 0}:
40
60
  dist.barrier(device_ids=[local_rank])
41
61
  yield
42
62
  if initialized and local_rank == 0:
43
- dist.barrier(device_ids=[0])
63
+ dist.barrier(device_ids=[local_rank])
44
64
 
45
65
 
46
66
  def smart_inference_mode():
@@ -56,14 +76,58 @@ def smart_inference_mode():
56
76
  return decorate
57
77
 
58
78
 
79
+ def autocast(enabled: bool, device: str = "cuda"):
80
+ """
81
+ Get the appropriate autocast context manager based on PyTorch version and AMP setting.
82
+
83
+ This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
84
+ older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
85
+
86
+ Args:
87
+ enabled (bool): Whether to enable automatic mixed precision.
88
+ device (str, optional): The device to use for autocast. Defaults to 'cuda'.
89
+
90
+ Returns:
91
+ (torch.amp.autocast): The appropriate autocast context manager.
92
+
93
+ Note:
94
+ - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
95
+ - For older versions, it uses `torch.cuda.autocast`.
96
+
97
+ Example:
98
+ ```python
99
+ with autocast(amp=True):
100
+ # Your mixed precision operations here
101
+ pass
102
+ ```
103
+ """
104
+ if TORCH_1_13:
105
+ return torch.amp.autocast(device, enabled=enabled)
106
+ else:
107
+ return torch.cuda.amp.autocast(enabled)
108
+
109
+
59
110
  def get_cpu_info():
60
111
  """Return a string with system CPU information, i.e. 'Apple M2'."""
61
- import cpuinfo # pip install py-cpuinfo
112
+ from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
62
113
 
63
- k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available)
64
- info = cpuinfo.get_cpu_info() # info dict
65
- string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
66
- return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
114
+ if "cpu_info" not in PERSISTENT_CACHE:
115
+ try:
116
+ import cpuinfo # pip install py-cpuinfo
117
+
118
+ k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
119
+ info = cpuinfo.get_cpu_info() # info dict
120
+ string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
121
+ PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
122
+ except Exception:
123
+ pass
124
+ return PERSISTENT_CACHE.get("cpu_info", "unknown")
125
+
126
+
127
+ def get_gpu_info(index):
128
+ """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
129
+ properties = torch.cuda.get_device_properties(index)
130
+ return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
67
131
 
68
132
 
69
133
  def select_device(device="", batch=0, newline=False, verbose=True):
@@ -90,30 +154,31 @@ def select_device(device="", batch=0, newline=False, verbose=True):
90
154
  devices when using multiple GPUs.
91
155
 
92
156
  Examples:
93
- >>> select_device('cuda:0')
157
+ >>> select_device("cuda:0")
94
158
  device(type='cuda', index=0)
95
159
 
96
- >>> select_device('cpu')
160
+ >>> select_device("cpu")
97
161
  device(type='cpu')
98
162
 
99
163
  Note:
100
164
  Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
101
165
  """
102
-
103
- if isinstance(device, torch.device):
166
+ if isinstance(device, torch.device) or str(device).startswith("tpu"):
104
167
  return device
105
168
 
106
- s = f"Ultralytics YOLOv{__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
169
+ s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
107
170
  device = str(device).lower()
108
171
  for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
109
172
  device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
110
173
  cpu = device == "cpu"
111
- mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
174
+ mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
112
175
  if cpu or mps:
113
176
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
114
177
  elif device: # non-cpu device requested
115
178
  if device == "cuda":
116
179
  device = "0"
180
+ if "," in device:
181
+ device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1"
117
182
  visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
118
183
  os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
119
184
  if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
@@ -135,17 +200,22 @@ def select_device(device="", batch=0, newline=False, verbose=True):
135
200
  )
136
201
 
137
202
  if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
138
- devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
203
+ devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"]
139
204
  n = len(devices) # device count
140
- if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
141
- raise ValueError(
142
- f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
143
- f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
144
- )
205
+ if n > 1: # multi-GPU
206
+ if batch < 1:
207
+ raise ValueError(
208
+ "AutoBatch with batch<1 not supported for Multi-GPU training, "
209
+ "please specify a valid batch size, i.e. batch=16."
210
+ )
211
+ if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count
212
+ raise ValueError(
213
+ f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
214
+ f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
215
+ )
145
216
  space = " " * (len(s) + 1)
146
217
  for i, d in enumerate(devices):
147
- p = torch.cuda.get_device_properties(i)
148
- s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
218
+ s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB
149
219
  arg = "cuda:0"
150
220
  elif mps and TORCH_2_0 and torch.backends.mps.is_available():
151
221
  # Prefer MPS if available
@@ -155,6 +225,8 @@ def select_device(device="", batch=0, newline=False, verbose=True):
155
225
  s += f"CPU ({get_cpu_info()})\n"
156
226
  arg = "cpu"
157
227
 
228
+ if arg in {"cpu", "mps"}:
229
+ torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training
158
230
  if verbose:
159
231
  LOGGER.info(s if newline else s.rstrip())
160
232
  return torch.device(arg)
@@ -185,7 +257,7 @@ def fuse_conv_and_bn(conv, bn):
185
257
  )
186
258
 
187
259
  # Prepare filters
188
- w_conv = conv.weight.clone().view(conv.out_channels, -1)
260
+ w_conv = conv.weight.view(conv.out_channels, -1)
189
261
  w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
190
262
  fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
191
263
 
@@ -216,7 +288,7 @@ def fuse_deconv_and_bn(deconv, bn):
216
288
  )
217
289
 
218
290
  # Prepare filters
219
- w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
291
+ w_deconv = deconv.weight.view(deconv.out_channels, -1)
220
292
  w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
221
293
  fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
222
294
 
@@ -229,33 +301,27 @@ def fuse_deconv_and_bn(deconv, bn):
229
301
 
230
302
 
231
303
  def model_info(model, detailed=False, verbose=True, imgsz=640):
232
- """
233
- Model information.
234
-
235
- imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
236
- """
304
+ """Print and return detailed model information layer by layer."""
237
305
  if not verbose:
238
306
  return
239
307
  n_p = get_num_params(model) # number of parameters
240
308
  n_g = get_num_gradients(model) # number of gradients
241
309
  n_l = len(list(model.modules())) # number of layers
242
310
  if detailed:
243
- LOGGER.info(
244
- f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
245
- )
311
+ LOGGER.info(f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}")
246
312
  for i, (name, p) in enumerate(model.named_parameters()):
247
313
  name = name.replace("module_list.", "")
248
314
  LOGGER.info(
249
- "%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
250
- % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
315
+ f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}"
316
+ f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}"
251
317
  )
252
318
 
253
- flops = get_flops(model, imgsz)
319
+ flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
254
320
  fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
255
321
  fs = f", {flops:.1f} GFLOPs" if flops else ""
256
322
  yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
257
323
  model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
258
- LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
324
+ LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
259
325
  return n_l, n_p, n_g, flops
260
326
 
261
327
 
@@ -276,11 +342,13 @@ def model_info_for_loggers(trainer):
276
342
  Example:
277
343
  YOLOv8n info for loggers
278
344
  ```python
279
- results = {'model/parameters': 3151904,
280
- 'model/GFLOPs': 8.746,
281
- 'model/speed_ONNX(ms)': 41.244,
282
- 'model/speed_TensorRT(ms)': 3.211,
283
- 'model/speed_PyTorch(ms)': 18.755}
345
+ results = {
346
+ "model/parameters": 3151904,
347
+ "model/GFLOPs": 8.746,
348
+ "model/speed_ONNX(ms)": 41.244,
349
+ "model/speed_TensorRT(ms)": 3.211,
350
+ "model/speed_PyTorch(ms)": 18.755,
351
+ }
284
352
  ```
285
353
  """
286
354
  if trainer.args.profile: # profile ONNX and TensorRT times
@@ -322,19 +390,28 @@ def get_flops(model, imgsz=640):
322
390
 
323
391
 
324
392
  def get_flops_with_torch_profiler(model, imgsz=640):
325
- """Compute model FLOPs (thop alternative)."""
326
- if TORCH_2_0:
327
- model = de_parallel(model)
328
- p = next(model.parameters())
393
+ """Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
394
+ if not TORCH_2_0: # torch profiler implemented in torch>=2.0
395
+ return 0.0
396
+ model = de_parallel(model)
397
+ p = next(model.parameters())
398
+ if not isinstance(imgsz, list):
399
+ imgsz = [imgsz, imgsz] # expand if int/float
400
+ try:
401
+ # Use stride size for input tensor
329
402
  stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
330
- im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
403
+ im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
331
404
  with torch.profiler.profile(with_flops=True) as prof:
332
405
  model(im)
333
406
  flops = sum(x.flops for x in prof.key_averages()) / 1e9
334
- imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
335
407
  flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
336
- return flops
337
- return 0
408
+ except Exception:
409
+ # Use actual image size for input tensor (i.e. required for RTDETR models)
410
+ im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
411
+ with torch.profiler.profile(with_flops=True) as prof:
412
+ model(im)
413
+ flops = sum(x.flops for x in prof.key_averages()) / 1e9
414
+ return flops
338
415
 
339
416
 
340
417
  def initialize_weights(model):
@@ -346,14 +423,12 @@ def initialize_weights(model):
346
423
  elif t is nn.BatchNorm2d:
347
424
  m.eps = 1e-3
348
425
  m.momentum = 0.03
349
- elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
426
+ elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
350
427
  m.inplace = True
351
428
 
352
429
 
353
430
  def scale_img(img, ratio=1.0, same_shape=False, gs=32):
354
- """Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally
355
- retaining the original shape.
356
- """
431
+ """Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
357
432
  if ratio == 1.0:
358
433
  return img
359
434
  h, w = img.shape[2:]
@@ -364,13 +439,6 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
364
439
  return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
365
440
 
366
441
 
367
- def make_divisible(x, divisor):
368
- """Returns nearest x divisible by divisor."""
369
- if isinstance(divisor, torch.Tensor):
370
- divisor = int(divisor.max()) # to int
371
- return math.ceil(x / divisor) * divisor
372
-
373
-
374
442
  def copy_attr(a, b, include=(), exclude=()):
375
443
  """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
376
444
  for k, v in b.__dict__.items():
@@ -381,8 +449,13 @@ def copy_attr(a, b, include=(), exclude=()):
381
449
 
382
450
 
383
451
  def get_latest_opset():
384
- """Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
385
- return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset
452
+ """Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
453
+ if TORCH_1_13:
454
+ # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
455
+ return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
456
+ # Otherwise for PyTorch<=1.12 return the corresponding predefined opset
457
+ version = torch.onnx.producer_version.rsplit(".", 1)[0] # i.e. '2.3'
458
+ return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
386
459
 
387
460
 
388
461
  def intersect_dicts(da, db, exclude=()):
@@ -427,14 +500,17 @@ def init_seeds(seed=0, deterministic=False):
427
500
 
428
501
 
429
502
  class ModelEMA:
430
- """Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
431
- Keeps a moving average of everything in the model state_dict (parameters and buffers)
503
+ """
504
+ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
505
+ average of everything in the model state_dict (parameters and buffers).
506
+
432
507
  For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
508
+
433
509
  To disable EMA set the `enabled` attribute to `False`.
434
510
  """
435
511
 
436
512
  def __init__(self, model, decay=0.9999, tau=2000, updates=0):
437
- """Create EMA."""
513
+ """Initialize EMA for 'model' with given arguments."""
438
514
  self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
439
515
  self.updates = updates # number of EMA updates
440
516
  self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
@@ -461,50 +537,113 @@ class ModelEMA:
461
537
  copy_attr(self.ema, model, include, exclude)
462
538
 
463
539
 
464
- def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
540
+ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
465
541
  """
466
542
  Strip optimizer from 'f' to finalize training, optionally save as 's'.
467
543
 
468
544
  Args:
469
545
  f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
470
546
  s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
547
+ updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
471
548
 
472
549
  Returns:
473
- None
550
+ (dict): The combined checkpoint dictionary.
474
551
 
475
552
  Example:
476
553
  ```python
477
554
  from pathlib import Path
478
555
  from ultralytics.utils.torch_utils import strip_optimizer
479
556
 
480
- for f in Path('path/to/weights').rglob('*.pt'):
557
+ for f in Path("path/to/model/checkpoints").rglob("*.pt"):
481
558
  strip_optimizer(f)
482
559
  ```
483
- """
484
- x = torch.load(f, map_location=torch.device("cpu"))
485
- if "model" not in x:
486
- LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
487
- return
488
560
 
561
+ Note:
562
+ Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
563
+ """
564
+ try:
565
+ x = torch.load(f, map_location=torch.device("cpu"))
566
+ assert isinstance(x, dict), "checkpoint is not a Python dictionary"
567
+ assert "model" in x, "'model' missing from checkpoint"
568
+ except Exception as e:
569
+ LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
570
+ return {}
571
+
572
+ metadata = {
573
+ "date": datetime.now().isoformat(),
574
+ "version": __version__,
575
+ "license": "AGPL-3.0 License (https://ultralytics.com/license)",
576
+ "docs": "https://docs.ultralytics.com",
577
+ }
578
+
579
+ # Update model
580
+ if x.get("ema"):
581
+ x["model"] = x["ema"] # replace model with EMA
489
582
  if hasattr(x["model"], "args"):
490
583
  x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
491
- args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
492
- if x.get("ema"):
493
- x["model"] = x["ema"] # replace model with ema
494
- for k in "optimizer", "best_fitness", "ema", "updates": # keys
495
- x[k] = None
496
- x["epoch"] = -1
584
+ if hasattr(x["model"], "criterion"):
585
+ x["model"].criterion = None # strip loss criterion
497
586
  x["model"].half() # to FP16
498
587
  for p in x["model"].parameters():
499
588
  p.requires_grad = False
589
+
590
+ # Update other keys
591
+ args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
592
+ for k in "optimizer", "best_fitness", "ema", "updates": # keys
593
+ x[k] = None
594
+ x["epoch"] = -1
500
595
  x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
501
596
  # x['model'].args = x['train_args']
502
- torch.save(x, s or f)
597
+
598
+ # Save
599
+ combined = {**metadata, **x, **(updates or {})}
600
+ torch.save(combined, s or f) # combine dicts (prefer to the right)
503
601
  mb = os.path.getsize(s or f) / 1e6 # file size
504
602
  LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
603
+ return combined
604
+
605
+
606
+ def convert_optimizer_state_dict_to_fp16(state_dict):
607
+ """
608
+ Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
609
+
610
+ This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
611
+ """
612
+ for state in state_dict["state"].values():
613
+ for k, v in state.items():
614
+ if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
615
+ state[k] = v.half()
616
+
617
+ return state_dict
618
+
619
+
620
+ @contextmanager
621
+ def cuda_memory_usage(device=None):
622
+ """
623
+ Monitor and manage CUDA memory usage.
624
+
625
+ This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
626
+ It then yields a dictionary containing memory usage information, which can be updated by the caller.
627
+ Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.
628
+
629
+ Args:
630
+ device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.
631
+
632
+ Yields:
633
+ (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
634
+ """
635
+ cuda_info = dict(memory=0)
636
+ if torch.cuda.is_available():
637
+ torch.cuda.empty_cache()
638
+ try:
639
+ yield cuda_info
640
+ finally:
641
+ cuda_info["memory"] = torch.cuda.memory_reserved(device)
642
+ else:
643
+ yield cuda_info
505
644
 
506
645
 
507
- def profile(input, ops, n=10, device=None):
646
+ def profile(input, ops, n=10, device=None, max_num_obj=0):
508
647
  """
509
648
  Ultralytics speed, memory and FLOPs profiler.
510
649
 
@@ -525,7 +664,8 @@ def profile(input, ops, n=10, device=None):
525
664
  f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
526
665
  f"{'input':>24s}{'output':>24s}"
527
666
  )
528
-
667
+ gc.collect() # attempt to free unused memory
668
+ torch.cuda.empty_cache()
529
669
  for x in input if isinstance(input, list) else [input]:
530
670
  x = x.to(device)
531
671
  x.requires_grad = True
@@ -539,19 +679,31 @@ def profile(input, ops, n=10, device=None):
539
679
  flops = 0
540
680
 
541
681
  try:
682
+ mem = 0
542
683
  for _ in range(n):
543
- t[0] = time_sync()
544
- y = m(x)
545
- t[1] = time_sync()
546
- try:
547
- (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
548
- t[2] = time_sync()
549
- except Exception: # no backward method
550
- # print(e) # for debug
551
- t[2] = float("nan")
684
+ with cuda_memory_usage(device) as cuda_info:
685
+ t[0] = time_sync()
686
+ y = m(x)
687
+ t[1] = time_sync()
688
+ try:
689
+ (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
690
+ t[2] = time_sync()
691
+ except Exception: # no backward method
692
+ # print(e) # for debug
693
+ t[2] = float("nan")
694
+ mem += cuda_info["memory"] / 1e9 # (GB)
552
695
  tf += (t[1] - t[0]) * 1000 / n # ms per op forward
553
696
  tb += (t[2] - t[1]) * 1000 / n # ms per op backward
554
- mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB)
697
+ if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)
698
+ with cuda_memory_usage(device) as cuda_info:
699
+ torch.randn(
700
+ x.shape[0],
701
+ max_num_obj,
702
+ int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
703
+ device=device,
704
+ dtype=torch.float32,
705
+ )
706
+ mem += cuda_info["memory"] / 1e9 # (GB)
555
707
  s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
556
708
  p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
557
709
  LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
@@ -559,7 +711,9 @@ def profile(input, ops, n=10, device=None):
559
711
  except Exception as e:
560
712
  LOGGER.info(e)
561
713
  results.append(None)
562
- torch.cuda.empty_cache()
714
+ finally:
715
+ gc.collect() # attempt to free unused memory
716
+ torch.cuda.empty_cache()
563
717
  return results
564
718
 
565
719
 
@@ -599,10 +753,56 @@ class EarlyStopping:
599
753
  self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
600
754
  stop = delta >= self.patience # stop training if patience exceeded
601
755
  if stop:
756
+ prefix = colorstr("EarlyStopping: ")
602
757
  LOGGER.info(
603
- f"Stopping training early as no improvement observed in last {self.patience} epochs. "
758
+ f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
604
759
  f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
605
760
  f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
606
761
  f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
607
762
  )
608
763
  return stop
764
+
765
+
766
+ class FXModel(nn.Module):
767
+ """
768
+ A custom model class for torch.fx compatibility.
769
+
770
+ This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph manipulation.
771
+ It copies attributes from an existing model and explicitly sets the model attribute to ensure proper copying.
772
+
773
+ Args:
774
+ model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
775
+ """
776
+
777
+ def __init__(self, model):
778
+ """
779
+ Initialize the FXModel.
780
+
781
+ Args:
782
+ model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
783
+ """
784
+ super().__init__()
785
+ copy_attr(self, model)
786
+ # Explicitly set `model` since `copy_attr` somehow does not copy it.
787
+ self.model = model.model
788
+
789
+ def forward(self, x):
790
+ """
791
+ Forward pass through the model.
792
+
793
+ This method performs the forward pass through the model, handling the dependencies between layers and saving intermediate outputs.
794
+
795
+ Args:
796
+ x (torch.Tensor): The input tensor to the model.
797
+
798
+ Returns:
799
+ (torch.Tensor): The output tensor from the model.
800
+ """
801
+ y = [] # outputs
802
+ for m in self.model:
803
+ if m.f != -1: # if not from previous layer
804
+ # from earlier layers
805
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
806
+ x = m(x) # run
807
+ y.append(x) # save output
808
+ return x
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from typing import List
4
4
  from urllib.parse import urlsplit
@@ -66,6 +66,7 @@ class TritonRemoteModel:
66
66
  self.np_input_formats = [type_map[x] for x in self.input_formats]
67
67
  self.input_names = [x["name"] for x in config["input"]]
68
68
  self.output_names = [x["name"] for x in config["output"]]
69
+ self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))
69
70
 
70
71
  def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
71
72
  """