ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +527 -67
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +44 -37
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +84 -56
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.28.dist-info/METADATA +0 -373
  244. ultralytics-8.1.28.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
- import contextlib
4
3
  import math
5
4
  import warnings
6
5
  from pathlib import Path
6
+ from typing import Callable, Dict, List, Optional, Union
7
7
 
8
8
  import cv2
9
9
  import matplotlib.pyplot as plt
@@ -12,14 +12,14 @@ import torch
12
12
  from PIL import Image, ImageDraw, ImageFont
13
13
  from PIL import __version__ as pil_version
14
14
 
15
- from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded
16
- from .checks import check_font, check_version, is_ascii
17
- from .files import increment_path
15
+ from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded
16
+ from ultralytics.utils.checks import check_font, check_version, is_ascii
17
+ from ultralytics.utils.files import increment_path
18
18
 
19
19
 
20
20
  class Colors:
21
21
  """
22
- Ultralytics default color palette https://ultralytics.com/.
22
+ Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
23
23
 
24
24
  This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
25
25
  RGB values.
@@ -28,31 +28,85 @@ class Colors:
28
28
  palette (list of tuple): List of RGB color values.
29
29
  n (int): The number of colors in the palette.
30
30
  pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
31
+
32
+ ## Ultralytics Color Palette
33
+
34
+ | Index | Color | HEX | RGB |
35
+ |-------|-------------------------------------------------------------------|-----------|-------------------|
36
+ | 0 | <i class="fa-solid fa-square fa-2xl" style="color: #042aff;"></i> | `#042aff` | (4, 42, 255) |
37
+ | 1 | <i class="fa-solid fa-square fa-2xl" style="color: #0bdbeb;"></i> | `#0bdbeb` | (11, 219, 235) |
38
+ | 2 | <i class="fa-solid fa-square fa-2xl" style="color: #f3f3f3;"></i> | `#f3f3f3` | (243, 243, 243) |
39
+ | 3 | <i class="fa-solid fa-square fa-2xl" style="color: #00dfb7;"></i> | `#00dfb7` | (0, 223, 183) |
40
+ | 4 | <i class="fa-solid fa-square fa-2xl" style="color: #111f68;"></i> | `#111f68` | (17, 31, 104) |
41
+ | 5 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6fdd;"></i> | `#ff6fdd` | (255, 111, 221) |
42
+ | 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff444f;"></i> | `#ff444f` | (255, 68, 79) |
43
+ | 7 | <i class="fa-solid fa-square fa-2xl" style="color: #cced00;"></i> | `#cced00` | (204, 237, 0) |
44
+ | 8 | <i class="fa-solid fa-square fa-2xl" style="color: #00f344;"></i> | `#00f344` | (0, 243, 68) |
45
+ | 9 | <i class="fa-solid fa-square fa-2xl" style="color: #bd00ff;"></i> | `#bd00ff` | (189, 0, 255) |
46
+ | 10 | <i class="fa-solid fa-square fa-2xl" style="color: #00b4ff;"></i> | `#00b4ff` | (0, 180, 255) |
47
+ | 11 | <i class="fa-solid fa-square fa-2xl" style="color: #dd00ba;"></i> | `#dd00ba` | (221, 0, 186) |
48
+ | 12 | <i class="fa-solid fa-square fa-2xl" style="color: #00ffff;"></i> | `#00ffff` | (0, 255, 255) |
49
+ | 13 | <i class="fa-solid fa-square fa-2xl" style="color: #26c000;"></i> | `#26c000` | (38, 192, 0) |
50
+ | 14 | <i class="fa-solid fa-square fa-2xl" style="color: #01ffb3;"></i> | `#01ffb3` | (1, 255, 179) |
51
+ | 15 | <i class="fa-solid fa-square fa-2xl" style="color: #7d24ff;"></i> | `#7d24ff` | (125, 36, 255) |
52
+ | 16 | <i class="fa-solid fa-square fa-2xl" style="color: #7b0068;"></i> | `#7b0068` | (123, 0, 104) |
53
+ | 17 | <i class="fa-solid fa-square fa-2xl" style="color: #ff1b6c;"></i> | `#ff1b6c` | (255, 27, 108) |
54
+ | 18 | <i class="fa-solid fa-square fa-2xl" style="color: #fc6d2f;"></i> | `#fc6d2f` | (252, 109, 47) |
55
+ | 19 | <i class="fa-solid fa-square fa-2xl" style="color: #a2ff0b;"></i> | `#a2ff0b` | (162, 255, 11) |
56
+
57
+ ## Pose Color Palette
58
+
59
+ | Index | Color | HEX | RGB |
60
+ |-------|-------------------------------------------------------------------|-----------|-------------------|
61
+ | 0 | <i class="fa-solid fa-square fa-2xl" style="color: #ff8000;"></i> | `#ff8000` | (255, 128, 0) |
62
+ | 1 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9933;"></i> | `#ff9933` | (255, 153, 51) |
63
+ | 2 | <i class="fa-solid fa-square fa-2xl" style="color: #ffb266;"></i> | `#ffb266` | (255, 178, 102) |
64
+ | 3 | <i class="fa-solid fa-square fa-2xl" style="color: #e6e600;"></i> | `#e6e600` | (230, 230, 0) |
65
+ | 4 | <i class="fa-solid fa-square fa-2xl" style="color: #ff99ff;"></i> | `#ff99ff` | (255, 153, 255) |
66
+ | 5 | <i class="fa-solid fa-square fa-2xl" style="color: #99ccff;"></i> | `#99ccff` | (153, 204, 255) |
67
+ | 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff66ff;"></i> | `#ff66ff` | (255, 102, 255) |
68
+ | 7 | <i class="fa-solid fa-square fa-2xl" style="color: #ff33ff;"></i> | `#ff33ff` | (255, 51, 255) |
69
+ | 8 | <i class="fa-solid fa-square fa-2xl" style="color: #66b2ff;"></i> | `#66b2ff` | (102, 178, 255) |
70
+ | 9 | <i class="fa-solid fa-square fa-2xl" style="color: #3399ff;"></i> | `#3399ff` | (51, 153, 255) |
71
+ | 10 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9999;"></i> | `#ff9999` | (255, 153, 153) |
72
+ | 11 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6666;"></i> | `#ff6666` | (255, 102, 102) |
73
+ | 12 | <i class="fa-solid fa-square fa-2xl" style="color: #ff3333;"></i> | `#ff3333` | (255, 51, 51) |
74
+ | 13 | <i class="fa-solid fa-square fa-2xl" style="color: #99ff99;"></i> | `#99ff99` | (153, 255, 153) |
75
+ | 14 | <i class="fa-solid fa-square fa-2xl" style="color: #66ff66;"></i> | `#66ff66` | (102, 255, 102) |
76
+ | 15 | <i class="fa-solid fa-square fa-2xl" style="color: #33ff33;"></i> | `#33ff33` | (51, 255, 51) |
77
+ | 16 | <i class="fa-solid fa-square fa-2xl" style="color: #00ff00;"></i> | `#00ff00` | (0, 255, 0) |
78
+ | 17 | <i class="fa-solid fa-square fa-2xl" style="color: #0000ff;"></i> | `#0000ff` | (0, 0, 255) |
79
+ | 18 | <i class="fa-solid fa-square fa-2xl" style="color: #ff0000;"></i> | `#ff0000` | (255, 0, 0) |
80
+ | 19 | <i class="fa-solid fa-square fa-2xl" style="color: #ffffff;"></i> | `#ffffff` | (255, 255, 255) |
81
+
82
+ !!! note "Ultralytics Brand Colors"
83
+
84
+ For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
31
85
  """
