dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,885 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import contextlib
+ import math
+ import re
+ import time
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from ultralytics.utils import LOGGER
+ from ultralytics.utils.metrics import batch_probiou
+
+
+ class Profile(contextlib.ContextDecorator):
+     """
+     YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
+
+     Attributes:
+         t (float): Accumulated time.
+         device (torch.device): Device used for model inference.
+         cuda (bool): Whether CUDA is being used.
+
+     Examples:
+         >>> from ultralytics.utils.ops import Profile
+         >>> with Profile(device=device) as dt:
+         ...     pass  # slow operation here
+         >>> print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"
+     """
+
+     def __init__(self, t=0.0, device: torch.device = None):
+         """
+         Initialize the Profile class.
+
+         Args:
+             t (float): Initial time.
+             device (torch.device): Device used for model inference.
+         """
+         self.t = t
+         self.device = device
+         self.cuda = bool(device and str(device).startswith("cuda"))
+
+     def __enter__(self):
+         """Start timing."""
+         self.start = self.time()
+         return self
+
+     def __exit__(self, type, value, traceback):  # noqa
+         """Stop timing."""
+         self.dt = self.time() - self.start  # delta-time
+         self.t += self.dt  # accumulate dt
+
+     def __str__(self):
+         """Returns a human-readable string representing the accumulated elapsed time in the profiler."""
+         return f"Elapsed time is {self.t} s"
+
+     def time(self):
+         """Get current time."""
+         if self.cuda:
+             torch.cuda.synchronize(self.device)
+         return time.perf_counter()
+
+
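Because `Profile` subclasses `contextlib.ContextDecorator`, a single instance also works as a decorator and accumulates time across calls; the docstring only shows the context-manager form. A minimal usage sketch (not part of the diff; the decorated function is illustrative):

    from ultralytics.utils.ops import Profile

    profiler = Profile()

    @profiler
    def slow_op():  # any function to be timed
        sum(range(1_000_000))

    slow_op()
    slow_op()
    print(profiler)  # accumulated elapsed time across both calls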
+ def segment2box(segment, width=640, height=640):
+     """
+     Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
+
+     Args:
+         segment (torch.Tensor): The segment label.
+         width (int): The width of the image.
+         height (int): The height of the image.
+
+     Returns:
+         (np.ndarray): The minimum and maximum x and y values of the segment.
+     """
+     x, y = segment.T  # segment xy
+     # If any 3 of the 4 sides extend outside the image, clip coordinates first, https://github.com/ultralytics/ultralytics/pull/18294
+     if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
+         x = x.clip(0, width)
+         y = y.clip(0, height)
+     inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
+     x = x[inside]
+     y = y[inside]
+     return (
+         np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
+         if any(x)
+         else np.zeros(4, dtype=segment.dtype)
+     )  # xyxy
+
+
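A small numeric sketch of the conversion (values arbitrary; the vertex outside the image is dropped by the inside-image constraint):

    import numpy as np
    from ultralytics.utils.ops import segment2box

    # A triangle polygon; one vertex lies outside the 640x640 image.
    segment = np.array([[10.0, 20.0], [100.0, 700.0], [50.0, 60.0]], dtype=np.float32)
    box = segment2box(segment, width=640, height=640)
    print(box)  # xyxy over the in-bounds points: [10. 20. 50. 60.]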
+ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
+     """
+     Rescale bounding boxes from img1_shape to img0_shape.
+
+     Args:
+         img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
+         boxes (torch.Tensor): The bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2).
+         img0_shape (tuple): The shape of the target image, in the format of (height, width).
+         ratio_pad (tuple): A tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
+             calculated based on the size difference between the two images.
+         padding (bool): If True, assume the boxes are based on an image augmented by YOLO-style letterboxing. If
+             False, do regular rescaling.
+         xywh (bool): Whether the box format is xywh.
+
+     Returns:
+         (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2).
+     """
+     if ratio_pad is None:  # calculate from img0_shape
+         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+         pad = (
+             round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
+             round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
+         )  # wh padding
+     else:
+         gain = ratio_pad[0][0]
+         pad = ratio_pad[1]
+
+     if padding:
+         boxes[..., 0] -= pad[0]  # x padding
+         boxes[..., 1] -= pad[1]  # y padding
+         if not xywh:
+             boxes[..., 2] -= pad[0]  # x padding
+             boxes[..., 3] -= pad[1]  # y padding
+     boxes[..., :4] /= gain
+     return clip_boxes(boxes, img0_shape)
+
+
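A worked sketch with assumed shapes: a 1280x720 frame letterboxed to 640x640 gives gain = min(640/720, 640/1280) = 0.5 and 140 px of vertical padding per side. Note `scale_boxes` writes into the tensor it is given:

    import torch
    from ultralytics.utils.ops import scale_boxes

    boxes = torch.tensor([[0.0, 140.0, 640.0, 500.0]])  # xyxy in the 640x640 letterboxed image
    restored = scale_boxes((640, 640), boxes, (720, 1280))
    print(restored)  # tensor([[0., 0., 1280., 720.]]) — the full original frame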
+ def make_divisible(x, divisor):
+     """
+     Return the smallest number that is greater than or equal to x and divisible by the given divisor.
+
+     Args:
+         x (int): The number to make divisible.
+         divisor (int | torch.Tensor): The divisor.
+
+     Returns:
+         (int): The nearest number divisible by the divisor.
+     """
+     if isinstance(divisor, torch.Tensor):
+         divisor = int(divisor.max())  # to int
+     return math.ceil(x / divisor) * divisor
+
+
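A one-line sketch; channel widths are commonly rounded up to a multiple of 8 (an assumed use, the helper itself is generic):

    from ultralytics.utils.ops import make_divisible

    print(make_divisible(100, 8))  # 104 — rounds up via ceil, not to the nearest multiple (96)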
+ def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
+     """
+     NMS for oriented bounding boxes using probiou and fast-nms.
+
+     Args:
+         boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
+         scores (torch.Tensor): Confidence scores, shape (N,).
+         threshold (float): IoU threshold.
+         use_triu (bool): Whether to use the `torch.triu` operator. Disabling it is useful when exporting OBB models
+             to formats that do not support `torch.triu`.
+
+     Returns:
+         (torch.Tensor): Indices of boxes to keep after NMS.
+     """
+     sorted_idx = torch.argsort(scores, descending=True)
+     boxes = boxes[sorted_idx]
+     ious = batch_probiou(boxes, boxes)
+     if use_triu:
+         ious = ious.triu_(diagonal=1)
+         # pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+         # NOTE: this formulation handles empty `boxes` without an if-else, keeping the graph exportable
+         pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
+     else:
+         n = boxes.shape[0]
+         row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
+         col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
+         upper_mask = row_idx < col_idx
+         ious = ious * upper_mask
+         # Zeroing these scores ensures the additional indices would not affect the final results
+         scores[~((ious >= threshold).sum(0) <= 0)] = 0
+         # NOTE: return indices with fixed length to avoid TFLite reshape error
+         pick = torch.topk(scores, scores.shape[0]).indices
+     return sorted_idx[pick]
+
+
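A minimal sketch with one near-duplicate rotated box (values arbitrary):

    import torch
    from ultralytics.utils.ops import nms_rotated

    boxes = torch.tensor(
        [
            [100.0, 100.0, 50.0, 20.0, 0.3],  # cx, cy, w, h, angle (rad)
            [101.0, 100.0, 50.0, 20.0, 0.3],  # near-duplicate of the first box
            [300.0, 300.0, 40.0, 40.0, 0.0],  # well-separated box
        ]
    )
    scores = torch.tensor([0.9, 0.8, 0.7])
    keep = nms_rotated(boxes, scores, threshold=0.45)
    print(keep)  # tensor([0, 2]) — the duplicate (index 1) is suppressed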
+ def non_max_suppression(
+     prediction,
+     conf_thres=0.25,
+     iou_thres=0.45,
+     classes=None,
+     agnostic=False,
+     multi_label=False,
+     labels=(),
+     max_det=300,
+     nc=0,  # number of classes (optional)
+     max_time_img=0.05,
+     max_nms=30000,
+     max_wh=7680,
+     in_place=True,
+     rotated=False,
+     end2end=False,
+     return_idxs=False,
+ ):
+     """
+     Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
+
+     Args:
+         prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
+             containing the predicted boxes, classes, and masks. The tensor should be in the format
+             output by a model, such as YOLO.
+         conf_thres (float): The confidence threshold below which boxes will be filtered out.
+             Valid values are between 0.0 and 1.0.
+         iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
+             Valid values are between 0.0 and 1.0.
+         classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
+         agnostic (bool): If True, the model is agnostic to the number of classes, and all
+             classes will be considered as one.
+         multi_label (bool): If True, each box may have multiple labels.
+         labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
+             list contains the apriori labels for a given image. The list should be in the format
+             output by a dataloader, with each label being a tuple of (class_index, x, y, w, h).
+         max_det (int): The maximum number of boxes to keep after NMS.
+         nc (int): The number of classes output by the model. Any indices after this will be considered masks.
+         max_time_img (float): The maximum time (seconds) for processing one image.
+         max_nms (int): The maximum number of boxes passed into torchvision.ops.nms().
+         max_wh (int): The maximum box width and height in pixels.
+         in_place (bool): If True, the input prediction tensor will be modified in place.
+         rotated (bool): Whether Oriented Bounding Boxes (OBB) are being passed for NMS.
+         end2end (bool): Whether the model is end-to-end and doesn't require NMS.
+         return_idxs (bool): Whether to return the indices of the detections that were kept.
+
+     Returns:
+         (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
+             shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
+             (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
+     """
+     import torchvision  # scope for faster 'import ultralytics'
+
+     # Checks
+     assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
+     assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
+     if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
+         prediction = prediction[0]  # select only inference output
+     if classes is not None:
+         classes = torch.tensor(classes, device=prediction.device)
+
+     if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
+         output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
+         if classes is not None:
+             output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
+         return output
+
+     bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
+     nc = nc or (prediction.shape[1] - 4)  # number of classes
+     nm = prediction.shape[1] - nc - 4  # number of masks
+     mi = 4 + nc  # mask start index
+     xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
+     xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None]  # to track idxs
+
+     # Settings
+     # min_wh = 2  # (pixels) minimum box width and height
+     time_limit = 2.0 + max_time_img * bs  # seconds to quit after
+     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
+
+     prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
+     if not rotated:
+         if in_place:
+             prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
+         else:
+             prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy
+
+     t = time.time()
+     output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
+     keepi = [torch.zeros((0, 1), device=prediction.device)] * bs  # to store the kept idxs
+     for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
+         # Apply constraints
+         # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
+         filt = xc[xi]  # confidence
+         x, xk = x[filt], xk[filt]
+
+         # Cat apriori labels if autolabelling
+         if labels and len(labels[xi]) and not rotated:
+             lb = labels[xi]
+             v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
+             v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
+             v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
+             x = torch.cat((x, v), 0)
+
+         # If none remain process next image
+         if not x.shape[0]:
+             continue
+
+         # Detections matrix nx6 (xyxy, conf, cls)
+         box, cls, mask = x.split((4, nc, nm), 1)
+
+         if multi_label:
+             i, j = torch.where(cls > conf_thres)
+             x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
+             xk = xk[i]
+         else:  # best class only
+             conf, j = cls.max(1, keepdim=True)
+             filt = conf.view(-1) > conf_thres
+             x = torch.cat((box, conf, j.float(), mask), 1)[filt]
+             xk = xk[filt]
+
+         # Filter by class
+         if classes is not None:
+             filt = (x[:, 5:6] == classes).any(1)
+             x, xk = x[filt], xk[filt]
+
+         # Check shape
+         n = x.shape[0]  # number of boxes
+         if not n:  # no boxes
+             continue
+         if n > max_nms:  # excess boxes
+             filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
+             x, xk = x[filt], xk[filt]
+
+         # Batched NMS
+         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
+         scores = x[:, 4]  # scores
+         if rotated:
+             boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
+             i = nms_rotated(boxes, scores, iou_thres)
+         else:
+             boxes = x[:, :4] + c  # boxes (offset by class)
+             i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+         i = i[:max_det]  # limit detections
+
+         # # Experimental
+         # merge = False  # use merge-NMS
+         # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
+         #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+         #     from .metrics import box_iou
+         #     iou = box_iou(boxes[i], boxes) > iou_thres  # IoU matrix
+         #     weights = iou * scores[None]  # box weights
+         #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+         #     redundant = True  # require redundant detections
+         #     if redundant:
+         #         i = i[iou.sum(1) > 1]  # require redundancy
+
+         output[xi], keepi[xi] = x[i], xk[i].reshape(-1)
+         if (time.time() - t) > time_limit:
+             LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
+             break  # time limit exceeded
+
+     return (output, keepi) if return_idxs else output
+
+
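A minimal end-to-end sketch on a hand-built raw output (batch of 1, two classes, three candidate boxes; all values arbitrary). The two overlapping class-0 candidates collapse to one detection:

    import torch
    from ultralytics.utils.ops import non_max_suppression

    # Synthetic raw output in BCN layout: rows are cx, cy, w, h, then per-class scores.
    pred = torch.tensor(
        [
            [
                [100.0, 102.0, 300.0],  # cx
                [100.0, 100.0, 300.0],  # cy
                [50.0, 50.0, 40.0],     # w
                [50.0, 50.0, 40.0],     # h
                [0.90, 0.85, 0.05],     # class-0 scores
                [0.10, 0.05, 0.70],     # class-1 scores
            ]
        ]
    )
    (out,) = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
    print(out)  # rows of (x1, y1, x2, y2, conf, cls); the overlapping duplicate is removed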
+ def clip_boxes(boxes, shape):
+     """
+     Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
+
+     Args:
+         boxes (torch.Tensor | numpy.ndarray): The bounding boxes to clip.
+         shape (tuple): The shape of the image.
+
+     Returns:
+         (torch.Tensor | numpy.ndarray): The clipped boxes.
+     """
+     if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
+         boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
+         boxes[..., 1] = boxes[..., 1].clamp(0, shape[0])  # y1
+         boxes[..., 2] = boxes[..., 2].clamp(0, shape[1])  # x2
+         boxes[..., 3] = boxes[..., 3].clamp(0, shape[0])  # y2
+     else:  # np.array (faster grouped)
+         boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
+         boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
+     return boxes
+
+
+ def clip_coords(coords, shape):
+     """
+     Clip line coordinates to the image boundaries.
+
+     Args:
+         coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
+         shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
+
+     Returns:
+         (torch.Tensor | numpy.ndarray): Clipped coordinates.
+     """
+     if isinstance(coords, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
+         coords[..., 0] = coords[..., 0].clamp(0, shape[1])  # x
+         coords[..., 1] = coords[..., 1].clamp(0, shape[0])  # y
+     else:  # np.array (faster grouped)
+         coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
+         coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y
+     return coords
+
+
+ def scale_image(masks, im0_shape, ratio_pad=None):
+     """
+     Takes a mask and resizes it to the original image size.
+
+     Args:
+         masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
+         im0_shape (tuple): The original image shape.
+         ratio_pad (tuple): The ratio of the padding to the original image.
+
+     Returns:
+         masks (np.ndarray): The masks that are being returned with shape [h, w, num].
+     """
+     # Rescale coordinates (xyxy) from im1_shape to im0_shape
+     im1_shape = masks.shape
+     if im1_shape[:2] == im0_shape[:2]:
+         return masks
+     if ratio_pad is None:  # calculate from im0_shape
+         gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
+         pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
+     else:
+         # gain = ratio_pad[0][0]
+         pad = ratio_pad[1]
+     top, left = int(pad[1]), int(pad[0])  # y, x
+     bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
+
+     if len(masks.shape) < 2:
+         raise ValueError(f"masks shape should have 2 or 3 dimensions, but got {len(masks.shape)}")
+     masks = masks[top:bottom, left:right]
+     masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
+     if len(masks.shape) == 2:
+         masks = masks[:, :, None]
+
+     return masks
+
+
+ def xyxy2xywh(x):
+     """
+     Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
+     top-left corner and (x2, y2) is the bottom-right corner.
+
+     Args:
+         x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
+
+     Returns:
+         y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
+     """
+     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
+     y = empty_like(x)  # faster than clone/copy
+     y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
+     y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
+     y[..., 2] = x[..., 2] - x[..., 0]  # width
+     y[..., 3] = x[..., 3] - x[..., 1]  # height
+     return y
+
+
+ def xywh2xyxy(x):
+     """
+     Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
+     top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.
+
+     Args:
+         x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
+
+     Returns:
+         y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
+     """
+     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
+     y = empty_like(x)  # faster than clone/copy
+     xy = x[..., :2]  # centers
+     wh = x[..., 2:] / 2  # half width-height
+     y[..., :2] = xy - wh  # top left xy
+     y[..., 2:] = xy + wh  # bottom right xy
+     return y
+
+
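Both converters allocate fresh output via `empty_like`, so the input is left untouched and round-trips are exact; a quick sketch:

    import torch
    from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh

    xywh = torch.tensor([[100.0, 100.0, 50.0, 20.0]])  # center (100, 100), 50x20 box
    xyxy = xywh2xyxy(xywh)
    print(xyxy)             # tensor([[75., 90., 125., 110.]])
    print(xyxy2xywh(xyxy))  # back to tensor([[100., 100., 50., 20.]])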
+ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
+     """
+     Convert normalized bounding box coordinates to pixel coordinates.
+
+     Args:
+         x (np.ndarray | torch.Tensor): The bounding box coordinates.
+         w (int): Width of the image.
+         h (int): Height of the image.
+         padw (int): Padding width.
+         padh (int): Padding height.
+
+     Returns:
+         y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
+             x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
+     """
+     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
+     y = empty_like(x)  # faster than clone/copy
+     y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
+     y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
+     y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
+     y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
+     return y
+
+
+ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
+     """
+     Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
+     width and height are normalized to image dimensions.
+
+     Args:
+         x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
+         w (int): The width of the image.
+         h (int): The height of the image.
+         clip (bool): If True, the boxes will be clipped to the image boundaries.
+         eps (float): The minimum value of the box's width and height.
+
+     Returns:
+         y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format.
+     """
+     if clip:
+         x = clip_boxes(x, (h - eps, w - eps))
+     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
+     y = empty_like(x)  # faster than clone/copy
+     y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
+     y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
+     y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
+     y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
+     return y
+
+
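A round-trip sketch between a YOLO-normalized label and pixel xyxy on an assumed 640x480 image:

    import torch
    from ultralytics.utils.ops import xywhn2xyxy, xyxy2xywhn

    # Normalized center (0.5, 0.5); box covers 25% of the width and 50% of the height.
    label = torch.tensor([[0.5, 0.5, 0.25, 0.5]])
    pixels = xywhn2xyxy(label, w=640, h=480)
    print(pixels)                            # tensor([[240., 120., 400., 360.]])
    print(xyxy2xywhn(pixels, w=640, h=480))  # back to tensor([[0.5000, 0.5000, 0.2500, 0.5000]])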
+ def xywh2ltwh(x):
+     """
+     Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
+
+     Args:
+         x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format.
+
+     Returns:
+         y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
+     """
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
+     y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
+     return y
+
+
+ def xyxy2ltwh(x):
+     """
+     Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
+
+     Args:
+         x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format.
+
+     Returns:
+         y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
+     """
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 2] = x[..., 2] - x[..., 0]  # width
+     y[..., 3] = x[..., 3] - x[..., 1]  # height
+     return y
+
+
+ def ltwh2xywh(x):
+     """
+     Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
+
+     Args:
+         x (torch.Tensor): The input tensor.
+
+     Returns:
+         y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
+     """
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
+     y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
+     return y
+
+
+ def xyxyxyxy2xywhr(x):
+     """
+     Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
+     returned in radians from 0 to pi/2.
+
+     Args:
+         x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8).
+
+     Returns:
+         (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
+     """
+     is_torch = isinstance(x, torch.Tensor)
+     points = x.cpu().numpy() if is_torch else x
+     points = points.reshape(len(x), -1, 2)
+     rboxes = []
+     for pts in points:
+         # NOTE: Use cv2.minAreaRect to get accurate xywhr,
+         # especially when some objects are cut off by augmentations in the dataloader.
+         (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
+         rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
+     return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)
+
+
+ def xywhr2xyxyxyxy(x):
+     """
+     Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
+     be in radians from 0 to pi/2.
+
+     Args:
+         x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
+
+     Returns:
+         (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
+     """
+     cos, sin, cat, stack = (
+         (torch.cos, torch.sin, torch.cat, torch.stack)
+         if isinstance(x, torch.Tensor)
+         else (np.cos, np.sin, np.concatenate, np.stack)
+     )
+
+     ctr = x[..., :2]
+     w, h, angle = (x[..., i : i + 1] for i in range(2, 5))
+     cos_value, sin_value = cos(angle), sin(angle)
+     vec1 = [w / 2 * cos_value, w / 2 * sin_value]
+     vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
+     vec1 = cat(vec1, -1)
+     vec2 = cat(vec2, -1)
+     pt1 = ctr + vec1 + vec2
+     pt2 = ctr + vec1 - vec2
+     pt3 = ctr - vec1 - vec2
+     pt4 = ctr - vec1 + vec2
+     return stack([pt1, pt2, pt3, pt4], -2)
+
+
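A round-trip sketch of the two OBB converters. Note that `cv2.minAreaRect` may return a swapped w/h with the complementary angle, so the recovered xywhr is geometrically equivalent rather than bitwise identical:

    import numpy as np
    from ultralytics.utils.ops import xywhr2xyxyxyxy, xyxyxyxy2xywhr

    rbox = np.array([[100.0, 100.0, 60.0, 20.0, np.pi / 6]])  # cx, cy, w, h, 30 degrees
    corners = xywhr2xyxyxyxy(rbox)  # shape (1, 4, 2) corner points
    recovered = xyxyxyxy2xywhr(corners.reshape(1, 8).astype(np.float32))  # minAreaRect needs float32
    print(recovered)  # approximately the original [cx, cy, w, h, r]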
+ def ltwh2xyxy(x):
+     """
+     Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
+
+     Args:
+         x (np.ndarray | torch.Tensor): The input bounding box coordinates in the ltwh format.
+
+     Returns:
+         (np.ndarray | torch.Tensor): The xyxy coordinates of the bounding boxes.
+     """
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 2] = x[..., 2] + x[..., 0]  # x2 = x1 + w
+     y[..., 3] = x[..., 3] + x[..., 1]  # y2 = y1 + h
+     return y
+
+
+ def segments2boxes(segments):
+     """
+     Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
+
+     Args:
+         segments (list): List of segments, each segment is a list of points, each point is a list of x, y coordinates.
+
+     Returns:
+         (np.ndarray): The xywh coordinates of the bounding boxes.
+     """
+     boxes = []
+     for s in segments:
+         x, y = s.T  # segment xy
+         boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
+     return xyxy2xywh(np.array(boxes))  # xywh
+
+
+ def resample_segments(segments, n=1000):
+     """
+     Input a list of segments (each an (m, 2) array) and return the segments resampled to n points each.
+
+     Args:
+         segments (list): A list of (m, 2) arrays, where m is the number of points in the segment.
+         n (int): Number of points to resample each segment to.
+
+     Returns:
+         segments (list): The resampled segments.
+     """
+     for i, s in enumerate(segments):
+         if len(s) == n:
+             continue
+         s = np.concatenate((s, s[0:1, :]), axis=0)  # close the polygon
+         x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)
+         xp = np.arange(len(s))
+         x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x
+         segments[i] = (
+             np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
+         )  # segment xy
+     return segments
+
+
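A quick sketch resampling a 4-point square outline to 16 points (the list is modified in place and also returned):

    import numpy as np
    from ultralytics.utils.ops import resample_segments

    square = np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0], [0.0, 10.0]], dtype=np.float32)
    (resampled,) = resample_segments([square], n=16)
    print(resampled.shape)  # (16, 2) — points interpolated along the closed square outline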
+ def crop_mask(masks, boxes):
+     """
+     Crop masks to bounding boxes.
+
+     Args:
+         masks (torch.Tensor): [n, h, w] tensor of masks.
+         boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form.
+
+     Returns:
+         (torch.Tensor): Cropped masks.
+     """
+     _, h, w = masks.shape
+     x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
+     r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
+     c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)
+
+     return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
+
+
+ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
+     """
+     Apply masks to bounding boxes using the output of the mask head.
+
+     Args:
+         protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
+         masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
+         bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
+         shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
+         upsample (bool): A flag to indicate whether to upsample the mask to the original image size.
+
+     Returns:
+         (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
+             are the height and width of the input image. The mask is applied to the bounding boxes.
+     """
+     c, mh, mw = protos.shape  # CHW
+     ih, iw = shape
+     masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # CHW
+     width_ratio = mw / iw
+     height_ratio = mh / ih
+
+     downsampled_bboxes = bboxes.clone()
+     downsampled_bboxes[:, 0] *= width_ratio
+     downsampled_bboxes[:, 2] *= width_ratio
+     downsampled_bboxes[:, 3] *= height_ratio
+     downsampled_bboxes[:, 1] *= height_ratio
+
+     masks = crop_mask(masks, downsampled_bboxes)  # CHW
+     if upsample:
+         masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
+     return masks.gt_(0.0)
+
+
+ def process_mask_native(protos, masks_in, bboxes, shape):
+     """
+     Apply masks to bounding boxes using the output of the mask head with native upsampling.
+
+     Args:
+         protos (torch.Tensor): [mask_dim, mask_h, mask_w].
+         masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
+         bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
+         shape (tuple): The size of the input image (h,w).
+
+     Returns:
+         (torch.Tensor): The returned masks with dimensions [n, h, w].
+     """
+     c, mh, mw = protos.shape  # CHW
+     masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
+     masks = scale_masks(masks[None], shape)[0]  # CHW
+     masks = crop_mask(masks, bboxes)  # CHW
+     return masks.gt_(0.0)
+
+
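A shape-level sketch of the prototype-mask assembly with random tensors (shapes follow the docstrings; the values are meaningless, so the resulting masks are arbitrary):

    import torch
    from ultralytics.utils.ops import process_mask

    protos = torch.randn(32, 160, 160)  # mask prototypes from the segmentation head
    coeffs = torch.randn(5, 32)         # per-detection mask coefficients (5 boxes after NMS)
    boxes = torch.tensor([[100.0, 100.0, 300.0, 300.0]]).repeat(5, 1)  # xyxy at input-image scale
    masks = process_mask(protos, coeffs, boxes, shape=(640, 640), upsample=True)
    print(masks.shape, masks.dtype)  # torch.Size([5, 640, 640]) torch.float32 — gt_() yields 0./1. masks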
+ def scale_masks(masks, shape, padding=True):
+     """
+     Rescale segment masks to shape.
+
+     Args:
+         masks (torch.Tensor): (N, C, H, W).
+         shape (tuple): Height and width.
+         padding (bool): If True, assume the masks are based on an image augmented by YOLO-style letterboxing. If
+             False, do regular rescaling.
+
+     Returns:
+         (torch.Tensor): Rescaled masks.
+     """
+     mh, mw = masks.shape[2:]
+     gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
+     pad = [mw - shape[1] * gain, mh - shape[0] * gain]  # wh padding
+     if padding:
+         pad[0] /= 2
+         pad[1] /= 2
+     top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0)  # y, x
+     bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
+     masks = masks[..., top:bottom, left:right]
+
+     masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False)  # NCHW
+     return masks
+
+
+ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
+     """
+     Rescale segment coordinates (xy) from img1_shape to img0_shape.
+
+     Args:
+         img1_shape (tuple): The shape of the image that the coords are from.
+         coords (torch.Tensor): The coords to be scaled of shape n,2.
+         img0_shape (tuple): The shape of the image that the segmentation is being applied to.
+         ratio_pad (tuple): The ratio of the image size to the padded image size.
+         normalize (bool): If True, the coordinates will be normalized to the range [0, 1].
+         padding (bool): If True, assume the coords are based on an image augmented by YOLO-style letterboxing. If
+             False, do regular rescaling.
+
+     Returns:
+         coords (torch.Tensor): The scaled coordinates.
+     """
+     if ratio_pad is None:  # calculate from img0_shape
+         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+         pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+     else:
+         gain = ratio_pad[0][0]
+         pad = ratio_pad[1]
+
+     if padding:
+         coords[..., 0] -= pad[0]  # x padding
+         coords[..., 1] -= pad[1]  # y padding
+     coords[..., 0] /= gain
+     coords[..., 1] /= gain
+     coords = clip_coords(coords, img0_shape)
+     if normalize:
+         coords[..., 0] /= img0_shape[1]  # width
+         coords[..., 1] /= img0_shape[0]  # height
+     return coords
+
+
+ def regularize_rboxes(rboxes):
+     """
+     Regularize rotated boxes to the angle range [0, pi/2).
+
+     Args:
+         rboxes (torch.Tensor): Input boxes of shape (N, 5) in xywhr format.
+
+     Returns:
+         (torch.Tensor): The regularized boxes.
+     """
+     x, y, w, h, t = rboxes.unbind(dim=-1)
+     # Swap width and height when the angle modulo pi falls in [pi/2, pi), then wrap the angle into [0, pi/2)
+     swap = t % math.pi >= math.pi / 2
+     w_ = torch.where(swap, h, w)
+     h_ = torch.where(swap, w, h)
+     t = t % (math.pi / 2)
+     return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes
+
+
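A quick sketch: an angle outside [0, pi/2) is wrapped, and width/height are swapped, leaving the box geometrically unchanged:

    import torch
    from ultralytics.utils.ops import regularize_rboxes

    rbox = torch.tensor([[100.0, 100.0, 20.0, 60.0, 2.0]])  # angle 2.0 rad is outside [0, pi/2)
    print(regularize_rboxes(rbox))  # ≈ tensor([[100., 100., 60., 20., 0.4292]]) — w/h swapped, angle wrapped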
+ def masks2segments(masks, strategy="all"):
+     """
+     Convert masks to segments.
+
+     Args:
+         masks (torch.Tensor): The output of the model, which is a tensor of shape (batch_size, 160, 160).
+         strategy (str): 'all' or 'largest'.
+
+     Returns:
+         (list): List of segment masks.
+     """
+     from ultralytics.data.converter import merge_multi_segment
+
+     segments = []
+     for x in masks.int().cpu().numpy().astype("uint8"):
+         c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+         if c:
+             if strategy == "all":  # merge and concatenate all segments
+                 c = (
+                     np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))
+                     if len(c) > 1
+                     else c[0].reshape(-1, 2)
+                 )
+             elif strategy == "largest":  # select largest segment
+                 c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
+         else:
+             c = np.zeros((0, 2))  # no segments found
+         segments.append(c.astype("float32"))
+     return segments
+
+
+ def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
+     """
+     Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.
+
+     Args:
+         batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.
+
+     Returns:
+         (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
+     """
+     return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
+
+
+ def clean_str(s):
+     """
+     Clean a string by replacing special characters with an underscore.
+
+     Args:
+         s (str): A string needing special characters replaced.
+
+     Returns:
+         (str): A string with special characters replaced by an underscore.
+     """
+     return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
+
+
+ def empty_like(x):
+     """Creates an empty torch.Tensor or np.ndarray with the same shape as the input and float32 dtype."""
+     return (
+         torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
+     )