ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247):
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +37 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +111 -41
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +579 -244
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +191 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +226 -82
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +172 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +305 -112
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.63.dist-info/METADATA +370 -0
  235. ultralytics-8.3.63.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
@@ -1,46 +1,62 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ import gc
3
4
  import math
4
5
  import os
5
6
  import random
6
7
  import time
7
8
  from contextlib import contextmanager
8
9
  from copy import deepcopy
10
+ from datetime import datetime
9
11
  from pathlib import Path
10
12
  from typing import Union
11
13
 
12
14
  import numpy as np
15
+ import thop
13
16
  import torch
14
17
  import torch.distributed as dist
15
18
  import torch.nn as nn
16
19
  import torch.nn.functional as F
17
- import torchvision
18
-
19
- from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
20
- from ultralytics.utils.checks import PYTHON_VERSION, check_version
21
-
22
- try:
23
- import thop
24
- except ImportError:
25
- thop = None
26
20
 
21
+ from ultralytics.utils import (
22
+ DEFAULT_CFG_DICT,
23
+ DEFAULT_CFG_KEYS,
24
+ LOGGER,
25
+ NUM_THREADS,
26
+ PYTHON_VERSION,
27
+ TORCHVISION_VERSION,
28
+ WINDOWS,
29
+ __version__,
30
+ colorstr,
31
+ )
32
+ from ultralytics.utils.checks import check_version
33
+
34
+ # Version checks (all default to version>=min_version)
27
35
  TORCH_1_9 = check_version(torch.__version__, "1.9.0")
28
36
  TORCH_1_13 = check_version(torch.__version__, "1.13.0")
29
37
  TORCH_2_0 = check_version(torch.__version__, "2.0.0")
30
- TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
31
- TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
32
- TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
38
+ TORCH_2_4 = check_version(torch.__version__, "2.4.0")
39
+ TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
40
+ TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
41
+ TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
42
+ TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
43
+ if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows
44
+ LOGGER.warning(
45
+ "WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
46
+ "https://github.com/ultralytics/ultralytics/issues/15049"
47
+ )
33
48
 
34
49
 
35
50
  @contextmanager
36
51
  def torch_distributed_zero_first(local_rank: int):
37
- """Decorator to make all processes in distributed training wait for each local_master to do something."""
38
- initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
39
- if initialized and local_rank not in (-1, 0):
52
+ """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
53
+ initialized = dist.is_available() and dist.is_initialized()
54
+
55
+ if initialized and local_rank not in {-1, 0}:
40
56
  dist.barrier(device_ids=[local_rank])
41
57
  yield
42
58
  if initialized and local_rank == 0:
43
- dist.barrier(device_ids=[0])
59
+ dist.barrier(device_ids=[local_rank])
44
60
 
45
61
 
46
62
  def smart_inference_mode():
@@ -56,14 +72,58 @@ def smart_inference_mode():
56
72
  return decorate
57
73
 
58
74
 
75
+ def autocast(enabled: bool, device: str = "cuda"):
76
+ """
77
+ Get the appropriate autocast context manager based on PyTorch version and AMP setting.
78
+
79
+ This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
80
+ older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
81
+
82
+ Args:
83
+ enabled (bool): Whether to enable automatic mixed precision.
84
+ device (str, optional): The device to use for autocast. Defaults to 'cuda'.
85
+
86
+ Returns:
87
+ (torch.amp.autocast): The appropriate autocast context manager.
88
+
89
+ Note:
90
+ - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
91
+ - For older versions, it uses `torch.cuda.autocast`.
92
+
93
+ Example:
94
+ ```python
95
+ with autocast(amp=True):
96
+ # Your mixed precision operations here
97
+ pass
98
+ ```
99
+ """
100
+ if TORCH_1_13:
101
+ return torch.amp.autocast(device, enabled=enabled)
102
+ else:
103
+ return torch.cuda.amp.autocast(enabled)
104
+
105
+
59
106
  def get_cpu_info():
60
107
  """Return a string with system CPU information, i.e. 'Apple M2'."""
61
- import cpuinfo # pip install py-cpuinfo
108
+ from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
109
+
110
+ if "cpu_info" not in PERSISTENT_CACHE:
111
+ try:
112
+ import cpuinfo # pip install py-cpuinfo
113
+
114
+ k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
115
+ info = cpuinfo.get_cpu_info() # info dict
116
+ string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
117
+ PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
118
+ except Exception:
119
+ pass
120
+ return PERSISTENT_CACHE.get("cpu_info", "unknown")
121
+
62
122
 
63
- k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available)
64
- info = cpuinfo.get_cpu_info() # info dict
65
- string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
66
- return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
123
+ def get_gpu_info(index):
124
+ """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
125
+ properties = torch.cuda.get_device_properties(index)
126
+ return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
67
127
 
