ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/utils/torch_utils.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
import gc
|
3
4
|
import math
|
4
5
|
import os
|
5
6
|
import random
|
6
7
|
import time
|
7
8
|
from contextlib import contextmanager
|
8
9
|
from copy import deepcopy
|
10
|
+
from datetime import datetime
|
9
11
|
from pathlib import Path
|
10
12
|
from typing import Union
|
11
13
|
|
@@ -14,33 +16,51 @@ import torch
|
|
14
16
|
import torch.distributed as dist
|
15
17
|
import torch.nn as nn
|
16
18
|
import torch.nn.functional as F
|
17
|
-
import torchvision
|
18
19
|
|
19
|
-
from ultralytics.utils import
|
20
|
-
|
20
|
+
from ultralytics.utils import (
|
21
|
+
DEFAULT_CFG_DICT,
|
22
|
+
DEFAULT_CFG_KEYS,
|
23
|
+
LOGGER,
|
24
|
+
NUM_THREADS,
|
25
|
+
PYTHON_VERSION,
|
26
|
+
TORCHVISION_VERSION,
|
27
|
+
WINDOWS,
|
28
|
+
__version__,
|
29
|
+
colorstr,
|
30
|
+
)
|
31
|
+
from ultralytics.utils.checks import check_version
|
21
32
|
|
22
33
|
try:
|
23
34
|
import thop
|
24
35
|
except ImportError:
|
25
36
|
thop = None
|
26
37
|
|
38
|
+
# Version checks (all default to version>=min_version)
|
27
39
|
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
|
28
40
|
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
|
29
41
|
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
|
30
|
-
|
31
|
-
|
32
|
-
|
42
|
+
TORCH_2_4 = check_version(torch.__version__, "2.4.0")
|
43
|
+
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
|
44
|
+
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
|
45
|
+
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
|
46
|
+
TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
|
47
|
+
if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows
|
48
|
+
LOGGER.warning(
|
49
|
+
"WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
|
50
|
+
"https://github.com/ultralytics/ultralytics/issues/15049"
|
51
|
+
)
|
33
52
|
|
34
53
|
|
35
54
|
@contextmanager
|
36
55
|
def torch_distributed_zero_first(local_rank: int):
|
37
|
-
"""
|
38
|
-
initialized =
|
39
|
-
|
56
|
+
"""Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
|
57
|
+
initialized = dist.is_available() and dist.is_initialized()
|
58
|
+
|
59
|
+
if initialized and local_rank not in {-1, 0}:
|
40
60
|
dist.barrier(device_ids=[local_rank])
|
41
61
|
yield
|
42
62
|
if initialized and local_rank == 0:
|
43
|
-
dist.barrier(device_ids=[
|
63
|
+
dist.barrier(device_ids=[local_rank])
|
44
64
|
|
45
65
|
|
46
66
|
def smart_inference_mode():
|
@@ -56,14 +76,58 @@ def smart_inference_mode():
|
|
56
76
|
return decorate
|
57
77
|
|
58
78
|
|
79
|
+
def autocast(enabled: bool, device: str = "cuda"):
|
80
|
+
"""
|
81
|
+
Get the appropriate autocast context manager based on PyTorch version and AMP setting.
|
82
|
+
|
83
|
+
This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
|
84
|
+
older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
|
85
|
+
|
86
|
+
Args:
|
87
|
+
enabled (bool): Whether to enable automatic mixed precision.
|
88
|
+
device (str, optional): The device to use for autocast. Defaults to 'cuda'.
|
89
|
+
|
90
|
+
Returns:
|
91
|
+
(torch.amp.autocast): The appropriate autocast context manager.
|
92
|
+
|
93
|
+
Note:
|
94
|
+
- For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
|
95
|
+
- For older versions, it uses `torch.cuda.amp.autocast`.
|
96
|
+
|
97
|
+
Example:
|
98
|
+
```python
|
99
|
+
with autocast(enabled=True):
|
100
|
+
# Your mixed precision operations here
|
101
|
+
pass
|
102
|
+
```
|
103
|
+
"""
|
104
|
+
if TORCH_1_13:
|
105
|
+
return torch.amp.autocast(device, enabled=enabled)
|
106
|
+
else:
|
107
|
+
return torch.cuda.amp.autocast(enabled)
|
108
|
+
|
109
|
+
|
59
110
|
def get_cpu_info():
|
60
111
|
"""Return a string with system CPU information, i.e. 'Apple M2'."""
|
61
|
-
import
|
112
|
+
from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
|
62
113
|
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
114
|
+
if "cpu_info" not in PERSISTENT_CACHE:
|
115
|
+
try:
|
116
|
+
import cpuinfo # pip install py-cpuinfo
|
117
|
+
|
118
|
+
k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
|
119
|
+
info = cpuinfo.get_cpu_info() # info dict
|
120
|
+
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
|
121
|
+
PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
|
122
|
+
except Exception:
|
123
|
+
pass
|
124
|
+
return PERSISTENT_CACHE.get("cpu_info", "unknown")
|
125
|
+
|
126
|
+
|
127
|
+
def get_gpu_info(index):
|
128
|
+
"""Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
|
129
|
+
properties = torch.cuda.get_device_properties(index)
|
130
|
+
return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
|
67
131
|
|
68
132
|
|
69
133
|
def select_device(device="", batch=0, newline=False, verbose=True):
|
@@ -90,30 +154,31 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|
90
154
|
devices when using multiple GPUs.
|
91
155
|
|
92
156
|
Examples:
|
93
|
-
>>> select_device(
|
157
|
+
>>> select_device("cuda:0")
|
94
158
|
device(type='cuda', index=0)
|
95
159
|
|
96
|
-
>>> select_device(
|
160
|
+
>>> select_device("cpu")
|
97
161
|
device(type='cpu')
|
98
162
|
|
99
163
|
Note:
|
100
164
|
Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
|
101
165
|
"""
|
102
|
-
|
103
|
-
if isinstance(device, torch.device):
|
166
|
+
if isinstance(device, torch.device) or str(device).startswith("tpu"):
|
104
167
|
return device
|
105
168
|
|
106
|
-
s = f"Ultralytics
|
169
|
+
s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
|
107
170
|
device = str(device).lower()
|
108
171
|
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
|
109
172
|
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
110
173
|
cpu = device == "cpu"
|
111
|
-
mps = device in
|
174
|
+
mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
|
112
175
|
if cpu or mps:
|
113
176
|
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
|
114
177
|
elif device: # non-cpu device requested
|
115
178
|
if device == "cuda":
|
116
179
|
device = "0"
|
180
|
+
if "," in device:
|
181
|
+
device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1"
|
117
182
|
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
118
183
|
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
|
119
184
|
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
|
@@ -135,17 +200,22 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|
135
200
|
)
|
136
201
|
|
137
202
|
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
138
|
-
devices = device.split(",") if device else "0" #
|
203
|
+
devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"]
|
139
204
|
n = len(devices) # device count
|
140
|
-
if n > 1
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
205
|
+
if n > 1: # multi-GPU
|
206
|
+
if batch < 1:
|
207
|
+
raise ValueError(
|
208
|
+
"AutoBatch with batch<1 not supported for Multi-GPU training, "
|
209
|
+
"please specify a valid batch size, i.e. batch=16."
|
210
|
+
)
|
211
|
+
if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count
|
212
|
+
raise ValueError(
|
213
|
+
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
214
|
+
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
|
215
|
+
)
|
145
216
|
space = " " * (len(s) + 1)
|
146
217
|
for i, d in enumerate(devices):
|
147
|
-
|
148
|
-
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
|
218
|
+
s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB
|
149
219
|
arg = "cuda:0"
|
150
220
|
elif mps and TORCH_2_0 and torch.backends.mps.is_available():
|
151
221
|
# Prefer MPS if available
|
@@ -155,6 +225,8 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|
155
225
|
s += f"CPU ({get_cpu_info()})\n"
|
156
226
|
arg = "cpu"
|
157
227
|
|
228
|
+
if arg in {"cpu", "mps"}:
|
229
|
+
torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training
|
158
230
|
if verbose:
|
159
231
|
LOGGER.info(s if newline else s.rstrip())
|
160
232
|
return torch.device(arg)
|
@@ -185,7 +257,7 @@ def fuse_conv_and_bn(conv, bn):
|
|
185
257
|
)
|
186
258
|
|
187
259
|
# Prepare filters
|
188
|
-
w_conv = conv.weight.
|
260
|
+
w_conv = conv.weight.view(conv.out_channels, -1)
|
189
261
|
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
190
262
|
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
191
263
|
|
@@ -216,7 +288,7 @@ def fuse_deconv_and_bn(deconv, bn):
|
|
216
288
|
)
|
217
289
|
|
218
290
|
# Prepare filters
|
219
|
-
w_deconv = deconv.weight.
|
291
|
+
w_deconv = deconv.weight.view(deconv.out_channels, -1)
|
220
292
|
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
221
293
|
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
|
222
294
|
|
@@ -229,33 +301,27 @@ def fuse_deconv_and_bn(deconv, bn):
|
|
229
301
|
|
230
302
|
|
231
303
|
def model_info(model, detailed=False, verbose=True, imgsz=640):
|
232
|
-
"""
|
233
|
-
Model information.
|
234
|
-
|
235
|
-
imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
|
236
|
-
"""
|
304
|
+
"""Print and return detailed model information layer by layer."""
|
237
305
|
if not verbose:
|
238
306
|
return
|
239
307
|
n_p = get_num_params(model) # number of parameters
|
240
308
|
n_g = get_num_gradients(model) # number of gradients
|
241
309
|
n_l = len(list(model.modules())) # number of layers
|
242
310
|
if detailed:
|
243
|
-
LOGGER.info(
|
244
|
-
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
|
245
|
-
)
|
311
|
+
LOGGER.info(f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}")
|
246
312
|
for i, (name, p) in enumerate(model.named_parameters()):
|
247
313
|
name = name.replace("module_list.", "")
|
248
314
|
LOGGER.info(
|
249
|
-
"
|
250
|
-
|
315
|
+
f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}"
|
316
|
+
f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}"
|
251
317
|
)
|
252
318
|
|
253
|
-
flops = get_flops(model, imgsz)
|
319
|
+
flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
|
254
320
|
fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
|
255
321
|
fs = f", {flops:.1f} GFLOPs" if flops else ""
|
256
322
|
yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
|
257
323
|
model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
|
258
|
-
LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
|
324
|
+
LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
|
259
325
|
return n_l, n_p, n_g, flops
|
260
326
|
|
261
327
|
|
@@ -276,11 +342,13 @@ def model_info_for_loggers(trainer):
|
|
276
342
|
Example:
|
277
343
|
YOLOv8n info for loggers
|
278
344
|
```python
|
279
|
-
results = {
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
345
|
+
results = {
|
346
|
+
"model/parameters": 3151904,
|
347
|
+
"model/GFLOPs": 8.746,
|
348
|
+
"model/speed_ONNX(ms)": 41.244,
|
349
|
+
"model/speed_TensorRT(ms)": 3.211,
|
350
|
+
"model/speed_PyTorch(ms)": 18.755,
|
351
|
+
}
|
284
352
|
```
|
285
353
|
"""
|
286
354
|
if trainer.args.profile: # profile ONNX and TensorRT times
|
@@ -322,19 +390,28 @@ def get_flops(model, imgsz=640):
|
|
322
390
|
|
323
391
|
|
324
392
|
def get_flops_with_torch_profiler(model, imgsz=640):
|
325
|
-
"""Compute model FLOPs (thop alternative)."""
|
326
|
-
if TORCH_2_0:
|
327
|
-
|
328
|
-
|
393
|
+
"""Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
|
394
|
+
if not TORCH_2_0: # torch profiler implemented in torch>=2.0
|
395
|
+
return 0.0
|
396
|
+
model = de_parallel(model)
|
397
|
+
p = next(model.parameters())
|
398
|
+
if not isinstance(imgsz, list):
|
399
|
+
imgsz = [imgsz, imgsz] # expand if int/float
|
400
|
+
try:
|
401
|
+
# Use stride size for input tensor
|
329
402
|
stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
|
330
|
-
im = torch.
|
403
|
+
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
331
404
|
with torch.profiler.profile(with_flops=True) as prof:
|
332
405
|
model(im)
|
333
406
|
flops = sum(x.flops for x in prof.key_averages()) / 1e9
|
334
|
-
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
335
407
|
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
|
336
|
-
|
337
|
-
|
408
|
+
except Exception:
|
409
|
+
# Use actual image size for input tensor (i.e. required for RTDETR models)
|
410
|
+
im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
|
411
|
+
with torch.profiler.profile(with_flops=True) as prof:
|
412
|
+
model(im)
|
413
|
+
flops = sum(x.flops for x in prof.key_averages()) / 1e9
|
414
|
+
return flops
|
338
415
|
|
339
416
|
|
340
417
|
def initialize_weights(model):
|
@@ -346,14 +423,12 @@ def initialize_weights(model):
|
|
346
423
|
elif t is nn.BatchNorm2d:
|
347
424
|
m.eps = 1e-3
|
348
425
|
m.momentum = 0.03
|
349
|
-
elif t in
|
426
|
+
elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
|
350
427
|
m.inplace = True
|
351
428
|
|
352
429
|
|
353
430
|
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
|
354
|
-
"""Scales and pads an image tensor
|
355
|
-
retaining the original shape.
|
356
|
-
"""
|
431
|
+
"""Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
|
357
432
|
if ratio == 1.0:
|
358
433
|
return img
|
359
434
|
h, w = img.shape[2:]
|
@@ -364,13 +439,6 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
|
|
364
439
|
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
365
440
|
|
366
441
|
|
367
|
-
def make_divisible(x, divisor):
|
368
|
-
"""Returns nearest x divisible by divisor."""
|
369
|
-
if isinstance(divisor, torch.Tensor):
|
370
|
-
divisor = int(divisor.max()) # to int
|
371
|
-
return math.ceil(x / divisor) * divisor
|
372
|
-
|
373
|
-
|
374
442
|
def copy_attr(a, b, include=(), exclude=()):
|
375
443
|
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
|
376
444
|
for k, v in b.__dict__.items():
|
@@ -381,8 +449,13 @@ def copy_attr(a, b, include=(), exclude=()):
|
|
381
449
|
|
382
450
|
|
383
451
|
def get_latest_opset():
|
384
|
-
"""Return second-most
|
385
|
-
|
452
|
+
"""Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
|
453
|
+
if TORCH_1_13:
|
454
|
+
# If PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
|
455
|
+
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
|
456
|
+
# Otherwise for PyTorch<=1.12 return the corresponding predefined opset
|
457
|
+
version = torch.onnx.producer_version.rsplit(".", 1)[0] # i.e. '2.3'
|
458
|
+
return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
|
386
459
|
|
387
460
|
|
388
461
|
def intersect_dicts(da, db, exclude=()):
|
@@ -427,14 +500,17 @@ def init_seeds(seed=0, deterministic=False):
|
|
427
500
|
|
428
501
|
|
429
502
|
class ModelEMA:
|
430
|
-
"""
|
431
|
-
|
503
|
+
"""
|
504
|
+
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
|
505
|
+
average of everything in the model state_dict (parameters and buffers).
|
506
|
+
|
432
507
|
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
508
|
+
|
433
509
|
To disable EMA set the `enabled` attribute to `False`.
|
434
510
|
"""
|
435
511
|
|
436
512
|
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
437
|
-
"""
|
513
|
+
"""Initialize EMA for 'model' with given arguments."""
|
438
514
|
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
439
515
|
self.updates = updates # number of EMA updates
|
440
516
|
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
@@ -461,50 +537,113 @@ class ModelEMA:
|
|
461
537
|
copy_attr(self.ema, model, include, exclude)
|
462
538
|
|
463
539
|
|
464
|
-
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") ->
|
540
|
+
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
|
465
541
|
"""
|
466
542
|
Strip optimizer from 'f' to finalize training, optionally save as 's'.
|
467
543
|
|
468
544
|
Args:
|
469
545
|
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
|
470
546
|
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
|
547
|
+
updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
|
471
548
|
|
472
549
|
Returns:
|
473
|
-
|
550
|
+
(dict): The combined checkpoint dictionary.
|
474
551
|
|
475
552
|
Example:
|
476
553
|
```python
|
477
554
|
from pathlib import Path
|
478
555
|
from ultralytics.utils.torch_utils import strip_optimizer
|
479
556
|
|
480
|
-
for f in Path(
|
557
|
+
for f in Path("path/to/model/checkpoints").rglob("*.pt"):
|
481
558
|
strip_optimizer(f)
|
482
559
|
```
|
483
|
-
"""
|
484
|
-
x = torch.load(f, map_location=torch.device("cpu"))
|
485
|
-
if "model" not in x:
|
486
|
-
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
|
487
|
-
return
|
488
560
|
|
561
|
+
Note:
|
562
|
+
Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
|
563
|
+
"""
|
564
|
+
try:
|
565
|
+
x = torch.load(f, map_location=torch.device("cpu"))
|
566
|
+
assert isinstance(x, dict), "checkpoint is not a Python dictionary"
|
567
|
+
assert "model" in x, "'model' missing from checkpoint"
|
568
|
+
except Exception as e:
|
569
|
+
LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
|
570
|
+
return {}
|
571
|
+
|
572
|
+
metadata = {
|
573
|
+
"date": datetime.now().isoformat(),
|
574
|
+
"version": __version__,
|
575
|
+
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
576
|
+
"docs": "https://docs.ultralytics.com",
|
577
|
+
}
|
578
|
+
|
579
|
+
# Update model
|
580
|
+
if x.get("ema"):
|
581
|
+
x["model"] = x["ema"] # replace model with EMA
|
489
582
|
if hasattr(x["model"], "args"):
|
490
583
|
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
|
491
|
-
|
492
|
-
|
493
|
-
x["model"] = x["ema"] # replace model with ema
|
494
|
-
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
495
|
-
x[k] = None
|
496
|
-
x["epoch"] = -1
|
584
|
+
if hasattr(x["model"], "criterion"):
|
585
|
+
x["model"].criterion = None # strip loss criterion
|
497
586
|
x["model"].half() # to FP16
|
498
587
|
for p in x["model"].parameters():
|
499
588
|
p.requires_grad = False
|
589
|
+
|
590
|
+
# Update other keys
|
591
|
+
args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
|
592
|
+
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
593
|
+
x[k] = None
|
594
|
+
x["epoch"] = -1
|
500
595
|
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
501
596
|
# x['model'].args = x['train_args']
|
502
|
-
|
597
|
+
|
598
|
+
# Save
|
599
|
+
combined = {**metadata, **x, **(updates or {})}
|
600
|
+
torch.save(combined, s or f) # combine dicts (prefer to the right)
|
503
601
|
mb = os.path.getsize(s or f) / 1e6 # file size
|
504
602
|
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
603
|
+
return combined
|
604
|
+
|
605
|
+
|
606
|
+
def convert_optimizer_state_dict_to_fp16(state_dict):
    """
    Cast the float32 tensors stored under an optimizer state_dict's 'state' key to FP16.

    Only the per-parameter entries under 'state' are touched; 'param_groups' is left as-is. Within each entry, the
    'step' value and anything that is not a float32 torch.Tensor are kept unchanged. The conversion is done in place.

    Args:
        state_dict (dict): Optimizer state dictionary containing a 'state' mapping of per-parameter state dicts.

    Returns:
        (dict): The same `state_dict` object, with eligible tensors replaced by their half-precision copies.
    """
    for param_state in state_dict["state"].values():
        # Collect the replacements first, then write them back; keys are unchanged so ordering is preserved
        halved = {
            name: value.half()
            for name, value in param_state.items()
            if name != "step" and isinstance(value, torch.Tensor) and value.dtype is torch.float32
        }
        param_state.update(halved)

    return state_dict
