dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,1011 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import math
|
4
|
+
import warnings
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Callable, Dict, List, Optional, Union
|
7
|
+
|
8
|
+
import cv2
|
9
|
+
import numpy as np
|
10
|
+
import torch
|
11
|
+
from PIL import Image, ImageDraw, ImageFont
|
12
|
+
from PIL import __version__ as pil_version
|
13
|
+
|
14
|
+
from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded
|
15
|
+
from ultralytics.utils.checks import check_font, check_version, is_ascii
|
16
|
+
from ultralytics.utils.files import increment_path
|
17
|
+
|
18
|
+
|
19
|
+
class Colors:
    """
    Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.

    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
    RGB values and looking up palette entries by (wrapping) index.

    Attributes:
        palette (List[Tuple]): List of RGB color values.
        n (int): The number of colors in the palette.
        pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.

    Examples:
        >>> from ultralytics.utils.plotting import Colors
        >>> colors = Colors()
        >>> colors(5)  # ff6fdd as RGB -> (255, 111, 221)
        >>> colors(5, True)  # ff6fdd as BGR -> (221, 111, 255)

    ## Ultralytics Color Palette

    | Index | Color | HEX | RGB |
    |-------|-------------------------------------------------------------------|-----------|-------------------|
    | 0 | <i class="fa-solid fa-square fa-2xl" style="color: #042aff;"></i> | `#042aff` | (4, 42, 255) |
    | 1 | <i class="fa-solid fa-square fa-2xl" style="color: #0bdbeb;"></i> | `#0bdbeb` | (11, 219, 235) |
    | 2 | <i class="fa-solid fa-square fa-2xl" style="color: #f3f3f3;"></i> | `#f3f3f3` | (243, 243, 243) |
    | 3 | <i class="fa-solid fa-square fa-2xl" style="color: #00dfb7;"></i> | `#00dfb7` | (0, 223, 183) |
    | 4 | <i class="fa-solid fa-square fa-2xl" style="color: #111f68;"></i> | `#111f68` | (17, 31, 104) |
    | 5 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6fdd;"></i> | `#ff6fdd` | (255, 111, 221) |
    | 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff444f;"></i> | `#ff444f` | (255, 68, 79) |
    | 7 | <i class="fa-solid fa-square fa-2xl" style="color: #cced00;"></i> | `#cced00` | (204, 237, 0) |
    | 8 | <i class="fa-solid fa-square fa-2xl" style="color: #00f344;"></i> | `#00f344` | (0, 243, 68) |
    | 9 | <i class="fa-solid fa-square fa-2xl" style="color: #bd00ff;"></i> | `#bd00ff` | (189, 0, 255) |
    | 10 | <i class="fa-solid fa-square fa-2xl" style="color: #00b4ff;"></i> | `#00b4ff` | (0, 180, 255) |
    | 11 | <i class="fa-solid fa-square fa-2xl" style="color: #dd00ba;"></i> | `#dd00ba` | (221, 0, 186) |
    | 12 | <i class="fa-solid fa-square fa-2xl" style="color: #00ffff;"></i> | `#00ffff` | (0, 255, 255) |
    | 13 | <i class="fa-solid fa-square fa-2xl" style="color: #26c000;"></i> | `#26c000` | (38, 192, 0) |
    | 14 | <i class="fa-solid fa-square fa-2xl" style="color: #01ffb3;"></i> | `#01ffb3` | (1, 255, 179) |
    | 15 | <i class="fa-solid fa-square fa-2xl" style="color: #7d24ff;"></i> | `#7d24ff` | (125, 36, 255) |
    | 16 | <i class="fa-solid fa-square fa-2xl" style="color: #7b0068;"></i> | `#7b0068` | (123, 0, 104) |
    | 17 | <i class="fa-solid fa-square fa-2xl" style="color: #ff1b6c;"></i> | `#ff1b6c` | (255, 27, 108) |
    | 18 | <i class="fa-solid fa-square fa-2xl" style="color: #fc6d2f;"></i> | `#fc6d2f` | (252, 109, 47) |
    | 19 | <i class="fa-solid fa-square fa-2xl" style="color: #a2ff0b;"></i> | `#a2ff0b` | (162, 255, 11) |

    ## Pose Color Palette

    | Index | Color | HEX | RGB |
    |-------|-------------------------------------------------------------------|-----------|-------------------|
    | 0 | <i class="fa-solid fa-square fa-2xl" style="color: #ff8000;"></i> | `#ff8000` | (255, 128, 0) |
    | 1 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9933;"></i> | `#ff9933` | (255, 153, 51) |
    | 2 | <i class="fa-solid fa-square fa-2xl" style="color: #ffb266;"></i> | `#ffb266` | (255, 178, 102) |
    | 3 | <i class="fa-solid fa-square fa-2xl" style="color: #e6e600;"></i> | `#e6e600` | (230, 230, 0) |
    | 4 | <i class="fa-solid fa-square fa-2xl" style="color: #ff99ff;"></i> | `#ff99ff` | (255, 153, 255) |
    | 5 | <i class="fa-solid fa-square fa-2xl" style="color: #99ccff;"></i> | `#99ccff` | (153, 204, 255) |
    | 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff66ff;"></i> | `#ff66ff` | (255, 102, 255) |
    | 7 | <i class="fa-solid fa-square fa-2xl" style="color: #ff33ff;"></i> | `#ff33ff` | (255, 51, 255) |
    | 8 | <i class="fa-solid fa-square fa-2xl" style="color: #66b2ff;"></i> | `#66b2ff` | (102, 178, 255) |
    | 9 | <i class="fa-solid fa-square fa-2xl" style="color: #3399ff;"></i> | `#3399ff` | (51, 153, 255) |
    | 10 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9999;"></i> | `#ff9999` | (255, 153, 153) |
    | 11 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6666;"></i> | `#ff6666` | (255, 102, 102) |
    | 12 | <i class="fa-solid fa-square fa-2xl" style="color: #ff3333;"></i> | `#ff3333` | (255, 51, 51) |
    | 13 | <i class="fa-solid fa-square fa-2xl" style="color: #99ff99;"></i> | `#99ff99` | (153, 255, 153) |
    | 14 | <i class="fa-solid fa-square fa-2xl" style="color: #66ff66;"></i> | `#66ff66` | (102, 255, 102) |
    | 15 | <i class="fa-solid fa-square fa-2xl" style="color: #33ff33;"></i> | `#33ff33` | (51, 255, 51) |
    | 16 | <i class="fa-solid fa-square fa-2xl" style="color: #00ff00;"></i> | `#00ff00` | (0, 255, 0) |
    | 17 | <i class="fa-solid fa-square fa-2xl" style="color: #0000ff;"></i> | `#0000ff` | (0, 0, 255) |
    | 18 | <i class="fa-solid fa-square fa-2xl" style="color: #ff0000;"></i> | `#ff0000` | (255, 0, 0) |
    | 19 | <i class="fa-solid fa-square fa-2xl" style="color: #ffffff;"></i> | `#ffffff` | (255, 255, 255) |

    !!! note "Ultralytics Brand Colors"

        For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
    """

    def __init__(self):
        """Build the 20-entry RGB palette from its hex codes and the fixed 20-entry uint8 pose palette."""
        hexs = (
            "042AFF",
            "0BDBEB",
            "F3F3F3",
            "00DFB7",
            "111F68",
            "FF6FDD",
            "FF444F",
            "CCED00",
            "00F344",
            "BD00FF",
            "00B4FF",
            "DD00BA",
            "00FFFF",
            "26C000",
            "01FFB3",
            "7D24FF",
            "7B0068",
            "FF1B6C",
            "FC6D2F",
            "A2FF0B",
        )
        self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
        self.n = len(self.palette)
        self.pose_palette = np.array(
            [
                [255, 128, 0],
                [255, 153, 51],
                [255, 178, 102],
                [230, 230, 0],
                [255, 153, 255],
                [153, 204, 255],
                [255, 102, 255],
                [255, 51, 255],
                [102, 178, 255],
                [51, 153, 255],
                [255, 153, 153],
                [255, 102, 102],
                [255, 51, 51],
                [153, 255, 153],
                [102, 255, 102],
                [51, 255, 51],
                [0, 255, 0],
                [0, 0, 255],
                [255, 0, 0],
                [255, 255, 255],
            ],
            dtype=np.uint8,
        )

    def __call__(self, i, bgr=False):
        """
        Return the palette color for index `i`, wrapping around the palette length.

        Args:
            i (int): Color index; it is cast with `int()` and taken modulo `self.n`.
            bgr (bool, optional): If True, return the color as (B, G, R) instead of (R, G, B).

        Returns:
            (tuple): Three integer channel values in RGB order, or BGR order when `bgr` is True.
        """
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):
        """
        Convert a hex color code to an (R, G, B) tuple (i.e. default PIL order).

        Args:
            h (str): Six-digit hex color such as "#ff6fdd"; the leading "#" is optional.

        Returns:
            (tuple): The (R, G, B) integer channel values.
        """
        h = h.lstrip("#")  # accept both "#RRGGBB" and "RRGGBB"
        return tuple(int(h[i : i + 2], 16) for i in (0, 2, 4))
