dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/checks.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import ast
|
|
5
6
|
import functools
|
|
6
7
|
import glob
|
|
7
8
|
import inspect
|
|
@@ -11,6 +12,7 @@ import platform
|
|
|
11
12
|
import re
|
|
12
13
|
import shutil
|
|
13
14
|
import subprocess
|
|
15
|
+
import sys
|
|
14
16
|
import time
|
|
15
17
|
from importlib import metadata
|
|
16
18
|
from pathlib import Path
|
|
@@ -23,6 +25,7 @@ import torch
|
|
|
23
25
|
from ultralytics.utils import (
|
|
24
26
|
ARM64,
|
|
25
27
|
ASSETS,
|
|
28
|
+
ASSETS_URL,
|
|
26
29
|
AUTOINSTALL,
|
|
27
30
|
GIT,
|
|
28
31
|
IS_COLAB,
|
|
@@ -52,8 +55,7 @@ from ultralytics.utils import (
|
|
|
52
55
|
|
|
53
56
|
|
|
54
57
|
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
|
|
55
|
-
"""
|
|
56
|
-
Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
|
|
58
|
+
"""Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
|
|
57
59
|
|
|
58
60
|
Args:
|
|
59
61
|
file_path (Path): Path to the requirements.txt file.
|
|
@@ -85,8 +87,7 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
|
|
|
85
87
|
|
|
86
88
|
@functools.lru_cache
|
|
87
89
|
def parse_version(version="0.0.0") -> tuple:
|
|
88
|
-
"""
|
|
89
|
-
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
|
|
90
|
+
"""Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
|
|
90
91
|
|
|
91
92
|
Args:
|
|
92
93
|
version (str): Version string, i.e. '2.0.1+cpu'
|
|
@@ -102,8 +103,7 @@ def parse_version(version="0.0.0") -> tuple:
|
|
|
102
103
|
|
|
103
104
|
|
|
104
105
|
def is_ascii(s) -> bool:
|
|
105
|
-
"""
|
|
106
|
-
Check if a string is composed of only ASCII characters.
|
|
106
|
+
"""Check if a string is composed of only ASCII characters.
|
|
107
107
|
|
|
108
108
|
Args:
|
|
109
109
|
s (str | list | tuple | dict): Input to be checked (all are converted to string for checking).
|
|
@@ -115,8 +115,7 @@ def is_ascii(s) -> bool:
|
|
|
115
115
|
|
|
116
116
|
|
|
117
117
|
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
|
118
|
-
"""
|
|
119
|
-
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
|
|
118
|
+
"""Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
|
|
120
119
|
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
|
|
121
120
|
|
|
122
121
|
Args:
|
|
@@ -138,7 +137,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
|
|
138
137
|
elif isinstance(imgsz, (list, tuple)):
|
|
139
138
|
imgsz = list(imgsz)
|
|
140
139
|
elif isinstance(imgsz, str): # i.e. '640' or '[640,640]'
|
|
141
|
-
imgsz = [int(imgsz)] if imgsz.isnumeric() else
|
|
140
|
+
imgsz = [int(imgsz)] if imgsz.isnumeric() else ast.literal_eval(imgsz)
|
|
142
141
|
else:
|
|
143
142
|
raise TypeError(
|
|
144
143
|
f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
|
|
@@ -186,8 +185,7 @@ def check_version(
|
|
|
186
185
|
verbose: bool = False,
|
|
187
186
|
msg: str = "",
|
|
188
187
|
) -> bool:
|
|
189
|
-
"""
|
|
190
|
-
Check current version against the required version or range.
|
|
188
|
+
"""Check current version against the required version or range.
|
|
191
189
|
|
|
192
190
|
Args:
|
|
193
191
|
current (str): Current version or package name to get version from.
|
|
@@ -267,8 +265,7 @@ def check_version(
|
|
|
267
265
|
|
|
268
266
|
|
|
269
267
|
def check_latest_pypi_version(package_name="ultralytics"):
|
|
270
|
-
"""
|
|
271
|
-
Return the latest version of a PyPI package without downloading or installing it.
|
|
268
|
+
"""Return the latest version of a PyPI package without downloading or installing it.
|
|
272
269
|
|
|
273
270
|
Args:
|
|
274
271
|
package_name (str): The name of the package to find the latest version for.
|
|
@@ -288,8 +285,7 @@ def check_latest_pypi_version(package_name="ultralytics"):
|
|
|
288
285
|
|
|
289
286
|
|
|
290
287
|
def check_pip_update_available():
|
|
291
|
-
"""
|
|
292
|
-
Check if a new version of the ultralytics package is available on PyPI.
|
|
288
|
+
"""Check if a new version of the ultralytics package is available on PyPI.
|
|
293
289
|
|
|
294
290
|
Returns:
|
|
295
291
|
(bool): True if an update is available, False otherwise.
|
|
@@ -313,8 +309,7 @@ def check_pip_update_available():
|
|
|
313
309
|
@ThreadingLocked()
|
|
314
310
|
@functools.lru_cache
|
|
315
311
|
def check_font(font="Arial.ttf"):
|
|
316
|
-
"""
|
|
317
|
-
Find font locally or download to user's configuration directory if it does not already exist.
|
|
312
|
+
"""Find font locally or download to user's configuration directory if it does not already exist.
|
|
318
313
|
|
|
319
314
|
Args:
|
|
320
315
|
font (str): Path or name of font.
|
|
@@ -336,15 +331,14 @@ def check_font(font="Arial.ttf"):
|
|
|
336
331
|
return matches[0]
|
|
337
332
|
|
|
338
333
|
# Download to USER_CONFIG_DIR if missing
|
|
339
|
-
url = f"
|
|
334
|
+
url = f"{ASSETS_URL}/{name}"
|
|
340
335
|
if downloads.is_url(url, check=True):
|
|
341
336
|
downloads.safe_download(url=url, file=file)
|
|
342
337
|
return file
|
|
343
338
|
|
|
344
339
|
|
|
345
340
|
def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool:
|
|
346
|
-
"""
|
|
347
|
-
Check current python version against the required minimum version.
|
|
341
|
+
"""Check current python version against the required minimum version.
|
|
348
342
|
|
|
349
343
|
Args:
|
|
350
344
|
minimum (str): Required minimum version of python.
|
|
@@ -358,13 +352,53 @@ def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = Fals
|
|
|
358
352
|
|
|
359
353
|
|
|
360
354
|
@TryExcept()
|
|
361
|
-
def
|
|
355
|
+
def check_apt_requirements(requirements):
|
|
356
|
+
"""Check if apt packages are installed and install missing ones.
|
|
357
|
+
|
|
358
|
+
Args:
|
|
359
|
+
requirements: List of apt package names to check and install
|
|
362
360
|
"""
|
|
363
|
-
|
|
361
|
+
prefix = colorstr("red", "bold", "apt requirements:")
|
|
362
|
+
# Check which packages are missing
|
|
363
|
+
missing_packages = []
|
|
364
|
+
for package in requirements:
|
|
365
|
+
try:
|
|
366
|
+
# Use dpkg -l to check if package is installed
|
|
367
|
+
result = subprocess.run(["dpkg", "-l", package], capture_output=True, text=True, check=False)
|
|
368
|
+
# Check if package is installed (look for "ii" status)
|
|
369
|
+
if result.returncode != 0 or not any(
|
|
370
|
+
line.startswith("ii") and package in line for line in result.stdout.splitlines()
|
|
371
|
+
):
|
|
372
|
+
missing_packages.append(package)
|
|
373
|
+
except Exception:
|
|
374
|
+
# If check fails, assume package is not installed
|
|
375
|
+
missing_packages.append(package)
|
|
376
|
+
|
|
377
|
+
# Install missing packages if any
|
|
378
|
+
if missing_packages:
|
|
379
|
+
LOGGER.info(
|
|
380
|
+
f"{prefix} Ultralytics requirement{'s' * (len(missing_packages) > 1)} {missing_packages} not found, attempting AutoUpdate..."
|
|
381
|
+
)
|
|
382
|
+
# Optionally update package list first
|
|
383
|
+
cmd = (["sudo"] if is_sudo_available() else []) + ["apt", "update"]
|
|
384
|
+
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
|
385
|
+
|
|
386
|
+
# Build and run the install command
|
|
387
|
+
cmd = (["sudo"] if is_sudo_available() else []) + ["apt", "install", "-y"] + missing_packages
|
|
388
|
+
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
|
389
|
+
|
|
390
|
+
LOGGER.info(f"{prefix} AutoUpdate success ✅")
|
|
391
|
+
LOGGER.warning(f"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
@TryExcept()
|
|
395
|
+
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
|
|
396
|
+
"""Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed.
|
|
364
397
|
|
|
365
398
|
Args:
|
|
366
|
-
requirements (Path | str | list[str]): Path to a requirements.txt file, a single package
|
|
367
|
-
string,
|
|
399
|
+
requirements (Path | str | list[str|tuple] | tuple[str]): Path to a requirements.txt file, a single package
|
|
400
|
+
requirement as a string, a list of package requirements as strings, or a list containing strings and tuples
|
|
401
|
+
of interchangeable packages.
|
|
368
402
|
exclude (tuple): Tuple of package names to exclude from checking.
|
|
369
403
|
install (bool): If True, attempt to auto-update packages that don't meet requirements.
|
|
370
404
|
cmds (str): Additional commands to pass to the pip install command when auto-updating.
|
|
@@ -376,12 +410,20 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
|
|
|
376
410
|
>>> check_requirements("path/to/requirements.txt")
|
|
377
411
|
|
|
378
412
|
Check a single package
|
|
379
|
-
>>> check_requirements("ultralytics>=8.
|
|
413
|
+
>>> check_requirements("ultralytics>=8.3.200", cmds="--index-url https://download.pytorch.org/whl/cpu")
|
|
380
414
|
|
|
381
415
|
Check multiple packages
|
|
382
|
-
>>> check_requirements(["numpy", "ultralytics
|
|
416
|
+
>>> check_requirements(["numpy", "ultralytics"])
|
|
417
|
+
|
|
418
|
+
Check with interchangeable packages
|
|
419
|
+
>>> check_requirements([("onnxruntime", "onnxruntime-gpu"), "numpy"])
|
|
383
420
|
"""
|
|
384
421
|
prefix = colorstr("red", "bold", "requirements:")
|
|
422
|
+
|
|
423
|
+
if os.environ.get("ULTRALYTICS_SKIP_REQUIREMENTS_CHECKS", "0") == "1":
|
|
424
|
+
LOGGER.info(f"{prefix} ULTRALYTICS_SKIP_REQUIREMENTS_CHECKS=1 detected, skipping requirements check.")
|
|
425
|
+
return True
|
|
426
|
+
|
|
385
427
|
if isinstance(requirements, Path): # requirements.txt file
|
|
386
428
|
file = requirements.resolve()
|
|
387
429
|
assert file.exists(), f"{prefix} {file} not found, check failed."
|
|
@@ -391,34 +433,39 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
|
|
|
391
433
|
|
|
392
434
|
pkgs = []
|
|
393
435
|
for r in requirements:
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
436
|
+
candidates = r if isinstance(r, (list, tuple)) else [r]
|
|
437
|
+
satisfied = False
|
|
438
|
+
|
|
439
|
+
for candidate in candidates:
|
|
440
|
+
r_stripped = candidate.rpartition("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
|
|
441
|
+
match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
|
|
442
|
+
name, required = match[1], match[2].strip() if match[2] else ""
|
|
443
|
+
try:
|
|
444
|
+
if check_version(metadata.version(name), required):
|
|
445
|
+
satisfied = True
|
|
446
|
+
break
|
|
447
|
+
except (AssertionError, metadata.PackageNotFoundError):
|
|
448
|
+
continue
|
|
449
|
+
|
|
450
|
+
if not satisfied:
|
|
451
|
+
pkgs.append(candidates[0])
|
|
401
452
|
|
|
402
453
|
@Retry(times=2, delay=1)
|
|
403
454
|
def attempt_install(packages, commands, use_uv):
|
|
404
455
|
"""Attempt package installation with uv if available, falling back to pip."""
|
|
405
456
|
if use_uv:
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
457
|
+
# Use --python to explicitly target current interpreter (venv or system)
|
|
458
|
+
# This ensures correct installation when VIRTUAL_ENV env var isn't set
|
|
459
|
+
return subprocess.check_output(
|
|
460
|
+
f'uv pip install --no-cache-dir --python "{sys.executable}" {packages} {commands} '
|
|
461
|
+
f"--index-strategy=unsafe-best-match --break-system-packages",
|
|
462
|
+
shell=True,
|
|
463
|
+
stderr=subprocess.STDOUT,
|
|
464
|
+
text=True,
|
|
409
465
|
)
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
if e.stderr and "No virtual environment found" in e.stderr:
|
|
414
|
-
return subprocess.check_output(
|
|
415
|
-
base.replace("uv pip install", "uv pip install --system"),
|
|
416
|
-
shell=True,
|
|
417
|
-
stderr=subprocess.PIPE,
|
|
418
|
-
text=True,
|
|
419
|
-
)
|
|
420
|
-
raise
|
|
421
|
-
return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True, text=True)
|
|
466
|
+
return subprocess.check_output(
|
|
467
|
+
f"pip install --no-cache-dir {packages} {commands}", shell=True, stderr=subprocess.STDOUT, text=True
|
|
468
|
+
)
|
|
422
469
|
|
|
423
470
|
s = " ".join(f'"{x}"' for x in pkgs) # console string
|
|
424
471
|
if s:
|
|
@@ -429,14 +476,18 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
|
|
|
429
476
|
try:
|
|
430
477
|
t = time.time()
|
|
431
478
|
assert ONLINE, "AutoUpdate skipped (offline)"
|
|
432
|
-
|
|
479
|
+
use_uv = not ARM64 and check_uv() # uv fails on ARM64
|
|
480
|
+
LOGGER.info(attempt_install(s, cmds, use_uv=use_uv))
|
|
433
481
|
dt = time.time() - t
|
|
434
482
|
LOGGER.info(f"{prefix} AutoUpdate success ✅ {dt:.1f}s")
|
|
435
483
|
LOGGER.warning(
|
|
436
484
|
f"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
|
437
485
|
)
|
|
438
486
|
except Exception as e:
|
|
439
|
-
|
|
487
|
+
msg = f"{prefix} ❌ {e}"
|
|
488
|
+
if hasattr(e, "output") and e.output:
|
|
489
|
+
msg += f"\n{e.output}"
|
|
490
|
+
LOGGER.warning(msg)
|
|
440
491
|
return False
|
|
441
492
|
else:
|
|
442
493
|
return False
|
|
@@ -445,8 +496,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
|
|
|
445
496
|
|
|
446
497
|
|
|
447
498
|
def check_torchvision():
|
|
448
|
-
"""
|
|
449
|
-
Check the installed versions of PyTorch and Torchvision to ensure they're compatible.
|
|
499
|
+
"""Check the installed versions of PyTorch and Torchvision to ensure they're compatible.
|
|
450
500
|
|
|
451
501
|
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
|
|
452
502
|
to the compatibility table based on: https://github.com/pytorch/vision#installation.
|
|
@@ -481,8 +531,7 @@ def check_torchvision():
|
|
|
481
531
|
|
|
482
532
|
|
|
483
533
|
def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
|
|
484
|
-
"""
|
|
485
|
-
Check file(s) for acceptable suffix.
|
|
534
|
+
"""Check file(s) for acceptable suffix.
|
|
486
535
|
|
|
487
536
|
Args:
|
|
488
537
|
file (str | list[str]): File or list of files to check.
|
|
@@ -498,8 +547,7 @@ def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
|
|
|
498
547
|
|
|
499
548
|
|
|
500
549
|
def check_yolov5u_filename(file: str, verbose: bool = True):
|
|
501
|
-
"""
|
|
502
|
-
Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.
|
|
550
|
+
"""Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.
|
|
503
551
|
|
|
504
552
|
Args:
|
|
505
553
|
file (str): Filename to check and potentially update.
|
|
@@ -526,8 +574,7 @@ def check_yolov5u_filename(file: str, verbose: bool = True):
|
|
|
526
574
|
|
|
527
575
|
|
|
528
576
|
def check_model_file_from_stem(model="yolo11n"):
|
|
529
|
-
"""
|
|
530
|
-
Return a model filename from a valid model stem.
|
|
577
|
+
"""Return a model filename from a valid model stem.
|
|
531
578
|
|
|
532
579
|
Args:
|
|
533
580
|
model (str): Model stem to check.
|
|
@@ -542,8 +589,7 @@ def check_model_file_from_stem(model="yolo11n"):
|
|
|
542
589
|
|
|
543
590
|
|
|
544
591
|
def check_file(file, suffix="", download=True, download_dir=".", hard=True):
|
|
545
|
-
"""
|
|
546
|
-
Search/download file (if necessary), check suffix (if provided), and return path.
|
|
592
|
+
"""Search/download file (if necessary), check suffix (if provided), and return path.
|
|
547
593
|
|
|
548
594
|
Args:
|
|
549
595
|
file (str): File name or path.
|
|
@@ -582,8 +628,7 @@ def check_file(file, suffix="", download=True, download_dir=".", hard=True):
|
|
|
582
628
|
|
|
583
629
|
|
|
584
630
|
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
|
|
585
|
-
"""
|
|
586
|
-
Search/download YAML file (if necessary) and return path, checking suffix.
|
|
631
|
+
"""Search/download YAML file (if necessary) and return path, checking suffix.
|
|
587
632
|
|
|
588
633
|
Args:
|
|
589
634
|
file (str | Path): File name or path.
|
|
@@ -597,8 +642,7 @@ def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
|
|
|
597
642
|
|
|
598
643
|
|
|
599
644
|
def check_is_path_safe(basedir, path):
|
|
600
|
-
"""
|
|
601
|
-
Check if the resolved path is under the intended directory to prevent path traversal.
|
|
645
|
+
"""Check if the resolved path is under the intended directory to prevent path traversal.
|
|
602
646
|
|
|
603
647
|
Args:
|
|
604
648
|
basedir (Path | str): The intended directory.
|
|
@@ -615,8 +659,7 @@ def check_is_path_safe(basedir, path):
|
|
|
615
659
|
|
|
616
660
|
@functools.lru_cache
|
|
617
661
|
def check_imshow(warn=False):
|
|
618
|
-
"""
|
|
619
|
-
Check if environment supports image displays.
|
|
662
|
+
"""Check if environment supports image displays.
|
|
620
663
|
|
|
621
664
|
Args:
|
|
622
665
|
warn (bool): Whether to warn if environment doesn't support image displays.
|
|
@@ -640,8 +683,7 @@ def check_imshow(warn=False):
|
|
|
640
683
|
|
|
641
684
|
|
|
642
685
|
def check_yolo(verbose=True, device=""):
|
|
643
|
-
"""
|
|
644
|
-
Return a human-readable YOLO software and hardware summary.
|
|
686
|
+
"""Return a human-readable YOLO software and hardware summary.
|
|
645
687
|
|
|
646
688
|
Args:
|
|
647
689
|
verbose (bool): Whether to print verbose information.
|
|
@@ -658,7 +700,7 @@ def check_yolo(verbose=True, device=""):
|
|
|
658
700
|
# System info
|
|
659
701
|
gib = 1 << 30 # bytes per GiB
|
|
660
702
|
ram = psutil.virtual_memory().total
|
|
661
|
-
total,
|
|
703
|
+
total, _used, free = shutil.disk_usage("/")
|
|
662
704
|
s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
|
|
663
705
|
try:
|
|
664
706
|
from IPython import display
|
|
@@ -669,13 +711,15 @@ def check_yolo(verbose=True, device=""):
|
|
|
669
711
|
else:
|
|
670
712
|
s = ""
|
|
671
713
|
|
|
714
|
+
if GIT.is_repo:
|
|
715
|
+
check_multiple_install() # check conflicting installation if using local clone
|
|
716
|
+
|
|
672
717
|
select_device(device=device, newline=False)
|
|
673
718
|
LOGGER.info(f"Setup complete ✅ {s}")
|
|
674
719
|
|
|
675
720
|
|
|
676
721
|
def collect_system_info():
|
|
677
|
-
"""
|
|
678
|
-
Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.
|
|
722
|
+
"""Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.
|
|
679
723
|
|
|
680
724
|
Returns:
|
|
681
725
|
(dict): Dictionary containing system information.
|
|
@@ -688,7 +732,7 @@ def collect_system_info():
|
|
|
688
732
|
gib = 1 << 30 # bytes per GiB
|
|
689
733
|
cuda = torch.cuda.is_available()
|
|
690
734
|
check_yolo()
|
|
691
|
-
total,
|
|
735
|
+
total, _used, free = shutil.disk_usage("/")
|
|
692
736
|
|
|
693
737
|
info_dict = {
|
|
694
738
|
"OS": platform.platform(),
|
|
@@ -735,8 +779,7 @@ def collect_system_info():
|
|
|
735
779
|
|
|
736
780
|
|
|
737
781
|
def check_amp(model):
|
|
738
|
-
"""
|
|
739
|
-
Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.
|
|
782
|
+
"""Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.
|
|
740
783
|
|
|
741
784
|
If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
|
|
742
785
|
results, so AMP will be disabled during training.
|
|
@@ -807,9 +850,32 @@ def check_amp(model):
|
|
|
807
850
|
return True
|
|
808
851
|
|
|
809
852
|
|
|
853
|
+
def check_multiple_install():
|
|
854
|
+
"""Check if there are multiple Ultralytics installations."""
|
|
855
|
+
import sys
|
|
856
|
+
|
|
857
|
+
try:
|
|
858
|
+
result = subprocess.run([sys.executable, "-m", "pip", "show", "ultralytics"], capture_output=True, text=True)
|
|
859
|
+
install_msg = (
|
|
860
|
+
f"Install your local copy in editable mode with 'pip install -e {ROOT.parent}' to avoid "
|
|
861
|
+
"issues. See https://docs.ultralytics.com/quickstart/"
|
|
862
|
+
)
|
|
863
|
+
if result.returncode != 0:
|
|
864
|
+
if "not found" in result.stderr.lower(): # Package not pip-installed but locally imported
|
|
865
|
+
LOGGER.warning(f"Ultralytics not found via pip but importing from: {ROOT}. {install_msg}")
|
|
866
|
+
return
|
|
867
|
+
yolo_path = (Path(re.findall(r"location:\s+(.+)", result.stdout, flags=re.I)[-1]) / "ultralytics").resolve()
|
|
868
|
+
if not yolo_path.samefile(ROOT.resolve()):
|
|
869
|
+
LOGGER.warning(
|
|
870
|
+
f"Multiple Ultralytics installations detected. The `yolo` command uses: {yolo_path}, "
|
|
871
|
+
f"but current session imports from: {ROOT}. This may cause version conflicts. {install_msg}"
|
|
872
|
+
)
|
|
873
|
+
except Exception:
|
|
874
|
+
return
|
|
875
|
+
|
|
876
|
+
|
|
810
877
|
def print_args(args: dict | None = None, show_file=True, show_func=False):
|
|
811
|
-
"""
|
|
812
|
-
Print function arguments (optional args dict).
|
|
878
|
+
"""Print function arguments (optional args dict).
|
|
813
879
|
|
|
814
880
|
Args:
|
|
815
881
|
args (dict, optional): Arguments to print.
|
|
@@ -835,8 +901,7 @@ def print_args(args: dict | None = None, show_file=True, show_func=False):
|
|
|
835
901
|
|
|
836
902
|
|
|
837
903
|
def cuda_device_count() -> int:
|
|
838
|
-
"""
|
|
839
|
-
Get the number of NVIDIA GPUs available in the environment.
|
|
904
|
+
"""Get the number of NVIDIA GPUs available in the environment.
|
|
840
905
|
|
|
841
906
|
Returns:
|
|
842
907
|
(int): The number of NVIDIA GPUs available.
|
|
@@ -861,8 +926,7 @@ def cuda_device_count() -> int:
|
|
|
861
926
|
|
|
862
927
|
|
|
863
928
|
def cuda_is_available() -> bool:
|
|
864
|
-
"""
|
|
865
|
-
Check if CUDA is available in the environment.
|
|
929
|
+
"""Check if CUDA is available in the environment.
|
|
866
930
|
|
|
867
931
|
Returns:
|
|
868
932
|
(bool): True if one or more NVIDIA GPUs are available, False otherwise.
|
|
@@ -871,8 +935,7 @@ def cuda_is_available() -> bool:
|
|
|
871
935
|
|
|
872
936
|
|
|
873
937
|
def is_rockchip():
|
|
874
|
-
"""
|
|
875
|
-
Check if the current environment is running on a Rockchip SoC.
|
|
938
|
+
"""Check if the current environment is running on a Rockchip SoC.
|
|
876
939
|
|
|
877
940
|
Returns:
|
|
878
941
|
(bool): True if running on a Rockchip SoC, False otherwise.
|
|
@@ -891,8 +954,7 @@ def is_rockchip():
|
|
|
891
954
|
|
|
892
955
|
|
|
893
956
|
def is_intel():
|
|
894
|
-
"""
|
|
895
|
-
Check if the system has Intel hardware (CPU or GPU).
|
|
957
|
+
"""Check if the system has Intel hardware (CPU or GPU).
|
|
896
958
|
|
|
897
959
|
Returns:
|
|
898
960
|
(bool): True if Intel hardware is detected, False otherwise.
|
|
@@ -907,13 +969,12 @@ def is_intel():
|
|
|
907
969
|
try:
|
|
908
970
|
result = subprocess.run(["xpu-smi", "discovery"], capture_output=True, text=True, timeout=5)
|
|
909
971
|
return "intel" in result.stdout.lower()
|
|
910
|
-
except
|
|
972
|
+
except Exception: # broad clause to capture all Intel GPU exception types
|
|
911
973
|
return False
|
|
912
974
|
|
|
913
975
|
|
|
914
976
|
def is_sudo_available() -> bool:
|
|
915
|
-
"""
|
|
916
|
-
Check if the sudo command is available in the environment.
|
|
977
|
+
"""Check if the sudo command is available in the environment.
|
|
917
978
|
|
|
918
979
|
Returns:
|
|
919
980
|
(bool): True if the sudo command is available, False otherwise.
|
|
@@ -930,8 +991,11 @@ check_torchvision() # check torch-torchvision compatibility
|
|
|
930
991
|
|
|
931
992
|
# Define constants
|
|
932
993
|
IS_PYTHON_3_8 = PYTHON_VERSION.startswith("3.8")
|
|
994
|
+
IS_PYTHON_3_9 = PYTHON_VERSION.startswith("3.9")
|
|
995
|
+
IS_PYTHON_3_10 = PYTHON_VERSION.startswith("3.10")
|
|
933
996
|
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")
|
|
934
997
|
IS_PYTHON_3_13 = PYTHON_VERSION.startswith("3.13")
|
|
935
998
|
|
|
999
|
+
IS_PYTHON_MINIMUM_3_9 = check_python("3.9", hard=False)
|
|
936
1000
|
IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
|
|
937
1001
|
IS_PYTHON_MINIMUM_3_12 = check_python("3.12", hard=False)
|
ultralytics/utils/cpu.py
CHANGED
|
@@ -10,8 +10,7 @@ from pathlib import Path
|
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class CPUInfo:
|
|
13
|
-
"""
|
|
14
|
-
Provide cross-platform CPU brand and model information.
|
|
13
|
+
"""Provide cross-platform CPU brand and model information.
|
|
15
14
|
|
|
16
15
|
Query platform-specific sources to retrieve a human-readable CPU descriptor and normalize it for consistent
|
|
17
16
|
presentation across macOS, Linux, and Windows. If platform-specific probing fails, generic platform identifiers are
|
|
@@ -71,13 +70,9 @@ class CPUInfo:
|
|
|
71
70
|
"""Normalize and prettify a raw CPU descriptor string."""
|
|
72
71
|
s = re.sub(r"\s+", " ", s.strip())
|
|
73
72
|
s = s.replace("(TM)", "").replace("(tm)", "").replace("(R)", "").replace("(r)", "").strip()
|
|
74
|
-
|
|
75
|
-
m = re.search(r"(Intel.*?i\d[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
|
|
76
|
-
if m:
|
|
73
|
+
if m := re.search(r"(Intel.*?i\d[\w-]*) CPU @ ([\d.]+GHz)", s, re.I):
|
|
77
74
|
return f"{m.group(1)} {m.group(2)}"
|
|
78
|
-
|
|
79
|
-
m = re.search(r"(AMD.*?Ryzen.*?[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
|
|
80
|
-
if m:
|
|
75
|
+
if m := re.search(r"(AMD.*?Ryzen.*?[\w-]*) CPU @ ([\d.]+GHz)", s, re.I):
|
|
81
76
|
return f"{m.group(1)} {m.group(2)}"
|
|
82
77
|
return s
|
|
83
78
|
|
ultralytics/utils/dist.py
CHANGED
|
@@ -10,8 +10,7 @@ from .torch_utils import TORCH_1_9
|
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
def find_free_network_port() -> int:
|
|
13
|
-
"""
|
|
14
|
-
Find a free port on localhost.
|
|
13
|
+
"""Find a free port on localhost.
|
|
15
14
|
|
|
16
15
|
It is useful in single-node training when we don't want to connect to a real main node but have to set the
|
|
17
16
|
`MASTER_PORT` environment variable.
|
|
@@ -27,11 +26,10 @@ def find_free_network_port() -> int:
|
|
|
27
26
|
|
|
28
27
|
|
|
29
28
|
def generate_ddp_file(trainer):
|
|
30
|
-
"""
|
|
31
|
-
Generate a DDP (Distributed Data Parallel) file for multi-GPU training.
|
|
29
|
+
"""Generate a DDP (Distributed Data Parallel) file for multi-GPU training.
|
|
32
30
|
|
|
33
|
-
This function creates a temporary Python file that enables distributed training across multiple GPUs.
|
|
34
|
-
|
|
31
|
+
This function creates a temporary Python file that enables distributed training across multiple GPUs. The file
|
|
32
|
+
contains the necessary configuration to initialize the trainer in a distributed environment.
|
|
35
33
|
|
|
36
34
|
Args:
|
|
37
35
|
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.
|
|
@@ -76,12 +74,10 @@ if __name__ == "__main__":
|
|
|
76
74
|
return file.name
|
|
77
75
|
|
|
78
76
|
|
|
79
|
-
def generate_ddp_command(
|
|
80
|
-
"""
|
|
81
|
-
Generate command for distributed training.
|
|
77
|
+
def generate_ddp_command(trainer):
|
|
78
|
+
"""Generate command for distributed training.
|
|
82
79
|
|
|
83
80
|
Args:
|
|
84
|
-
world_size (int): Number of processes to spawn for distributed training.
|
|
85
81
|
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.
|
|
86
82
|
|
|
87
83
|
Returns:
|
|
@@ -95,16 +91,24 @@ def generate_ddp_command(world_size: int, trainer):
|
|
|
95
91
|
file = generate_ddp_file(trainer)
|
|
96
92
|
dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
|
|
97
93
|
port = find_free_network_port()
|
|
98
|
-
cmd = [
|
|
94
|
+
cmd = [
|
|
95
|
+
sys.executable,
|
|
96
|
+
"-m",
|
|
97
|
+
dist_cmd,
|
|
98
|
+
"--nproc_per_node",
|
|
99
|
+
f"{trainer.world_size}",
|
|
100
|
+
"--master_port",
|
|
101
|
+
f"{port}",
|
|
102
|
+
file,
|
|
103
|
+
]
|
|
99
104
|
return cmd, file
|
|
100
105
|
|
|
101
106
|
|
|
102
107
|
def ddp_cleanup(trainer, file):
|
|
103
|
-
"""
|
|
104
|
-
Delete temporary file if created during distributed data parallel (DDP) training.
|
|
108
|
+
"""Delete temporary file if created during distributed data parallel (DDP) training.
|
|
105
109
|
|
|
106
|
-
This function checks if the provided file contains the trainer's ID in its name, indicating it was created
|
|
107
|
-
|
|
110
|
+
This function checks if the provided file contains the trainer's ID in its name, indicating it was created as a
|
|
111
|
+
temporary file for DDP training, and deletes it if so.
|
|
108
112
|
|
|
109
113
|
Args:
|
|
110
114
|
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.
|