dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/utils/benchmarks.py (CHANGED)

Deleted lines that the diff viewer truncated mid-line are marked with `…`.
```diff
@@ -23,9 +23,13 @@ TensorFlow.js | `tfjs` | yolo11n_web_model/
 PaddlePaddle | `paddle` | yolo11n_paddle_model/
 MNN | `mnn` | yolo11n.mnn
 NCNN | `ncnn` | yolo11n_ncnn_model/
+IMX | `imx` | yolo11n_imx_model/
 RKNN | `rknn` | yolo11n_rknn_model/
+ExecuTorch | `executorch` | yolo11n_executorch_model/
 """
 
+from __future__ import annotations
+
 import glob
 import os
 import platform
```
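The new `from __future__ import annotations` line is what lets later hunks use PEP 604 unions such as `torch.device | str | None` as annotations while the package still supports Python versions below 3.10. A minimal sketch of the effect, unrelated to the ultralytics code itself:

```python
# Hypothetical minimal example, not from ultralytics: without this future
# import, "int | None" as an annotation raises TypeError at def time on
# Python 3.8/3.9; with it, annotations are stored as strings and never
# evaluated eagerly.
from __future__ import annotations


def parse_batch(batch: int | None = None) -> int:
    """Return the batch size, defaulting to 1 when not given."""
    return 1 if batch is None else batch


print(parse_batch())    # 1
print(parse_batch(16))  # 16
```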
```diff
@@ -40,7 +44,7 @@ import torch.cuda
 from ultralytics import YOLO, YOLOWorld
 from ultralytics.cfg import TASK2DATA, TASK2METRIC
 from ultralytics.engine.exporter import export_formats
-from ultralytics.utils import ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML
+from ultralytics.utils import ARM64, ASSETS, ASSETS_URL, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML
 from ultralytics.utils.checks import IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip
 from ultralytics.utils.downloads import safe_download
 from ultralytics.utils.files import file_size
```
```diff
@@ -57,9 +61,9 @@ def benchmark(
     verbose=False,
     eps=1e-3,
     format="",
+    **kwargs,
 ):
-    """
-    Benchmark a YOLO model across different formats for speed and accuracy.
+    """Benchmark a YOLO model across different formats for speed and accuracy.
 
     Args:
         model (str | Path): Path to the model file or directory.
```
```diff
@@ -71,10 +75,11 @@ def benchmark(
         verbose (bool | float): If True or a float, assert benchmarks pass with given metric.
         eps (float): Epsilon value for divide by zero prevention.
         format (str): Export format for benchmarking. If not supplied all formats are benchmarked.
+        **kwargs (Any): Additional keyword arguments for exporter.
 
     Returns:
-        (…
-        …
+        (polars.DataFrame): A polars DataFrame with benchmark results for each format, including file size, metric, and
+            inference time.
 
     Examples:
         Benchmark a YOLO model with default settings:
```
```diff
@@ -84,10 +89,15 @@ def benchmark(
     imgsz = check_imgsz(imgsz)
     assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, "benchmark() only supports square imgsz."
 
-    import …
+    import polars as pl  # scope for faster 'import ultralytics'
+
+    pl.Config.set_tbl_cols(-1)  # Show all columns
+    pl.Config.set_tbl_rows(-1)  # Show all rows
+    pl.Config.set_tbl_width_chars(-1)  # No width limit
+    pl.Config.set_tbl_hide_column_data_types(True)  # Hide data types
+    pl.Config.set_tbl_hide_dataframe_shape(True)  # Hide shape info
+    pl.Config.set_tbl_formatting("ASCII_BORDERS_ONLY_CONDENSED")
 
-    pd.options.display.max_columns = 10
-    pd.options.display.width = 120
     device = select_device(device, verbose=False)
     if isinstance(model, (str, Path)):
         model = YOLO(model)
```
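The removed pandas display options are replaced by polars' global `Config` switches. A standalone sketch of what those six calls do to table rendering (the DataFrame here is invented):

```python
import polars as pl

pl.Config.set_tbl_cols(-1)  # never elide columns with "..."
pl.Config.set_tbl_rows(-1)  # never elide rows
pl.Config.set_tbl_width_chars(-1)  # no terminal-width wrapping
pl.Config.set_tbl_hide_column_data_types(True)  # drop the dtype header row
pl.Config.set_tbl_hide_dataframe_shape(True)  # drop the "shape: (2, 3)" banner
pl.Config.set_tbl_formatting("ASCII_BORDERS_ONLY_CONDENSED")  # plain ASCII borders

# Made-up benchmark-style rows, just to show the rendering
df = pl.DataFrame({"Format": ["PyTorch", "ONNX"], "Size (MB)": [5.4, 10.2], "FPS": [98.7, 142.3]})
print(df)  # renders as a condensed ASCII table with no shape/dtype noise
```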
```diff
@@ -102,45 +112,50 @@ def benchmark(
     if format_arg:
         formats = frozenset(export_formats()["Argument"])
         assert format in formats, f"Expected format to be one of {formats}, but got '{format_arg}'."
-    for …
+    for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):
         emoji, filename = "❌", None  # export defaults
         try:
             if format_arg and format_arg != format:
                 continue
 
             # Checks
-            if …
+            if format == "pb":
                 assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
-            elif …
+            elif format == "edgetpu":
                 assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
-            elif …
+            elif format in {"coreml", "tfjs"}:
                 assert MACOS or (LINUX and not ARM64), (
                     "CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
                 )
-            if …
+            if format == "coreml":
                 assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13"
-            if …
+            if format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
                 # assert not IS_PYTHON_MINIMUM_3_12, "TFLite exports not supported on Python>=3.12 yet"
-            if …
+            if format == "paddle":
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
                 assert model.task != "obb", "Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024"
                 assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
                 assert (LINUX and not IS_JETSON) or MACOS, "Windows and Jetson Paddle exports not supported yet"
-            if …
+            if format == "mnn":
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
-            if …
+            if format == "ncnn":
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
-            if …
+            if format == "imx":
                 assert not is_end2end
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
-                assert model.task …
-                …
-                …
+                assert model.task in {"detect", "classify", "pose"}, (
+                    "IMX export is only supported for detection, classification and pose estimation tasks"
+                )
+                assert "C2f" in model.__str__(), "IMX only supported for YOLOv8n and YOLO11n"
+            if format == "rknn":
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
                 assert not is_end2end, "End-to-end models not supported by RKNN yet"
                 assert LINUX, "RKNN only supported on Linux"
                 assert not is_rockchip(), "RKNN Inference only supported on Rockchip devices"
+            if format == "executorch":
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 ExecuTorch exports not supported yet"
+                assert not is_end2end, "End-to-end models not supported by ExecuTorch yet"
             if "cpu" in device.type:
                 assert cpu, "inference not supported on CPU"
             if "cuda" in device.type:
```
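The rewritten loop header replaces index-based checks with named tuple unpacking over `export_formats()`, which (as the `export_formats()["Argument"]` lookup above suggests) returns a column-oriented mapping. A toy sketch of the `zip(*values())` transposition idiom, with an invented three-row table:

```python
# Toy column-oriented table in the style of export_formats(); the real
# function returns more columns and rows than shown here.
table = {
    "Format": ["PyTorch", "ONNX", "NCNN"],
    "Argument": ["-", "onnx", "ncnn"],
    "Suffix": [".pt", ".onnx", "_ncnn_model"],
}

# zip(*table.values()) transposes the columns into per-format rows
for name, argument, suffix in zip(*table.values()):
    print(f"{name}: format={argument!r}, suffix={suffix!r}")
```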
```diff
@@ -152,23 +167,32 @@ def benchmark(
                 exported_model = model  # PyTorch format
             else:
                 filename = model.export(
-                    imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False
+                    imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False, **kwargs
                 )
                 exported_model = YOLO(filename, task=model.task)
                 assert suffix in str(filename), "export failed"
             emoji = "❎"  # indicates export succeeded
 
             # Predict
-            assert model.task != "pose" or …
-            assert …
-            assert …
-            …
+            assert model.task != "pose" or format != "pb", "GraphDef Pose inference is not supported"
+            assert model.task != "pose" or format != "executorch", "ExecuTorch Pose inference is not supported"
+            assert format not in {"edgetpu", "tfjs"}, "inference not supported"
+            assert format != "coreml" or platform.system() == "Darwin", "inference only supported on macOS>=10.13"
+            if format == "ncnn":
                 assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
             exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False)
 
             # Validate
             results = exported_model.val(
-                data=data, …
+                data=data,
+                batch=1,
+                imgsz=imgsz,
+                plots=False,
+                device=device,
+                half=half,
+                int8=int8,
+                verbose=False,
+                conf=0.001,  # all the pre-set benchmark mAP values are based on conf=0.001
             )
             metric, speed = results.results_dict[key], results.speed["inference"]
             fps = round(1000 / (speed + eps), 2)  # frames per second
```
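Because `benchmark()` now forwards `**kwargs` into `model.export(...)`, exporter-only options can be set from the benchmark call. A hedged usage sketch (the extra arguments shown are standard ultralytics export options, chosen here only for illustration):

```python
from ultralytics.utils.benchmarks import benchmark

# Benchmark only the ONNX format; 'dynamic' and 'simplify' are exporter
# arguments that previously could not be forwarded through benchmark()
benchmark(model="yolo11n.pt", imgsz=640, format="onnx", dynamic=True, simplify=True)
```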
```diff
@@ -181,35 +205,36 @@ def benchmark(
 
     # Print results
     check_yolo(device=device)  # print system info
-    df = …
+    df = pl.DataFrame(y, schema=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"], orient="row")
+    df = df.with_row_index(" ", offset=1)  # add index info
+    df_display = df.with_columns(pl.all().cast(pl.String).fill_null("-"))
 
     name = model.model_name
     dt = time.time() - t0
     legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed"
-    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{…
+    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df_display}\n"
     LOGGER.info(s)
     with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
         f.write(s)
 
     if verbose and isinstance(verbose, float):
-        metrics = df[key].…
+        metrics = df[key].to_numpy()  # values to compare to floor
         floor = verbose  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
-        assert all(x > floor for x in metrics if …
+        assert all(x > floor for x in metrics if not np.isnan(x)), f"Benchmark failure: metric(s) < floor {floor}"
 
-    return …
+    return df_display
 
 
 class RF100Benchmark:
-    """
-    Benchmark YOLO model performance across various formats for speed and accuracy.
+    """Benchmark YOLO model performance across various formats for speed and accuracy.
 
     This class provides functionality to benchmark YOLO models on the RF100 dataset collection.
 
     Attributes:
-        ds_names (…
-        ds_cfg_list (…
+        ds_names (list[str]): Names of datasets used for benchmarking.
+        ds_cfg_list (list[Path]): List of paths to dataset configuration files.
         rf (Roboflow): Roboflow instance for accessing datasets.
-        val_metrics (…
+        val_metrics (list[str]): Metrics used for validation.
 
     Methods:
         set_key: Set Roboflow API key for accessing datasets.
```
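This hunk swaps the pandas results table for a polars one built in three steps: row-oriented construction, a 1-based index column, then a string-cast copy for display. A self-contained sketch with made-up rows (the metric column name is illustrative):

```python
import polars as pl

# One row per export format: [Format, Status, Size (MB), metric, ms/im, FPS];
# values here are invented for the example
y = [
    ["PyTorch", "✅", 5.4, 0.61, 12.3, 81.3],
    ["ONNX", "✅", 10.2, 0.61, 7.0, 142.9],
    ["OpenVINO", "❌", None, None, None, None],  # failed export -> nulls
]
schema = ["Format", "Status❔", "Size (MB)", "metrics/mAP50-95(B)", "Inference time (ms/im)", "FPS"]

df = pl.DataFrame(y, schema=schema, orient="row")
df = df.with_row_index(" ", offset=1)  # 1-based row numbers in a blank-named column
df_display = df.with_columns(pl.all().cast(pl.String).fill_null("-"))  # nulls print as "-"
print(df_display)
```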
```diff
@@ -225,9 +250,8 @@ class RF100Benchmark:
         self.rf = None
         self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]
 
-    def set_key(self, api_key):
-        """
-        Set Roboflow API key for processing.
+    def set_key(self, api_key: str):
+        """Set Roboflow API key for processing.
 
         Args:
             api_key (str): The API key.
```
```diff
@@ -242,16 +266,15 @@ class RF100Benchmark:
 
         self.rf = Roboflow(api_key=api_key)
 
-    def parse_dataset(self, ds_link_txt="datasets_links.txt"):
-        """
-        Parse dataset links and download datasets.
+    def parse_dataset(self, ds_link_txt: str = "datasets_links.txt"):
+        """Parse dataset links and download datasets.
 
         Args:
             ds_link_txt (str): Path to the file containing dataset links.
 
         Returns:
-            ds_names (…
-            ds_cfg_list (…
+            ds_names (list[str]): List of dataset names.
+            ds_cfg_list (list[Path]): List of paths to dataset configuration files.
 
         Examples:
             >>> benchmark = RF100Benchmark()
```
```diff
@@ -261,12 +284,12 @@ class RF100Benchmark:
         (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
         os.chdir("rf-100")
         os.mkdir("ultralytics-benchmarks")
-        safe_download("…
+        safe_download(f"{ASSETS_URL}/datasets_links.txt")
 
         with open(ds_link_txt, encoding="utf-8") as file:
             for line in file:
                 try:
-                    _, …
+                    _, _url, workspace, project, version = re.split("/+", line.strip())
                     self.ds_names.append(project)
                     proj_version = f"{project}-{version}"
                     if not Path(proj_version).exists():
```
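The new unpacking expects each dataset link to split into exactly five parts; the `"/+"` pattern collapses the double slash after the URL scheme. A quick check of the `re.split` behavior on a made-up Roboflow-style link:

```python
import re

line = "https://universe.roboflow.com/some-workspace/some-project/2"  # made-up link
# Splitting on runs of "/" yields: ['https:', 'universe.roboflow.com',
# 'some-workspace', 'some-project', '2']
_, _url, workspace, project, version = re.split("/+", line.strip())
print(workspace, project, version)  # some-workspace some-project 2
```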
```diff
@@ -280,16 +303,15 @@ class RF100Benchmark:
         return self.ds_names, self.ds_cfg_list
 
     @staticmethod
-    def fix_yaml(path):
+    def fix_yaml(path: Path):
         """Fix the train and validation paths in a given YAML file."""
         yaml_data = YAML.load(path)
         yaml_data["train"] = "train/images"
         yaml_data["val"] = "valid/images"
         YAML.dump(yaml_data, path)
 
-    def evaluate(self, yaml_path, val_log_file, eval_log_file, list_ind):
-        """
-        Evaluate model performance on validation results.
+    def evaluate(self, yaml_path: str, val_log_file: str, eval_log_file: str, list_ind: int):
+        """Evaluate model performance on validation results.
 
         Args:
             yaml_path (str): Path to the YAML configuration file.
```
```diff
@@ -337,20 +359,21 @@ class RF100Benchmark:
                 map_val = lst["map50"]
         else:
             LOGGER.info("Single dict found")
-            map_val = …
+            map_val = next(res["map50"] for res in eval_lines)
 
         with open(eval_log_file, "a", encoding="utf-8") as f:
             f.write(f"{self.ds_names[list_ind]}: {map_val}\n")
 
+        return float(map_val)
+
 
 class ProfileModels:
-    """
-    ProfileModels class for profiling different models on ONNX and TensorRT.
+    """ProfileModels class for profiling different models on ONNX and TensorRT.
 
     This class profiles the performance of different models, returning results such as model speed and FLOPs.
 
     Attributes:
-        paths (…
+        paths (list[str]): Paths of the models to profile.
         num_timed_runs (int): Number of timed runs for the profiling.
         num_warmup_runs (int): Number of warmup runs before profiling.
         min_time (float): Minimum number of seconds to profile for.
```
```diff
@@ -360,15 +383,15 @@ class ProfileModels:
         device (torch.device): Device used for profiling.
 
     Methods:
-        …
-        get_files: …
-        get_onnx_model_info: …
-        iterative_sigma_clipping: …
-        profile_tensorrt_model: …
-        profile_onnx_model: …
-        generate_table_row: …
-        generate_results_dict: …
-        print_table: …
+        run: Profile YOLO models for speed and accuracy across various formats.
+        get_files: Get all relevant model files.
+        get_onnx_model_info: Extract metadata from an ONNX model.
+        iterative_sigma_clipping: Apply sigma clipping to remove outliers.
+        profile_tensorrt_model: Profile a TensorRT model.
+        profile_onnx_model: Profile an ONNX model.
+        generate_table_row: Generate a table row with model metrics.
+        generate_results_dict: Generate a dictionary of profiling results.
+        print_table: Print a formatted table of results.
 
     Examples:
         Profile models and print results
```
```diff
@@ -379,20 +402,19 @@ class ProfileModels:
 
     def __init__(
         self,
-        paths: list,
-        num_timed_runs=100,
-        num_warmup_runs=10,
-        min_time=60,
-        imgsz=640,
-        half=True,
-        trt=True,
-        device=None,
+        paths: list[str],
+        num_timed_runs: int = 100,
+        num_warmup_runs: int = 10,
+        min_time: float = 60,
+        imgsz: int = 640,
+        half: bool = True,
+        trt: bool = True,
+        device: torch.device | str | None = None,
     ):
-        """
-        Initialize the ProfileModels class for profiling models.
+        """Initialize the ProfileModels class for profiling models.
 
         Args:
-            paths (…
+            paths (list[str]): List of paths of the models to be profiled.
             num_timed_runs (int): Number of timed runs for the profiling.
             num_warmup_runs (int): Number of warmup runs before the actual profiling starts.
             min_time (float): Minimum time in seconds for profiling a model.
```
```diff
@@ -401,14 +423,14 @@ class ProfileModels:
             trt (bool): Flag to indicate whether to profile using TensorRT.
             device (torch.device | str | None): Device used for profiling. If None, it is determined automatically.
 
-        Notes:
-            FP16 'half' argument option removed for ONNX as slower on CPU than FP32.
-
         Examples:
             Initialize and profile models
             >>> from ultralytics.utils.benchmarks import ProfileModels
             >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640)
             >>> profiler.run()
+
+        Notes:
+            FP16 'half' argument option removed for ONNX as slower on CPU than FP32.
         """
         self.paths = paths
         self.num_timed_runs = num_timed_runs
```
```diff
@@ -420,11 +442,10 @@ class ProfileModels:
         self.device = device if isinstance(device, torch.device) else select_device(device)
 
     def run(self):
-        """
-        Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT.
+        """Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT.
 
         Returns:
-            (…
+            (list[dict]): List of dictionaries containing profiling results for each model.
 
         Examples:
             Profile models and print results
```
```diff
@@ -436,7 +457,7 @@ class ProfileModels:
 
         if not files:
             LOGGER.warning("No matching *.pt or *.onnx files found.")
-            return
+            return []
 
         table_rows = []
         output = []
```
```diff
@@ -475,11 +496,10 @@ class ProfileModels:
         return output
 
     def get_files(self):
-        """
-        Return a list of paths for all relevant model files given by the user.
+        """Return a list of paths for all relevant model files given by the user.
 
         Returns:
-            (…
+            (list[Path]): List of Path objects for the model files.
         """
         files = []
         for path in self.paths:
```
```diff
@@ -497,21 +517,20 @@ class ProfileModels:
 
     @staticmethod
     def get_onnx_model_info(onnx_file: str):
-        """…
+        """Extract metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
         return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)
 
     @staticmethod
-    def iterative_sigma_clipping(data, sigma=2, max_iters=3):
-        """
-        Apply iterative sigma clipping to data to remove outliers.
+    def iterative_sigma_clipping(data: np.ndarray, sigma: float = 2, max_iters: int = 3):
+        """Apply iterative sigma clipping to data to remove outliers.
 
         Args:
-            data (…
+            data (np.ndarray): Input data array.
             sigma (float): Number of standard deviations to use for clipping.
             max_iters (int): Maximum number of iterations for the clipping process.
 
         Returns:
-            (…
+            (np.ndarray): Clipped data array with outliers removed.
         """
         data = np.array(data)
         for _ in range(max_iters):
```
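Only the signature and docstring of `iterative_sigma_clipping` change here, but the technique is worth seeing end to end. A minimal reimplementation of iterative sigma clipping (not the package's exact body, which this diff truncates):

```python
import numpy as np


def sigma_clip(data: np.ndarray, sigma: float = 2, max_iters: int = 3) -> np.ndarray:
    """Iteratively drop points farther than `sigma` standard deviations from the mean."""
    data = np.asarray(data, dtype=float)
    for _ in range(max_iters):
        mean, std = np.mean(data), np.std(data)
        clipped = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
        if len(clipped) == len(data):  # converged, nothing removed this pass
            break
        data = clipped
    return data


times = np.array([7.1, 7.3, 7.2, 7.4, 7.2, 42.0])  # one warm-up outlier
print(sigma_clip(times))  # the 42.0 ms outlier is removed
```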
```diff
@@ -523,8 +542,7 @@ class ProfileModels:
         return data
 
     def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
-        """
-        Profile YOLO model performance with TensorRT, measuring average run time and standard deviation.
+        """Profile YOLO model performance with TensorRT, measuring average run time and standard deviation.
 
         Args:
             engine_file (str): Path to the TensorRT engine file.
```
```diff
@@ -561,9 +579,13 @@ class ProfileModels:
         run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3)  # sigma clipping
         return np.mean(run_times), np.std(run_times)
 
+    @staticmethod
+    def check_dynamic(tensor_shape):
+        """Check whether the tensor shape in the ONNX model is dynamic."""
+        return not all(isinstance(dim, int) and dim >= 0 for dim in tensor_shape)
+
     def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
-        """
-        Profile an ONNX model, measuring average inference time and standard deviation across multiple runs.
+        """Profile an ONNX model, measuring average inference time and standard deviation across multiple runs.
 
         Args:
             onnx_file (str): Path to the ONNX model file.
```
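The new `check_dynamic` helper flags a shape as dynamic when any dimension is not a non-negative int, which is how onnxruntime surfaces symbolic axes (strings such as `'batch'`). The function is short enough to exercise directly; the shapes below are hand-written, not read from a model:

```python
def check_dynamic(tensor_shape):
    """Check whether the tensor shape in the ONNX model is dynamic."""
    return not all(isinstance(dim, int) and dim >= 0 for dim in tensor_shape)


print(check_dynamic([1, 3, 640, 640]))           # False: fully static
print(check_dynamic(["batch", 3, 640, 640]))     # True: symbolic batch axis
print(check_dynamic([1, 3, "height", "width"]))  # True: symbolic spatial axes
```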
```diff
@@ -573,7 +595,7 @@ class ProfileModels:
             mean_time (float): Mean inference time in milliseconds.
             std_time (float): Standard deviation of inference time in milliseconds.
         """
-        check_requirements("onnxruntime")
+        check_requirements([("onnxruntime", "onnxruntime-gpu")])  # either package meets requirements
         import onnxruntime as ort
 
         # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
```
```diff
@@ -582,27 +604,36 @@ class ProfileModels:
         sess_options.intra_op_num_threads = 8  # Limit the number of threads
         sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
 
-        … (18 deleted lines of the old input-construction code, truncated by the diff viewer)
+        input_data_dict = dict()
+        for input_tensor in sess.get_inputs():
+            input_type = input_tensor.type
+            if self.check_dynamic(input_tensor.shape):
+                if len(input_tensor.shape) != 4 and self.check_dynamic(input_tensor.shape[1:]):
+                    raise ValueError(f"Unsupported dynamic shape {input_tensor.shape} of {input_tensor.name}")
+                input_shape = (
+                    (1, 3, self.imgsz, self.imgsz) if len(input_tensor.shape) == 4 else (1, *input_tensor.shape[1:])
+                )
+            else:
+                input_shape = input_tensor.shape
+
+            # Mapping ONNX datatype to numpy datatype
+            if "float16" in input_type:
+                input_dtype = np.float16
+            elif "float" in input_type:
+                input_dtype = np.float32
+            elif "double" in input_type:
+                input_dtype = np.float64
+            elif "int64" in input_type:
+                input_dtype = np.int64
+            elif "int32" in input_type:
+                input_dtype = np.int32
+            else:
+                raise ValueError(f"Unsupported ONNX datatype {input_type}")
+
+            input_data = np.random.rand(*input_shape).astype(input_dtype)
+            input_name = input_tensor.name
+            input_data_dict.update({input_name: input_data})
 
-        input_data = np.random.rand(*input_shape).astype(input_dtype)
-        input_name = input_tensor.name
         output_name = sess.get_outputs()[0].name
 
         # Warmup runs
```
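The new loop sizes one random array per graph input and maps the ONNX dtype string to a numpy dtype. A condensed standalone sketch of that mapping, rewritten as a lookup table rather than the substring checks above (the `'tensor(float)'`-style strings are how onnxruntime reports input types):

```python
import numpy as np

# Map onnxruntime input-type strings to numpy dtypes, mirroring the loop above
ONNX_TO_NUMPY = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(int64)": np.int64,
    "tensor(int32)": np.int32,
}

input_type = "tensor(float)"  # e.g. input_tensor.type from onnxruntime
input_shape = (1, 3, 640, 640)
input_data = np.random.rand(*input_shape).astype(ONNX_TO_NUMPY[input_type])
print(input_data.dtype, input_data.shape)  # float32 (1, 3, 640, 640)
```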
```diff
@@ -610,7 +641,7 @@ class ProfileModels:
         for _ in range(3):
             start_time = time.time()
             for _ in range(self.num_warmup_runs):
-                sess.run([output_name], …
+                sess.run([output_name], input_data_dict)
             elapsed = time.time() - start_time
 
         # Compute number of runs as higher of min_time or num_timed_runs
```
```diff
@@ -620,15 +651,20 @@ class ProfileModels:
         run_times = []
         for _ in TQDM(range(num_runs), desc=onnx_file):
             start_time = time.time()
-            sess.run([output_name], …
+            sess.run([output_name], input_data_dict)
             run_times.append((time.time() - start_time) * 1000)  # Convert to milliseconds
 
         run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5)  # sigma clipping
         return np.mean(run_times), np.std(run_times)
 
-    def generate_table_row(…
-    …
-    …
+    def generate_table_row(
+        self,
+        model_name: str,
+        t_onnx: tuple[float, float],
+        t_engine: tuple[float, float],
+        model_info: tuple[float, float, float, float],
+    ):
+        """Generate a table row string with model performance metrics.
 
         Args:
             model_name (str): Name of the model.
```
```diff
@@ -639,16 +675,20 @@ class ProfileModels:
         Returns:
             (str): Formatted table row string with model metrics.
         """
-        …
+        _layers, params, _gradients, flops = model_info
         return (
             f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±"
             f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |"
         )
 
     @staticmethod
-    def generate_results_dict(…
-    …
-    …
+    def generate_results_dict(
+        model_name: str,
+        t_onnx: tuple[float, float],
+        t_engine: tuple[float, float],
+        model_info: tuple[float, float, float, float],
+    ):
+        """Generate a dictionary of profiling results.
 
         Args:
             model_name (str): Name of the model.
```
```diff
@@ -659,7 +699,7 @@ class ProfileModels:
         Returns:
             (dict): Dictionary containing profiling results.
         """
-        …
+        _layers, params, _gradients, flops = model_info
         return {
             "model/name": model_name,
             "model/parameters": params,
```
```diff
@@ -669,12 +709,11 @@ class ProfileModels:
         }
 
     @staticmethod
-    def print_table(table_rows):
-        """
-        Print a formatted table of model profiling results.
+    def print_table(table_rows: list[str]):
+        """Print a formatted table of model profiling results.
 
         Args:
-            table_rows (…
+            table_rows (list[str]): List of formatted table row strings.
         """
         gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
         headers = [
```