|
152
|
+
|
153
|
+
|
154
|
+
colors = Colors()  # shared module-level palette instance, e.g. 'from ultralytics.utils.plotting import colors'
|
155
|
+
|
156
|
+
|
157
|
+
class Annotator:
|
158
|
+
"""
|
159
|
+
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
160
|
+
|
161
|
+
Attributes:
|
162
|
+
im (Image.Image or np.ndarray): The image to annotate.
|
163
|
+
pil (bool): Whether to use PIL or cv2 for drawing annotations.
|
164
|
+
font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
|
165
|
+
lw (float): Line width for drawing.
|
166
|
+
skeleton (List[List[int]]): Skeleton structure for keypoints.
|
167
|
+
limb_color (List[int]): Color palette for limbs.
|
168
|
+
kpt_color (List[int]): Color palette for keypoints.
|
169
|
+
dark_colors (set): Set of colors considered dark for text contrast.
|
170
|
+
light_colors (set): Set of colors considered light for text contrast.
|
171
|
+
|
172
|
+
Examples:
|
173
|
+
>>> from ultralytics.utils.plotting import Annotator
|
174
|
+
>>> im0 = cv2.imread("test.png")
|
175
|
+
>>> annotator = Annotator(im0, line_width=10)
|
176
|
+
"""
|
177
|
+
|
178
|
+
def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
    """
    Initialize the Annotator with an image, drawing backend, line width and pose/label color tables.

    Args:
        im (Image.Image | np.ndarray): Image to annotate. A NumPy image may be 2-D grayscale, single-channel,
            3-channel or multispectral (more than 3 channels; only the first three are kept on the cv2 path).
        line_width (int, optional): Line width for drawing. When falsy it is derived from the image dimensions.
        font_size (int, optional): Font size for the PIL backend. When falsy it is derived from the image size.
        font (str, optional): Font file name resolved with `check_font`; "Arial.Unicode.ttf" is used instead when
            `example` contains non-ASCII characters.
        pil (bool, optional): Force the PIL backend. PIL is also used for PIL inputs and non-ASCII examples.
        example (str, optional): Sample label text used only to detect whether non-ASCII rendering is needed.
    """
    non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
    input_is_pil = isinstance(im, Image.Image)
    self.pil = pil or non_ascii or input_is_pil
    self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
    if self.pil:  # use PIL
        self.im = im if input_is_pil else Image.fromarray(im)
        if self.im.mode not in {"RGB", "RGBA"}:  # multispectral
            self.im = self.im.convert("RGB")
        self.draw = ImageDraw.Draw(self.im, "RGBA")
        try:
            font = check_font("Arial.Unicode.ttf" if non_ascii else font)
            size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
            self.font = ImageFont.truetype(str(font), size)
        except Exception:
            # Deliberate best-effort: any failure to resolve/load the font falls back to PIL's built-in font
            self.font = ImageFont.load_default()
        # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
        if check_version(pil_version, "9.2.0"):
            self.font.getsize = lambda x: self.font.getbbox(x)[2:4]  # text width, height
    else:  # use cv2
        if im.ndim == 2 or im.shape[2] == 1:  # handle grayscale, both (h, w) and (h, w, 1)
            im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
        elif im.shape[2] > 3:  # multispectral
            im = np.ascontiguousarray(im[..., :3])
        assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
        self.im = im if im.flags.writeable else im.copy()  # copy read-only arrays so drawing can modify them
        self.tf = max(self.lw - 1, 1)  # font thickness
        self.sf = self.lw / 3  # font scale
    # Pose: 1-based keypoint index pairs that are joined by limb lines
    self.skeleton = [
        [16, 14],
        [14, 12],
        [17, 15],
        [15, 13],
        [12, 13],
        [6, 12],
        [7, 13],
        [6, 7],
        [6, 8],
        [7, 9],
        [8, 10],
        [9, 11],
        [2, 3],
        [1, 2],
        [1, 3],
        [2, 4],
        [3, 5],
        [4, 6],
        [5, 7],
    ]

    # One pose-palette color per skeleton limb (19) and per keypoint (17)
    self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
    self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
    # Background colors that get the dark label text color in get_txt_color()
    self.dark_colors = {
        (235, 219, 11),
        (243, 243, 243),
        (183, 223, 0),
        (221, 111, 255),
        (0, 237, 204),
        (68, 243, 0),
        (255, 255, 0),
        (179, 255, 1),
        (11, 255, 162),
    }
    # Background colors that get white label text in get_txt_color()
    self.light_colors = {
        (255, 42, 4),
        (79, 68, 255),
        (255, 0, 189),
        (255, 180, 0),
        (186, 0, 221),
        (0, 192, 38),
        (255, 36, 125),
        (104, 0, 123),
        (108, 27, 255),
        (47, 109, 252),
        (104, 31, 17),
    }