32
86
 
33
87
  def __init__(self):
34
88
  """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
35
89
  hexs = (
36
- "FF3838",
37
- "FF9D97",
38
- "FF701F",
39
- "FFB21D",
40
- "CFD231",
41
- "48F90A",
42
- "92CC17",
43
- "3DDB86",
44
- "1A9334",
45
- "00D4BB",
46
- "2C99A8",
47
- "00C2FF",
48
- "344593",
49
- "6473FF",
50
- "0018EC",
51
- "8438FF",
52
- "520085",
53
- "CB38FF",
54
- "FF95C8",
55
- "FF37C7",
90
+ "042AFF",
91
+ "0BDBEB",
92
+ "F3F3F3",
93
+ "00DFB7",
94
+ "111F68",
95
+ "FF6FDD",
96
+ "FF444F",
97
+ "CCED00",
98
+ "00F344",
99
+ "BD00FF",
100
+ "00B4FF",
101
+ "DD00BA",
102
+ "00FFFF",
103
+ "26C000",
104
+ "01FFB3",
105
+ "7D24FF",
106
+ "7B0068",
107
+ "FF1B6C",
108
+ "FC6D2F",
109
+ "A2FF0B",
56
110
  )
57
111
  self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
58
112
  self.n = len(self.palette)
@@ -158,22 +212,153 @@ class Annotator:
158
212
 
159
213
  self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
160
214
  self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
215
+ self.dark_colors = {
216
+ (235, 219, 11),
217
+ (243, 243, 243),
218
+ (183, 223, 0),
219
+ (221, 111, 255),
220
+ (0, 237, 204),
221
+ (68, 243, 0),
222
+ (255, 255, 0),
223
+ (179, 255, 1),
224
+ (11, 255, 162),
225
+ }
226
+ self.light_colors = {
227
+ (255, 42, 4),
228
+ (79, 68, 255),
229
+ (255, 0, 189),
230
+ (255, 180, 0),
231
+ (186, 0, 221),
232
+ (0, 192, 38),
233
+ (255, 36, 125),
234
+ (104, 0, 123),
235
+ (108, 27, 255),
236
+ (47, 109, 252),
237
+ (104, 31, 17),
238
+ }
239
+
240
+ def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
241
+ """
242
+ Assign text color based on background color.
243
+
244
+ Args:
245
+ color (tuple, optional): The background color of the rectangle for text (B, G, R).
246
+ txt_color (tuple, optional): The color of the text (R, G, B).
247
+
248
+ Returns:
249
+ txt_color (tuple): Text color for label
250
+ """
251
+ if color in self.dark_colors:
252
+ return 104, 31, 17
253
+ elif color in self.light_colors:
254
+ return 255, 255, 255
255
+ else:
256
+ return txt_color
257
+
258
+ def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):
259
+ """
260
+ Draws a label with a background circle centered within a given bounding box.
261
+
262
+ Args:
263
+ box (tuple): The bounding box coordinates (x1, y1, x2, y2).
264
+ label (str): The text label to be displayed.
265
+ color (tuple, optional): The background color of the rectangle (B, G, R).
266
+ txt_color (tuple, optional): The color of the text (R, G, B).
267
+ margin (int, optional): The margin between the text and the rectangle border.
268
+ """
269
+ # If label have more than 3 characters, skip other characters, due to circle size
270
+ if len(label) > 3:
271
+ print(
272
+ f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!"
273
+ )
274
+ label = label[:3]
275
+
276
+ # Calculate the center of the box
277
+ x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
278
+ # Get the text size
279
+ text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
280
+ # Calculate the required radius to fit the text with the margin
281
+ required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
282
+ # Draw the circle with the required radius
283
+ cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
284
+ # Calculate the position for the text
285
+ text_x = x_center - text_size[0] // 2
286
+ text_y = y_center + text_size[1] // 2
287
+ # Draw the text
288
+ cv2.putText(
289
+ self.im,
290
+ str(label),
291
+ (text_x, text_y),
292
+ cv2.FONT_HERSHEY_SIMPLEX,
293
+ self.sf - 0.15,
294
+ self.get_txt_color(color, txt_color),
295
+ self.tf,
296
+ lineType=cv2.LINE_AA,
297
+ )
298
+
299
+ def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):
300
+ """
301
+ Draws a label with a background rectangle centered within a given bounding box.
302
+
303
+ Args:
304
+ box (tuple): The bounding box coordinates (x1, y1, x2, y2).
305
+ label (str): The text label to be displayed.
306
+ color (tuple, optional): The background color of the rectangle (B, G, R).
307
+ txt_color (tuple, optional): The color of the text (R, G, B).
308
+ margin (int, optional): The margin between the text and the rectangle border.
309
+ """
310
+ # Calculate the center of the bounding box
311
+ x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
312
+ # Get the size of the text
313
+ text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
314
+ # Calculate the top-left corner of the text (to center it)
315
+ text_x = x_center - text_size[0] // 2
316
+ text_y = y_center + text_size[1] // 2
317
+ # Calculate the coordinates of the background rectangle
318
+ rect_x1 = text_x - margin
319
+ rect_y1 = text_y - text_size[1] - margin
320
+ rect_x2 = text_x + text_size[0] + margin
321
+ rect_y2 = text_y + margin
322
+ # Draw the background rectangle
323
+ cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
324
+ # Draw the text on top of the rectangle
325
+ cv2.putText(
326
+ self.im,
327
+ label,
328
+ (text_x, text_y),
329
+ cv2.FONT_HERSHEY_SIMPLEX,
330
+ self.sf - 0.1,
331
+ self.get_txt_color(color, txt_color),
332
+ self.tf,
333
+ lineType=cv2.LINE_AA,
334
+ )
161
335
 
162
336
  def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
163
- """Add one xyxy box to image with label."""
337
+ """
338
+ Draws a bounding box to image with label.
339
+
340
+ Args:
341
+ box (tuple): The bounding box coordinates (x1, y1, x2, y2).
342
+ label (str): The text label to be displayed.
343
+ color (tuple, optional): The background color of the rectangle (B, G, R).
344
+ txt_color (tuple, optional): The color of the text (R, G, B).
345
+ rotated (bool, optional): Variable used to check if task is OBB
346
+ """
347
+ txt_color = self.get_txt_color(color, txt_color)
164
348
  if isinstance(box, torch.Tensor):
165
349
  box = box.tolist()
166
350
  if self.pil or not is_ascii(label):
167
351
  if rotated:
168
352
  p1 = box[0]
169
- # NOTE: PIL-version polygon needs tuple type.
170
- self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color)
353
+ self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple box
171
354
  else:
172
355
  p1 = (box[0], box[1])
173
356
  self.draw.rectangle(box, width=self.lw, outline=color) # box
174
357
  if label:
175
358
  w, h = self.font.getsize(label) # text width, height
176
- outside = p1[1] - h >= 0 # label fits outside box
359
+ outside = p1[1] >= h # label fits outside box
360
+ if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image
361
+ p1 = self.im.size[0] - w, p1[1]
177
362
  self.draw.rectangle(
178
363
  (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
179
364
  fill=color,
@@ -183,20 +368,22 @@ class Annotator:
183
368
  else: # cv2
184
369
  if rotated:
185
370
  p1 = [int(b) for b in box[0]]
186
- # NOTE: cv2-version polylines needs np.asarray type.
187
- cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw)
371
+ cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray box
188
372
  else:
189
373
  p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
190
374
  cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
191
375
  if label:
192
376
  w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
193
- outside = p1[1] - h >= 3
194
- p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
377
+ h += 3 # add pixels to pad text
378
+ outside = p1[1] >= h # label fits outside box
379
+ if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image
380
+ p1 = self.im.shape[1] - w, p1[1]
381
+ p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h
195
382
  cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
196
383
  cv2.putText(
197
384
  self.im,
198
385
  label,
199
- (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
386
+ (p1[0], p1[1] - 2 if outside else p1[1] + h - 1),
200
387
  0,
201
388
  self.sf,
202
389
  txt_color,
@@ -240,20 +427,24 @@ class Annotator:
240
427
  # Convert im back to PIL and update draw
241
428
  self.fromarray(self.im)
242
429
 
243
- def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
430
+ def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
244
431
  """