68
128
 
69
129
  def select_device(device="", batch=0, newline=False, verbose=True):
@@ -90,30 +150,31 @@ def select_device(device="", batch=0, newline=False, verbose=True):
90
150
  devices when using multiple GPUs.
91
151
 
92
152
  Examples:
93
- >>> select_device('cuda:0')
153
+ >>> select_device("cuda:0")
94
154
  device(type='cuda', index=0)
95
155
 
96
- >>> select_device('cpu')
156
+ >>> select_device("cpu")
97
157
  device(type='cpu')
98
158
 
99
159
  Note:
100
160
  Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
101
161
  """
102
-
103
- if isinstance(device, torch.device):
162
+ if isinstance(device, torch.device) or str(device).startswith("tpu"):
104
163
  return device
105
164
 
106
- s = f"Ultralytics YOLOv{__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
165
+ s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
107
166
  device = str(device).lower()
108
167
  for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
109
168
  device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
110
169
  cpu = device == "cpu"
111
- mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
170
+ mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
112
171
  if cpu or mps:
113
172
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
114
173
  elif device: # non-cpu device requested
115
174
  if device == "cuda":
116
175
  device = "0"
176
+ if "," in device:
177
+ device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1"
117
178
  visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
118
179
  os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
119
180
  if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
@@ -135,17 +196,22 @@ def select_device(device="", batch=0, newline=False, verbose=True):
135
196
  )
136
197
 
137
198
  if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
138
- devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
199
+ devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"]
139
200
  n = len(devices) # device count
140
- if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
141
- raise ValueError(
142
- f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
143
- f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
144
- )
201
+ if n > 1: # multi-GPU
202
+ if batch < 1:
203
+ raise ValueError(
204
+ "AutoBatch with batch<1 not supported for Multi-GPU training, "
205
+ "please specify a valid batch size, i.e. batch=16."
206
+ )
207
+ if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count
208
+ raise ValueError(
209
+ f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
210
+ f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
211
+ )
145
212
  space = " " * (len(s) + 1)
146
213
  for i, d in enumerate(devices):
147
- p = torch.cuda.get_device_properties(i)
148
- s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
214
+ s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB
149
215
  arg = "cuda:0"
150
216
  elif mps and TORCH_2_0 and torch.backends.mps.is_available():
151
217
  # Prefer MPS if available
@@ -155,6 +221,8 @@ def select_device(device="", batch=0, newline=False, verbose=True):
155
221
  s += f"CPU ({get_cpu_info()})\n"
156
222
  arg = "cpu"
157
223
 
224
+ if arg in {"cpu", "mps"}:
225
+ torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training
158
226
  if verbose:
159
227
  LOGGER.info(s if newline else s.rstrip())
160
228
  return torch.device(arg)
@@ -185,7 +253,7 @@ def fuse_conv_and_bn(conv, bn):
185
253
  )
186
254
 
187
255
  # Prepare filters
188
- w_conv = conv.weight.clone().view(conv.out_channels, -1)
256
+ w_conv = conv.weight.view(conv.out_channels, -1)
189
257
  w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
190
258
  fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
191
259
 
@@ -216,7 +284,7 @@ def fuse_deconv_and_bn(deconv, bn):
216
284
  )
217
285
 
218
286
  # Prepare filters
219
- w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
287
+ w_deconv = deconv.weight.view(deconv.out_channels, -1)
220
288
  w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
221
289
  fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
222
290
 
@@ -229,33 +297,27 @@ def fuse_deconv_and_bn(deconv, bn):
229
297
 
230
298
 
231
299
  def model_info(model, detailed=False, verbose=True, imgsz=640):
232
- """
233
- Model information.
234
-
235
- imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
236
- """
300
+ """Print and return detailed model information layer by layer."""
237
301
  if not verbose:
238
302
  return
239
303
  n_p = get_num_params(model) # number of parameters
240
304
  n_g = get_num_gradients(model) # number of gradients
241
305
  n_l = len(list(model.modules())) # number of layers
242
306
  if detailed:
243
- LOGGER.info(
244
- f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
245
- )
307
+ LOGGER.info(f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}")
246
308
  for i, (name, p) in enumerate(model.named_parameters()):
247
309
  name = name.replace("module_list.", "")
248
310
  LOGGER.info(
249
- "%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
250
- % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
311
+ f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}"
312
+ f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}"
251
313
  )
252
314
 
253
- flops = get_flops(model, imgsz)
315
+ flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
254
316
  fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
255
317
  fs = f", {flops:.1f} GFLOPs" if flops else ""
256
318
  yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
257
319
  model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
258
- LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
320
+ LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
259
321
  return n_l, n_p, n_g, flops
260
322
 
261
323
 
@@ -276,11 +338,13 @@ def model_info_for_loggers(trainer):
276
338
  Example:
277
339
  YOLOv8n info for loggers
278
340
  ```python
279
- results = {'model/parameters': 3151904,
280
- 'model/GFLOPs': 8.746,
281
- 'model/speed_ONNX(ms)': 41.244,
282
- 'model/speed_TensorRT(ms)': 3.211,
283
- 'model/speed_PyTorch(ms)': 18.755}
341
+ results = {
342
+ "model/parameters": 3151904,
343
+ "model/GFLOPs": 8.746,
344
+ "model/speed_ONNX(ms)": 41.244,
345
+ "model/speed_TensorRT(ms)": 3.211,
346
+ "model/speed_PyTorch(ms)": 18.755,
347
+ }
284
348
  ```
285
349
  """