|
256
|
+
|
257
|
+
def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
|
258
|
+
"""
|
259
|
+
Assign text color based on background color.
|
260
|
+
|
261
|
+
Args:
|
262
|
+
color (tuple, optional): The background color of the rectangle for text (B, G, R).
|
263
|
+
txt_color (tuple, optional): The color of the text (R, G, B).
|
264
|
+
|
265
|
+
Returns:
|
266
|
+
(tuple): Text color for label.
|
267
|
+
|
268
|
+
Examples:
|
269
|
+
>>> from ultralytics.utils.plotting import Annotator
|
270
|
+
>>> im0 = cv2.imread("test.png")
|
271
|
+
>>> annotator = Annotator(im0, line_width=10)
|
272
|
+
>>> annotator.get_txt_color(color=(104, 31, 17)) # return (255, 255, 255)
|
273
|
+
"""
|
274
|
+
if color in self.dark_colors:
|
275
|
+
return 104, 31, 17
|
276
|
+
elif color in self.light_colors:
|
277
|
+
return 255, 255, 255
|
278
|
+
else:
|
279
|
+
return txt_color
|
280
|
+
|
281
|
+
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
    """
    Draw a bounding box, and optionally a filled label tag with text, on the image.

    The label tag is placed above the box's first corner when there is room, otherwise just below it, and is
    shifted left when it would run past the right edge of the image.

    Args:
        box (tuple): The bounding box coordinates (x1, y1, x2, y2), or a sequence of corner points if `rotated`.
        label (str, optional): The text label to be displayed.
        color (tuple, optional): The background color of the rectangle (B, G, R).
        txt_color (tuple, optional): The color of the text (R, G, B).
        rotated (bool, optional): Whether the task is oriented bounding box detection.

    Examples:
        >>> from ultralytics.utils.plotting import Annotator
        >>> im0 = cv2.imread("test.png")
        >>> annotator = Annotator(im0, line_width=10)
        >>> annotator.box_label(box=[10, 20, 30, 40], label="person")
    """
    txt_color = self.get_txt_color(color, txt_color)
    if isinstance(box, torch.Tensor):
        box = box.tolist()

    if self.pil or not is_ascii(label):
        # PIL backend
        if rotated:
            anchor = box[0]
            self.draw.polygon([tuple(pt) for pt in box], width=self.lw, outline=color)  # PIL requires tuple box
        else:
            anchor = (box[0], box[1])
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
        if not label:
            return
        text_w, text_h = self.font.getsize(label)  # text width, height
        fits_above = anchor[1] >= text_h  # label fits outside box
        max_x = self.im.size[0] - text_w  # size is (w, h)
        if anchor[0] > max_x:  # label would extend beyond right side of image
            anchor = max_x, anchor[1]
        left = anchor[0]
        if fits_above:
            top = anchor[1] - text_h
            bottom = anchor[1] + 1
        else:
            top = anchor[1]
            bottom = anchor[1] + text_h + 1
        self.draw.rectangle((left, top, left + text_w + 1, bottom), fill=color)
        self.draw.text((left, top), label, fill=txt_color, font=self.font)
        return

    # cv2 backend
    if rotated:
        anchor = [int(v) for v in box[0]]
        cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw)  # cv2 requires nparray box
    else:
        anchor = (int(box[0]), int(box[1]))
        opposite = (int(box[2]), int(box[3]))
        cv2.rectangle(self.im, anchor, opposite, color, thickness=self.lw, lineType=cv2.LINE_AA)
    if not label:
        return
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]  # text width, height
    text_h += 3  # add pixels to pad text
    fits_above = anchor[1] >= text_h  # label fits outside box
    max_x = self.im.shape[1] - text_w  # shape is (h, w)
    if anchor[0] > max_x:  # label would extend beyond right side of image
        anchor = max_x, anchor[1]
    if fits_above:
        tag_corner = (anchor[0] + text_w, anchor[1] - text_h)
        text_origin = (anchor[0], anchor[1] - 2)
    else:
        tag_corner = (anchor[0] + text_w, anchor[1] + text_h)
        text_origin = (anchor[0], anchor[1] + text_h - 1)
    cv2.rectangle(self.im, anchor, tag_corner, color, -1, cv2.LINE_AA)  # filled
    cv2.putText(
        self.im,
        label,
        text_origin,
        0,
        self.sf,
        txt_color,
        thickness=self.tf,
        lineType=cv2.LINE_AA,
    )
