dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
@@ -0,0 +1,1047 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from collections.abc import Callable
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ from PIL import Image, ImageDraw, ImageFont
14
+ from PIL import __version__ as pil_version
15
+
16
+ from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded
17
+ from ultralytics.utils.checks import check_font, check_version, is_ascii
18
+ from ultralytics.utils.files import increment_path
19
+
20
+
21
+ class Colors:
22
+ """Ultralytics color palette for visualization and plotting.
23
+
24
+ This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
25
+ values and accessing predefined color schemes for object detection and pose estimation.
26
+
27
+ ## Ultralytics Color Palette
28
+
29
+ | Index | Color | HEX | RGB |
30
+ |-------|-------------------------------------------------------------------|-----------|-------------------|
31
+ | 0 | <i class="fa-solid fa-square fa-2xl" style="color: #042aff;"></i> | `#042aff` | (4, 42, 255) |
32
+ | 1 | <i class="fa-solid fa-square fa-2xl" style="color: #0bdbeb;"></i> | `#0bdbeb` | (11, 219, 235) |
33
+ | 2 | <i class="fa-solid fa-square fa-2xl" style="color: #f3f3f3;"></i> | `#f3f3f3` | (243, 243, 243) |
34
+ | 3 | <i class="fa-solid fa-square fa-2xl" style="color: #00dfb7;"></i> | `#00dfb7` | (0, 223, 183) |
35
+ | 4 | <i class="fa-solid fa-square fa-2xl" style="color: #111f68;"></i> | `#111f68` | (17, 31, 104) |
36
+ | 5 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6fdd;"></i> | `#ff6fdd` | (255, 111, 221) |
37
+ | 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff444f;"></i> | `#ff444f` | (255, 68, 79) |
38
+ | 7 | <i class="fa-solid fa-square fa-2xl" style="color: #cced00;"></i> | `#cced00` | (204, 237, 0) |
39
+ | 8 | <i class="fa-solid fa-square fa-2xl" style="color: #00f344;"></i> | `#00f344` | (0, 243, 68) |
40
+ | 9 | <i class="fa-solid fa-square fa-2xl" style="color: #bd00ff;"></i> | `#bd00ff` | (189, 0, 255) |
41
+ | 10 | <i class="fa-solid fa-square fa-2xl" style="color: #00b4ff;"></i> | `#00b4ff` | (0, 180, 255) |
42
+ | 11 | <i class="fa-solid fa-square fa-2xl" style="color: #dd00ba;"></i> | `#dd00ba` | (221, 0, 186) |
43
+ | 12 | <i class="fa-solid fa-square fa-2xl" style="color: #00ffff;"></i> | `#00ffff` | (0, 255, 255) |
44
+ | 13 | <i class="fa-solid fa-square fa-2xl" style="color: #26c000;"></i> | `#26c000` | (38, 192, 0) |
45
+ | 14 | <i class="fa-solid fa-square fa-2xl" style="color: #01ffb3;"></i> | `#01ffb3` | (1, 255, 179) |
46
+ | 15 | <i class="fa-solid fa-square fa-2xl" style="color: #7d24ff;"></i> | `#7d24ff` | (125, 36, 255) |
47
+ | 16 | <i class="fa-solid fa-square fa-2xl" style="color: #7b0068;"></i> | `#7b0068` | (123, 0, 104) |
48
+ | 17 | <i class="fa-solid fa-square fa-2xl" style="color: #ff1b6c;"></i> | `#ff1b6c` | (255, 27, 108) |
49
+ | 18 | <i class="fa-solid fa-square fa-2xl" style="color: #fc6d2f;"></i> | `#fc6d2f` | (252, 109, 47) |
50
+ | 19 | <i class="fa-solid fa-square fa-2xl" style="color: #a2ff0b;"></i> | `#a2ff0b` | (162, 255, 11) |
51
+
52
+ ## Pose Color Palette
53
+
54
+ | Index | Color | HEX | RGB |
55
+ |-------|-------------------------------------------------------------------|-----------|-------------------|
56
+ | 0 | <i class="fa-solid fa-square fa-2xl" style="color: #ff8000;"></i> | `#ff8000` | (255, 128, 0) |
57
+ | 1 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9933;"></i> | `#ff9933` | (255, 153, 51) |
58
+ | 2 | <i class="fa-solid fa-square fa-2xl" style="color: #ffb266;"></i> | `#ffb266` | (255, 178, 102) |
59
+ | 3 | <i class="fa-solid fa-square fa-2xl" style="color: #e6e600;"></i> | `#e6e600` | (230, 230, 0) |
60
+ | 4 | <i class="fa-solid fa-square fa-2xl" style="color: #ff99ff;"></i> | `#ff99ff` | (255, 153, 255) |
61
+ | 5 | <i class="fa-solid fa-square fa-2xl" style="color: #99ccff;"></i> | `#99ccff` | (153, 204, 255) |
62
+ | 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff66ff;"></i> | `#ff66ff` | (255, 102, 255) |
63
+ | 7 | <i class="fa-solid fa-square fa-2xl" style="color: #ff33ff;"></i> | `#ff33ff` | (255, 51, 255) |
64
+ | 8 | <i class="fa-solid fa-square fa-2xl" style="color: #66b2ff;"></i> | `#66b2ff` | (102, 178, 255) |
65
+ | 9 | <i class="fa-solid fa-square fa-2xl" style="color: #3399ff;"></i> | `#3399ff` | (51, 153, 255) |
66
+ | 10 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9999;"></i> | `#ff9999` | (255, 153, 153) |
67
+ | 11 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6666;"></i> | `#ff6666` | (255, 102, 102) |
68
+ | 12 | <i class="fa-solid fa-square fa-2xl" style="color: #ff3333;"></i> | `#ff3333` | (255, 51, 51) |
69
+ | 13 | <i class="fa-solid fa-square fa-2xl" style="color: #99ff99;"></i> | `#99ff99` | (153, 255, 153) |
70
+ | 14 | <i class="fa-solid fa-square fa-2xl" style="color: #66ff66;"></i> | `#66ff66` | (102, 255, 102) |
71
+ | 15 | <i class="fa-solid fa-square fa-2xl" style="color: #33ff33;"></i> | `#33ff33` | (51, 255, 51) |
72
+ | 16 | <i class="fa-solid fa-square fa-2xl" style="color: #00ff00;"></i> | `#00ff00` | (0, 255, 0) |
73
+ | 17 | <i class="fa-solid fa-square fa-2xl" style="color: #0000ff;"></i> | `#0000ff` | (0, 0, 255) |
74
+ | 18 | <i class="fa-solid fa-square fa-2xl" style="color: #ff0000;"></i> | `#ff0000` | (255, 0, 0) |
75
+ | 19 | <i class="fa-solid fa-square fa-2xl" style="color: #ffffff;"></i> | `#ffffff` | (255, 255, 255) |
76
+
77
+ !!! note "Ultralytics Brand Colors"
78
+
79
+ For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
80
+ Please use the official Ultralytics colors for all marketing materials.
81
+
82
+ Attributes:
83
+ palette (list[tuple]): List of RGB color tuples for general use.
84
+ n (int): The number of colors in the palette.
85
+ pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
86
+
87
+ Examples:
88
+ >>> from ultralytics.utils.plotting import Colors
89
+ >>> colors = Colors()
90
+ >>> colors(5, True) # Returns BGR format: (221, 111, 255)
91
+ >>> colors(5, False) # Returns RGB format: (255, 111, 221)
92
+ """
93
+
94
+ def __init__(self):
95
+ """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
96
+ hexs = (
97
+ "042AFF",
98
+ "0BDBEB",
99
+ "F3F3F3",
100
+ "00DFB7",
101
+ "111F68",
102
+ "FF6FDD",
103
+ "FF444F",
104
+ "CCED00",
105
+ "00F344",
106
+ "BD00FF",
107
+ "00B4FF",
108
+ "DD00BA",
109
+ "00FFFF",
110
+ "26C000",
111
+ "01FFB3",
112
+ "7D24FF",
113
+ "7B0068",
114
+ "FF1B6C",
115
+ "FC6D2F",
116
+ "A2FF0B",
117
+ )
118
+ self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
119
+ self.n = len(self.palette)
120
+ self.pose_palette = np.array(
121
+ [
122
+ [255, 128, 0],
123
+ [255, 153, 51],
124
+ [255, 178, 102],
125
+ [230, 230, 0],
126
+ [255, 153, 255],
127
+ [153, 204, 255],
128
+ [255, 102, 255],
129
+ [255, 51, 255],
130
+ [102, 178, 255],
131
+ [51, 153, 255],
132
+ [255, 153, 153],
133
+ [255, 102, 102],
134
+ [255, 51, 51],
135
+ [153, 255, 153],
136
+ [102, 255, 102],
137
+ [51, 255, 51],
138
+ [0, 255, 0],
139
+ [0, 0, 255],
140
+ [255, 0, 0],
141
+ [255, 255, 255],
142
+ ],
143
+ dtype=np.uint8,
144
+ )
145
+
146
+ def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
147
+ """Convert hex color codes to RGB values.
148
+
149
+ Args:
150
+ i (int | torch.Tensor): Color index.
151
+ bgr (bool, optional): Whether to return BGR format instead of RGB.
152
+
153
+ Returns:
154
+ (tuple): RGB or BGR color tuple.
155
+ """
156
+ c = self.palette[int(i) % self.n]
157
+ return (c[2], c[1], c[0]) if bgr else c
158
+
159
+ @staticmethod
160
+ def hex2rgb(h: str) -> tuple:
161
+ """Convert hex color codes to RGB values (i.e. default PIL order)."""
162
+ return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
163
+
164
+
165
colors = Colors()  # shared module-level palette instance, e.g. 'from ultralytics.utils.plotting import colors'
166
+
167
+
168
+ class Annotator:
169
+ """Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
170
+
171
+ Attributes:
172
+ im (Image.Image | np.ndarray): The image to annotate.
173
+ pil (bool): Whether to use PIL or cv2 for drawing annotations.
174
+ font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.
175
+ lw (float): Line width for drawing.
176
+ skeleton (list[list[int]]): Skeleton structure for keypoints.
177
+ limb_color (list[int]): Color palette for limbs.
178
+ kpt_color (list[int]): Color palette for keypoints.
179
+ dark_colors (set): Set of colors considered dark for text contrast.
180
+ light_colors (set): Set of colors considered light for text contrast.
181
+
182
+ Examples:
183
+ >>> from ultralytics.utils.plotting import Annotator
184
+ >>> im0 = cv2.imread("test.png")
185
+ >>> annotator = Annotator(im0, line_width=10)
186
+ >>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0))
187
+ """
188
+
189
    def __init__(
        self,
        im,
        line_width: int | None = None,
        font_size: int | None = None,
        font: str = "Arial.ttf",
        pil: bool = False,
        example: str = "abc",
    ):
        """Initialize the Annotator with an image, line width, and pose color palettes.

        Args:
            im (Image.Image | np.ndarray): Image to annotate (PIL image, or numpy array —
                presumably BGR for the cv2 path; confirm with callers).
            line_width (int | None, optional): Line width; when None, scaled from image size (min 2).
            font_size (int | None, optional): Font size for the PIL backend; auto-scaled when None.
            font (str, optional): TrueType font filename for the PIL backend.
            pil (bool, optional): Force the PIL drawing backend instead of cv2.
            example (str, optional): Sample label text; non-ASCII content forces the PIL backend.
        """
        non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
        input_is_pil = isinstance(im, Image.Image)
        # PIL backend is used when requested, when labels are non-ASCII, or when input is a PIL image
        self.pil = pil or non_ascii or input_is_pil
        # Default line width scales with image size (sum of dims), minimum of 2 px
        self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
        if not input_is_pil:
            # Normalize numpy inputs to exactly 3 channels so color drawing works
            if im.shape[2] == 1:  # handle grayscale
                im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
            elif im.shape[2] == 2:  # handle 2-channel images
                im = np.ascontiguousarray(np.dstack((im, np.zeros_like(im[..., :1]))))
            elif im.shape[2] > 3:  # multispectral
                im = np.ascontiguousarray(im[..., :3])
        if self.pil:  # use PIL
            self.im = im if input_is_pil else Image.fromarray(im)  # stay in BGR since color palette is in BGR
            if self.im.mode not in {"RGB", "RGBA"}:  # multispectral
                self.im = self.im.convert("RGB")
            self.draw = ImageDraw.Draw(self.im, "RGBA")
            try:
                font = check_font("Arial.Unicode.ttf" if non_ascii else font)
                size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
                self.font = ImageFont.truetype(str(font), size)
            except Exception:
                # Fall back to PIL's built-in bitmap font if the TrueType font is unavailable
                self.font = ImageFont.load_default()
            # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
            if check_version(pil_version, "9.2.0"):
                self.font.getsize = lambda x: self.font.getbbox(x)[2:4]  # text width, height
        else:  # use cv2
            assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
            self.im = im if im.flags.writeable else im.copy()  # cv2 drawing mutates in place, so ensure writability
            self.tf = max(self.lw - 1, 1)  # font thickness
            self.sf = self.lw / 3  # font scale
        # Pose
        # Keypoint connectivity pairs (1-indexed, COCO-style ordering — TODO confirm against pose head)
        self.skeleton = [
            [16, 14],
            [14, 12],
            [17, 15],
            [15, 13],
            [12, 13],
            [6, 12],
            [7, 13],
            [6, 7],
            [6, 8],
            [7, 9],
            [8, 10],
            [9, 11],
            [2, 3],
            [1, 2],
            [1, 3],
            [2, 4],
            [3, 5],
            [4, 6],
            [5, 7],
        ]

        # Per-limb and per-keypoint colors indexed into the shared pose palette
        self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
        self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
        # Background colors (B, G, R) that get a fixed dark text color in get_txt_color()
        self.dark_colors = {
            (235, 219, 11),
            (243, 243, 243),
            (183, 223, 0),
            (221, 111, 255),
            (0, 237, 204),
            (68, 243, 0),
            (255, 255, 0),
            (179, 255, 1),
            (11, 255, 162),
        }
        # Background colors (B, G, R) that get white text in get_txt_color()
        self.light_colors = {
            (255, 42, 4),
            (79, 68, 255),
            (255, 0, 189),
            (255, 180, 0),
            (186, 0, 221),
            (0, 192, 38),
            (255, 36, 125),
            (104, 0, 123),
            (108, 27, 255),
            (47, 109, 252),
            (104, 31, 17),
        }
278
+
279
+ def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
280
+ """Assign text color based on background color.
281
+
282
+ Args:
283
+ color (tuple, optional): The background color of the rectangle for text (B, G, R).
284
+ txt_color (tuple, optional): The color of the text (R, G, B).
285
+
286
+ Returns:
287
+ (tuple): Text color for label.
288
+
289
+ Examples:
290
+ >>> from ultralytics.utils.plotting import Annotator
291
+ >>> im0 = cv2.imread("test.png")
292
+ >>> annotator = Annotator(im0, line_width=10)
293
+ >>> annotator.get_txt_color(color=(104, 31, 17)) # return (255, 255, 255)
294
+ """
295
+ if color in self.dark_colors:
296
+ return 104, 31, 17
297
+ elif color in self.light_colors:
298
+ return 255, 255, 255
299
+ else:
300
+ return txt_color
301
+
302
    def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
        """Draw a bounding box (or closed polygon) on the image with an optional label.

        Args:
            box (tuple | list | torch.Tensor): The bounding box coordinates (x1, y1, x2, y2),
                or a list of [x, y] points for a polygon.
            label (str, optional): The text label to be displayed.
            color (tuple, optional): The background color of the rectangle (B, G, R).
            txt_color (tuple, optional): The color of the text (R, G, B).

        Examples:
            >>> from ultralytics.utils.plotting import Annotator
            >>> im0 = cv2.imread("test.png")
            >>> annotator = Annotator(im0, line_width=10)
            >>> annotator.box_label(box=[10, 20, 30, 40], label="person")
        """
        txt_color = self.get_txt_color(color, txt_color)  # pick a contrasting text color for this background
        if isinstance(box, torch.Tensor):
            box = box.tolist()

        multi_points = isinstance(box[0], list)  # multiple points with shape (n, 2)
        # p1 is the label anchor: first polygon point, or the box's top-left corner
        p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1]))
        if self.pil:
            # Polygon outline for point lists, rectangle outline for (x1, y1, x2, y2) boxes
            self.draw.polygon(
                [tuple(b) for b in box], width=self.lw, outline=color
            ) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color)
            if label:
                w, h = self.font.getsize(label)  # text width, height
                outside = p1[1] >= h  # label fits outside box
                if p1[0] > self.im.size[0] - w:  # size is (w, h), check if label extend beyond right side of image
                    p1 = self.im.size[0] - w, p1[1]
                # Filled background rectangle behind the text, placed above the box when it fits
                self.draw.rectangle(
                    (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
                    fill=color,
                )
                # self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
        else:  # cv2
            cv2.polylines(
                self.im, [np.asarray(box, dtype=int)], True, color, self.lw
            ) if multi_points else cv2.rectangle(
                self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA
            )
            if label:
                w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]  # text width, height
                h += 3  # add pixels to pad text
                outside = p1[1] >= h  # label fits outside box
                if p1[0] > self.im.shape[1] - w:  # shape is (h, w), check if label extend beyond right side of image
                    p1 = self.im.shape[1] - w, p1[1]
                p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h
                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(
                    self.im,
                    label,
                    (p1[0], p1[1] - 2 if outside else p1[1] + h - 1),
                    0,  # fontFace 0 == cv2.FONT_HERSHEY_SIMPLEX
                    self.sf,
                    txt_color,
                    thickness=self.tf,
                    lineType=cv2.LINE_AA,
                )
362
+
363
    def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
        """Plot masks on image.

        Args:
            masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
            colors (list[list[int]]): Colors for predicted masks, [[r, g, b] * n]
            im_gpu (torch.Tensor | None): Image is in cuda, shape: [3, h, w], range: [0, 1]
            alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque.
            retina_masks (bool, optional): Whether to use high resolution masks or not.
        """
        if self.pil:
            # Convert to numpy first so the mask blending below can operate on an array
            self.im = np.asarray(self.im).copy()
        if im_gpu is None:
            # CPU/numpy path: paint each mask into an overlay, then blend once
            assert isinstance(masks, np.ndarray), "`masks` must be a np.ndarray if `im_gpu` is not provided."
            overlay = self.im.copy()
            for i, mask in enumerate(masks):
                overlay[mask.astype(bool)] = colors[i]
            self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
        else:
            # Tensor path: composite all masks on the GPU image in one pass
            assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
            if len(masks) == 0:
                # No masks: just write the rendered GPU image back to self.im
                self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
                return
            if im_gpu.device != masks.device:
                im_gpu = im_gpu.to(masks.device)

            ih, iw = self.im.shape[:2]
            if not retina_masks:
                # Use scale_masks to properly remove padding and upsample, convert bool to float first
                masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
                # Convert original BGR image to RGB tensor
                im_gpu = (
                    torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
                )

            colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0  # shape(n,3)
            colors = colors[:, None, None]  # shape(n,1,1,3)
            masks = masks.unsqueeze(3)  # shape(n,h,w,1)
            masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
            # Sequential alpha compositing: each mask attenuates the layers beneath it
            inv_alpha_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
            mcs = masks_color.max(dim=0).values  # shape(h,w,3)

            im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous()  # shape(h,w,3)
            im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
            self.im[:] = (im_gpu * 255).byte().cpu().numpy()
        if self.pil:
            # Convert im back to PIL and update draw
            # NOTE(review): self.fromarray is presumably defined later in this class — not visible here
            self.fromarray(self.im)
412
+
413
def kpts(
    self,
    kpts,
    shape: tuple = (640, 640),
    radius: int | None = None,
    kpt_line: bool = True,
    conf_thres: float = 0.25,
    kpt_color: tuple | None = None,
):
    """Plot keypoints on the image.

    Args:
        kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
        shape (tuple, optional): Image shape (h, w).
        radius (int, optional): Keypoint radius; defaults to self.lw when None.
        kpt_line (bool, optional): Draw lines between keypoints.
        conf_thres (float, optional): Confidence threshold.
        kpt_color (tuple, optional): Keypoint color (B, G, R).

    Notes:
        - `kpt_line=True` currently only supports human pose plotting.
        - Modifies self.im in-place.
        - If self.pil is True, converts image to numpy array and back to PIL.
    """
    radius = radius if radius is not None else self.lw
    if self.pil:
        # Convert to numpy first
        self.im = np.asarray(self.im).copy()
    nkpt, ndim = kpts.shape
    is_pose = nkpt == 17 and ndim in {2, 3}  # COCO 17-keypoint layout
    kpt_line &= is_pose  # `kpt_line=True` for now only supports human pose plotting
    for i, k in enumerate(kpts):
        color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))
        x_coord, y_coord = k[0], k[1]
        # Coordinates that are exact multiples of the image dims (e.g. 0 or w/h) are skipped as invalid
        if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
            if len(k) == 3:
                conf = k[2]
                if conf < conf_thres:
                    continue  # low-confidence keypoint
            cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)

    if kpt_line:
        ndim = kpts.shape[-1]
        for i, sk in enumerate(self.skeleton):
            # Skeleton entries use 1-based keypoint indices
            pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
            pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
            if ndim == 3:
                conf1 = kpts[(sk[0] - 1), 2]
                conf2 = kpts[(sk[1] - 1), 2]
                if conf1 < conf_thres or conf2 < conf_thres:
                    continue  # skip limb if either endpoint is low-confidence
            # Skip limbs whose endpoints are invalid (on a dim multiple) or negative
            if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
                continue
            if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
                continue
            cv2.line(
                self.im,
                pos1,
                pos2,
                kpt_color or self.limb_color[i].tolist(),
                thickness=int(np.ceil(self.lw / 2)),
                lineType=cv2.LINE_AA,
            )
    if self.pil:
        # Convert im back to PIL and update draw
        self.fromarray(self.im)
479
+
480
def rectangle(self, xy, fill=None, outline=None, width: int = 1):
    """Draw a rectangle on the image; valid only when the annotator is in PIL mode."""
    self.draw.rectangle(xy, fill=fill, outline=outline, width=width)
483
+
484
def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
    """Add text to an image using PIL or cv2.

    Args:
        xy (list[int]): Top-left coordinates for text placement. Mutated in-place in the PIL branch
            (y advances per line).
        text (str): Text to be drawn; newlines are honored in the PIL branch only.
        txt_color (tuple, optional): Text color (R, G, B).
        anchor (str, optional): Text anchor position ('top' or 'bottom').
        box_color (tuple, optional): Box color (R, G, B, A) with optional alpha; empty tuple draws no box.
    """
    if self.pil:
        # NOTE(review): font.getsize was removed in Pillow>=10 — presumably self.font still provides it; confirm
        w, h = self.font.getsize(text)
        if anchor == "bottom":  # start y from font bottom
            xy[1] += 1 - h
        for line in text.split("\n"):
            if box_color:
                # Draw rectangle for each line
                w, h = self.font.getsize(line)
                self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=box_color)
            self.draw.text(xy, line, fill=txt_color, font=self.font)
            xy[1] += h  # advance to next line
    else:
        if box_color:
            w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
            h += 3  # add pixels to pad text
            outside = xy[1] >= h  # label fits outside box
            p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
            cv2.rectangle(self.im, xy, p2, box_color, -1, cv2.LINE_AA)  # filled
        cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
513
+
514
def fromarray(self, im):
    """Replace the working image with `im` (ndarray or PIL Image) and rebuild the drawing context."""
    if not isinstance(im, Image.Image):
        im = Image.fromarray(im)
    self.im = im
    self.draw = ImageDraw.Draw(self.im)
518
+
519
def result(self, pil=False):
    """Return the annotated image as a BGR ndarray, or as an RGB PIL image when `pil=True`."""
    arr = np.asarray(self.im)  # self.im stores BGR channel order
    if pil:
        return Image.fromarray(arr[..., ::-1])  # flip BGR -> RGB for PIL
    return arr
523
+
524
def show(self, title: str | None = None):
    """Display the annotated image, inline in Colab/Kaggle notebooks or via the default system viewer.

    Args:
        title (str, optional): Window title passed to PIL's viewer (ignored in notebook environments).
    """
    im = Image.fromarray(np.asarray(self.im)[..., ::-1])  # Convert BGR NumPy array to RGB PIL Image
    if IS_COLAB or IS_KAGGLE:  # cannot use IS_JUPYTER as it runs for all IPython environments
        try:
            display(im)  # noqa - display() function only available in ipython environments
        except ImportError as e:
            LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}")
    else:
        im.show(title=title)
534
+
535
def save(self, filename: str = "image.jpg"):
    """Write the annotated image to disk at `filename`."""
    arr = np.asarray(self.im)  # handles both ndarray and PIL-backed images
    cv2.imwrite(filename, arr)
538
+
539
+ @staticmethod
540
+ def get_bbox_dimension(bbox: tuple | list):
541
+ """Calculate the dimensions and area of a bounding box.
542
+
543
+ Args:
544
+ bbox (tuple | list): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
545
+
546
+ Returns:
547
+ width (float): Width of the bounding box.
548
+ height (float): Height of the bounding box.
549
+ area (float): Area enclosed by the bounding box.
550
+
551
+ Examples:
552
+ >>> from ultralytics.utils.plotting import Annotator
553
+ >>> im0 = cv2.imread("test.png")
554
+ >>> annotator = Annotator(im0, line_width=10)
555
+ >>> annotator.get_bbox_dimension(bbox=[10, 20, 30, 40])
556
+ """
557
+ x_min, y_min, x_max, y_max = bbox
558
+ width = x_max - x_min
559
+ height = y_max - y_min
560
+ return width, height, width * height
561
+
562
+
563
@TryExcept()
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
    """Plot training labels including class histograms and box statistics.

    Args:
        boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
        cls (np.ndarray): Class indices.
        names (dict, optional): Dictionary mapping class indices to class names.
        save_dir (Path, optional): Directory to save the plot.
        on_plot (Callable, optional): Function to call after plot is saved.
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
    import polars
    from matplotlib.colors import LinearSegmentedColormap

    # Plot dataset labels
    LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
    nc = int(cls.max() + 1)  # number of classes
    boxes = boxes[:1000000]  # limit to 1M boxes
    x = polars.DataFrame(boxes, schema=["x", "y", "width", "height"])

    # Matplotlib labels: 2x2 grid of subplots
    subplot_3_4_color = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    # Subplot 1: per-class instance histogram
    y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    for i in range(nc):
        y[2].patches[i].set_color([x / 255 for x in colors(i)])  # color each bar with its class color
    ax[0].set_ylabel("instances")
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
        ax[0].bar_label(y[2])
    else:
        ax[0].set_xlabel("classes")
    # Subplot 2: draw up to 500 boxes, centered, on a blank 1000x1000 canvas to show wh distribution
    boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
    img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
    for class_id, box in zip(cls[:500], boxes[:500]):
        ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(class_id))  # plot
    ax[1].imshow(img)
    ax[1].axis("off")

    # Subplots 3-4: 2D histograms of box centers (x, y) and sizes (width, height)
    ax[2].hist2d(x["x"], x["y"], bins=50, cmap=subplot_3_4_color)
    ax[2].set_xlabel("x")
    ax[2].set_ylabel("y")
    ax[3].hist2d(x["width"], x["height"], bins=50, cmap=subplot_3_4_color)
    ax[3].set_xlabel("width")
    ax[3].set_ylabel("height")
    # Hide all subplot borders
    for a in {0, 1, 2, 3}:
        for s in {"top", "right", "left", "bottom"}:
            ax[a].spines[s].set_visible(False)

    fname = save_dir / "labels.jpg"
    plt.savefig(fname, dpi=200)
    plt.close()
    if on_plot:
        on_plot(fname)