286
350
  if trainer.args.profile: # profile ONNX and TensorRT times
@@ -299,9 +363,6 @@ def model_info_for_loggers(trainer):
299
363
 
300
364
  def get_flops(model, imgsz=640):
301
365
  """Return a YOLO model's FLOPs."""
302
- if not thop:
303
- return 0.0 # if not installed return 0.0 GFLOPs
304
-
305
366
  try:
306
367
  model = de_parallel(model)
307
368
  p = next(model.parameters())
@@ -322,19 +383,28 @@ def get_flops(model, imgsz=640):
322
383
 
323
384
 
324
385
  def get_flops_with_torch_profiler(model, imgsz=640):
325
- """Compute model FLOPs (thop alternative)."""
326
- if TORCH_2_0:
327
- model = de_parallel(model)
328
- p = next(model.parameters())
386
+ """Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
387
+ if not TORCH_2_0: # torch profiler implemented in torch>=2.0
388
+ return 0.0
389
+ model = de_parallel(model)
390
+ p = next(model.parameters())
391
+ if not isinstance(imgsz, list):
392
+ imgsz = [imgsz, imgsz] # expand if int/float
393
+ try:
394
+ # Use stride size for input tensor
329
395
  stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
330
- im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
396
+ im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
331
397
  with torch.profiler.profile(with_flops=True) as prof:
332
398
  model(im)
333
399
  flops = sum(x.flops for x in prof.key_averages()) / 1e9
334
- imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
335
400
  flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
336
- return flops
337
- return 0
401
+ except Exception:
402
+ # Use actual image size for input tensor (i.e. required for RTDETR models)
403
+ im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
404
+ with torch.profiler.profile(with_flops=True) as prof:
405
+ model(im)
406
+ flops = sum(x.flops for x in prof.key_averages()) / 1e9
407
+ return flops
338
408
 
339
409
 
340
410
  def initialize_weights(model):
@@ -346,14 +416,12 @@ def initialize_weights(model):
346
416
  elif t is nn.BatchNorm2d:
347
417
  m.eps = 1e-3
348
418
  m.momentum = 0.03
349
- elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
419
+ elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
350
420
  m.inplace = True
351
421
 
352
422
 
353
423
  def scale_img(img, ratio=1.0, same_shape=False, gs=32):
354
- """Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally
355
- retaining the original shape.
356
- """
424
+ """Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
357
425
  if ratio == 1.0:
358
426
  return img
359
427
  h, w = img.shape[2:]
@@ -364,13 +432,6 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
364
432
  return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
365
433
 
366
434
 
367
- def make_divisible(x, divisor):
368
- """Returns nearest x divisible by divisor."""
369
- if isinstance(divisor, torch.Tensor):
370
- divisor = int(divisor.max()) # to int
371
- return math.ceil(x / divisor) * divisor
372
-
373
-
374
435
  def copy_attr(a, b, include=(), exclude=()):
375
436
  """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
376
437
  for k, v in b.__dict__.items():
@@ -381,8 +442,13 @@ def copy_attr(a, b, include=(), exclude=()):
381
442
 
382
443
 
383
444
  def get_latest_opset():
384
- """Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
385
- return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset
445
+ """Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
446
+ if TORCH_1_13:
447
+ # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
448
+ return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
449
+ # Otherwise for PyTorch<=1.12 return the corresponding predefined opset
450
+ version = torch.onnx.producer_version.rsplit(".", 1)[0] # i.e. '2.3'
451
+ return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
386
452
 
387
453
 
388
454
  def intersect_dicts(da, db, exclude=()):
@@ -427,14 +493,17 @@ def init_seeds(seed=0, deterministic=False):
427
493
 
428
494
 
429
495
  class ModelEMA:
430
- """Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
431
- Keeps a moving average of everything in the model state_dict (parameters and buffers)
496
+ """
497
+ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
498
+ average of everything in the model state_dict (parameters and buffers).
499
+
432
500
  For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
501
+
433
502
  To disable EMA set the `enabled` attribute to `False`.
434
503
  """
435
504
 
436
505
  def __init__(self, model, decay=0.9999, tau=2000, updates=0):
437
- """Create EMA."""
506
+ """Initialize EMA for 'model' with given arguments."""
438
507
  self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
439
508
  self.updates = updates # number of EMA updates
440
509
  self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
@@ -461,50 +530,113 @@ class ModelEMA:
461
530
  copy_attr(self.ema, model, include, exclude)
462
531
 
463
532
 
464
- def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
533
+ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
465
534
  """
466
535
  Strip optimizer from 'f' to finalize training, optionally save as 's'.
467
536
 
468
537
  Args:
469
538
  f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
470
539
  s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
540
+ updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
471
541
 
472
542
  Returns:
473
- None
543
+ (dict): The combined checkpoint dictionary.
474
544
 
475
545
  Example:
476
546
  ```python
477
547
  from pathlib import Path
478
548
  from ultralytics.utils.torch_utils import strip_optimizer
479
549
 
480
- for f in Path('path/to/weights').rglob('*.pt'):
550
+ for f in Path("path/to/model/checkpoints").rglob("*.pt"):
481
551
  strip_optimizer(f)
482
552
  ```
483
- """
484
- x = torch.load(f, map_location=torch.device("cpu"))
485
- if "model" not in x:
486
- LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
487
- return
488
553
 