|
618
|
+
|
619
|
+
|
620
|
+
@contextmanager
def cuda_memory_usage(device=None):
    """
    Context manager that reports the CUDA memory reserved once the wrapped block has finished.

    A dictionary with a single 'memory' entry (initially 0) is yielded to the caller. When CUDA is available the
    CUDA cache is emptied before the block runs, and on exit (including on error) 'memory' is overwritten with the
    value of `torch.cuda.memory_reserved(device)`. Without CUDA the dictionary is yielded and left untouched.

    Args:
        device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.

    Yields:
        (dict): A dictionary with key 'memory', set to 0 on entry and updated with the reserved memory on exit.
    """
    usage = {"memory": 0}

    # No CUDA: nothing to measure, hand back the zero-initialised dict
    if not torch.cuda.is_available():
        yield usage
        return

    torch.cuda.empty_cache()
    try:
        yield usage
    finally:
        # Recorded even if the wrapped block raised
        usage["memory"] = torch.cuda.memory_reserved(device)
|
505
644
|
|
506
645
|
|
507
|
-
def profile(input, ops, n=10, device=None):
|
646
|
+
def profile(input, ops, n=10, device=None, max_num_obj=0):
|
508
647
|
"""
|
509
648
|
Ultralytics speed, memory and FLOPs profiler.
|
510
649
|
|
@@ -525,7 +664,8 @@ def profile(input, ops, n=10, device=None):
|
|
525
664
|
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
526
665
|
f"{'input':>24s}{'output':>24s}"
|
527
666
|
)
|
528
|
-
|
667
|
+
gc.collect() # attempt to free unused memory
|
668
|
+
torch.cuda.empty_cache()
|
529
669
|
for x in input if isinstance(input, list) else [input]:
|
530
670
|
x = x.to(device)
|
531
671
|
x.requires_grad = True
|
@@ -539,19 +679,31 @@ def profile(input, ops, n=10, device=None):
|
|
539
679
|
flops = 0
|
540
680
|
|
541
681
|
try:
|
682
|
+
mem = 0
|
542
683
|
for _ in range(n):
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
684
|
+
with cuda_memory_usage(device) as cuda_info:
|
685
|
+
t[0] = time_sync()
|
686
|
+
y = m(x)
|
687
|
+
t[1] = time_sync()
|
688
|
+
try:
|
689
|
+
(sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
|
690
|
+
t[2] = time_sync()
|
691
|
+
except Exception: # no backward method
|
692
|
+
# print(e) # for debug
|
693
|
+
t[2] = float("nan")
|
694
|
+
mem += cuda_info["memory"] / 1e9 # (GB)
|
552
695
|
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
553
696
|
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
554
|
-
|
697
|
+
if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)
|
698
|
+
with cuda_memory_usage(device) as cuda_info:
|
699
|
+
torch.randn(
|
700
|
+
x.shape[0],
|
701
|
+
max_num_obj,
|
702
|
+
int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
|
703
|
+
device=device,
|
704
|
+
dtype=torch.float32,
|
705
|
+
)
|
706
|
+
mem += cuda_info["memory"] / 1e9 # (GB)
|
555
707
|
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
|
556
708
|
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
557
709
|
LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
|
@@ -559,7 +711,9 @@ def profile(input, ops, n=10, device=None):
|
|
559
711
|
except Exception as e:
|
560
712
|
LOGGER.info(e)
|
561
713
|
results.append(None)
|
562
|
-
|
714
|
+
finally:
|
715
|
+
gc.collect() # attempt to free unused memory
|
716
|
+
torch.cuda.empty_cache()
|
563
717
|
return results
|
564
718
|
|
565
719
|
|
@@ -599,10 +753,56 @@ class EarlyStopping:
|
|
599
753
|
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
|
600
754
|
stop = delta >= self.patience # stop training if patience exceeded
|
601
755
|
if stop:
|
756
|
+
prefix = colorstr("EarlyStopping: ")
|
602
757
|
LOGGER.info(
|
603
|
-
f"
|
758
|
+
f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
|
604
759
|
f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
|
605
760
|
f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
|
606
761
|
f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
|
607
762
|
)
|
608
763
|
return stop
|
764
|
+
|
765
|
+
|
766
|
+
class FXModel(nn.Module):
    """
    Wrapper module intended for use with torch.fx tracing and graph manipulation.

    The wrapper takes over the attributes of an existing model via `copy_attr` and then assigns the wrapped model's
    `model` attribute (its layer sequence) explicitly. Its forward pass walks that layer sequence directly.

    Args:
        model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
    """

    def __init__(self, model):
        """
        Build the wrapper around an existing model.

        Args:
            model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
        """
        super().__init__()
        copy_attr(self, model)
        # `model` has to be assigned by hand because `copy_attr` does not carry it over
        # NOTE(review): presumably because submodules are not stored as plain instance attributes - confirm
        self.model = model.model

    def forward(self, x):
        """
        Run the input through every layer in order, routing earlier outputs where a layer asks for them.

        Each layer's `f` attribute selects its input: -1 means the previous layer's output, another int selects one
        saved output by index, and an iterable of indices builds a list of inputs (with -1 again standing for the
        current value). Every layer's output is saved so later layers can refer back to it.

        Args:
            x (torch.Tensor): The input tensor to the model.

        Returns:
            (torch.Tensor): The output of the last layer.
        """
        saved = []  # output of every layer, indexed by layer position
        for layer in self.model:
            source = layer.f
            if source != -1:  # input does not come (only) from the previous layer
                if isinstance(source, int):
                    x = saved[source]
                else:
                    x = [x if idx == -1 else saved[idx] for idx in source]
            x = layer(x)  # run
            saved.append(x)  # save output
        return x
|
ultralytics/utils/triton.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
from typing import List
|
4
4
|
from urllib.parse import urlsplit
|
@@ -66,6 +66,7 @@ class TritonRemoteModel:
|
|
66
66
|
self.np_input_formats = [type_map[x] for x in self.input_formats]
|
67
67
|
self.input_names = [x["name"] for x in config["input"]]
|
68
68
|
self.output_names = [x["name"] for x in config["output"]]
|
69
|
+
self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))
|
69
70
|
|
70
71
|
def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
|
71
72
|
"""
|