620
+
621
+
622
def save_one_box(
    xyxy,
    im,
    file: Path = Path("im.jpg"),
    gain: float = 1.02,
    pad: int = 10,
    square: bool = False,
    BGR: bool = False,
    save: bool = True,
):
    """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.

    This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
    bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
    bounding box.

    Args:
        xyxy (torch.Tensor | list): A tensor or list (of numbers or of tensors) representing the bounding box in
            xyxy format.
        im (np.ndarray): The input image.
        file (Path, optional): The path where the cropped image will be saved.
        gain (float, optional): A multiplicative factor to increase the size of the bounding box.
        pad (int, optional): The number of pixels to add to the width and height of the bounding box.
        square (bool, optional): If True, the bounding box will be transformed into a square.
        BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.
        save (bool, optional): If True, the cropped image will be saved to disk.

    Returns:
        (np.ndarray): The cropped image.

    Examples:
        >>> from ultralytics.utils.plotting import save_one_box
        >>> xyxy = [50, 50, 150, 150]
        >>> im = cv2.imread("image.jpg")
        >>> cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
    """
    if not isinstance(xyxy, torch.Tensor):  # may be a list of numbers or a list of tensors
        # torch.stack alone fails on plain numbers (e.g. the documented [50, 50, 150, 150] example),
        # so convert each element to a tensor first; list-of-tensor input still stacks as before.
        xyxy = torch.stack([torch.as_tensor(v) for v in xyxy])
    b = ops.xyxy2xywh(xyxy.view(-1, 4))  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = ops.xywh2xyxy(b).long()
    xyxy = ops.clip_boxes(xyxy, im.shape)
    grayscale = im.shape[2] == 1  # grayscale image
    # Negative step on the channel axis flips BGR -> RGB unless BGR output (or grayscale) is requested
    crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)]
    if save:
        file.parent.mkdir(parents=True, exist_ok=True)  # make directory
        f = str(increment_path(file).with_suffix(".jpg"))
        # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
        crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop
        Image.fromarray(crop).save(f, quality=95, subsampling=0)  # save RGB
    return crop