554
+ Note:
555
+ Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
556
+ """
557
+ try:
558
+ x = torch.load(f, map_location=torch.device("cpu"))
559
+ assert isinstance(x, dict), "checkpoint is not a Python dictionary"
560
+ assert "model" in x, "'model' missing from checkpoint"
561
+ except Exception as e:
562
+ LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
563
+ return {}
564
+
565
+ metadata = {
566
+ "date": datetime.now().isoformat(),
567
+ "version": __version__,
568
+ "license": "AGPL-3.0 License (https://ultralytics.com/license)",
569
+ "docs": "https://docs.ultralytics.com",
570
+ }
571
+
572
+ # Update model
573
+ if x.get("ema"):
574
+ x["model"] = x["ema"] # replace model with EMA
489
575
  if hasattr(x["model"], "args"):
490
576
  x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
491
- args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
492
- if x.get("ema"):
493
- x["model"] = x["ema"] # replace model with ema
494
- for k in "optimizer", "best_fitness", "ema", "updates": # keys
495
- x[k] = None
496
- x["epoch"] = -1
577
+ if hasattr(x["model"], "criterion"):
578
+ x["model"].criterion = None # strip loss criterion
497
579
  x["model"].half() # to FP16
498
580
  for p in x["model"].parameters():
499
581
  p.requires_grad = False
582
+
583
+ # Update other keys
584
+ args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
585
+ for k in "optimizer", "best_fitness", "ema", "updates": # keys
586
+ x[k] = None
587
+ x["epoch"] = -1
500
588
  x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
501
589
  # x['model'].args = x['train_args']
502
- torch.save(x, s or f)
590
+
591
+ # Save
592
+ combined = {**metadata, **x, **(updates or {})}
593
+ torch.save(combined, s or f) # combine dicts (prefer to the right)
503
594
  mb = os.path.getsize(s or f) / 1e6 # file size
504
595
  LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
596
+ return combined
597
+
598
+
599
+ def convert_optimizer_state_dict_to_fp16(state_dict):
600
+ """
601
+ Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
602
+
603
+ This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
604
+ """
605
+ for state in state_dict["state"].values():
606
+ for k, v in state.items():
607
+ if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
608
+ state[k] = v.half()
609
+
610
+ return state_dict
611
+
612
+
613
+ @contextmanager
614
+ def cuda_memory_usage(device=None):
615
+ """
616
+ Monitor and manage CUDA memory usage.
617
+
618
+ This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
619
+ It then yields a dictionary containing memory usage information, which can be updated by the caller.
620
+ Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.
621
+
622
+ Args:
623
+ device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.
505
624
 
625
+ Yields:
626
+ (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
627
+ """
628
+ cuda_info = dict(memory=0)
629
+ if torch.cuda.is_available():
630
+ torch.cuda.empty_cache()
631
+ try:
632
+ yield cuda_info
633
+ finally:
634
+ cuda_info["memory"] = torch.cuda.memory_reserved(device)
635
+ else:
636
+ yield cuda_info
506
637
 
507
- def profile(input, ops, n=10, device=None):
638
+
639
+ def profile(input, ops, n=10, device=None, max_num_obj=0):
508
640
  """
509
641
  Ultralytics speed, memory and FLOPs profiler.
510
642
 
@@ -525,7 +657,8 @@ def profile(input, ops, n=10, device=None):
525
657
  f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
526
658
  f"{'input':>24s}{'output':>24s}"
527
659
  )
528
-
660
+ gc.collect() # attempt to free unused memory
661
+ torch.cuda.empty_cache()
529
662
  for x in input if isinstance(input, list) else [input]:
530
663
  x = x.to(device)