|
344
|
+
|
345
|
+
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
|
346
|
+
"""
|
347
|
+
Plot masks on image.
|
348
|
+
|
349
|
+
Args:
|
350
|
+
masks (torch.Tensor): Predicted masks on cuda, shape: [n, h, w]
|
351
|
+
colors (List[List[int]]): Colors for predicted masks, [[r, g, b] * n]
|
352
|
+
im_gpu (torch.Tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
|
353
|
+
alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque.
|
354
|
+
retina_masks (bool, optional): Whether to use high resolution masks or not.
|
355
|
+
"""
|
356
|
+
if self.pil:
|
357
|
+
# Convert to numpy first
|
358
|
+
self.im = np.asarray(self.im).copy()
|
359
|
+
if len(masks) == 0:
|
360
|
+
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
361
|
+
if im_gpu.device != masks.device:
|
362
|
+
im_gpu = im_gpu.to(masks.device)
|
363
|
+
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
|
364
|
+
colors = colors[:, None, None] # shape(n,1,1,3)
|
365
|
+
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
366
|
+
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
367
|
+
|
368
|
+
inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
369
|
+
mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
|
370
|
+
|
371
|
+
im_gpu = im_gpu.flip(dims=[0]) # flip channel
|
372
|
+
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
373
|
+
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
|
374
|
+
im_mask = im_gpu * 255
|
375
|
+
im_mask_np = im_mask.byte().cpu().numpy()
|
376
|
+
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
|
377
|
+
if self.pil:
|
378
|
+
# Convert im back to PIL and update draw
|
379
|
+
self.fromarray(self.im)
|
380
|
+
|
381
|
+
def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
    """
    Draw keypoints, and optionally the connecting skeleton, on the image.

    Args:
        kpts (torch.Tensor): Keypoints as a 2D array of (x, y) or (x, y, confidence) rows, e.g. [17, 3].
        shape (tuple, optional): Image shape (h, w).
        radius (int, optional): Keypoint radius; falls back to self.lw when None.
        kpt_line (bool, optional): Draw lines between keypoints.
        conf_thres (float, optional): Confidence threshold below which points and limbs are skipped.
        kpt_color (tuple, optional): Keypoint color (B, G, R); also used for limbs when given.

    Note:
        - `kpt_line=True` currently only supports human pose plotting (17 keypoints).
        - Points whose x is a multiple of the width or whose y is a multiple of the height are skipped.
        - Modifies self.im in-place.
        - If self.pil is True, converts image to numpy array and back to PIL.
    """
    if radius is None:
        radius = self.lw
    if self.pil:
        # Drawing is done with cv2, so work on a numpy copy
        self.im = np.asarray(self.im).copy()

    nkpt, ndim = kpts.shape
    is_pose = nkpt == 17 and ndim in {2, 3}
    kpt_line &= is_pose  # `kpt_line=True` for now only supports human pose plotting

    for idx, point in enumerate(kpts):
        color_k = kpt_color or (self.kpt_color[idx].tolist() if is_pose else colors(idx))
        px, py = point[0], point[1]
        if px % shape[1] == 0 or py % shape[0] == 0:
            continue
        if len(point) == 3 and point[2] < conf_thres:
            continue
        cv2.circle(self.im, (int(px), int(py)), radius, color_k, -1, lineType=cv2.LINE_AA)

    if kpt_line:
        has_conf = kpts.shape[-1] == 3
        for limb, sk in enumerate(self.skeleton):
            start, end = sk[0] - 1, sk[1] - 1  # skeleton indices are 1-based
            pos1 = (int(kpts[start, 0]), int(kpts[start, 1]))
            pos2 = (int(kpts[end, 0]), int(kpts[end, 1]))
            if has_conf:
                conf1 = kpts[start, 2]
                conf2 = kpts[end, 2]
                if conf1 < conf_thres or conf2 < conf_thres:
                    continue
            if any(p[0] % shape[1] == 0 or p[1] % shape[0] == 0 or p[0] < 0 or p[1] < 0 for p in (pos1, pos2)):
                continue
            cv2.line(
                self.im,
                pos1,
                pos2,
                kpt_color or self.limb_color[limb].tolist(),
                thickness=int(np.ceil(self.lw / 2)),
                lineType=cv2.LINE_AA,
            )

    if self.pil:
        # Convert im back to PIL and update draw
        self.fromarray(self.im)