245
432
  Plot keypoints on the image.
246
433
 
247
434
  Args:
248
- kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
249
- shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
250
- radius (int, optional): Radius of the drawn keypoints. Default is 5.
251
- kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
252
- for human pose. Default is True.
435
+ kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
436
+ shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).
437
+ radius (int, optional): Keypoint radius. Defaults to 5.
438
+ kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.
439
+ conf_thres (float, optional): Confidence threshold. Defaults to 0.25.
440
+ kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None.
253
441
 
254
442
  Note:
255
- `kpt_line=True` currently only supports human pose plotting.
443
+ - `kpt_line=True` currently only supports human pose plotting.
444
+ - Modifies self.im in-place.
445
+ - If self.pil is True, converts image to numpy array and back to PIL.
256
446
  """
447
+ radius = radius if radius is not None else self.lw
257
448
  if self.pil:
258
449
  # Convert to numpy first
259
450
  self.im = np.asarray(self.im).copy()
@@ -261,12 +452,12 @@ class Annotator:
261
452
  is_pose = nkpt == 17 and ndim in {2, 3}
262
453
  kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
263
454
  for i, k in enumerate(kpts):
264
- color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i)
455
+ color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))
265
456
  x_coord, y_coord = k[0], k[1]
266
457
  if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
267
458
  if len(k) == 3:
268
459
  conf = k[2]
269
- if conf < 0.5:
460
+ if conf < conf_thres:
270
461
  continue
271
462
  cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
272
463
 
@@ -278,13 +469,20 @@ class Annotator:
278
469
  if ndim == 3:
279
470
  conf1 = kpts[(sk[0] - 1), 2]
280
471
  conf2 = kpts[(sk[1] - 1), 2]
281
- if conf1 < 0.5 or conf2 < 0.5:
472
+ if conf1 < conf_thres or conf2 < conf_thres:
282
473
  continue
283
474
  if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
284
475
  continue
285
476
  if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
286
477
  continue
287
- cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA)
478
+ cv2.line(
479
+ self.im,
480
+ pos1,
481
+ pos2,
482
+ kpt_color or self.limb_color[i].tolist(),
483
+ thickness=int(np.ceil(self.lw / 2)),
484
+ lineType=cv2.LINE_AA,
485
+ )
288
486
  if self.pil:
289
487
  # Convert im back to PIL and update draw
290
488
  self.fromarray(self.im)
@@ -315,8 +513,9 @@ class Annotator:
315
513
  else:
316
514
  if box_style:
317
515
  w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
318
- outside = xy[1] - h >= 3
319
- p2 = xy[0] + w, xy[1] - h - 3 if outside else xy[1] + h + 3
516
+ h += 3 # add pixels to pad text
517
+ outside = xy[1] >= h # label fits outside box
518
+ p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
320
519
  cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled
321
520
  # Using `txt_color` for background and draw fg with white color
322
521
  txt_color = (255, 255, 255)
@@ -333,12 +532,37 @@ class Annotator:
333
532
 
334
533
  def show(self, title=None):
335
534
  """Show the annotated image."""
336
- Image.fromarray(np.asarray(self.im)[..., ::-1]).show(title)
535
+ im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
536
+ if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
537
+ try:
538
+ display(im) # noqa - display() function only available in ipython environments
539
+ except ImportError as e:
540
+ LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}")
541
+ else:
542
+ im.show(title=title)
337
543
 
338
544
  def save(self, filename="image.jpg"):
339
545
  """Save the annotated image to 'filename'."""
340
546
  cv2.imwrite(filename, np.asarray(self.im))
341
547
 
548
+ @staticmethod
549
+ def get_bbox_dimension(bbox=None):
550
+ """
551
+ Calculate the area of a bounding box.
552
+
553
+ Args:
554
+ bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
555
+
556
+ Returns:
557
+ width (float): Width of the bounding box.
558
+ height (float): Height of the bounding box.
559
+ area (float): Area enclosed by the bounding box.
560
+ """
561
+ x_min, y_min, x_max, y_max = bbox
562
+ width = x_max - x_min
563
+ height = y_max - y_min
564
+ return width, height, width * height
565
+
342
566
  def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
343
567
  """
