dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
+import functools
 import gc
 import math
 import os
@@ -9,7 +12,7 @@ from contextlib import contextmanager
 from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
-from typing import Union
+from typing import Any
 
 import numpy as np
 import torch
@@ -24,22 +27,29 @@ from ultralytics.utils import (
     LOGGER,
     NUM_THREADS,
     PYTHON_VERSION,
+    TORCH_VERSION,
     TORCHVISION_VERSION,
     WINDOWS,
     colorstr,
 )
 from ultralytics.utils.checks import check_version
+from ultralytics.utils.cpu import CPUInfo
+from ultralytics.utils.patches import torch_load
 
 # Version checks (all default to version>=min_version)
-TORCH_1_9 = check_version(torch.__version__, "1.9.0")
-TORCH_1_13 = check_version(torch.__version__, "1.13.0")
-TORCH_2_0 = check_version(torch.__version__, "2.0.0")
-TORCH_2_4 = check_version(torch.__version__, "2.4.0")
+TORCH_1_9 = check_version(TORCH_VERSION, "1.9.0")
+TORCH_1_10 = check_version(TORCH_VERSION, "1.10.0")
+TORCH_1_11 = check_version(TORCH_VERSION, "1.11.0")
+TORCH_1_13 = check_version(TORCH_VERSION, "1.13.0")
+TORCH_2_0 = check_version(TORCH_VERSION, "2.0.0")
+TORCH_2_1 = check_version(TORCH_VERSION, "2.1.0")
+TORCH_2_4 = check_version(TORCH_VERSION, "2.4.0")
+TORCH_2_9 = check_version(TORCH_VERSION, "2.9.0")
 TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
 TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
 TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
 TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
-if WINDOWS and check_version(torch.__version__, "==2.4.0"):  # reject version 2.4.0 on Windows
+if WINDOWS and check_version(TORCH_VERSION, "==2.4.0"):  # reject version 2.4.0 on Windows
     LOGGER.warning(
         "Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
         "https://github.com/ultralytics/ultralytics/issues/15049"
@@ -48,7 +58,7 @@ if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.
 
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
-    """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
+    """Ensure all processes in distributed training wait for the local master (rank 0) to complete a task first."""
     initialized = dist.is_available() and dist.is_initialized()
     use_ids = initialized and dist.get_backend() == "nccl"
 
@@ -60,10 +70,10 @@ def torch_distributed_zero_first(local_rank: int):
 
 
 def smart_inference_mode():
-    """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
+    """Apply torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
 
     def decorate(fn):
-        """Applies appropriate torch decorator for inference mode based on torch version."""
+        """Apply appropriate torch decorator for inference mode based on torch version."""
         if TORCH_1_9 and torch.is_inference_mode_enabled():
             return fn  # already in inference_mode, act as a pass-through
         else:
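A condensed sketch of the dispatch that `smart_inference_mode` performs, simplified to a plain decorator (the real helper is applied as `@smart_inference_mode()`; this standalone version only assumes the branching visible in the hunk above):

```python
import torch

def decorate(fn):
    # Pass through if already inside inference_mode, otherwise wrap with the
    # strongest no-grad context this torch build supports
    if hasattr(torch, "inference_mode") and torch.is_inference_mode_enabled():
        return fn
    ctx = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
    return ctx()(fn)

@decorate
def predict(model, x):
    return model(x)  # gradients are disabled inside this call
```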
@@ -73,27 +83,26 @@
 
 
 def autocast(enabled: bool, device: str = "cuda"):
-    """
-    Get the appropriate autocast context manager based on PyTorch version and AMP setting.
+    """Get the appropriate autocast context manager based on PyTorch version and AMP setting.
 
     This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
     older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
 
     Args:
         enabled (bool): Whether to enable automatic mixed precision.
-        device (str, optional): The device to use for autocast. Defaults to 'cuda'.
+        device (str, optional): The device to use for autocast.
 
     Returns:
         (torch.amp.autocast): The appropriate autocast context manager.
 
-    Notes:
-        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
-        - For older versions, it uses `torch.cuda.autocast`.
-
     Examples:
         >>> with autocast(enabled=True):
         ...     # Your mixed precision operations here
         ...     pass
+
+    Notes:
+        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
+        - For older versions, it uses `torch.cuda.autocast`.
     """
     if TORCH_1_13:
         return torch.amp.autocast(device, enabled=enabled)
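A hedged sketch of the same version dispatch for a standalone script, using feature detection instead of the module's `TORCH_1_13` flag (assumption: `torch.amp.autocast` availability tracks torch>=1.13):

```python
import torch

def amp_context(enabled: bool, device: str = "cuda"):
    # torch>=1.13: device-agnostic torch.amp.autocast; older builds: CUDA-only API
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast(device, enabled=enabled)
    return torch.cuda.amp.autocast(enabled)

with amp_context(enabled=True, device="cuda"):
    pass  # mixed-precision forward/backward ops here
```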
@@ -101,52 +110,42 @@ def autocast(enabled: bool, device: str = "cuda"):
         return torch.cuda.amp.autocast(enabled)
 
 
+@functools.lru_cache
 def get_cpu_info():
     """Return a string with system CPU information, i.e. 'Apple M2'."""
     from ultralytics.utils import PERSISTENT_CACHE  # avoid circular import error
 
     if "cpu_info" not in PERSISTENT_CACHE:
         try:
-            import cpuinfo  # pip install py-cpuinfo
-
-            k = "brand_raw", "hardware_raw", "arch_string_raw"  # keys sorted by preference
-            info = cpuinfo.get_cpu_info()  # info dict
-            string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
-            PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
+            PERSISTENT_CACHE["cpu_info"] = CPUInfo.name()
         except Exception:
             pass
     return PERSISTENT_CACHE.get("cpu_info", "unknown")
 
 
+@functools.lru_cache
 def get_gpu_info(index):
     """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
     properties = torch.cuda.get_device_properties(index)
     return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
 
 
-def select_device(device="", batch=0, newline=False, verbose=True):
-    """
-    Select the appropriate PyTorch device based on the provided arguments.
+def select_device(device="", newline=False, verbose=True):
+    """Select the appropriate PyTorch device based on the provided arguments.
 
     The function takes a string specifying the device or a torch.device object and returns a torch.device object
     representing the selected device. The function also validates the number of available devices and raises an
     exception if the requested device(s) are not available.
 
     Args:
-        device (str | torch.device, optional): Device string or torch.device object.
-            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
-            the first available GPU, or CPU if no GPU is available.
-        batch (int, optional): Batch size being used in your model.
+        device (str | torch.device, optional): Device string or torch.device object. Options are 'None', 'cpu', or
+            'cuda', or '0' or '0,1,2,3'. Auto-selects the first available GPU, or CPU if no GPU is available.
         newline (bool, optional): If True, adds a newline at the end of the log string.
         verbose (bool, optional): If True, logs the device information.
 
     Returns:
         (torch.device): Selected device.
 
-    Raises:
-        ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
-            devices when using multiple GPUs.
-
     Examples:
         >>> select_device("cuda:0")
         device(type='cuda', index=0)
@@ -154,13 +153,13 @@ def select_device(device="", batch=0, newline=False, verbose=True):
         >>> select_device("cpu")
         device(type='cpu')
 
-    Note:
+    Notes:
         Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
     """
     if isinstance(device, torch.device) or str(device).startswith(("tpu", "intel")):
         return device
 
-    s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
+    s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{TORCH_VERSION} "
     device = str(device).lower()
     for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
         device = device.replace(remove, "")  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
@@ -171,7 +170,7 @@
 
         # Replace each -1 with a selected GPU or remove it
         parts = device.split(",")
-        selected = GPUInfo().select_idle_gpu(count=parts.count("-1"), min_memory_mb=2048)
+        selected = GPUInfo().select_idle_gpu(count=parts.count("-1"), min_memory_fraction=0.2)
        for i in range(len(parts)):
            if parts[i] == "-1":
                parts[i] = str(selected.pop(0)) if selected else ""
@@ -208,19 +207,7 @@
 
     if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
         devices = device.split(",") if device else "0"  # i.e. "0,1" -> ["0", "1"]
-        n = len(devices)  # device count
-        if n > 1:  # multi-GPU
-            if batch < 1:
-                raise ValueError(
-                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
-                    f"please specify a valid batch size multiple of GPU count {n}, i.e. batch={n * 8}."
-                )
-            if batch >= 0 and batch % n != 0:  # check batch_size is divisible by device_count
-                raise ValueError(
-                    f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
-                    f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
-                )
-        space = " " * (len(s) + 1)
+        space = " " * len(s)
         for i, d in enumerate(devices):
             s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n"  # bytes to MB
         arg = "cuda:0"
@@ -240,89 +227,92 @@
 
 
 def time_sync():
-    """PyTorch-accurate time."""
+    """Return PyTorch-accurate time."""
     if torch.cuda.is_available():
         torch.cuda.synchronize()
     return time.time()
 
 
 def fuse_conv_and_bn(conv, bn):
-    """Fuse Conv2d() and BatchNorm2d() layers."""
-    fusedconv = (
-        nn.Conv2d(
-            conv.in_channels,
-            conv.out_channels,
-            kernel_size=conv.kernel_size,
-            stride=conv.stride,
-            padding=conv.padding,
-            dilation=conv.dilation,
-            groups=conv.groups,
-            bias=True,
-        )
-        .requires_grad_(False)
-        .to(conv.weight.device)
-    )
+    """Fuse Conv2d and BatchNorm2d layers for inference optimization.
+
+    Args:
+        conv (nn.Conv2d): Convolutional layer to fuse.
+        bn (nn.BatchNorm2d): Batch normalization layer to fuse.
+
+    Returns:
+        (nn.Conv2d): The fused convolutional layer with gradients disabled.
 
-    # Prepare filters
+    Examples:
+        >>> conv = nn.Conv2d(3, 16, 3)
+        >>> bn = nn.BatchNorm2d(16)
+        >>> fused_conv = fuse_conv_and_bn(conv, bn)
+    """
+    # Compute fused weights
     w_conv = conv.weight.view(conv.out_channels, -1)
     w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
-    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
+    conv.weight.data = torch.mm(w_bn, w_conv).view(conv.weight.shape)
 
-    # Prepare spatial bias
-    b_conv = (
-        torch.zeros(conv.weight.shape[0], dtype=conv.weight.dtype, device=conv.weight.device)
-        if conv.bias is None
-        else conv.bias
-    )
+    # Compute fused bias
+    b_conv = torch.zeros(conv.out_channels, device=conv.weight.device) if conv.bias is None else conv.bias
     b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
-    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+    fused_bias = torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn
+
+    if conv.bias is None:
+        conv.register_parameter("bias", nn.Parameter(fused_bias))
+    else:
+        conv.bias.data = fused_bias
 
-    return fusedconv
+    return conv.requires_grad_(False)
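The rewrite mutates `conv` in place instead of allocating a new `nn.Conv2d`, but the BN folding math is unchanged. A quick numerical check of the fusion (the reference output must be computed before fusing, since the layer is modified in place):

```python
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import fuse_conv_and_bn

conv, bn = nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16)
bn.eval()  # fusion is exact only when BN uses its running statistics
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y_ref = bn(conv(x))                 # unfused reference
    fused = fuse_conv_and_bn(conv, bn)  # mutates conv in place in 8.3.224
    y_fused = fused(x)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
```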
 
 
 def fuse_deconv_and_bn(deconv, bn):
-    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
-    fuseddconv = (
-        nn.ConvTranspose2d(
-            deconv.in_channels,
-            deconv.out_channels,
-            kernel_size=deconv.kernel_size,
-            stride=deconv.stride,
-            padding=deconv.padding,
-            output_padding=deconv.output_padding,
-            dilation=deconv.dilation,
-            groups=deconv.groups,
-            bias=True,
-        )
-        .requires_grad_(False)
-        .to(deconv.weight.device)
-    )
+    """Fuse ConvTranspose2d and BatchNorm2d layers for inference optimization.
+
+    Args:
+        deconv (nn.ConvTranspose2d): Transposed convolutional layer to fuse.
+        bn (nn.BatchNorm2d): Batch normalization layer to fuse.
 
-    # Prepare filters
+    Returns:
+        (nn.ConvTranspose2d): The fused transposed convolutional layer with gradients disabled.
+
+    Examples:
+        >>> deconv = nn.ConvTranspose2d(16, 3, 3)
+        >>> bn = nn.BatchNorm2d(3)
+        >>> fused_deconv = fuse_deconv_and_bn(deconv, bn)
+    """
+    # Compute fused weights
     w_deconv = deconv.weight.view(deconv.out_channels, -1)
     w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
-    fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
+    deconv.weight.data = torch.mm(w_bn, w_deconv).view(deconv.weight.shape)
 
-    # Prepare spatial bias
-    b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
+    # Compute fused bias
+    b_conv = torch.zeros(deconv.out_channels, device=deconv.weight.device) if deconv.bias is None else deconv.bias
     b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
-    fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+    fused_bias = torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn
 
-    return fuseddconv
+    if deconv.bias is None:
+        deconv.register_parameter("bias", nn.Parameter(fused_bias))
+    else:
+        deconv.bias.data = fused_bias
+
+    return deconv.requires_grad_(False)
 
 
 def model_info(model, detailed=False, verbose=True, imgsz=640):
-    """
-    Print and return detailed model information layer by layer.
+    """Print and return detailed model information layer by layer.
 
     Args:
         model (nn.Module): Model to analyze.
-        detailed (bool, optional): Whether to print detailed layer information. Defaults to False.
-        verbose (bool, optional): Whether to print model information. Defaults to True.
-        imgsz (int | List, optional): Input image size. Defaults to 640.
+        detailed (bool, optional): Whether to print detailed layer information.
+        verbose (bool, optional): Whether to print model information.
+        imgsz (int | list, optional): Input image size.
 
     Returns:
-        (Tuple[int, int, int, float]): Number of layers, parameters, gradients, and GFLOPs.
+        n_l (int): Number of layers.
+        n_p (int): Number of parameters.
+        n_g (int): Number of gradients.
+        flops (float): GFLOPs.
     """
     if not verbose:
         return
@@ -339,10 +329,10 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
             if len(m._parameters):
                 for pn, p in m.named_parameters():
                     LOGGER.info(
-                        f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
+                        f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{list(p.shape)!s:>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
                     )
             else:  # layers with no learnable params
-                LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")
+                LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{[]!s:>20}{'-':>10}{'-':>10}{'-':>15}")
 
     flops = get_flops(model, imgsz)  # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
     fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
@@ -364,8 +354,7 @@ def get_num_gradients(model):
 
 
 def model_info_for_loggers(trainer):
-    """
-    Return model info dict with useful model information.
+    """Return model info dict with useful model information.
 
     Args:
         trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.
@@ -398,16 +387,14 @@
 
 
 def get_flops(model, imgsz=640):
-    """
-    Calculate FLOPs (floating point operations) for a model in billions.
+    """Calculate FLOPs (floating point operations) for a model in billions.
 
-    Attempts two calculation methods: first with a stride-based tensor for efficiency,
-    then falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0
-    if thop library is unavailable or calculation fails.
+    Attempts two calculation methods: first with a stride-based tensor for efficiency, then falls back to full image
+    size if needed (e.g., for RTDETR models). Returns 0.0 if thop library is unavailable or calculation fails.
 
     Args:
         model (nn.Module): The model to calculate FLOPs for.
-        imgsz (int | List[int], optional): Input image size. Defaults to 640.
+        imgsz (int | list, optional): Input image size.
 
     Returns:
         (float): The model FLOPs in billions.
@@ -421,7 +408,7 @@ def get_flops(model, imgsz=640):
         return 0.0  # if not installed return 0.0 GFLOPs
 
     try:
-        model = de_parallel(model)
+        model = unwrap_model(model)
         p = next(model.parameters())
         if not isinstance(imgsz, list):
             imgsz = [imgsz, imgsz]  # expand if int/float
@@ -440,19 +427,18 @@
 
 
 def get_flops_with_torch_profiler(model, imgsz=640):
-    """
-    Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).
+    """Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).
 
     Args:
         model (nn.Module): The model to calculate FLOPs for.
-        imgsz (int | List[int], optional): Input image size. Defaults to 640.
+        imgsz (int | list, optional): Input image size.
 
     Returns:
         (float): The model's FLOPs in billions.
     """
     if not TORCH_2_0:  # torch profiler implemented in torch>=2.0
         return 0.0
-    model = de_parallel(model)
+    model = unwrap_model(model)
     p = next(model.parameters())
     if not isinstance(imgsz, list):
         imgsz = [imgsz, imgsz]  # expand if int/float
@@ -487,14 +473,13 @@ def initialize_weights(model):
 
 
 def scale_img(img, ratio=1.0, same_shape=False, gs=32):
-    """
-    Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple.
+    """Scale and pad an image tensor, optionally maintaining aspect ratio and padding to gs multiple.
 
     Args:
         img (torch.Tensor): Input image tensor.
-        ratio (float, optional): Scaling ratio. Defaults to 1.0.
-        same_shape (bool, optional): Whether to maintain the same shape. Defaults to False.
-        gs (int, optional): Grid size for padding. Defaults to 32.
+        ratio (float, optional): Scaling ratio.
+        same_shape (bool, optional): Whether to maintain the same shape.
+        gs (int, optional): Grid size for padding.
 
     Returns:
         (torch.Tensor): Scaled and padded image tensor.
@@ -510,14 +495,13 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
 
 
 def copy_attr(a, b, include=(), exclude=()):
-    """
-    Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.
+    """Copy attributes from object 'b' to object 'a', with options to include/exclude certain attributes.
 
     Args:
-        a (object): Destination object to copy attributes to.
-        b (object): Source object to copy attributes from.
-        include (tuple, optional): Attributes to include. If empty, all attributes are included. Defaults to ().
-        exclude (tuple, optional): Attributes to exclude. Defaults to ().
+        a (Any): Destination object to copy attributes to.
+        b (Any): Source object to copy attributes from.
+        include (tuple, optional): Attributes to include. If empty, all attributes are included.
+        exclude (tuple, optional): Attributes to exclude.
     """
     for k, v in b.__dict__.items():
         if (len(include) and k not in include) or k.startswith("_") or k in exclude:
@@ -526,29 +510,13 @@ def copy_attr(a, b, include=(), exclude=()):
             setattr(a, k, v)
 
 
-def get_latest_opset():
-    """
-    Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.
-
-    Returns:
-        (int): The ONNX opset version.
-    """
-    if TORCH_1_13:
-        # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
-        return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
-    # Otherwise for PyTorch<=1.12 return the corresponding predefined opset
-    version = torch.onnx.producer_version.rsplit(".", 1)[0]  # i.e. '2.3'
-    return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
-
-
 def intersect_dicts(da, db, exclude=()):
-    """
-    Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.
+    """Return a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.
 
     Args:
         da (dict): First dictionary.
         db (dict): Second dictionary.
-        exclude (tuple, optional): Keys to exclude. Defaults to ().
+        exclude (tuple, optional): Keys to exclude.
 
     Returns:
         (dict): Dictionary of intersecting keys with matching shapes.
@@ -557,8 +525,7 @@ def intersect_dicts(da, db, exclude=()):
 
 
 def is_parallel(model):
-    """
-    Returns True if model is of type DP or DDP.
+    """Return True if model is of type DP or DDP.
 
     Args:
         model (nn.Module): Model to check.
@@ -569,27 +536,32 @@
     return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
 
 
-def de_parallel(model):
-    """
-    De-parallelize a model: returns single-GPU model if model is of type DP or DDP.
+def unwrap_model(m: nn.Module) -> nn.Module:
+    """Unwrap compiled and parallel models to get the base model.
 
     Args:
-        model (nn.Module): Model to de-parallelize.
+        m (nn.Module): A model that may be wrapped by torch.compile (._orig_mod) or parallel wrappers such as
+            DataParallel/DistributedDataParallel (.module).
 
     Returns:
-        (nn.Module): De-parallelized model.
+        m (nn.Module): The unwrapped base model without compile or parallel wrappers.
     """
-    return model.module if is_parallel(model) else model
+    while True:
+        if hasattr(m, "_orig_mod") and isinstance(m._orig_mod, nn.Module):
+            m = m._orig_mod
+        elif hasattr(m, "module") and isinstance(m.module, nn.Module):
+            m = m.module
+        else:
+            return m
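Unlike the old `de_parallel`, `unwrap_model` also peels `torch.compile` wrappers, and the while-loop handles nested wrapping. A small sketch (requires torch>=2.0 for `torch.compile`):

```python
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import unwrap_model

base = nn.Linear(4, 2)
compiled = torch.compile(base)         # OptimizedModule exposing ._orig_mod
print(unwrap_model(compiled) is base)  # True

wrapped = nn.DataParallel(compiled)    # .module around the compiled model
print(unwrap_model(wrapped) is base)   # True: the loop peels both wrappers
```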
 
 
 def one_cycle(y1=0.0, y2=1.0, steps=100):
-    """
-    Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.
+    """Return a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.
 
     Args:
-        y1 (float, optional): Initial value. Defaults to 0.0.
-        y2 (float, optional): Final value. Defaults to 1.0.
-        steps (int, optional): Number of steps. Defaults to 100.
+        y1 (float, optional): Initial value.
+        y2 (float, optional): Final value.
+        steps (int, optional): Number of steps.
 
     Returns:
         (function): Lambda function for computing the sinusoidal ramp.
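The docstring describes a sinusoidal ramp from y1 to y2; a sketch of that cosine interpolation for intuition (the exact lambda body is not shown in this hunk, so this is an illustrative reconstruction):

```python
import math

# Cosine ramp y1 -> y2 over `steps`, matching the docstring's description
def one_cycle_sketch(y1=0.0, y2=1.0, steps=100):
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

lf = one_cycle_sketch(y1=1.0, y2=0.01, steps=100)  # e.g. an LR schedule multiplier
print(lf(0), lf(50), lf(100))  # 1.0 at the start, ~0.505 at the midpoint, 0.01 at the end
```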
@@ -598,12 +570,11 @@
 
 
 def init_seeds(seed=0, deterministic=False):
-    """
-    Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.
+    """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.
 
     Args:
-        seed (int, optional): Random seed. Defaults to 0.
-        deterministic (bool, optional): Whether to set deterministic algorithms. Defaults to False.
+        seed (int, optional): Random seed.
+        deterministic (bool, optional): Whether to set deterministic algorithms.
     """
     random.seed(seed)
     np.random.seed(seed)
@@ -624,7 +595,7 @@
 
 
 def unset_deterministic():
-    """Unsets all the configurations applied for deterministic training."""
+    """Unset all the configurations applied for deterministic training."""
     torch.use_deterministic_algorithms(False)
     torch.backends.cudnn.deterministic = False
     os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
@@ -632,11 +603,10 @@
 
 
 class ModelEMA:
-    """
-    Updated Exponential Moving Average (EMA) implementation.
+    """Updated Exponential Moving Average (EMA) implementation.
 
-    Keeps a moving average of everything in the model state_dict (parameters and buffers).
-    For EMA details see References.
+    Keeps a moving average of everything in the model state_dict (parameters and buffers). For EMA details see
+    References.
 
     To disable EMA set the `enabled` attribute to `False`.
 
@@ -652,16 +622,15 @@
     """
 
     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
-        """
-        Initialize EMA for 'model' with given arguments.
+        """Initialize EMA for 'model' with given arguments.
 
         Args:
             model (nn.Module): Model to create EMA for.
-            decay (float, optional): Maximum EMA decay rate. Defaults to 0.9999.
-            tau (int, optional): EMA decay time constant. Defaults to 2000.
-            updates (int, optional): Initial number of updates. Defaults to 0.
+            decay (float, optional): Maximum EMA decay rate.
+            tau (int, optional): EMA decay time constant.
+            updates (int, optional): Initial number of updates.
         """
-        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.ema = deepcopy(unwrap_model(model)).eval()  # FP32 EMA
         self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
         for p in self.ema.parameters():
@@ -669,8 +638,7 @@ class ModelEMA:
         self.enabled = True
 
     def update(self, model):
-        """
-        Update EMA parameters.
+        """Update EMA parameters.
 
         Args:
             model (nn.Module): Model to update EMA from.
@@ -679,7 +647,7 @@
             self.updates += 1
             d = self.decay(self.updates)
 
-            msd = de_parallel(model).state_dict()  # model state_dict
+            msd = unwrap_model(model).state_dict()  # model state_dict
             for k, v in self.ema.state_dict().items():
                 if v.dtype.is_floating_point:  # true for FP16 and FP32
                     v *= d
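The decay `d` ramps from near zero toward the 0.9999 ceiling with time constant `tau`, so early updates track the model quickly while late updates average over a long window. Assuming the standard EMA step `v = d * v + (1 - d) * msd[k]` that the loop above begins, the ramp behaves as follows:

```python
import math

decay, tau = 0.9999, 2000
d = lambda updates: decay * (1 - math.exp(-updates / tau))

print(f"{d(1):.6f}")      # ~0.000500: the first update copies the model almost directly
print(f"{d(2000):.4f}")   # ~0.6321: one time constant into training
print(f"{d(20000):.4f}")  # ~0.9999: late training, heavy averaging
```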
@@ -687,25 +655,24 @@
                     # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
 
     def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
-        """
-        Updates attributes and saves stripped model with optimizer removed.
+        """Update attributes and save stripped model with optimizer removed.
 
         Args:
             model (nn.Module): Model to update attributes from.
-            include (tuple, optional): Attributes to include. Defaults to ().
-            exclude (tuple, optional): Attributes to exclude. Defaults to ("process_group", "reducer").
+            include (tuple, optional): Attributes to include.
+            exclude (tuple, optional): Attributes to exclude.
         """
         if self.enabled:
             copy_attr(self.ema, model, include, exclude)
 
 
-def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
-    """
-    Strip optimizer from 'f' to finalize training, optionally save as 's'.
+def strip_optimizer(f: str | Path = "best.pt", s: str = "", updates: dict[str, Any] | None = None) -> dict[str, Any]:
+    """Strip optimizer from 'f' to finalize training, optionally save as 's'.
 
     Args:
-        f (str | Path): File path to model to strip the optimizer from. Defaults to 'best.pt'.
-        s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
+        f (str | Path): File path to model to strip the optimizer from.
+        s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be
+            overwritten.
         updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.
 
     Returns:
@@ -718,7 +685,7 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict
         >>> strip_optimizer(f)
     """
     try:
-        x = torch.load(f, map_location=torch.device("cpu"))
+        x = torch_load(f, map_location=torch.device("cpu"))
         assert isinstance(x, dict), "checkpoint is not a Python dictionary"
         assert "model" in x, "'model' missing from checkpoint"
     except Exception as e:
@@ -745,7 +712,7 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict
 
     # Update other keys
     args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
-    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
+    for k in "optimizer", "best_fitness", "ema", "updates", "scaler":  # keys
        x[k] = None
    x["epoch"] = -1
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
@@ -760,8 +727,7 @@
 
 
 def convert_optimizer_state_dict_to_fp16(state_dict):
-    """
-    Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
+    """Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
 
     Args:
         state_dict (dict): Optimizer state dictionary.
@@ -779,15 +745,14 @@
 
 @contextmanager
 def cuda_memory_usage(device=None):
-    """
-    Monitor and manage CUDA memory usage.
+    """Monitor and manage CUDA memory usage.
 
-    This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
-    It then yields a dictionary containing memory usage information, which can be updated by the caller.
-    Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.
+    This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory. It then
+    yields a dictionary containing memory usage information, which can be updated by the caller. Finally, it updates the
+    dictionary with the amount of memory reserved by CUDA on the specified device.
 
     Args:
-        device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.
+        device (torch.device, optional): The CUDA device to query memory usage for.
 
     Yields:
         (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
@@ -804,15 +769,14 @@
 
 
 def profile_ops(input, ops, n=10, device=None, max_num_obj=0):
-    """
-    Ultralytics speed, memory and FLOPs profiler.
+    """Ultralytics speed, memory and FLOPs profiler.
 
     Args:
-        input (torch.Tensor | List[torch.Tensor]): Input tensor(s) to profile.
-        ops (nn.Module | List[nn.Module]): Model or list of operations to profile.
-        n (int, optional): Number of iterations to average. Defaults to 10.
-        device (str | torch.device, optional): Device to profile on. Defaults to None.
-        max_num_obj (int, optional): Maximum number of objects for simulation. Defaults to 0.
+        input (torch.Tensor | list): Input tensor(s) to profile.
+        ops (nn.Module | list): Model or list of operations to profile.
+        n (int, optional): Number of iterations to average.
+        device (str | torch.device, optional): Device to profile on.
+        max_num_obj (int, optional): Maximum number of objects for simulation.
 
     Returns:
         (list): Profile results for each operation.
@@ -878,7 +842,7 @@
                     mem += cuda_info["memory"] / 1e9  # (GB)
                 s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                 p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
-                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
+                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{s_in!s:>24s}{s_out!s:>24s}")
                 results.append([p, flops, mem, tf, tb, s_in, s_out])
             except Exception as e:
                 LOGGER.info(e)
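The `cuda_memory_usage` context manager above yields a dict whose 'memory' key is filled in on exit; a minimal sketch of how a caller reads it (safe on CPU-only machines, where the value stays 0):

```python
import torch
from ultralytics.utils.torch_utils import cuda_memory_usage

with cuda_memory_usage() as info:
    x = torch.randn(1024, 1024, device="cuda") if torch.cuda.is_available() else None
print(f"{info['memory'] / 1e9:.3f} GB reserved")  # 0.000 without a CUDA device
```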
@@ -890,8 +854,7 @@
 
 
 class EarlyStopping:
-    """
-    Early stopping class that stops training when a specified number of epochs have passed without improvement.
+    """Early stopping class that stops training when a specified number of epochs have passed without improvement.
 
     Attributes:
         best_fitness (float): Best fitness value observed.
@@ -901,8 +864,7 @@
     """
 
    def __init__(self, patience=50):
-        """
-        Initialize early stopping object.
+        """Initialize early stopping object.
 
         Args:
             patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
@@ -913,8 +875,7 @@
        self.possible_stop = False  # possible stop may occur next epoch
 
    def __call__(self, epoch, fitness):
-        """
-        Check whether to stop training.
+        """Check whether to stop training.
 
        Args:
            epoch (int): Current epoch of training
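The class is used as a callable inside a training loop; a minimal runnable sketch with synthetic fitness values (higher is better):

```python
from ultralytics.utils.torch_utils import EarlyStopping

stopper = EarlyStopping(patience=3)
fitnesses = [0.1, 0.2, 0.3, 0.29, 0.28, 0.27]  # peaks at epoch 2, then degrades
for epoch, fitness in enumerate(fitnesses):
    if stopper(epoch, fitness):  # True once `patience` epochs pass without improvement
        print(f"stopping at epoch {epoch}")  # epoch 5
        break
```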
@@ -943,48 +904,80 @@
         return stop
 
 
-class FXModel(nn.Module):
-    """
-    A custom model class for torch.fx compatibility.
-
-    This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
-    manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
-    copying.
-
-    Attributes:
-        model (nn.Module): The original model's layers.
-    """
+def attempt_compile(
+    model: torch.nn.Module,
+    device: torch.device,
+    imgsz: int = 640,
+    use_autocast: bool = False,
+    warmup: bool = False,
+    mode: bool | str = "default",
+) -> torch.nn.Module:
+    """Compile a model with torch.compile and optionally warm up the graph to reduce first-iteration latency.
 
-    def __init__(self, model):
-        """
-        Initialize the FXModel.
+    This utility attempts to compile the provided model using the inductor backend with dynamic shapes enabled and an
+    autotuning mode. If compilation is unavailable or fails, the original model is returned unchanged. An optional
+    warmup performs a single forward pass on a dummy input to prime the compiled graph and measure compile/warmup time.
 
-        Args:
-            model (nn.Module): The original model to wrap for torch.fx compatibility.
-        """
-        super().__init__()
-        copy_attr(self, model)
-        # Explicitly set `model` since `copy_attr` somehow does not copy it.
-        self.model = model.model
-
-    def forward(self, x):
-        """
-        Forward pass through the model.
+    Args:
+        model (torch.nn.Module): Model to compile.
+        device (torch.device): Inference device used for warmup and autocast decisions.
+        imgsz (int, optional): Square input size to create a dummy tensor with shape (1, 3, imgsz, imgsz) for warmup.
+        use_autocast (bool, optional): Whether to run warmup under autocast on CUDA or MPS devices.
+        warmup (bool, optional): Whether to execute a single dummy forward pass to warm up the compiled model.
+        mode (bool | str, optional): torch.compile mode. True → "default", False → no compile, or a string like
+            "default", "reduce-overhead", "max-autotune-no-cudagraphs".
 
-        This method performs the forward pass through the model, handling the dependencies between layers and saving
-        intermediate outputs.
+    Returns:
+        model (torch.nn.Module): Compiled model if compilation succeeds, otherwise the original unmodified model.
 
-        Args:
-            x (torch.Tensor): The input tensor to the model.
+    Examples:
+        >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        >>> # Try to compile and warm up a model with a 640x640 input
+        >>> model = attempt_compile(model, device=device, imgsz=640, use_autocast=True, warmup=True)
 
-        Returns:
-            (torch.Tensor): The output tensor from the model.
-        """
-        y = []  # outputs
-        for m in self.model:
-            if m.f != -1:  # if not from previous layer
-                # from earlier layers
-                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
-            x = m(x)  # run
-            y.append(x)  # save output
-        return x
+    Notes:
+        - If the current PyTorch build does not provide torch.compile, the function returns the input model immediately.
+        - Warmup runs under torch.inference_mode and may use torch.autocast for CUDA/MPS to align compute precision.
+        - CUDA devices are synchronized after warmup to account for asynchronous kernel execution.
+    """
+    if not hasattr(torch, "compile") or not mode:
+        return model
+
+    if mode is True:
+        mode = "default"
+    prefix = colorstr("compile:")
+    LOGGER.info(f"{prefix} starting torch.compile with '{mode}' mode...")
+    if mode == "max-autotune":
+        LOGGER.warning(f"{prefix} mode='{mode}' not recommended, using mode='max-autotune-no-cudagraphs' instead")
+        mode = "max-autotune-no-cudagraphs"
+    t0 = time.perf_counter()
+    try:
+        model = torch.compile(model, mode=mode, backend="inductor")
+    except Exception as e:
+        LOGGER.warning(f"{prefix} torch.compile failed, continuing uncompiled: {e}")
+        return model
+    t_compile = time.perf_counter() - t0
+
+    t_warm = 0.0
+    if warmup:
+        # Use a single dummy tensor to build the graph shape state and reduce first-iteration latency
+        dummy = torch.zeros(1, 3, imgsz, imgsz, device=device)
+        if use_autocast and device.type == "cuda":
+            dummy = dummy.half()
+        t1 = time.perf_counter()
+        with torch.inference_mode():
+            if use_autocast and device.type in {"cuda", "mps"}:
+                with torch.autocast(device.type):
+                    _ = model(dummy)
+            else:
+                _ = model(dummy)
+        if device.type == "cuda":
+            torch.cuda.synchronize(device)
+        t_warm = time.perf_counter() - t1
+
+    total = t_compile + t_warm
+    if warmup:
+        LOGGER.info(f"{prefix} complete in {total:.1f}s (compile {t_compile:.1f}s + warmup {t_warm:.1f}s)")
+    else:
+        LOGGER.info(f"{prefix} compile complete in {t_compile:.1f}s (no warmup)")
+    return model
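A runnable sketch of the new helper, following its own docstring: `mode=True` maps to "default", and "max-autotune" is downgraded to "max-autotune-no-cudagraphs" as the warning above describes. The small Conv2d stands in for a real detection model:

```python
import torch
from ultralytics.utils.torch_utils import attempt_compile

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.nn.Conv2d(3, 16, 3, padding=1).to(device)  # stand-in for a YOLO model

# Compile with the inductor backend and prime the graph with one (1, 3, 640, 640) dummy pass
model = attempt_compile(model, device=device, imgsz=640, warmup=True, mode=True)
```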