ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +37 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +111 -41
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +579 -244
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +191 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +226 -82
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +172 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +305 -112
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.63.dist-info/METADATA +370 -0
- ultralytics-8.3.63.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
ultralytics/utils/torch_utils.py
CHANGED
@@ -1,46 +1,62 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
import gc
|
3
4
|
import math
|
4
5
|
import os
|
5
6
|
import random
|
6
7
|
import time
|
7
8
|
from contextlib import contextmanager
|
8
9
|
from copy import deepcopy
|
10
|
+
from datetime import datetime
|
9
11
|
from pathlib import Path
|
10
12
|
from typing import Union
|
11
13
|
|
12
14
|
import numpy as np
|
15
|
+
import thop
|
13
16
|
import torch
|
14
17
|
import torch.distributed as dist
|
15
18
|
import torch.nn as nn
|
16
19
|
import torch.nn.functional as F
|
17
|
-
import torchvision
|
18
|
-
|
19
|
-
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
|
20
|
-
from ultralytics.utils.checks import PYTHON_VERSION, check_version
|
21
|
-
|
22
|
-
try:
|
23
|
-
import thop
|
24
|
-
except ImportError:
|
25
|
-
thop = None
|
26
20
|
|
21
|
+
from ultralytics.utils import (
|
22
|
+
DEFAULT_CFG_DICT,
|
23
|
+
DEFAULT_CFG_KEYS,
|
24
|
+
LOGGER,
|
25
|
+
NUM_THREADS,
|
26
|
+
PYTHON_VERSION,
|
27
|
+
TORCHVISION_VERSION,
|
28
|
+
WINDOWS,
|
29
|
+
__version__,
|
30
|
+
colorstr,
|
31
|
+
)
|
32
|
+
from ultralytics.utils.checks import check_version
|
33
|
+
|
34
|
+
# Version checks (all default to version>=min_version)
|
27
35
|
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
|
28
36
|
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
|
29
37
|
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
|
30
|
-
|
31
|
-
|
32
|
-
|
38
|
+
TORCH_2_4 = check_version(torch.__version__, "2.4.0")
|
39
|
+
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
|
40
|
+
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
|
41
|
+
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
|
42
|
+
TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
|
43
|
+
if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows
|
44
|
+
LOGGER.warning(
|
45
|
+
"WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
|
46
|
+
"https://github.com/ultralytics/ultralytics/issues/15049"
|
47
|
+
)
|
33
48
|
|
34
49
|
|
35
50
|
@contextmanager
|
36
51
|
def torch_distributed_zero_first(local_rank: int):
|
37
|
-
"""
|
38
|
-
initialized =
|
39
|
-
|
52
|
+
"""Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
|
53
|
+
initialized = dist.is_available() and dist.is_initialized()
|
54
|
+
|
55
|
+
if initialized and local_rank not in {-1, 0}:
|
40
56
|
dist.barrier(device_ids=[local_rank])
|
41
57
|
yield
|
42
58
|
if initialized and local_rank == 0:
|
43
|
-
dist.barrier(device_ids=[
|
59
|
+
dist.barrier(device_ids=[local_rank])
|
44
60
|
|
45
61
|
|
46
62
|
def smart_inference_mode():
|
@@ -56,14 +72,58 @@ def smart_inference_mode():
|
|
56
72
|
return decorate
|
57
73
|
|
58
74
|
|
75
|
+
def autocast(enabled: bool, device: str = "cuda"):
|
76
|
+
"""
|
77
|
+
Get the appropriate autocast context manager based on PyTorch version and AMP setting.
|
78
|
+
|
79
|
+
This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
|
80
|
+
older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
|
81
|
+
|
82
|
+
Args:
|
83
|
+
enabled (bool): Whether to enable automatic mixed precision.
|
84
|
+
device (str, optional): The device to use for autocast. Defaults to 'cuda'.
|
85
|
+
|
86
|
+
Returns:
|
87
|
+
(torch.amp.autocast): The appropriate autocast context manager.
|
88
|
+
|
89
|
+
Note:
|
90
|
+
- For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
|
91
|
+
- For older versions, it uses `torch.cuda.autocast`.
|
92
|
+
|
93
|
+
Example:
|
94
|
+
```python
|
95
|
+
with autocast(amp=True):
|
96
|
+
# Your mixed precision operations here
|
97
|
+
pass
|
98
|
+
```
|
99
|
+
"""
|
100
|
+
if TORCH_1_13:
|
101
|
+
return torch.amp.autocast(device, enabled=enabled)
|
102
|
+
else:
|
103
|
+
return torch.cuda.amp.autocast(enabled)
|
104
|
+
|
105
|
+
|
59
106
|
def get_cpu_info():
|
60
107
|
"""Return a string with system CPU information, i.e. 'Apple M2'."""
|
61
|
-
import
|
108
|
+
from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
|
109
|
+
|
110
|
+
if "cpu_info" not in PERSISTENT_CACHE:
|
111
|
+
try:
|
112
|
+
import cpuinfo # pip install py-cpuinfo
|
113
|
+
|
114
|
+
k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
|
115
|
+
info = cpuinfo.get_cpu_info() # info dict
|
116
|
+
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
|
117
|
+
PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
|
118
|
+
except Exception:
|
119
|
+
pass
|
120
|
+
return PERSISTENT_CACHE.get("cpu_info", "unknown")
|
121
|
+
|
62
122
|
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
return
|
123
|
+
def get_gpu_info(index):
|
124
|
+
"""Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
|
125
|
+
properties = torch.cuda.get_device_properties(index)
|
126
|
+
return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
|
67
127
|
|
68
128
|
|
69
129
|
def select_device(device="", batch=0, newline=False, verbose=True):
|
@@ -90,30 +150,31 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|
90
150
|
devices when using multiple GPUs.
|
91
151
|
|
92
152
|
Examples:
|
93
|
-
>>> select_device(
|
153
|
+
>>> select_device("cuda:0")
|
94
154
|
device(type='cuda', index=0)
|
95
155
|
|
96
|
-
>>> select_device(
|
156
|
+
>>> select_device("cpu")
|
97
157
|
device(type='cpu')
|
98
158
|
|
99
159
|
Note:
|
100
160
|
Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
|
101
161
|
"""
|
102
|
-
|
103
|
-
if isinstance(device, torch.device):
|
162
|
+
if isinstance(device, torch.device) or str(device).startswith("tpu"):
|
104
163
|
return device
|
105
164
|
|
106
|
-
s = f"Ultralytics
|
165
|
+
s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
|
107
166
|
device = str(device).lower()
|
108
167
|
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
|
109
168
|
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
110
169
|
cpu = device == "cpu"
|
111
|
-
mps = device in
|
170
|
+
mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
|
112
171
|
if cpu or mps:
|
113
172
|
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
|
114
173
|
elif device: # non-cpu device requested
|
115
174
|
if device == "cuda":
|
116
175
|
device = "0"
|
176
|
+
if "," in device:
|
177
|
+
device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1"
|
117
178
|
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
118
179
|
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
|
119
180
|
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
|
@@ -135,17 +196,22 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|
135
196
|
)
|
136
197
|
|
137
198
|
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
138
|
-
devices = device.split(",") if device else "0" #
|
199
|
+
devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"]
|
139
200
|
n = len(devices) # device count
|
140
|
-
if n > 1
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
201
|
+
if n > 1: # multi-GPU
|
202
|
+
if batch < 1:
|
203
|
+
raise ValueError(
|
204
|
+
"AutoBatch with batch<1 not supported for Multi-GPU training, "
|
205
|
+
"please specify a valid batch size, i.e. batch=16."
|
206
|
+
)
|
207
|
+
if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count
|
208
|
+
raise ValueError(
|
209
|
+
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
210
|
+
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
|
211
|
+
)
|
145
212
|
space = " " * (len(s) + 1)
|
146
213
|
for i, d in enumerate(devices):
|
147
|
-
|
148
|
-
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
|
214
|
+
s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB
|
149
215
|
arg = "cuda:0"
|
150
216
|
elif mps and TORCH_2_0 and torch.backends.mps.is_available():
|
151
217
|
# Prefer MPS if available
|
@@ -155,6 +221,8 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|
155
221
|
s += f"CPU ({get_cpu_info()})\n"
|
156
222
|
arg = "cpu"
|
157
223
|
|
224
|
+
if arg in {"cpu", "mps"}:
|
225
|
+
torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training
|
158
226
|
if verbose:
|
159
227
|
LOGGER.info(s if newline else s.rstrip())
|
160
228
|
return torch.device(arg)
|
@@ -185,7 +253,7 @@ def fuse_conv_and_bn(conv, bn):
|
|
185
253
|
)
|
186
254
|
|
187
255
|
# Prepare filters
|
188
|
-
w_conv = conv.weight.
|
256
|
+
w_conv = conv.weight.view(conv.out_channels, -1)
|
189
257
|
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
190
258
|
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
191
259
|
|
@@ -216,7 +284,7 @@ def fuse_deconv_and_bn(deconv, bn):
|
|
216
284
|
)
|
217
285
|
|
218
286
|
# Prepare filters
|
219
|
-
w_deconv = deconv.weight.
|
287
|
+
w_deconv = deconv.weight.view(deconv.out_channels, -1)
|
220
288
|
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
221
289
|
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
|
222
290
|
|
@@ -229,33 +297,27 @@ def fuse_deconv_and_bn(deconv, bn):
|
|
229
297
|
|
230
298
|
|
231
299
|
def model_info(model, detailed=False, verbose=True, imgsz=640):
|
232
|
-
"""
|
233
|
-
Model information.
|
234
|
-
|
235
|
-
imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
|
236
|
-
"""
|
300
|
+
"""Print and return detailed model information layer by layer."""
|
237
301
|
if not verbose:
|
238
302
|
return
|
239
303
|
n_p = get_num_params(model) # number of parameters
|
240
304
|
n_g = get_num_gradients(model) # number of gradients
|
241
305
|
n_l = len(list(model.modules())) # number of layers
|
242
306
|
if detailed:
|
243
|
-
LOGGER.info(
|
244
|
-
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
|
245
|
-
)
|
307
|
+
LOGGER.info(f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}")
|
246
308
|
for i, (name, p) in enumerate(model.named_parameters()):
|
247
309
|
name = name.replace("module_list.", "")
|
248
310
|
LOGGER.info(
|
249
|
-
"
|
250
|
-
|
311
|
+
f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}"
|
312
|
+
f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}"
|
251
313
|
)
|
252
314
|
|
253
|
-
flops = get_flops(model, imgsz)
|
315
|
+
flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
|
254
316
|
fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
|
255
317
|
fs = f", {flops:.1f} GFLOPs" if flops else ""
|
256
318
|
yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
|
257
319
|
model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
|
258
|
-
LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
|
320
|
+
LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
|
259
321
|
return n_l, n_p, n_g, flops
|
260
322
|
|
261
323
|
|
@@ -276,11 +338,13 @@ def model_info_for_loggers(trainer):
|
|
276
338
|
Example:
|
277
339
|
YOLOv8n info for loggers
|
278
340
|
```python
|
279
|
-
results = {
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
341
|
+
results = {
|
342
|
+
"model/parameters": 3151904,
|
343
|
+
"model/GFLOPs": 8.746,
|
344
|
+
"model/speed_ONNX(ms)": 41.244,
|
345
|
+
"model/speed_TensorRT(ms)": 3.211,
|
346
|
+
"model/speed_PyTorch(ms)": 18.755,
|
347
|
+
}
|
284
348
|
```
|
285
349
|
"""
|
286
350
|
if trainer.args.profile: # profile ONNX and TensorRT times
|
@@ -299,9 +363,6 @@ def model_info_for_loggers(trainer):
|
|
299
363
|
|
300
364
|
def get_flops(model, imgsz=640):
|
301
365
|
"""Return a YOLO model's FLOPs."""
|
302
|
-
if not thop:
|
303
|
-
return 0.0 # if not installed return 0.0 GFLOPs
|
304
|
-
|
305
366
|
try:
|
306
367
|
model = de_parallel(model)
|
307
368
|
p = next(model.parameters())
|
@@ -322,19 +383,28 @@ def get_flops(model, imgsz=640):
|
|
322
383
|
|
323
384
|
|
324
385
|
def get_flops_with_torch_profiler(model, imgsz=640):
|
325
|
-
"""Compute model FLOPs (thop alternative)."""
|
326
|
-
if TORCH_2_0:
|
327
|
-
|
328
|
-
|
386
|
+
"""Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
|
387
|
+
if not TORCH_2_0: # torch profiler implemented in torch>=2.0
|
388
|
+
return 0.0
|
389
|
+
model = de_parallel(model)
|
390
|
+
p = next(model.parameters())
|
391
|
+
if not isinstance(imgsz, list):
|
392
|
+
imgsz = [imgsz, imgsz] # expand if int/float
|
393
|
+
try:
|
394
|
+
# Use stride size for input tensor
|
329
395
|
stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
|
330
|
-
im = torch.
|
396
|
+
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
331
397
|
with torch.profiler.profile(with_flops=True) as prof:
|
332
398
|
model(im)
|
333
399
|
flops = sum(x.flops for x in prof.key_averages()) / 1e9
|
334
|
-
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
335
400
|
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
|
336
|
-
|
337
|
-
|
401
|
+
except Exception:
|
402
|
+
# Use actual image size for input tensor (i.e. required for RTDETR models)
|
403
|
+
im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
|
404
|
+
with torch.profiler.profile(with_flops=True) as prof:
|
405
|
+
model(im)
|
406
|
+
flops = sum(x.flops for x in prof.key_averages()) / 1e9
|
407
|
+
return flops
|
338
408
|
|
339
409
|
|
340
410
|
def initialize_weights(model):
|
@@ -346,14 +416,12 @@ def initialize_weights(model):
|
|
346
416
|
elif t is nn.BatchNorm2d:
|
347
417
|
m.eps = 1e-3
|
348
418
|
m.momentum = 0.03
|
349
|
-
elif t in
|
419
|
+
elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
|
350
420
|
m.inplace = True
|
351
421
|
|
352
422
|
|
353
423
|
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
|
354
|
-
"""Scales and pads an image tensor
|
355
|
-
retaining the original shape.
|
356
|
-
"""
|
424
|
+
"""Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
|
357
425
|
if ratio == 1.0:
|
358
426
|
return img
|
359
427
|
h, w = img.shape[2:]
|
@@ -364,13 +432,6 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
|
|
364
432
|
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
365
433
|
|
366
434
|
|
367
|
-
def make_divisible(x, divisor):
|
368
|
-
"""Returns nearest x divisible by divisor."""
|
369
|
-
if isinstance(divisor, torch.Tensor):
|
370
|
-
divisor = int(divisor.max()) # to int
|
371
|
-
return math.ceil(x / divisor) * divisor
|
372
|
-
|
373
|
-
|
374
435
|
def copy_attr(a, b, include=(), exclude=()):
|
375
436
|
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
|
376
437
|
for k, v in b.__dict__.items():
|
@@ -381,8 +442,13 @@ def copy_attr(a, b, include=(), exclude=()):
|
|
381
442
|
|
382
443
|
|
383
444
|
def get_latest_opset():
|
384
|
-
"""Return second-most
|
385
|
-
|
445
|
+
"""Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
|
446
|
+
if TORCH_1_13:
|
447
|
+
# If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
|
448
|
+
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
|
449
|
+
# Otherwise for PyTorch<=1.12 return the corresponding predefined opset
|
450
|
+
version = torch.onnx.producer_version.rsplit(".", 1)[0] # i.e. '2.3'
|
451
|
+
return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
|
386
452
|
|
387
453
|
|
388
454
|
def intersect_dicts(da, db, exclude=()):
|
@@ -427,14 +493,17 @@ def init_seeds(seed=0, deterministic=False):
|
|
427
493
|
|
428
494
|
|
429
495
|
class ModelEMA:
|
430
|
-
"""
|
431
|
-
|
496
|
+
"""
|
497
|
+
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
|
498
|
+
average of everything in the model state_dict (parameters and buffers).
|
499
|
+
|
432
500
|
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
501
|
+
|
433
502
|
To disable EMA set the `enabled` attribute to `False`.
|
434
503
|
"""
|
435
504
|
|
436
505
|
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
437
|
-
"""
|
506
|
+
"""Initialize EMA for 'model' with given arguments."""
|
438
507
|
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
439
508
|
self.updates = updates # number of EMA updates
|
440
509
|
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
@@ -461,50 +530,113 @@ class ModelEMA:
|
|
461
530
|
copy_attr(self.ema, model, include, exclude)
|
462
531
|
|
463
532
|
|
464
|
-
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") ->
|
533
|
+
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
|
465
534
|
"""
|
466
535
|
Strip optimizer from 'f' to finalize training, optionally save as 's'.
|
467
536
|
|
468
537
|
Args:
|
469
538
|
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
|
470
539
|
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
|
540
|
+
updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
|
471
541
|
|
472
542
|
Returns:
|
473
|
-
|
543
|
+
(dict): The combined checkpoint dictionary.
|
474
544
|
|
475
545
|
Example:
|
476
546
|
```python
|
477
547
|
from pathlib import Path
|
478
548
|
from ultralytics.utils.torch_utils import strip_optimizer
|
479
549
|
|
480
|
-
for f in Path(
|
550
|
+
for f in Path("path/to/model/checkpoints").rglob("*.pt"):
|
481
551
|
strip_optimizer(f)
|
482
552
|
```
|
483
|
-
"""
|
484
|
-
x = torch.load(f, map_location=torch.device("cpu"))
|
485
|
-
if "model" not in x:
|
486
|
-
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
|
487
|
-
return
|
488
553
|
|
554
|
+
Note:
|
555
|
+
Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
|
556
|
+
"""
|
557
|
+
try:
|
558
|
+
x = torch.load(f, map_location=torch.device("cpu"))
|
559
|
+
assert isinstance(x, dict), "checkpoint is not a Python dictionary"
|
560
|
+
assert "model" in x, "'model' missing from checkpoint"
|
561
|
+
except Exception as e:
|
562
|
+
LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
|
563
|
+
return {}
|
564
|
+
|
565
|
+
metadata = {
|
566
|
+
"date": datetime.now().isoformat(),
|
567
|
+
"version": __version__,
|
568
|
+
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
569
|
+
"docs": "https://docs.ultralytics.com",
|
570
|
+
}
|
571
|
+
|
572
|
+
# Update model
|
573
|
+
if x.get("ema"):
|
574
|
+
x["model"] = x["ema"] # replace model with EMA
|
489
575
|
if hasattr(x["model"], "args"):
|
490
576
|
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
|
491
|
-
|
492
|
-
|
493
|
-
x["model"] = x["ema"] # replace model with ema
|
494
|
-
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
495
|
-
x[k] = None
|
496
|
-
x["epoch"] = -1
|
577
|
+
if hasattr(x["model"], "criterion"):
|
578
|
+
x["model"].criterion = None # strip loss criterion
|
497
579
|
x["model"].half() # to FP16
|
498
580
|
for p in x["model"].parameters():
|
499
581
|
p.requires_grad = False
|
582
|
+
|
583
|
+
# Update other keys
|
584
|
+
args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
|
585
|
+
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
586
|
+
x[k] = None
|
587
|
+
x["epoch"] = -1
|
500
588
|
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
501
589
|
# x['model'].args = x['train_args']
|
502
|
-
|
590
|
+
|
591
|
+
# Save
|
592
|
+
combined = {**metadata, **x, **(updates or {})}
|
593
|
+
torch.save(combined, s or f) # combine dicts (prefer to the right)
|
503
594
|
mb = os.path.getsize(s or f) / 1e6 # file size
|
504
595
|
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
596
|
+
return combined
|
597
|
+
|
598
|
+
|
599
|
+
def convert_optimizer_state_dict_to_fp16(state_dict):
    """
    Cast the FP32 tensors held under an optimizer state_dict's 'state' entry to FP16, in place.

    Only the per-parameter dictionaries found in `state_dict["state"]` are visited. Inside each of them, entries
    keyed "step", values that are not tensors, and tensors whose dtype is not torch.float32 are left untouched.
    No other top-level key of `state_dict` is read or modified.

    Args:
        state_dict (dict): Optimizer state dictionary that contains a 'state' mapping of per-parameter states.

    Returns:
        (dict): The very same `state_dict` object that was passed in, after the in-place conversion.
    """
    for param_state in state_dict["state"].values():
        # Collect the names first, then overwrite them; only existing keys are reassigned, so order is kept.
        fp32_names = [
            name
            for name, value in param_state.items()
            if name != "step" and isinstance(value, torch.Tensor) and value.dtype is torch.float32
        ]
        for name in fp32_names:
            param_state[name] = param_state[name].half()

    return state_dict
|
611
|
+
|
612
|
+
|
613
|
+
@contextmanager
def cuda_memory_usage(device=None):
    """
    Monitor and manage CUDA memory usage.

    When CUDA is available and `device` refers to a CUDA device (or is None / an integer index), the CUDA cache is
    emptied, a dictionary is yielded to the caller, and on exit its 'memory' entry is set to the value returned by
    `torch.cuda.memory_reserved(device)`. In every other case (CUDA unavailable, or `device` is a non-CUDA device
    such as CPU) the dictionary is yielded unchanged and 'memory' stays 0, because `torch.cuda.memory_reserved`
    rejects non-CUDA devices and raising from the exit path would mask any exception raised by the `with` body.

    Args:
        device (torch.device | str | int, optional): The device to query memory usage for. Defaults to None.

    Yields:
        (dict): A dictionary with a key 'memory' initialized to 0, which is updated with the reserved memory
            (as returned by `torch.cuda.memory_reserved`) when a CUDA device is being tracked.
    """
    cuda_info = dict(memory=0)

    track = torch.cuda.is_available()
    if track and isinstance(device, (str, torch.device)):
        # Only CUDA devices can be queried; a CPU device on a CUDA-capable machine must not reach memory_reserved().
        track = torch.device(device).type == "cuda"

    if not track:
        yield cuda_info
        return

    torch.cuda.empty_cache()
    try:
        yield cuda_info
    finally:
        # Recorded even when the body raises, so callers that catch the error still see the figure.
        cuda_info["memory"] = torch.cuda.memory_reserved(device)
|
506
637
|
|
507
|
-
|
638
|
+
|
639
|
+
def profile(input, ops, n=10, device=None, max_num_obj=0):
|
508
640
|
"""
|
509
641
|
Ultralytics speed, memory and FLOPs profiler.
|
510
642
|
|
@@ -525,7 +657,8 @@ def profile(input, ops, n=10, device=None):
|
|
525
657
|
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
526
658
|
f"{'input':>24s}{'output':>24s}"
|
527
659
|
)
|
528
|
-
|
660
|
+
gc.collect() # attempt to free unused memory
|
661
|
+
torch.cuda.empty_cache()
|
529
662
|
for x in input if isinstance(input, list) else [input]:
|
530
663
|
x = x.to(device)
|
531
664
|
x.requires_grad = True
|
@@ -534,24 +667,36 @@ def profile(input, ops, n=10, device=None):
|
|
534
667
|
m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
535
668
|
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
|
536
669
|
try:
|
537
|
-
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2
|
670
|
+
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 # GFLOPs
|
538
671
|
except Exception:
|
539
672
|
flops = 0
|
540
673
|
|
541
674
|
try:
|
675
|
+
mem = 0
|
542
676
|
for _ in range(n):
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
677
|
+
with cuda_memory_usage(device) as cuda_info:
|
678
|
+
t[0] = time_sync()
|
679
|
+
y = m(x)
|
680
|
+
t[1] = time_sync()
|
681
|
+
try:
|
682
|
+
(sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
|
683
|
+
t[2] = time_sync()
|
684
|
+
except Exception: # no backward method
|
685
|
+
# print(e) # for debug
|
686
|
+
t[2] = float("nan")
|
687
|
+
mem += cuda_info["memory"] / 1e9 # (GB)
|
552
688
|
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
553
689
|
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
554
|
-
|
690
|
+
if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)
|
691
|
+
with cuda_memory_usage(device) as cuda_info:
|
692
|
+
torch.randn(
|
693
|
+
x.shape[0],
|
694
|
+
max_num_obj,
|
695
|
+
int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
|
696
|
+
device=device,
|
697
|
+
dtype=torch.float32,
|
698
|
+
)
|
699
|
+
mem += cuda_info["memory"] / 1e9 # (GB)
|
555
700
|
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
|
556
701
|
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
557
702
|
LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
|
@@ -559,7 +704,9 @@ def profile(input, ops, n=10, device=None):
|
|
559
704
|
except Exception as e:
|
560
705
|
LOGGER.info(e)
|
561
706
|
results.append(None)
|
562
|
-
|
707
|
+
finally:
|
708
|
+
gc.collect() # attempt to free unused memory
|
709
|
+
torch.cuda.empty_cache()
|
563
710
|
return results
|
564
711
|
|
565
712
|
|
@@ -599,10 +746,56 @@ class EarlyStopping:
|
|
599
746
|
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
|
600
747
|
stop = delta >= self.patience # stop training if patience exceeded
|
601
748
|
if stop:
|
749
|
+
prefix = colorstr("EarlyStopping: ")
|
602
750
|
LOGGER.info(
|
603
|
-
f"
|
751
|
+
f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
|
604
752
|
f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
|
605
753
|
f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
|
606
754
|
f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
|
607
755
|
)
|
608
756
|
return stop
|
757
|
+
|
758
|
+
|
759
|
+
class FXModel(nn.Module):
    """
    Wrapper module intended for torch.fx tracing and graph manipulation.

    The wrapper takes over the attributes of an existing model through `copy_attr` and then assigns the wrapped
    model's `model` attribute explicitly. Its forward pass walks the layers of `self.model` one by one, keeping every
    layer output so that later layers can read from earlier ones through their `f` attribute.

    Args:
        model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
    """

    def __init__(self, model):
        """
        Build the wrapper from an existing model.

        Args:
            model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
        """
        super().__init__()
        copy_attr(self, model)
        # `copy_attr` is reported not to carry `model` over, so the layer container is assigned by hand.
        self.model = model.model

    def forward(self, x):
        """
        Run the input through every layer of the wrapped model in order.

        Each layer's `f` attribute selects its input: -1 means the previous layer's output, an int indexes the list
        of saved outputs, and any other iterable builds a list of inputs in which -1 again stands for the current `x`.

        Args:
            x (torch.Tensor): The input tensor to the model.

        Returns:
            (torch.Tensor): The output produced by the last layer.
        """
        outputs = []  # output of every layer executed so far, in execution order
        for layer in self.model:
            source = layer.f
            if source != -1:  # input does not come straight from the previous layer
                if isinstance(source, int):
                    x = outputs[source]
                else:
                    x = [x if idx == -1 else outputs[idx] for idx in source]
            x = layer(x)  # run
            outputs.append(x)  # save output
        return x
|