344
568
  Draw region line.
@@ -350,6 +574,10 @@ class Annotator:
350
574
  """
351
575
  cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
352
576
 
577
+ # Draw small circles at the corner points
578
+ for point in reg_pts:
579
+ cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle
580
+
353
581
  def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
354
582
  """
355
583
  Draw centroid point and track trails.
@@ -363,36 +591,99 @@ class Annotator:
363
591
  cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
364
592
  cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
365
593
 
366
- def count_labels(self, counts=0, count_txt_size=2, color=(255, 255, 255), txt_color=(0, 0, 0)):
594
+ def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)):
367
595
  """
368
- Plot counts for object counter.
596
+ Displays queue counts on an image centered at the points with customizable font size and colors.
369
597
 
370
598
  Args:
371
- counts (int): objects counts value
372
- count_txt_size (int): text size for counts display
373
- color (tuple): background color of counts display
374
- txt_color (tuple): text color of counts display
599
+ label (str): Queue counts label.
600
+ points (tuple): Region points for center point calculation to display text.
601
+ region_color (tuple): RGB queue region color.
602
+ txt_color (tuple): RGB text display color.
375
603
  """
376
- self.tf = count_txt_size
377
- tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
378
- tf = max(tl - 1, 1)
604
+ x_values = [point[0] for point in points]
605
+ y_values = [point[1] for point in points]
606
+ center_x = sum(x_values) // len(points)
607
+ center_y = sum(y_values) // len(points)
379
608
 
380
- # Get text size for in_count and out_count
381
- t_size_in = cv2.getTextSize(str(counts), 0, fontScale=tl / 2, thickness=tf)[0]
609
+ text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]
610
+ text_width = text_size[0]
611
+ text_height = text_size[1]
382
612
 
383
- # Calculate positions for counts label
384
- text_width = t_size_in[0]
385
- text_x = (self.im.shape[1] - text_width) // 2 # Center x-coordinate
386
- text_y = t_size_in[1]
613
+ rect_width = text_width + 20
614
+ rect_height = text_height + 20
615
+ rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)
616
+ rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)
617
+ cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)
387
618
 
388
- # Create a rounded rectangle for in_count
389
- cv2.rectangle(
390
- self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1
391
- )
619
+ text_x = center_x - text_width // 2
620
+ text_y = center_y + text_height // 2
621
+
622
+ # Draw text
392
623
  cv2.putText(
393
- self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA
624
+ self.im,
625
+ label,
626
+ (text_x, text_y),
627
+ 0,
628
+ fontScale=self.sf,
629
+ color=txt_color,
630
+ thickness=self.tf,
631
+ lineType=cv2.LINE_AA,
394
632
  )
395
633
 
634
+ def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin):
635
+ """
636
+ Display the bounding boxes labels in parking management app.
637
+
638
+ Args:
639
+ im0 (ndarray): Inference image.
640
+ text (str): Object/class name.
641
+ txt_color (tuple): Display color for text foreground.
642
+ bg_color (tuple): Display color for text background.
643
+ x_center (float): The x position center point for bounding box.
644
+ y_center (float): The y position center point for bounding box.
645
+ margin (int): The gap between text and rectangle for better display.
646
+ """
647
+ text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
648
+ text_x = x_center - text_size[0] // 2
649
+ text_y = y_center + text_size[1] // 2
650
+
651
+ rect_x1 = text_x - margin
652
+ rect_y1 = text_y - text_size[1] - margin
653
+ rect_x2 = text_x + text_size[0] + margin
654
+ rect_y2 = text_y + margin
655
+ cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
656
+ cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
657
+
658
+ def display_analytics(self, im0, text, txt_color, bg_color, margin):
659
+ """
660
+ Display the overall statistics for parking lots.
661
+
662
+ Args:
663
+ im0 (ndarray): Inference image.
664
+ text (dict): Labels dictionary.
665
+ txt_color (tuple): Display color for text foreground.
666
+ bg_color (tuple): Display color for text background.
667
+ margin (int): Gap between text and rectangle for better display.
668
+ """
669
+ horizontal_gap = int(im0.shape[1] * 0.02)
670
+ vertical_gap = int(im0.shape[0] * 0.01)
671
+ text_y_offset = 0
672
+ for label, value in text.items():
673
+ txt = f"{label}: {value}"
674
+ text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]
675
+ if text_size[0] < 5 or text_size[1] < 5:
676
+ text_size = (5, 5)
677
+ text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap
678
+ text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap
679
+ rect_x1 = text_x - margin * 2
680
+ rect_y1 = text_y - text_size[1] - margin * 2
681
+ rect_x2 = text_x + text_size[0] + margin * 2
682
+ rect_y2 = text_y + margin * 2
683
+ cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
684
+ cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
685
+ text_y_offset = rect_y2
686
+
396
687
  @staticmethod