531
664
  x.requires_grad = True
@@ -534,24 +667,36 @@ def profile(input, ops, n=10, device=None):
534
667
  m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
535
668
  tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
536
669
  try:
537
- flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
670
+ flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 # GFLOPs
538
671
  except Exception:
539
672
  flops = 0
540
673
 
541
674
  try:
675
+ mem = 0
542
676
  for _ in range(n):
543
- t[0] = time_sync()
544
- y = m(x)
545
- t[1] = time_sync()
546
- try:
547
- (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
548
- t[2] = time_sync()
549
- except Exception: # no backward method
550
- # print(e) # for debug
551
- t[2] = float("nan")
677
+ with cuda_memory_usage(device) as cuda_info:
678
+ t[0] = time_sync()
679
+ y = m(x)
680
+ t[1] = time_sync()
681
+ try:
682
+ (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
683
+ t[2] = time_sync()
684
+ except Exception: # no backward method
685
+ # print(e) # for debug
686
+ t[2] = float("nan")
687
+ mem += cuda_info["memory"] / 1e9 # (GB)
552
688
  tf += (t[1] - t[0]) * 1000 / n # ms per op forward
553
689
  tb += (t[2] - t[1]) * 1000 / n # ms per op backward
554
- mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB)
690
+ if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)
691
+ with cuda_memory_usage(device) as cuda_info:
692
+ torch.randn(
693
+ x.shape[0],
694
+ max_num_obj,
695
+ int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
696
+ device=device,
697
+ dtype=torch.float32,
698
+ )
699
+ mem += cuda_info["memory"] / 1e9 # (GB)
555
700
  s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
556
701
  p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
557
702
  LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
@@ -559,7 +704,9 @@ def profile(input, ops, n=10, device=None):
559
704
  except Exception as e:
560
705
  LOGGER.info(e)
561
706
  results.append(None)
562
- torch.cuda.empty_cache()
707
+ finally:
708
+ gc.collect() # attempt to free unused memory
709
+ torch.cuda.empty_cache()
563
710
  return results
564
711
 
565
712
 
@@ -599,10 +746,56 @@ class EarlyStopping:
599
746
  self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
600
747
  stop = delta >= self.patience # stop training if patience exceeded
601
748
  if stop:
749
+ prefix = colorstr("EarlyStopping: ")
602
750
  LOGGER.info(
603
- f"Stopping training early as no improvement observed in last {self.patience} epochs. "
751
+ f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
604
752
  f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
605
753
  f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
606
754
  f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
607
755
  )
608
756
  return stop
757
+
758
+
759
+ class FXModel(nn.Module):
760
+ """
761
+ A custom model class for torch.fx compatibility.
762
+
763
+ This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph manipulation.
764
+ It copies attributes from an existing model and explicitly sets the model attribute to ensure proper copying.
765
+
766
+ Args:
767
+ model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
768
+ """
769
+
770
+ def __init__(self, model):
771
+ """
772
+ Initialize the FXModel.
773
+
774
+ Args:
775
+ model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
776
+ """
777
+ super().__init__()
778
+ copy_attr(self, model)
779
+ # Explicitly set `model` since `copy_attr` somehow does not copy it.
780
+ self.model = model.model
781
+
782
+ def forward(self, x):
783
+ """
784
+ Forward pass through the model.
785
+
786
+ This method performs the forward pass through the model, handling the dependencies between layers and saving intermediate outputs.
787
+
788
+ Args:
789
+ x (torch.Tensor): The input tensor to the model.
790
+
791
+ Returns:
792
+ (torch.Tensor): The output tensor from the model.
793
+ """
794
+ y = [] # outputs
795
+ for m in self.model:
796
+ if m.f != -1: # if not from previous layer
797
+ # from earlier layers
798
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
799
+ x = m(x) # run
800
+ y.append(x) # save output
801
+ return x