674
+
675
+
676
@threaded
def plot_images(
    labels: dict[str, Any],
    images: torch.Tensor | np.ndarray = np.zeros((0, 3, 640, 640), dtype=np.float32),
    paths: list[str] | None = None,
    fname: str = "images.jpg",
    names: dict[int, str] | None = None,
    on_plot: Callable | None = None,
    max_size: int = 1920,
    max_subplots: int = 16,
    save: bool = True,
    conf_thres: float = 0.25,
) -> np.ndarray | None:
    """Plot image grid with labels, bounding boxes, masks, and keypoints.

    Args:
        labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
            'keypoints', 'batch_idx', 'img'.
        images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
        paths (Optional[list[str]]): List of file paths for each image in the batch.
        fname (str): Output filename for the plotted image grid.
        names (Optional[dict[int, str]]): Dictionary mapping class indices to class names.
        on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.
        max_size (int): Maximum size of the output image grid.
        max_subplots (int): Maximum number of subplots in the image grid.
        save (bool): Whether to save the plotted image grid to a file.
        conf_thres (float): Confidence threshold for displaying detections.

    Returns:
        (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.

    Notes:
        This function supports both tensor and numpy array inputs. It will automatically
        convert tensor inputs to numpy arrays for processing.

    Channel Support:
        - 1 channel: Grayscale
        - 2 channels: Third channel added as zeros
        - 3 channels: Used as-is (standard RGB)
        - 4+ channels: Cropped to first 3 channels
    """
    # Normalize label arrays: squeeze (n, 1) cls and move tensors to CPU numpy (mutates `labels` in place)
    for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}:
        if k not in labels:
            continue
        if k == "cls" and labels[k].ndim == 2:
            labels[k] = labels[k].squeeze(1)  # squeeze if shape is (n, 1)
        if isinstance(labels[k], torch.Tensor):
            labels[k] = labels[k].cpu().numpy()

    cls = labels.get("cls", np.zeros(0, dtype=np.int64))
    batch_idx = labels.get("batch_idx", np.zeros(cls.shape, dtype=np.int64))
    bboxes = labels.get("bboxes", np.zeros(0, dtype=np.float32))
    confs = labels.get("conf", None)
    masks = labels.get("masks", np.zeros(0, dtype=np.uint8))
    kpts = labels.get("keypoints", np.zeros(0, dtype=np.float32))
    images = labels.get("img", images)  # default to input images

    if len(images) and isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()

    # Handle 2-ch and n-ch images
    c = images.shape[1]
    if c == 2:
        zero = np.zeros_like(images[:, :1])
        images = np.concatenate((images, zero), axis=1)  # pad 2-ch with a black channel
    elif c > 3:
        images = images[:, :3]  # crop multispectral images to first 3 channels

    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs**0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        images *= 255  # de-normalise (optional)

    # Build Image: tile each batch image into a ns x ns mosaic
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    fs = max(fs, 18)  # ensure that the font size is large enough to be easily readable.
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=str(names))
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(cls) > 0:
            idx = batch_idx == i  # select detections belonging to this mosaic tile
            classes = cls[idx].astype("int")
            # NOTE: rebinds the `labels` dict argument to a bool — True when plotting ground truth (no confidences)
            labels = confs is None
            conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)

            if len(bboxes):
                boxes = bboxes[idx]
                if len(boxes):
                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
                        boxes[..., [0, 2]] *= w  # scale to pixels
                        boxes[..., [1, 3]] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        boxes[..., :4] *= scale
                boxes[..., 0] += x
                boxes[..., 1] += y
                is_obb = boxes.shape[-1] == 5  # xywhr
                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
                for j, box in enumerate(boxes.astype(np.int64).tolist()):
                    c = classes[j]
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    if labels or conf[j] > conf_thres:
                        label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
                        annotator.box_label(box, label, color=color)

            elif len(classes):
                # Classification-style output: no boxes, just class names in the tile corner
                for c in classes:
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    label = f"{c}" if labels else f"{c} {conf[0]:.1f}"
                    annotator.text([x, y], label, txt_color=color, box_color=(64, 64, 64, 128))

            # Plot keypoints
            if len(kpts):
                kpts_ = kpts[idx].copy()
                if len(kpts_):
                    if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01:  # if normalized with tolerance .01
                        kpts_[..., 0] *= w  # scale to pixels
                        kpts_[..., 1] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        kpts_ *= scale
                kpts_[..., 0] += x
                kpts_[..., 1] += y
                for j in range(len(kpts_)):
                    if labels or conf[j] > conf_thres:
                        annotator.kpts(kpts_[j], conf_thres=conf_thres)

            # Plot masks
            if len(masks):
                if idx.shape[0] == masks.shape[0] and masks.max() <= 1:  # overlap_mask=False
                    image_masks = masks[idx]
                else:  # overlap_mask=True
                    image_masks = masks[[i]]  # (1, 640, 640)
                    nl = idx.sum()
                    # Split the single overlap mask into per-instance binary masks by integer label
                    index = np.arange(1, nl + 1).reshape((nl, 1, 1))
                    image_masks = (image_masks == index).astype(np.float32)

                im = np.asarray(annotator.im).copy()
                for j in range(len(image_masks)):
                    if labels or conf[j] > conf_thres:
                        color = colors(classes[j])
                        mh, mw = image_masks[j].shape
                        if mh != h or mw != w:
                            mask = image_masks[j].astype(np.uint8)
                            mask = cv2.resize(mask, (w, h))
                            mask = mask.astype(bool)
                        else:
                            mask = image_masks[j].astype(bool)
                        try:
                            # Blend mask color into the tile region (60% color, 40% image)
                            im[y : y + h, x : x + w, :][mask] = (
                                im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
                            )
                        except Exception:
                            pass  # best-effort: shape mismatches are silently skipped
                annotator.fromarray(im)
    if not save:
        return np.asarray(annotator.im)
    annotator.im.save(fname)  # save
    if on_plot:
        on_plot(fname)