397
688
  def estimate_pose_angle(a, b, c):
398
689
  """
@@ -413,162 +704,180 @@ class Annotator:
413
704
  angle = 360 - angle
414
705
  return angle
415
706
 
416
- def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2):
707
+ def draw_specific_points(self, keypoints, indices=None, radius=2, conf_thres=0.25):
417
708
  """
418
709
  Draw specific keypoints for gym steps counting.
419
710
 
420
711
  Args:
421
- keypoints (list): list of keypoints data to be plotted
422
- indices (list): keypoints ids list to be plotted
423
- shape (tuple): imgsz for model inference
424
- radius (int): Keypoint radius value
712
+ keypoints (list): Keypoints data to be plotted.
713
+ indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7].
714
+ radius (int, optional): Keypoint radius. Defaults to 2.
715
+ conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25.
716
+
717
+ Returns:
718
+ (numpy.ndarray): Image with drawn keypoints.
719
+
720
+ Note:
721
+ Keypoint format: [x, y] or [x, y, confidence].
722
+ Modifies self.im in-place.
425
723
  """
426
- for i, k in enumerate(keypoints):
427
- if i in indices:
428
- x_coord, y_coord = k[0], k[1]
429
- if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
430
- if len(k) == 3:
431
- conf = k[2]
432
- if conf < 0.5:
433
- continue
434
- cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA)
724
+ indices = indices or [2, 5, 7]
725
+ points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thres]
726
+
727
+ # Draw lines between consecutive points
728
+ for start, end in zip(points[:-1], points[1:]):
729
+ cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA)
730
+
731
+ # Draw circles for keypoints
732
+ for pt in points:
733
+ cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA)
734
+
435
735
  return self.im
436
736
 
437
- def plot_angle_and_count_and_stage(self, angle_text, count_text, stage_text, center_kpt, line_thickness=2):
737
+ def plot_workout_information(self, display_text, position, color=(104, 31, 17), txt_color=(255, 255, 255)):
438
738
  """
439
- Plot the pose angle, count value and step stage.
739
+ Draw text with a background on the image.
440
740
 
441
741
  Args:
442
- angle_text (str): angle value for workout monitoring
443
- count_text (str): counts value for workout monitoring
444
- stage_text (str): stage decision for workout monitoring
445
- center_kpt (int): centroid pose index for workout monitoring
446
- line_thickness (int): thickness for text display
742
+ display_text (str): The text to be displayed.
743
+ position (tuple): Coordinates (x, y) on the image where the text will be placed.
744
+ color (tuple, optional): Text background color
745
+ txt_color (tuple, optional): Text foreground color
447
746
  """
448
- angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}")
449
- font_scale = 0.6 + (line_thickness / 10.0)
450
-
451
- # Draw angle
452
- (angle_text_width, angle_text_height), _ = cv2.getTextSize(angle_text, 0, font_scale, line_thickness)
453
- angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))
454
- angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5)
455
- angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2))
456
- cv2.rectangle(
457
- self.im,
458
- angle_background_position,
459
- (
460
- angle_background_position[0] + angle_background_size[0],
461
- angle_background_position[1] + angle_background_size[1],
462
- ),
463
- (255, 255, 255),
464
- -1,
465
- )
466
- cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness)
467
-
468
- # Draw Counts
469
- (count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness)
470
- count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)
471
- count_background_position = (
472
- angle_background_position[0],
473
- angle_background_position[1] + angle_background_size[1] + 5,
474
- )
475
- count_background_size = (count_text_width + 10, count_text_height + 10 + (line_thickness * 2))
747
+ (text_width, text_height), _ = cv2.getTextSize(display_text, 0, self.sf, self.tf)
476
748
 
749
+ # Draw background rectangle
477
750
  cv2.rectangle(
478
751
  self.im,
479
- count_background_position,
480
- (
481
- count_background_position[0] + count_background_size[0],
482
- count_background_position[1] + count_background_size[1],
483
- ),
484
- (255, 255, 255),
752
+ (position[0], position[1] - text_height - 5),
753
+ (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf),
754
+ color,
485
755
  -1,
486
756
  )
487
- cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness)
757
+ # Draw text
758
+ cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf)
488
759
 
489
- # Draw Stage
490
- (stage_text_width, stage_text_height), _ = cv2.getTextSize(stage_text, 0, font_scale, line_thickness)
491
- stage_text_position = (int(center_kpt[0]), int(center_kpt[1]) + angle_text_height + count_text_height + 40)
492
- stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)
493
- stage_background_size = (stage_text_width + 10, stage_text_height + 10)
760
+ return text_height
494
761
 
495
- cv2.rectangle(
496
- self.im,
497
- stage_background_position,
498
- (
499
- stage_background_position[0] + stage_background_size[0],
500
- stage_background_position[1] + stage_background_size[1],
501
- ),
502
- (255, 255, 255),
503
- -1,
762
+ def plot_angle_and_count_and_stage(
763
+ self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255)
764
+ ):
765
+ """
766
+ Plot the pose angle, count value, and step stage.
767
+
768
+ Args:
769
+ angle_text (str): Angle value for workout monitoring
770
+ count_text (str): Counts value for workout monitoring
771
+ stage_text (str): Stage decision for workout monitoring
772
+ center_kpt (list): Centroid pose index for workout monitoring
773
+ color (tuple, optional): Text background color
774
+ txt_color (tuple, optional): Text foreground color
775
+ """
776
+ # Format text
777
+ angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}"
778
+
779
+ # Draw angle, count and stage text
780
+ angle_height = self.plot_workout_information(
781
+ angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color
782
+ )
783
+ count_height = self.plot_workout_information(
784
+ count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color
785
+ )
786
+ self.plot_workout_information(
787
+ stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color
504
788
  )
505
- cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness)
506
789
 
507
- def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
790
+ def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)):
508
791
  """