|
440
|
+
|
441
|
+
def rectangle(self, xy, fill=None, outline=None, width=1):
    """Draw a rectangle on the image through self.draw (PIL-only)."""
    self.draw.rectangle(xy, fill=fill, outline=outline, width=width)
|
444
|
+
|
445
|
+
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_color=()):
    """
    Add text to an image using PIL or cv2.

    Args:
        xy (List[int] | Tuple[int, int]): Top-left coordinates for text placement. The sequence is
            never modified.
        text (str): Text to be drawn. In PIL mode each "\\n"-separated line is drawn below the previous one.
        txt_color (tuple, optional): Text color (R, G, B).
        anchor (str, optional): Text anchor position ('top' or 'bottom'). Only used in PIL mode.
        box_color (tuple, optional): Box color (R, G, B, A) with optional alpha; empty for no box.
    """
    if self.pil:
        # Track the pen position in locals so the caller's `xy` is left untouched (and may be a tuple).
        x, y = xy[0], xy[1]
        _, h = self.font.getsize(text)
        if anchor == "bottom":  # start y from font bottom
            y += 1 - h
        for line in text.split("\n"):
            if box_color:
                # Draw rectangle for each line
                w, h = self.font.getsize(line)
                self.draw.rectangle((x, y, x + w + 1, y + h + 1), fill=box_color)
            self.draw.text((x, y), line, fill=txt_color, font=self.font)
            y += h
    else:
        if box_color:
            w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
            h += 3  # add pixels to pad text
            outside = xy[1] >= h  # label fits outside box
            p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
            cv2.rectangle(self.im, xy, p2, box_color, -1, cv2.LINE_AA)  # filled
        cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
|
475
|
+
|
476
|
+
def fromarray(self, im):
    """Set self.im from a PIL image or numpy array and rebuild the drawing context on it."""
    if not isinstance(im, Image.Image):
        im = Image.fromarray(im)
    self.im = im
    self.draw = ImageDraw.Draw(self.im)
|
480
|
+
|
481
|
+
def result(self):
    """Return the annotated image as a numpy array."""
    annotated = np.asarray(self.im)
    return annotated
|
484
|
+
|
485
|
+
def show(self, title=None):
    """
    Show the annotated image.

    Args:
        title (str, optional): Window title forwarded to PIL's `Image.show` outside Colab/Kaggle.
    """
    # Reverse the channel order of the array before handing it to PIL
    im = Image.fromarray(np.asarray(self.im)[..., ::-1])
    if IS_COLAB or IS_KAGGLE:  # can not use IS_JUPYTER as will run for all ipython environments
        try:
            display(im)  # noqa - display() function only available in ipython environments
        except (ImportError, NameError) as e:
            # A missing `display` builtin surfaces as NameError, so catch it alongside ImportError
            LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}")
    else:
        im.show(title=title)
|
495
|
+
|
496
|
+
def save(self, filename="image.jpg"):
    """Write the annotated image to 'filename' with cv2.imwrite."""
    pixels = np.asarray(self.im)
    cv2.imwrite(filename, pixels)