853
+
854
+
855
@plt_settings()
def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
    """Plot training results from a results CSV file. The function supports various types of data including
    segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
    CSV is located.

    Args:
        file (str, optional): Path to the CSV file containing the training results.
        dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
        on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.

    Examples:
        >>> from ultralytics.utils.plotting import plot_results
        >>> plot_results("path/to/results.csv", segment=True)
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
    import polars as pl
    from scipy.ndimage import gaussian_filter1d

    save_dir = Path(file).parent if file else Path(dir)
    files = list(save_dir.glob("results*.csv"))
    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."

    loss_keys, metric_keys = [], []
    for i, f in enumerate(files):
        try:
            data = pl.read_csv(f, infer_schema_length=None)
            if i == 0:
                # Derive column layout from the first CSV only; reused for subsequent files
                for c in data.columns:
                    if "loss" in c:
                        loss_keys.append(c)
                    elif "metric" in c:
                        metric_keys.append(c)
            # Interleave: first-half losses, first-half metrics on row 1; second halves on row 2
            loss_mid, metric_mid = len(loss_keys) // 2, len(metric_keys) // 2
            columns = (
                loss_keys[:loss_mid] + metric_keys[:metric_mid] + loss_keys[loss_mid:] + metric_keys[metric_mid:]
            )
            fig, ax = plt.subplots(2, len(columns) // 2, figsize=(len(columns) + 2, 6), tight_layout=True)
            ax = ax.ravel()
            x = data.select(data.columns[0]).to_numpy().flatten()  # first column is the x-axis (epoch)
            for i, j in enumerate(columns):
                y = data.select(j).to_numpy().flatten().astype("float")
                ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)  # actual results
                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)  # smoothing line
                ax[i].set_title(j, fontsize=12)
        except Exception as e:
            LOGGER.error(f"Plotting error for {f}: {e}")
    # NOTE(review): if every file fails above, `ax`/`fig` are unbound here and this raises NameError
    ax[1].legend()
    fname = save_dir / "results.png"
    fig.savefig(fname, dpi=200)
    plt.close()
    if on_plot:
        on_plot(fname)
908
+
909
+
910
def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
    """Plot a scatter plot with points colored based on a 2D histogram.

    Args:
        v (array-like): Values for the x-axis.
        f (array-like): Values for the y-axis.
        bins (int, optional): Number of bins for the histogram.
        cmap (str, optional): Colormap for the scatter plot.
        alpha (float, optional): Alpha for the scatter plot.
        edgecolors (str, optional): Edge colors for the scatter plot.

    Examples:
        >>> v = np.random.rand(100)
        >>> f = np.random.rand(100)
        >>> plt_color_scatter(v, f)
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'

    # Bin the (v, f) pairs so each point can be colored by the local density of its cell
    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
    max_row, max_col = hist.shape[0] - 1, hist.shape[1] - 1
    point_colors = []
    for vi, fi in zip(v, f):
        row = min(np.digitize(vi, xedges, right=True) - 1, max_row)
        col = min(np.digitize(fi, yedges, right=True) - 1, max_col)
        point_colors.append(hist[row, col])

    # Scatter plot colored by local bin count
    plt.scatter(v, f, c=point_colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
940
+
941
+
942
@plt_settings()
def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
    """Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
    key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
    the plots.

    Args:
        csv_file (str, optional): Path to the CSV file containing the tuning results.
        exclude_zero_fitness_points (bool, optional): Don't include points with zero fitness in tuning plots.
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
    import polars as pl
    from scipy.ndimage import gaussian_filter1d

    def _save_one_file(file):
        """Save one matplotlib plot to 'file'."""
        plt.savefig(file, dpi=200)
        plt.close()
        LOGGER.info(f"Saved {file}")

    # Scatter plots for each hyperparameter
    csv_file = Path(csv_file)
    data = pl.read_csv(csv_file, infer_schema_length=None)
    num_metrics_columns = 1  # first column is fitness; the rest are hyperparameters
    keys = [x.strip() for x in data.columns][num_metrics_columns:]
    x = data.to_numpy()
    fitness = x[:, 0]  # fitness
    if exclude_zero_fitness_points:
        mask = fitness > 0  # exclude zero-fitness points
        x, fitness = x[mask], fitness[mask]
    if len(fitness) == 0:
        LOGGER.warning("No valid fitness values to plot (all iterations may have failed)")
        return
    # Iterative sigma rejection on lower bound only
    for _ in range(3):  # max 3 iterations
        mean, std = fitness.mean(), fitness.std()
        lower_bound = mean - 3 * std
        mask = fitness >= lower_bound
        if mask.all():  # no more outliers
            break
        x, fitness = x[mask], fitness[mask]
    j = np.argmax(fitness)  # max fitness index
    n = math.ceil(len(keys) ** 0.5)  # columns and rows in plot
    plt.figure(figsize=(10, 10), tight_layout=True)
    for i, k in enumerate(keys):
        v = x[:, i + num_metrics_columns]
        mu = v[j]  # best single result
        plt.subplot(n, n, i + 1)
        plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
        plt.plot(mu, fitness.max(), "k+", markersize=15)  # mark best configuration
        plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9})  # limit to 40 characters
        plt.tick_params(axis="both", labelsize=8)  # Set axis label size to 8
        if i % n != 0:
            plt.yticks([])  # hide y ticks on all but the first column
    _save_one_file(csv_file.with_name("tune_scatter_plots.png"))

    # Fitness vs iteration
    x = range(1, len(fitness) + 1)
    plt.figure(figsize=(10, 6), tight_layout=True)
    plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
    plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2)  # smoothing line
    plt.title("Fitness vs Iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Fitness")
    plt.grid(True)
    plt.legend()
    _save_one_file(csv_file.with_name("tune_fitness.png"))
1012
+
1013
+
1014
@plt_settings()
def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
    """Visualize feature maps of a given model module during inference.

    Args:
        x (torch.Tensor): Features to be visualized.
        module_type (str): Module type.
        stage (int): Module stage within the model.
        n (int, optional): Maximum number of feature maps to plot.
        save_dir (Path, optional): Directory to save results.
    """
    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'

    # Skip model heads: their outputs are predictions, not feature maps
    for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}:  # all model heads
        if m in module_type:
            return
    if isinstance(x, torch.Tensor):
        _, channels, height, width = x.shape  # batch, channels, height, width
        if height > 1 and width > 1:  # skip 1x1 spatial outputs — nothing to visualize
            f = save_dir / f"stage{stage}_{module_type.rsplit('.', 1)[-1]}_features.png"  # filename

            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
            n = min(n, channels)  # number of plots
            _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # 8 rows x n/8 cols
            ax = ax.ravel()
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            for i in range(n):
                ax[i].imshow(blocks[i].squeeze())  # cmap='gray'
                ax[i].axis("off")

            LOGGER.info(f"Saving {f}... ({n}/{channels})")
            plt.savefig(f, dpi=300, bbox_inches="tight")
            plt.close()
            np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy())  # npy save