509
792
  Function for drawing segmented object in bounding box shape.
510
793
 
511
794
  Args:
512
- mask (list): masks data list for instance segmentation area plotting
513
- mask_color (tuple): mask foreground color
514
- det_label (str): Detection label text
515
- track_label (str): Tracking label text
795
+ mask (np.ndarray): A 2D array of shape (N, 2) containing the contour points of the segmented object.
796
+ mask_color (tuple): RGB color for the contour and label background.
797
+ label (str, optional): Text label for the object. If None, no label is drawn.
798
+ txt_color (tuple): RGB color for the label text.
516
799
  """
517
- cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
518
-
519
- label = f"Track ID: {track_label}" if track_label else det_label
520
- text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
521
-
522
- cv2.rectangle(
523
- self.im,
524
- (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
525
- (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)),
526
- mask_color,
527
- -1,
528
- )
800
+ if mask.size == 0: # no masks to plot
801
+ return
529
802
 
530
- cv2.putText(
531
- self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2
532
- )
803
+ cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
804
+ text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)
805
+
806
+ if label:
807
+ cv2.rectangle(
808
+ self.im,
809
+ (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
810
+ (int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),
811
+ mask_color,
812
+ -1,
813
+ )
814
+ cv2.putText(
815
+ self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
816
+ )
817
+
818
+ def sweep_annotator(self, line_x=0, line_y=0, label=None, color=(221, 0, 186), txt_color=(255, 255, 255)):
819
+ """
820
+ Function for drawing a sweep annotation line and an optional label.
533
821
 
534
- def plot_distance_and_line(self, distance_m, distance_mm, centroids, line_color, centroid_color):
822
+ Args:
823
+ line_x (int): The x-coordinate of the sweep line.
824
+ line_y (int): The y-coordinate limit of the sweep line.
825
+ label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn.
826
+ color (tuple): RGB color for the line and label background.
827
+ txt_color (tuple): RGB color for the label text.
828
+ """
829
+ # Draw the sweep line
830
+ cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2)
831
+
832
+ # Draw label, if provided
833
+ if label:
834
+ (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf)
835
+ cv2.rectangle(
836
+ self.im,
837
+ (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10),
838
+ (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10),
839
+ color,
840
+ -1,
841
+ )
842
+ cv2.putText(
843
+ self.im,
844
+ label,
845
+ (line_x - text_width // 2, line_y // 2 + text_height // 2),
846
+ cv2.FONT_HERSHEY_SIMPLEX,
847
+ self.sf,
848
+ txt_color,
849
+ self.tf,
850
+ )
851
+
852
+ def plot_distance_and_line(
853
+ self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255)
854
+ ):
535
855
  """
536
856
  Plot the distance and line on frame.
537
857
 
538
858
  Args:
539
- distance_m (float): Distance between two bbox centroids in meters.
540
- distance_mm (float): Distance between two bbox centroids in millimeters.
859
+ pixels_distance (float): Pixels distance between two bbox centroids.
541
860
  centroids (list): Bounding box centroids data.
542
- line_color (RGB): Distance line color.
543
- centroid_color (RGB): Bounding box centroid color.
861
+ line_color (tuple, optional): Distance line color.
862
+ centroid_color (tuple, optional): Bounding box centroid color.
544
863
  """
545
- (text_width_m, text_height_m), _ = cv2.getTextSize(
546
- f"Distance M: {distance_m:.2f}m", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
547
- )
548
- cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), (255, 255, 255), -1)
549
- cv2.putText(
550
- self.im,
551
- f"Distance M: {distance_m:.2f}m",
552
- (20, 50),
553
- cv2.FONT_HERSHEY_SIMPLEX,
554
- 0.8,
555
- (0, 0, 0),
556
- 2,
557
- cv2.LINE_AA,
558
- )
864
+ # Get the text size
865
+ text = f"Pixels Distance: {pixels_distance:.2f}"
866
+ (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)
559
867
 
560
- (text_width_mm, text_height_mm), _ = cv2.getTextSize(
561
- f"Distance MM: {distance_mm:.2f}mm", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
562
- )
563
- cv2.rectangle(self.im, (15, 75), (15 + text_width_mm + 10, 75 + text_height_mm + 20), (255, 255, 255), -1)
868
+ # Define corners with 10-pixel margin and draw rectangle
869
+ cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)
870
+
871
+ # Calculate the position for the text with a 10-pixel margin and draw text
872
+ text_position = (25, 25 + text_height_m + 10)
564
873
  cv2.putText(
565
874
  self.im,
566
- f"Distance MM: {distance_mm:.2f}mm",
567
- (20, 100),
568
- cv2.FONT_HERSHEY_SIMPLEX,
569
- 0.8,
570
- (0, 0, 0),
571
- 2,
875
+ text,
876
+ text_position,
877
+ 0,
878
+ self.sf,
879
+ (255, 255, 255),
880
+ self.tf,
572
881
  cv2.LINE_AA,
573
882
  )
574
883
 
@@ -576,7 +885,7 @@ class Annotator:
576
885
  cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
577
886
  cv2.circle(self.im, centroids[1], 6, centroid_color, -1)
578
887
 
579
- def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
888
+ def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)):
580
889
  """
581
890
  Function for pinpoint human-vision eye mapping and plotting.
582
891
 
@@ -585,21 +894,19 @@ class Annotator:
585
894
  center_point (tuple): center point for vision eye view
586
895
  color (tuple): object centroid and line color value
587
896
  pin_color (tuple): visioneye point color value
