dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1028 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1580 -0
- ultralytics/engine/model.py +1125 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +977 -0
- ultralytics/engine/tuner.py +449 -0
- ultralytics/engine/validator.py +387 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +964 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +453 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1020 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +506 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1047 -0
- ultralytics/utils/tal.py +404 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +443 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +168 -0
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
|
|
6
|
+
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def run_ray_tune(
    model,
    space: dict | None = None,
    grace_period: int = 10,
    gpu_per_trial: int | None = None,
    max_samples: int = 10,
    **train_args,
):
    """Run hyperparameter tuning using Ray Tune.

    Args:
        model (YOLO): Model to run the tuner on.
        space (dict, optional): The hyperparameter search space. If not provided, uses default space. The dict is
            copied before use, so the caller's object is never modified.
        grace_period (int, optional): The grace period in epochs of the ASHA scheduler.
        gpu_per_trial (int, optional): The number of GPUs to allocate per trial.
        max_samples (int, optional): The maximum number of trials to run.
        **train_args (Any): Additional arguments to pass to the `train()` method. `resume=True` reuses the same
            tuning directory and restores the previous Ray Tune run from it when one can be restored; `name` sets the
            tuning directory name and the per-trial run-name prefix. Neither is forwarded to individual trials.

    Returns:
        (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.

    Raises:
        ModuleNotFoundError: If Ray Tune cannot be imported.

    Examples:
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")  # Load a YOLO11n model

        Start tuning hyperparameters for YOLO11n training on the COCO8 dataset
        >>> result_grid = model.tune(data="coco8.yaml", use_ray=True)
    """
    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
    try:
        checks.check_requirements("ray[tune]")

        import ray
        from ray import tune
        from ray.air import RunConfig
        from ray.air.integrations.wandb import WandbLoggerCallback
        from ray.tune.schedulers import ASHAScheduler
    except ImportError as e:
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"') from e

    try:
        import wandb

        assert hasattr(wandb, "__version__")  # only treat wandb as usable if it exposes a version
    except (ImportError, AssertionError):
        wandb = False

    checks.check_version(ray.__version__, ">=2.0.0", "ray")
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        "lr0": tune.uniform(1e-5, 1e-1),
        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay
        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
        "box": tune.uniform(0.02, 0.2),  # box loss gain
        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        "bgr": tune.uniform(0.0, 1.0),  # swap RGB↔BGR channels (probability)
        "mosaic": tune.uniform(0.0, 1.0),  # image mosaic (probability)
        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
        "cutmix": tune.uniform(0.0, 1.0),  # image cutmix (probability)
        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
    }

    task = model.task
    resume = train_args.get("resume", False)
    base_name = train_args.get("name", "tune")
    # Arguments forwarded to every trial's train() call. "resume" and "name" describe the tuning run itself, not an
    # individual trial, so they are excluded explicitly rather than relying on later pops from `train_args`.
    trial_args = {k: v for k, v in train_args.items() if k not in {"resume", "name"}}

    # Put the model in the Ray object store so each trial fetches it with ray.get (this may implicitly initialize Ray)
    model_in_store = ray.put(model)

    def _tune(config):
        """Train the YOLO model with the specified hyperparameters and return results."""
        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
        model_to_train.reset_callbacks()
        config.update(trial_args)  # fixed train() arguments override sampled values

        # Set trial-specific name for W&B logging
        try:
            trial_id = tune.get_trial_id()  # Get current trial ID (e.g., "2c2fc_00000")
            config["name"] = f"{base_name}_{trial_id.rsplit('_', 1)[-1]}"
        except Exception:
            # Best effort: not in a Ray Tune context, or the trial ID is unavailable, so use the base name
            config["name"] = base_name

        results = model_to_train.train(**config)
        return results.results_dict

    try:
        # Get search space. Fall back to the default even when resuming: if nothing can be restored from tune_dir,
        # a fresh Tuner is built below and still needs a param_space (previously this crashed on `None`).
        if not space:
            if not resume:
                LOGGER.warning("Search space not provided, using default search space.")
            space = default_space
        space = {**space}  # copy so the "data" entry below never mutates the caller's dict

        # Get dataset
        data = train_args.get("data", TASK2DATA[task])
        space["data"] = data
        if "data" not in train_args:
            LOGGER.warning(f'Data not provided, using default "data={data}".')

        # Define the trainable function with allocated resources
        trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

        # Define the ASHA scheduler for hyperparameter search
        asha_scheduler = ASHAScheduler(
            time_attr="epoch",
            metric=TASK2METRIC[task],
            mode="max",
            max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
            grace_period=grace_period,
            reduction_factor=3,
        )

        # Define the callbacks for the hyperparameter search
        tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

        # Create the Ray Tune hyperparameter search tuner
        tune_dir = get_save_dir(
            get_cfg(DEFAULT_CFG, {**train_args, "exist_ok": resume}),  # resume w/ same tune_dir
            name=base_name,  # runs/{task}/{tune_dir}
        )  # must be absolute dir
        tune_dir.mkdir(parents=True, exist_ok=True)
        if tune.Tuner.can_restore(tune_dir):
            LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...")
            tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)
        else:
            tuner = tune.Tuner(
                trainable_with_resources,
                param_space=space,
                tune_config=tune.TuneConfig(
                    scheduler=asha_scheduler,
                    num_samples=max_samples,
                    trial_name_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
                    trial_dirname_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
                ),
                run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name),
            )

        # Run the hyperparameter search and return its results
        tuner.fit()
        return tuner.get_results()
    finally:
        # Shut down Ray to clean up workers, also when setup or tuning raises
        ray.shutdown()
|