|
499
|
+
|
500
|
+
@staticmethod
|
501
|
+
def get_bbox_dimension(bbox=None):
|
502
|
+
"""
|
503
|
+
Calculate the dimensions and area of a bounding box.
|
504
|
+
|
505
|
+
Args:
|
506
|
+
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
|
507
|
+
|
508
|
+
Returns:
|
509
|
+
width (float): Width of the bounding box.
|
510
|
+
height (float): Height of the bounding box.
|
511
|
+
area (float): Area enclosed by the bounding box.
|
512
|
+
|
513
|
+
Examples:
|
514
|
+
>>> from ultralytics.utils.plotting import Annotator
|
515
|
+
>>> im0 = cv2.imread("test.png")
|
516
|
+
>>> annotator = Annotator(im0, line_width=10)
|
517
|
+
>>> annotator.get_bbox_dimension(bbox=[10, 20, 30, 40])
|
518
|
+
"""
|
519
|
+
x_min, y_min, x_max, y_max = bbox
|
520
|
+
width = x_max - x_min
|
521
|
+
height = y_max - y_min
|
522
|
+
return width, height, width * height
|
523
|
+
|
524
|
+
|
525
|
+
@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
    """
    Plot training labels including class histograms and box statistics.

    Saves 'labels_correlogram.jpg' and 'labels.jpg' into `save_dir`. The input `boxes` array is not
    modified.

    Args:
        boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
        cls (np.ndarray): Class indices.
        names (dict, optional): Dictionary mapping class indices to class names.
        save_dir (Path, optional): Directory to save the plot.
        on_plot (Callable, optional): Function to call with the saved 'labels.jpg' path.
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
    import pandas
    import seaborn

    # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
    warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Plot dataset labels
    LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
    nc = int(cls.max() + 1)  # number of classes
    boxes = boxes[:1000000]  # limit to 1M boxes
    x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])

    # Seaborn correlogram
    seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
    plt.close()

    # Matplotlib labels
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    for i in range(nc):
        y[2].patches[i].set_color([channel / 255 for channel in colors(i)])
    ax[0].set_ylabel("instances")
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
    else:
        ax[0].set_xlabel("classes")
    seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
    seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)

    # Rectangles: re-centre on a private copy, since the slice above is a view of the caller's array
    rects = np.array(boxes)
    rects[:, 0:2] = 0.5  # center
    rects = ops.xywh2xyxy(rects) * 1000
    img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
    canvas = ImageDraw.Draw(img)
    for c, box in zip(cls[:500], rects[:500]):
        canvas.rectangle(box, width=1, outline=colors(c))  # plot
    ax[1].imshow(img)
    ax[1].axis("off")

    for a in [0, 1, 2, 3]:
        for s in ["top", "right", "left", "bottom"]:
            ax[a].spines[s].set_visible(False)

    fname = save_dir / "labels.jpg"
    plt.savefig(fname, dpi=200)
    plt.close()
    if on_plot:
        on_plot(fname)
|
589
|
+
|
590
|
+
|
591
|
+
def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
    """
    Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.

    The bounding box is converted to xywh, optionally squared to its longer side, enlarged by
    `gain` and `pad`, converted back to integer xyxy, clipped to the image and used to slice `im`.

    Args:
        xyxy (torch.Tensor | list): Bounding box in xyxy format, as a tensor, a list of tensors, or a
            list of plain numbers.
        im (np.ndarray): The input image.
        file (Path | str, optional): The path where the cropped image will be saved. The suffix is
            replaced with '.jpg'.
        gain (float, optional): A multiplicative factor to increase the size of the bounding box.
        pad (int, optional): The number of pixels to add to the width and height of the bounding box.
        square (bool, optional): If True, the bounding box will be transformed into a square.
        BGR (bool, optional): If True, the returned crop keeps the channel order of `im`, otherwise the
            channel axis is reversed.
        save (bool, optional): If True, the cropped image will be saved to disk.

    Returns:
        (np.ndarray): The cropped image.

    Examples:
        >>> from ultralytics.utils.plotting import save_one_box
        >>> xyxy = [50, 50, 150, 150]
        >>> im = cv2.imread("image.jpg")
        >>> cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
    """
    if not isinstance(xyxy, torch.Tensor):  # may be list
        if all(isinstance(v, torch.Tensor) for v in xyxy):
            xyxy = torch.stack(xyxy)
        else:
            # torch.stack only accepts tensors; plain numbers go through torch.tensor instead
            xyxy = torch.tensor(xyxy, dtype=torch.float32)
    b = ops.xyxy2xywh(xyxy.view(-1, 4))  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = ops.xywh2xyxy(b).long()
    xyxy = ops.clip_boxes(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
    if save:
        file = Path(file)  # accept plain string paths as well
        file.parent.mkdir(parents=True, exist_ok=True)  # make directory
        f = str(increment_path(file).with_suffix(".jpg"))
        # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
        # NOTE(review): the reversal below presumably assumes `crop` is BGR (i.e. BGR=True) - confirm for BGR=False
        Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
    return crop
|
633
|
+
|
634
|
+
|
635
|
+
@threaded
def plot_images(
    images: Union[torch.Tensor, np.ndarray],
    batch_idx: Union[torch.Tensor, np.ndarray],
    cls: Union[torch.Tensor, np.ndarray],
    bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
    confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
    masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
    kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
    paths: Optional[List[str]] = None,
    fname: str = "images.jpg",
    names: Optional[Dict[int, str]] = None,
    on_plot: Optional[Callable] = None,
    max_size: int = 1920,
    max_subplots: int = 16,
    save: bool = True,
    conf_thres: float = 0.25,
) -> Optional[np.ndarray]:
    """
    Plot image grid with labels, bounding boxes, masks, and keypoints.

    Args:
        images: Batch of images to plot. Shape: (batch_size, channels, height, width). Only the first
            three channels are used. The input is not modified.
        batch_idx: Batch indices for each detection. Shape: (num_detections,).
        cls: Class labels for each detection. Shape: (num_detections,).
        bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
        confs: Confidence scores for each detection. Shape: (num_detections,). When None the inputs are
            treated as ground-truth labels and nothing is filtered by `conf_thres`.
        masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
        kpts: Keypoints for each detection. Shape: (num_detections, 51).
        paths: List of file paths for each image in the batch.
        fname: Output filename for the plotted image grid.
        names: Dictionary mapping class indices to class names.
        on_plot: Optional callback function to be called after saving the plot.
        max_size: Maximum size of the output image grid.
        max_subplots: Maximum number of subplots in the image grid.
        save: Whether to save the plotted image grid to a file.
        conf_thres: Confidence threshold for displaying detections.

    Returns:
        (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.

    Note:
        This function supports both tensor and numpy array inputs. It will automatically
        convert tensor inputs to numpy arrays for processing.
    """
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(cls, torch.Tensor):
        cls = cls.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy().astype(int)
    if isinstance(kpts, torch.Tensor):
        kpts = kpts.cpu().numpy()
    if isinstance(batch_idx, torch.Tensor):
        batch_idx = batch_idx.cpu().numpy()
    if images.shape[1] > 3:
        images = images[:, :3]  # crop multispectral images to first 3 channels

    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs**0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        # de-normalise (optional); out-of-place so memory shared with the caller is never rescaled
        images = images * 255

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    fs = max(fs, 18)  # ensure that the font size is large enough to be easily readable.
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=str(names))
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(cls) > 0:
            idx = batch_idx == i
            classes = cls[idx].astype("int")
            labels = confs is None
            # Resolve per-image confidences up front: the keypoint and mask sections read `conf`
            # even when no boxes are supplied.
            conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)

            if len(bboxes):
                boxes = bboxes[idx]
                if len(boxes):
                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
                        boxes[..., [0, 2]] *= w  # scale to pixels
                        boxes[..., [1, 3]] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        boxes[..., :4] *= scale
                boxes[..., 0] += x
                boxes[..., 1] += y
                is_obb = boxes.shape[-1] == 5  # xywhr
                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
                for j, box in enumerate(boxes.astype(np.int64).tolist()):
                    c = classes[j]
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    if labels or conf[j] > conf_thres:
                        label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
                        annotator.box_label(box, label, color=color, rotated=is_obb)

            elif len(classes):
                for c in classes:
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    annotator.text([x, y], f"{c}", txt_color=color, box_color=(64, 64, 64, 128))

            # Plot keypoints
            if len(kpts):
                kpts_ = kpts[idx].copy()
                if len(kpts_):
                    if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01:  # if normalized with tolerance .01
                        kpts_[..., 0] *= w  # scale to pixels
                        kpts_[..., 1] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        kpts_ *= scale
                kpts_[..., 0] += x
                kpts_[..., 1] += y
                for j in range(len(kpts_)):
                    if labels or conf[j] > conf_thres:
                        annotator.kpts(kpts_[j], conf_thres=conf_thres)

            # Plot masks
            if len(masks):
                if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
                    image_masks = masks[idx]
                else:  # overlap_masks=True
                    image_masks = masks[[i]]  # (1, 640, 640)
                    nl = idx.sum()
                    index = np.arange(nl).reshape((nl, 1, 1)) + 1
                    image_masks = np.repeat(image_masks, nl, axis=0)
                    image_masks = np.where(image_masks == index, 1.0, 0.0)

                im = np.asarray(annotator.im).copy()
                for j in range(len(image_masks)):
                    if labels or conf[j] > conf_thres:
                        color = colors(classes[j])
                        mh, mw = image_masks[j].shape
                        if mh != h or mw != w:
                            mask = image_masks[j].astype(np.uint8)
                            mask = cv2.resize(mask, (w, h))
                            mask = mask.astype(bool)
                        else:
                            mask = image_masks[j].astype(bool)
                        try:
                            im[y : y + h, x : x + w, :][mask] = (
                                im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
                            )
                        except Exception:
                            # Best-effort overlay: a mask that cannot be blended is skipped
                            pass
                annotator.fromarray(im)
    if not save:
        return np.asarray(annotator.im)
    annotator.im.save(fname)  # save
    if on_plot:
        on_plot(fname)
|
804
|
+
|
805
|
+
|
806
|
+
@plt_settings()
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
    """
    Plot training curves from every 'results*.csv' found next to `file`.

    The subplot layout and the CSV columns plotted depend on the task flag (classify, segment, pose,
    or plain detection when none is set). Each column is drawn as raw values plus a Gaussian-smoothed
    line, and the figure is saved as 'results.png' in the same directory.

    Args:
        file (str, optional): Path to the CSV file containing the training results.
        dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
        segment (bool, optional): Flag to indicate if the data is for segmentation.
        pose (bool, optional): Flag to indicate if the data is for pose estimation.
        classify (bool, optional): Flag to indicate if the data is for classification.
        on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.

    Examples:
        >>> from ultralytics.utils.plotting import plot_results
        >>> plot_results("path/to/results.csv", segment=True)
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
    import pandas as pd
    from scipy.ndimage import gaussian_filter1d

    save_dir = Path(file).parent if file else Path(dir)

    # Pick the subplot grid, figure size and CSV column order for the task
    if classify:
        grid, figsize, index = (2, 2), (6, 6), [2, 5, 3, 4]
    elif segment:
        grid, figsize, index = (2, 8), (18, 6), [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
    elif pose:
        grid, figsize, index = (2, 9), (21, 6), [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
    else:
        grid, figsize, index = (2, 5), (12, 6), [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
    fig, ax = plt.subplots(grid[0], grid[1], figsize=figsize, tight_layout=True)
    ax = ax.ravel()

    files = list(save_dir.glob("results*.csv"))
    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
    for f in files:
        try:
            data = pd.read_csv(f)
            titles = [column.strip() for column in data.columns]
            epochs = data.values[:, 0]
            for slot, column in enumerate(index):
                series = data.values[:, column].astype("float")
                panel = ax[slot]
                panel.plot(epochs, series, marker=".", label=f.stem, linewidth=2, markersize=8)  # actual results
                panel.plot(epochs, gaussian_filter1d(series, sigma=3), ":", label="smooth", linewidth=2)  # smoothing
                panel.set_title(titles[column], fontsize=12)
        except Exception as e:
            # A broken CSV is reported and skipped so the remaining files still get plotted
            LOGGER.error(f"Plotting error for {f}: {e}")
    ax[1].legend()
    fname = save_dir / "results.png"
    fig.savefig(fname, dpi=200)
    plt.close()
    if on_plot:
        on_plot(fname)
|
865
|
+
|
866
|
+
|
867
|
+
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
    """
    Draw a scatter plot whose points are colored by the count of their 2D histogram bin.

    Args:
        v (array-like): Values for the x-axis.
        f (array-like): Values for the y-axis.
        bins (int, optional): Number of bins for the histogram.
        cmap (str, optional): Colormap for the scatter plot.
        alpha (float, optional): Alpha for the scatter plot.
        edgecolors (str, optional): Edge colors for the scatter plot.

    Examples:
        >>> v = np.random.rand(100)
        >>> f = np.random.rand(100)
        >>> plt_color_scatter(v, f)
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'

    # Calculate 2D histogram, then look up the bin count for every point
    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
    last_row, last_col = hist.shape[0] - 1, hist.shape[1] - 1
    density = []
    for i in range(len(v)):
        row = min(np.digitize(v[i], xedges, right=True) - 1, last_row)
        col = min(np.digitize(f[i], yedges, right=True) - 1, last_col)
        density.append(hist[row, col])

    # Scatter plot
    plt.scatter(v, f, c=density, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
|
898
|
+
|
899
|
+
|
900
|
+
def plot_tune_results(csv_file="tune_results.csv"):
    """
    Plot hyperparameter tuning results stored in a 'tune_results.csv' file.

    The first CSV column is read as the fitness score and every remaining column as a hyperparameter.
    Two figures are written next to the CSV: 'tune_scatter_plots.png' (one density-colored scatter of
    value vs fitness per hyperparameter, with the best run marked) and 'tune_fitness.png' (fitness
    per iteration with a smoothed trend line).

    Args:
        csv_file (str, optional): Path to the CSV file containing the tuning results.

    Examples:
        >>> plot_tune_results("path/to/tune_results.csv")
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
    import pandas as pd
    from scipy.ndimage import gaussian_filter1d

    def _save_one_file(file):
        """Save one matplotlib plot to 'file'."""
        plt.savefig(file, dpi=200)
        plt.close()
        LOGGER.info(f"Saved {file}")

    # Scatter plots for each hyperparameter
    csv_file = Path(csv_file)
    data = pd.read_csv(csv_file)
    num_metrics_columns = 1
    keys = [name.strip() for name in data.columns][num_metrics_columns:]
    table = data.values
    fitness = table[:, 0]  # fitness
    best = np.argmax(fitness)  # max fitness index
    side = math.ceil(len(keys) ** 0.5)  # columns and rows in plot
    plt.figure(figsize=(10, 10), tight_layout=True)
    for pos, key in enumerate(keys):
        values = table[:, pos + num_metrics_columns]
        mu = values[best]  # best single result
        plt.subplot(side, side, pos + 1)
        plt_color_scatter(values, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
        plt.plot(mu, fitness.max(), "k+", markersize=15)
        plt.title(f"{key} = {mu:.3g}", fontdict={"size": 9})
        plt.tick_params(axis="both", labelsize=8)  # Set axis label size to 8
        if pos % side != 0:
            plt.yticks([])  # only the first column keeps its y ticks
    _save_one_file(csv_file.with_name("tune_scatter_plots.png"))

    # Fitness vs iteration
    iterations = range(1, len(fitness) + 1)
    plt.figure(figsize=(10, 6), tight_layout=True)
    plt.plot(iterations, fitness, marker="o", linestyle="none", label="fitness")
    plt.plot(iterations, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2)  # smoothing line
    plt.title("Fitness vs Iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Fitness")
    plt.grid(True)
    plt.legend()
    _save_one_file(csv_file.with_name("tune_fitness.png"))
|
954
|
+
|
955
|
+
|
956
|
+
def output_to_target(output, max_det=300):
    """
    Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.

    Args:
        output (Iterable[torch.Tensor]): Per-image predictions whose first six columns are read as
            (4 box values, conf, cls); boxes are converted with ops.xyxy2xywh.
        max_det (int, optional): Maximum number of rows kept per image.

    Returns:
        (tuple): numpy arrays (batch_ids, class_ids, boxes, confs).
    """
    per_image = []
    for image_index, preds in enumerate(output):
        box, conf, cls = preds[:max_det, :6].cpu().split((4, 1, 1), 1)
        image_col = torch.full((conf.shape[0], 1), image_index)
        per_image.append(torch.cat((image_col, cls, ops.xyxy2xywh(box), conf), 1))
    merged = torch.cat(per_image, 0).numpy()
    return merged[:, 0], merged[:, 1], merged[:, 2:-1], merged[:, -1]
|
965
|
+
|
966
|
+
|
967
|
+
def output_to_rotated_target(output, max_det=300):
    """
    Convert rotated-box model output to target format [batch_id, class_id, box(4), angle, conf] for plotting.

    Args:
        output (Iterable[torch.Tensor]): Per-image predictions with seven columns, split as
            (4 box values, conf, cls, angle).
        max_det (int, optional): Maximum number of rows kept per image.

    Returns:
        (tuple): numpy arrays (batch_ids, class_ids, boxes_with_angle, confs), where
            boxes_with_angle has five columns (the four box values followed by the angle).
    """
    per_image = []
    for image_index, preds in enumerate(output):
        box, conf, cls, angle = preds[:max_det].cpu().split((4, 1, 1, 1), 1)
        image_col = torch.full((conf.shape[0], 1), image_index)
        per_image.append(torch.cat((image_col, cls, box, angle, conf), 1))
    merged = torch.cat(per_image, 0).numpy()
    return merged[:, 0], merged[:, 1], merged[:, 2:-1], merged[:, -1]
|
976
|
+
|
977
|
+
|
978
|
+
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
    """
    Save a grid image and a .npy dump of the feature maps produced by a model module.

    Nothing is written for head modules (Detect, Segment, Pose, Classify, OBB, RTDETRDecoder), for
    non-tensor inputs, or when the feature maps are not larger than 1x1. Only batch index 0 is used.

    Args:
        x (torch.Tensor): Features to be visualized, unpacked as (batch, channels, height, width).
        module_type (str): Module type; the part after the last '.' is used in the filename.
        stage (int): Module stage within the model.
        n (int, optional): Maximum number of feature maps to plot.
        save_dir (Path, optional): Directory to save results.
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'

    heads = {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}  # all model heads
    if any(head in module_type for head in heads):
        return
    if not isinstance(x, torch.Tensor):
        return
    _, channels, height, width = x.shape  # batch, channels, height, width
    if height <= 1 or width <= 1:
        return

    f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename

    blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
    n = min(n, channels)  # number of plots
    _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # ceil(n/8) rows x 8 cols
    ax = ax.ravel()
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    for axis, block in zip(ax[:n], blocks[:n]):
        axis.imshow(block.squeeze())  # cmap='gray'
        axis.axis("off")

    LOGGER.info(f"Saving {f}... ({n}/{channels})")
    plt.savefig(f, dpi=300, bbox_inches="tight")
    plt.close()
    np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy())  # npy save
|