588
- thickness (int): int value for line thickness
589
- pins_radius (int): visioneye point radius value
590
897
  """
591
898
  center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
592
- cv2.circle(self.im, center_point, pins_radius, pin_color, -1)
593
- cv2.circle(self.im, center_bbox, pins_radius, color, -1)
594
- cv2.line(self.im, center_point, center_bbox, color, thickness)
899
+ cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
900
+ cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
901
+ cv2.line(self.im, center_point, center_bbox, color, self.tf)
595
902
 
596
903
 
597
904
  @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
598
905
  @plt_settings()
599
906
  def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
600
907
  """Plot training labels including class histograms and box statistics."""
601
- import pandas as pd
602
- import seaborn as sn
908
+ import pandas # scope for faster 'import ultralytics'
909
+ import seaborn # scope for faster 'import ultralytics'
603
910
 
604
911
  # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
605
912
  warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
@@ -609,10 +916,10 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
609
916
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
610
917
  nc = int(cls.max() + 1) # number of classes
611
918
  boxes = boxes[:1000000] # limit to 1M boxes
612
- x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
919
+ x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
613
920
 
614
921
  # Seaborn correlogram
615
- sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
922
+ seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
616
923
  plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
617
924
  plt.close()
618
925
 
@@ -627,8 +934,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
627
934
  ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
628
935
  else:
629
936
  ax[0].set_xlabel("classes")
630
- sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
631
- sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
937
+ seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
938
+ seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
632
939
 
633
940
  # Rectangles
634
941
  boxes[:, 0:2] = 0.5 # center
@@ -676,11 +983,10 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
676
983
  from ultralytics.utils.plotting import save_one_box
677
984
 
678
985
  xyxy = [50, 50, 150, 150]
679
- im = cv2.imread('image.jpg')
680
- cropped_im = save_one_box(xyxy, im, file='cropped.jpg', square=True)
986
+ im = cv2.imread("image.jpg")
987
+ cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
681
988
  ```
682
989
  """
683
-
684
990
  if not isinstance(xyxy, torch.Tensor): # may be list
685
991
  xyxy = torch.stack(xyxy)
686
992
  b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes
@@ -700,22 +1006,49 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
700
1006
 
701
1007
  @threaded
702
1008
  def plot_images(
703
- images,
704
- batch_idx,
705
- cls,
706
- bboxes=np.zeros(0, dtype=np.float32),
707
- confs=None,
708
- masks=np.zeros(0, dtype=np.uint8),
709
- kpts=np.zeros((0, 51), dtype=np.float32),
710
- paths=None,
711
- fname="images.jpg",
712
- names=None,
713
- on_plot=None,
714
- max_subplots=16,
715
- save=True,
716
- conf_thres=0.25,
717
- ):
718
- """Plot image grid with labels."""
1009
+ images: Union[torch.Tensor, np.ndarray],
1010
+ batch_idx: Union[torch.Tensor, np.ndarray],
1011
+ cls: Union[torch.Tensor, np.ndarray],
1012
+ bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
1013
+ confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
1014
+ masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
1015
+ kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
1016
+ paths: Optional[List[str]] = None,
1017
+ fname: str = "images.jpg",
1018
+ names: Optional[Dict[int, str]] = None,
1019
+ on_plot: Optional[Callable] = None,
1020
+ max_size: int = 1920,
1021
+ max_subplots: int = 16,
1022
+ save: bool = True,
1023
+ conf_thres: float = 0.25,
1024
+ ) -> Optional[np.ndarray]:
1025
+ """
1026
+ Plot image grid with labels, bounding boxes, masks, and keypoints.
1027
+
1028
+ Args:
1029
+ images: Batch of images to plot. Shape: (batch_size, channels, height, width).
1030
+ batch_idx: Batch indices for each detection. Shape: (num_detections,).
1031
+ cls: Class labels for each detection. Shape: (num_detections,).
1032
+ bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
1033
+ confs: Confidence scores for each detection. Shape: (num_detections,).
1034
+ masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
1035
+ kpts: Keypoints for each detection. Shape: (num_detections, 51).
1036
+ paths: List of file paths for each image in the batch.
1037
+ fname: Output filename for the plotted image grid.
1038
+ names: Dictionary mapping class indices to class names.
1039
+ on_plot: Optional callback function to be called after saving the plot.
1040
+ max_size: Maximum size of the output image grid.
1041
+ max_subplots: Maximum number of subplots in the image grid.
1042
+ save: Whether to save the plotted image grid to a file.
1043
+ conf_thres: Confidence threshold for displaying detections.
1044
+
1045
+ Returns:
1046
+ np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.
1047
+
1048
+ Note:
1049
+ This function supports both tensor and numpy array inputs. It will automatically
1050
+ convert tensor inputs to numpy arrays for processing.
1051
+ """
719
1052
  if isinstance(images, torch.Tensor):
720
1053
  images = images.cpu().float().numpy()
721
1054
  if isinstance(cls, torch.Tensor):
