dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1028 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1580 -0
- ultralytics/engine/model.py +1125 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +977 -0
- ultralytics/engine/tuner.py +449 -0
- ultralytics/engine/validator.py +387 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +964 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +453 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1020 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +506 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1047 -0
- ultralytics/utils/tal.py +404 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +443 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +168 -0
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
"""
|
|
3
|
+
Check a model's accuracy on a test or val split of a dataset.
|
|
4
|
+
|
|
5
|
+
Usage:
|
|
6
|
+
$ yolo mode=val model=yolo11n.pt data=coco8.yaml imgsz=640
|
|
7
|
+
|
|
8
|
+
Usage - formats:
|
|
9
|
+
$ yolo mode=val model=yolo11n.pt # PyTorch
|
|
10
|
+
yolo11n.torchscript # TorchScript
|
|
11
|
+
yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
|
12
|
+
yolo11n_openvino_model # OpenVINO
|
|
13
|
+
yolo11n.engine # TensorRT
|
|
14
|
+
yolo11n.mlpackage # CoreML (macOS-only)
|
|
15
|
+
yolo11n_saved_model # TensorFlow SavedModel
|
|
16
|
+
yolo11n.pb # TensorFlow GraphDef
|
|
17
|
+
yolo11n.tflite # TensorFlow Lite
|
|
18
|
+
yolo11n_edgetpu.tflite # TensorFlow Edge TPU
|
|
19
|
+
yolo11n_paddle_model # PaddlePaddle
|
|
20
|
+
yolo11n.mnn # MNN
|
|
21
|
+
yolo11n_ncnn_model # NCNN
|
|
22
|
+
yolo11n_imx_model # Sony IMX
|
|
23
|
+
yolo11n_rknn_model # Rockchip RKNN
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
import json
|
|
27
|
+
import time
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
|
|
30
|
+
import numpy as np
|
|
31
|
+
import torch
|
|
32
|
+
import torch.distributed as dist
|
|
33
|
+
|
|
34
|
+
from ultralytics.cfg import get_cfg, get_save_dir
|
|
35
|
+
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
|
36
|
+
from ultralytics.nn.autobackend import AutoBackend
|
|
37
|
+
from ultralytics.utils import LOGGER, RANK, TQDM, callbacks, colorstr, emojis
|
|
38
|
+
from ultralytics.utils.checks import check_imgsz
|
|
39
|
+
from ultralytics.utils.ops import Profile
|
|
40
|
+
from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class BaseValidator:
    """A base class for creating validators.

    This class provides the foundation for validation processes, including model evaluation, metric computation, and
    result visualization.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): DataLoader to use for validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary containing dataset information.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        names (dict): Class names mapping.
        seen (int): Number of images seen so far during validation.
        stats (dict): Statistics collected during validation.
        confusion_matrix: Confusion matrix for classification evaluation.
        nc (int): Number of classes.
        iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
        jdict (list): List to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch
            processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
        stride (int): Model stride for padding calculations.
        loss (torch.Tensor): Accumulated loss during training validation.

    Methods:
        __call__: Execute validation process, running inference on dataloader and computing performance metrics.
        match_predictions: Match predictions to ground truth objects using IoU.
        add_callback: Append the given callback to the specified event.
        run_callbacks: Run all callbacks associated with a specified event.
        get_dataloader: Get data loader from dataset path and batch size.
        build_dataset: Build dataset from image path.
        preprocess: Preprocess an input batch.
        postprocess: Postprocess the predictions.
        init_metrics: Initialize performance metrics for the YOLO model.
        update_metrics: Update metrics based on predictions and batch.
        finalize_metrics: Finalize and return all metrics.
        get_stats: Return statistics about the model's performance.
        print_results: Print the results of the model's predictions.
        get_desc: Get description of the YOLO model.
        on_plot: Register plots for visualization.
        plot_val_samples: Plot validation samples during training.
        plot_predictions: Plot YOLO model predictions on batch images.
        pred_to_json: Convert predictions to JSON format.
        eval_json: Evaluate and return JSON format of prediction statistics.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
        """Initialize a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            args (SimpleNamespace, optional): Configuration for the validator.
            _callbacks (dict, optional): Dictionary to store various callback functions.
        """
        import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)

        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        # Most state is filled in lazily by __call__ / init_metrics; start everything as None.
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True  # __call__ flips this to False when no trainer is supplied
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        # Per-image timing accumulators (milliseconds), populated at the end of __call__.
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        # Create save_dir (and the labels/ subdirectory when txt output is requested) up front.
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.01 if self.args.task == "obb" else 0.001  # reduce OBB val memory usage
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
|
129
|
+
|
|
130
|
+
    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """Execute validation process, running inference on dataloader and computing performance metrics.

        Args:
            trainer (object, optional): Trainer object that contains the model to validate.
            model (nn.Module, optional): Model to validate if not using a trainer.

        Returns:
            (dict): Dictionary containing validation statistics.
        """
        # Training mode: validate the (EMA) model supplied by the trainer; otherwise load via AutoBackend.
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # Force FP16 val during training
            self.args.half = self.device.type != "cpu" and trainer.amp
            model = trainer.ema.ema or trainer.model
            if trainer.args.compile and hasattr(model, "_orig_mod"):
                model = model._orig_mod  # validate non-compiled original model to avoid issues
            model = model.half() if self.args.half else model.float()
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            # Only produce plots on the final epoch or when early stopping may trigger.
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            if str(self.args.model).endswith(".yaml") and model is None:
                LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                model=model or self.args.model,
                device=select_device(self.args.device) if RANK == -1 else torch.device("cuda", RANK),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit = model.stride, model.pt, model.jit
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if not (pt or jit or getattr(model, "dynamic", False)):
                self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
                LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

            # Resolve the dataset config: YAML data files go through the detection checker,
            # classification tasks accept a directory path; anything else is an error.
            if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in {"cpu", "mps"}:
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
                self.args.rect = False  # fixed-shape exported models cannot take rectangular batches
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            if self.args.compile:
                model = attempt_compile(model, device=self.device)
            model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        # One Profile timer per stage: preprocess, inference, loss, postprocess.
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(unwrap_model(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"], augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            # Plot only the first 3 batches, and only on the main process.
            if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")

        stats = {}
        self.gather_stats()  # DDP: collect per-rank stats onto rank 0 (no-op in the base class)
        if RANK in {-1, 0}:
            stats = self.get_stats()
            # Convert accumulated stage times to per-image milliseconds.
            self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
            self.finalize_metrics()
            self.print_results()
        self.run_callbacks("on_val_end")

        if self.training:
            model.float()
            # Reduce loss across all GPUs
            loss = self.loss.clone().detach()
            if trainer.world_size > 1:
                dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
            if RANK > 0:
                return  # non-main DDP ranks return nothing during training validation
            results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            if RANK > 0:
                return stats
            LOGGER.info(
                "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                    *tuple(self.speed.values())
                )
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats
|
|
265
|
+
|
|
266
|
+
    def match_predictions(
        self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
    ) -> torch.Tensor:
        """Match predictions to ground truth objects using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape (N,).
            true_classes (torch.Tensor): Target class indices of shape (M,).
            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
            use_scipy (bool, optional): Whether to use scipy for matching (more precise).

        Returns:
            (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
        """
        # Dx10 matrix, where D - detections, 10 - IoU thresholds
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
        # LxD matrix where L - labels (rows), D - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
                import scipy  # scope import to avoid importing for all commands

                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    # NOTE(review): linear_sum_assignment minimizes cost by default; some variants pass
                    # maximize=True here — confirm the intended optimization direction.
                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
                matches = np.array(matches).T  # -> (K, 2) array of (label_idx, detection_idx) candidate pairs
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        # Greedy one-to-one matching: sort candidates by IoU descending, then keep only the
                        # first (highest-IoU) pair per detection and per label. np.unique(..., return_index=True)
                        # returns the index of each value's first occurrence in the sorted-by-IoU order.
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
|
|
307
|
+
|
|
308
|
+
    def add_callback(self, event: str, callback):
        """Append the given callback to the specified event."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Run all callbacks associated with a specified event."""
        # Each callback receives this validator instance as its only argument.
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        # Subclasses must implement this; the base class cannot build task-specific loaders.
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset from image path."""
        # Subclasses must implement this; the base class cannot build task-specific datasets.
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocess an input batch."""
        # Identity by default; subclasses override to move/normalize tensors.
        return batch

    def postprocess(self, preds):
        """Postprocess the predictions."""
        # Identity by default; subclasses override (e.g. to apply NMS).
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass  # no-op hook; subclasses set up their metric state here

    def update_metrics(self, preds, batch):
        """Update metrics based on predictions and batch."""
        pass  # no-op hook; subclasses accumulate per-batch statistics here

    def finalize_metrics(self):
        """Finalize and return all metrics."""
        pass  # no-op hook; subclasses compute final metric values here

    def get_stats(self):
        """Return statistics about the model's performance."""
        return {}

    def gather_stats(self):
        """Gather statistics from all the GPUs during DDP training to GPU 0."""
        pass  # no-op hook; DDP-aware subclasses override

    def print_results(self):
        """Print the results of the model's predictions."""
        pass  # no-op hook; subclasses log their metric tables here

    def get_desc(self):
        """Get description of the YOLO model."""
        pass  # no-op hook; subclasses return the progress-bar header string

    @property
    def metric_keys(self):
        """Return the metric keys used in YOLO training/validation."""
        return []
|
|
365
|
+
|
|
366
|
+
def on_plot(self, name, data=None):
    """Record a plot under its path for later visualization, skipping plot types already registered."""
    new_type = data.get("type") if data else None
    if new_type:
        registered = ((entry.get("data") or {}).get("type") for entry in self.plots.values())
        if new_type in registered:
            # A plot of this type already exists; keep the first one.
            return
    self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
|
|
372
|
+
|
|
373
|
+
def plot_val_samples(self, batch, ni):
    """Plot validation samples during training; no-op in the base class."""
    return None
|
|
376
|
+
|
|
377
|
+
def plot_predictions(self, batch, preds, ni):
    """Plot model predictions over batch images; no-op in the base class."""
    return None
|
|
380
|
+
|
|
381
|
+
def pred_to_json(self, preds, batch):
    """Serialize predictions into JSON-compatible structures; no-op in the base class."""
    return None
|
|
384
|
+
|
|
385
|
+
def eval_json(self, stats):
    """Evaluate JSON-formatted prediction statistics; no-op in the base class."""
    return None
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from ultralytics.data.utils import HUBDatasetStats
|
|
6
|
+
from ultralytics.hub.auth import Auth
|
|
7
|
+
from ultralytics.hub.session import HUBTrainingSession
|
|
8
|
+
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
|
9
|
+
from ultralytics.utils import LOGGER, SETTINGS, checks
|
|
10
|
+
|
|
11
|
+
# Public API of this module: the only names exported by `from ultralytics.hub import *`.
__all__ = (
    "HUB_WEB_ROOT",
    "PREFIX",
    "HUBTrainingSession",
    "check_dataset",
    "export_fmts_hub",
    "export_model",
    "get_export",
    "login",
    "logout",
    "reset_model",
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def login(api_key: str | None = None, save: bool = True) -> bool:
    """Log in to the Ultralytics HUB API using the provided API key.

    The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
    environment variable if successfully authenticated.

    Args:
        api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from SETTINGS
            or HUB_API_KEY environment variable.
        save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.

    Returns:
        (bool): True if authentication is successful, False otherwise.
    """
    checks.check_requirements("hub-sdk>=0.0.12")
    from hub_sdk import HUBClient

    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # set the redirect URL
    saved_key = SETTINGS.get("api_key")
    active_key = api_key or saved_key
    # An empty string is falsy, so `if active_key` alone excludes both None and "" (no extra != "" check needed)
    credentials = {"api_key": active_key} if active_key else None  # set credentials

    client = HUBClient(credentials)  # initialize HUBClient

    if client.authenticated:
        # Successfully authenticated with HUB
        if save and client.api_key != saved_key:
            SETTINGS.update({"api_key": client.api_key})  # update settings with valid API key

        # Set message based on whether key was provided or retrieved from settings
        log_message = (
            "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
        )
        LOGGER.info(f"{PREFIX}{log_message}")

        return True
    else:
        # Failed to authenticate with HUB
        LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo login API_KEY'")
        return False
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def logout():
    """Log out of Ultralytics HUB by clearing the stored API key from the settings file."""
    SETTINGS["api_key"] = ""
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def reset_model(model_id: str = ""):
    """Reset a trained model on Ultralytics HUB back to an untrained state."""
    import requests  # scoped as slow import

    response = requests.post(
        f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key}
    )
    if response.status_code != 200:
        # Non-200 means the reset request was rejected; surface the HTTP status for debugging.
        LOGGER.warning(f"{PREFIX}Model reset failure {response.status_code} {response.reason}")
        return
    LOGGER.info(f"{PREFIX}Model reset successfully")
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def export_fmts_hub():
    """Return a list of HUB-supported export formats.

    Returns:
        (list): Export format argument names (excluding the first, native PyTorch entry) plus the
            HUB-specific 'ultralytics_tflite' and 'ultralytics_coreml' formats.
    """
    from ultralytics.engine.exporter import export_formats

    # The slice is already unpacked directly into the literal; the extra list() wrapper was redundant.
    return [*export_formats()["Argument"][1:], "ultralytics_tflite", "ultralytics_coreml"]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def export_model(model_id: str = "", format: str = "torchscript"):
    """Request an export of a HUB model into the given deployment format via the Ultralytics HUB API.

    Args:
        model_id (str): The ID of the model to export. An empty string will use the default model.
        format (str): The format to export the model to. Must be one of the supported formats returned by
            export_fmts_hub().

    Raises:
        AssertionError: If the specified format is not supported or if the export request fails.

    Examples:
        >>> from ultralytics import hub
        >>> hub.export_model(model_id="your_model_id", format="torchscript")
    """
    import requests  # scoped as slow import

    supported = export_fmts_hub()
    assert format in supported, f"Unsupported export format '{format}', valid formats are {supported}"
    response = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
    )
    assert response.status_code == 200, f"{PREFIX}{format} export failure {response.status_code} {response.reason}"
    LOGGER.info(f"{PREFIX}{format} export started ✅")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def get_export(model_id: str = "", format: str = "torchscript"):
    """Fetch an exported model in the requested format from Ultralytics HUB by model ID.

    Args:
        model_id (str): The ID of the model to retrieve from Ultralytics HUB.
        format (str): The export format to retrieve. Must be one of the supported formats returned by export_fmts_hub().

    Returns:
        (dict): JSON response containing the exported model information.

    Raises:
        AssertionError: If the specified format is not supported or if the API request fails.

    Examples:
        >>> from ultralytics import hub
        >>> result = hub.get_export(model_id="your_model_id", format="torchscript")
    """
    import requests  # scoped as slow import

    supported = export_fmts_hub()
    assert format in supported, f"Unsupported export format '{format}', valid formats are {supported}"
    response = requests.post(
        f"{HUB_API_ROOT}/get-export",
        json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
        headers={"x-api-key": Auth().api_key},
    )
    assert response.status_code == 200, f"{PREFIX}{format} get_export failure {response.status_code} {response.reason}"
    return response.json()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def check_dataset(path: str, task: str) -> None:
    """Validate a HUB dataset Zip file for errors before it is uploaded.

    Args:
        path (str): Path to data.zip (with data.yaml inside data.zip).
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.

    Examples:
        >>> from ultralytics.hub import check_dataset
        >>> check_dataset("path/to/coco8.zip", task="detect")  # detect dataset
        >>> check_dataset("path/to/coco8-seg.zip", task="segment")  # segment dataset
        >>> check_dataset("path/to/coco8-pose.zip", task="pose")  # pose dataset
        >>> check_dataset("path/to/dota8.zip", task="obb")  # OBB dataset
        >>> check_dataset("path/to/imagenet10.zip", task="classify")  # classification dataset

    Notes:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
    """
    # Building the JSON stats raises if the archive or its data.yaml is malformed.
    stats = HUBDatasetStats(path=path, task=task)
    stats.get_json()
    LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
|