@@ -729,7 +1062,6 @@ def plot_images(
729
1062
  if isinstance(batch_idx, torch.Tensor):
730
1063
  batch_idx = batch_idx.cpu().numpy()
731
1064
 
732
- max_size = 1920 # max image size
733
1065
  bs, _, h, w = images.shape # batch size, _, height, width
734
1066
  bs = min(bs, max_subplots) # limit plot images
735
1067
  ns = np.ceil(bs**0.5) # number of subplots (square)
@@ -765,16 +1097,16 @@ def plot_images(
765
1097
  if len(bboxes):
766
1098
  boxes = bboxes[idx]
767
1099
  conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
768
- is_obb = boxes.shape[-1] == 5 # xywhr
769
- boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
770
1100
  if len(boxes):
771
1101
  if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
772
- boxes[..., 0::2] *= w # scale to pixels
773
- boxes[..., 1::2] *= h
1102
+ boxes[..., [0, 2]] *= w # scale to pixels
1103
+ boxes[..., [1, 3]] *= h
774
1104
  elif scale < 1: # absolute coords need scale if image scales
775
1105
  boxes[..., :4] *= scale
776
- boxes[..., 0::2] += x
777
- boxes[..., 1::2] += y
1106
+ boxes[..., 0] += x
1107
+ boxes[..., 1] += y
1108
+ is_obb = boxes.shape[-1] == 5 # xywhr
1109
+ boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
778
1110
  for j, box in enumerate(boxes.astype(np.int64).tolist()):
779
1111
  c = classes[j]
780
1112
  color = colors(c)
@@ -802,7 +1134,7 @@ def plot_images(
802
1134
  kpts_[..., 1] += y
803
1135
  for j in range(len(kpts_)):
804
1136
  if labels or conf[j] > conf_thres:
805
- annotator.kpts(kpts_[j])
1137
+ annotator.kpts(kpts_[j], conf_thres=conf_thres)
806
1138
 
807
1139
  # Plot masks
808
1140
  if len(masks):
@@ -826,10 +1158,12 @@ def plot_images(
826
1158
  mask = mask.astype(bool)
827
1159
  else:
828
1160
  mask = image_masks[j].astype(bool)
829
- with contextlib.suppress(Exception):
1161
+ try:
830
1162
  im[y : y + h, x : x + w, :][mask] = (
831
1163
  im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
832
1164
  )
1165
+ except Exception:
1166
+ pass
833
1167
  annotator.fromarray(im)
834
1168
  if not save:
835
1169
  return np.asarray(annotator.im)
@@ -857,25 +1191,25 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
857
1191
  ```python
858
1192
  from ultralytics.utils.plotting import plot_results
859
1193
 
860
- plot_results('path/to/results.csv', segment=True)
1194
+ plot_results("path/to/results.csv", segment=True)
861
1195
  ```
862
1196
  """
863
- import pandas as pd
1197
+ import pandas as pd # scope for faster 'import ultralytics'
864
1198
  from scipy.ndimage import gaussian_filter1d
865
1199
 
866
1200
  save_dir = Path(file).parent if file else Path(dir)
867
1201
  if classify:
868
1202
  fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
869
- index = [1, 4, 2, 3]
1203
+ index = [2, 5, 3, 4]
870
1204
  elif segment:
871
1205
  fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
872
- index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
1206
+ index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
873
1207
  elif pose:
874
1208
  fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
875
- index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]
1209
+ index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
876
1210
  else:
877
1211
  fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
878
- index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
1212
+ index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
879
1213
  ax = ax.ravel()
880
1214
  files = list(save_dir.glob("results*.csv"))
881
1215
  assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
@@ -890,7 +1224,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
890
1224
  ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
891
1225
  ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
892
1226
  ax[i].set_title(s[j], fontsize=12)
893
- # if j in [8, 9, 10]: # share train and val loss y axes
1227
+ # if j in {8, 9, 10}: # share train and val loss y axes
894
1228
  # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
895
1229
  except Exception as e:
896
1230
  LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
@@ -919,7 +1253,6 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
919
1253
  >>> f = np.random.rand(100)
920
1254
  >>> plt_color_scatter(v, f)
921
1255
  """
922
-
923
1256
  # Calculate 2D histogram and corresponding colors
924
1257
  hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
925
1258
  colors = [
@@ -936,19 +1269,24 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
936
1269
 
937
1270
  def plot_tune_results(csv_file="tune_results.csv"):
938
1271
  """
939
- Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key
1272
+ Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
940
1273
  in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
941
1274
 
942
1275
  Args:
943
1276
  csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
944
1277
 
945
1278
  Examples:
946
- >>> plot_tune_results('path/to/tune_results.csv')
1279
+ >>> plot_tune_results("path/to/tune_results.csv")
947
1280
  """
948
-
949
- import pandas as pd
1281
+ import pandas as pd # scope for faster 'import ultralytics'
950
1282
  from scipy.ndimage import gaussian_filter1d
951
1283
 
1284
+ def _save_one_file(file):
1285
+ """Save one matplotlib plot to 'file'."""
1286
+ plt.savefig(file, dpi=200)
1287
+ plt.close()
1288
+ LOGGER.info(f"Saved {file}")
1289
+
952
1290
  # Scatter plots for each hyperparameter
953
1291
  csv_file = Path(csv_file)
954
1292
  data = pd.read_csv(csv_file)
@@ -969,11 +1307,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
969
1307
  plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
970
1308
  if i % n != 0:
971
1309
  plt.yticks([])
972
-
973
- file = csv_file.with_name("tune_scatter_plots.png") # filename
974
- plt.savefig(file, dpi=200)
975
- plt.close()
976
- LOGGER.info(f"Saved {file}")
1310
+ _save_one_file(csv_file.with_name("tune_scatter_plots.png"))
977
1311
 
978
1312
  # Fitness vs iteration
979
1313
  x = range(1, len(fitness) + 1)
@@ -985,11 +1319,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
985
1319
  plt.ylabel("Fitness")
986
1320
  plt.grid(True)
987
1321
  plt.legend()
988
-
989
- file = csv_file.with_name("tune_fitness.png") # filename
990
- plt.savefig(file, dpi=200)
991
- plt.close()
992
- LOGGER.info(f"Saved {file}")
1322
+ _save_one_file(csv_file.with_name("tune_fitness.png"))
993
1323
 
994
1324
 
995
1325
  def output_to_target(output, max_det=300):
@@ -1025,23 +1355,24 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detec
1025
1355
  n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
1026
1356
  save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
1027
1357
  """
1028
- for m in ["Detect", "Pose", "Segment"]:
1358
+ for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads
1029
1359
  if m in module_type:
1030
1360
  return
1031
- _, channels, height, width = x.shape # batch, channels, height, width
1032
- if height > 1 and width > 1:
1033
- f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
1034
-
1035
- blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
1036
- n = min(n, channels) # number of plots
1037
- _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
1038
- ax = ax.ravel()
1039
- plt.subplots_adjust(wspace=0.05, hspace=0.05)
1040
- for i in range(n):
1041
- ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
1042
- ax[i].axis("off")
1043
-
1044
- LOGGER.info(f"Saving {f}... ({n}/{channels})")
1045
- plt.savefig(f, dpi=300, bbox_inches="tight")
1046
- plt.close()
1047
- np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save
1361
+ if isinstance(x, torch.Tensor):
1362
+ _, channels, height, width = x.shape # batch, channels, height, width
1363
+ if height > 1 and width > 1:
1364
+ f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
1365
+
1366
+ blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
1367
+ n = min(n, channels) # number of plots
1368
+ _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
1369
+ ax = ax.ravel()
1370
+ plt.subplots_adjust(wspace=0.05, hspace=0.05)
1371
+ for i in range(n):
1372
+ ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
1373
+ ax[i].axis("off")
1374
+
1375
+ LOGGER.info(f"Saving {f}... ({n}/{channels})")
1376
+ plt.savefig(f, dpi=300, bbox_inches="tight")
1377
+ plt.close()
1378
+ np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save