ultralytics-opencv-headless 8.3.242__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298) hide show
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
@@ -0,0 +1,337 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import sys
4
+ import time
5
+
6
+ import torch
7
+
8
+ from ultralytics.utils import LOGGER
9
+ from ultralytics.utils.metrics import batch_probiou, box_iou
10
+ from ultralytics.utils.ops import xywh2xyxy
11
+
12
+
13
def non_max_suppression(
    prediction,
    conf_thres: float = 0.25,
    iou_thres: float = 0.45,
    classes=None,
    agnostic: bool = False,
    multi_label: bool = False,
    labels=(),
    max_det: int = 300,
    nc: int = 0,  # number of classes (optional)
    max_time_img: float = 0.05,
    max_nms: int = 30000,
    max_wh: int = 7680,
    rotated: bool = False,
    end2end: bool = False,
    return_idxs: bool = False,
):
    """Perform non-maximum suppression (NMS) on prediction results.

    Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple detection
    formats including standard boxes, rotated boxes, and masks. The caller's ``prediction`` tensor is not modified.

    Args:
        prediction (torch.Tensor | list | tuple): Predictions with shape (batch_size, 4 + num_classes + num_extra,
            num_boxes) laid out as box (xywh), per-class scores, then extra values (e.g. mask coefficients; for
            rotated boxes the angle is expected in the last channel). If a list/tuple is given, only its first element
            is used. Tensors whose last dimension is 6 are treated as end-to-end output (see ``end2end``).
        conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
        iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
        classes (list[int] | torch.Tensor, optional): Class indices to keep. If None, all classes are considered.
        agnostic (bool): Whether to perform class-agnostic NMS.
        multi_label (bool): Whether each box can have multiple labels (only effective when there is more than 1 class).
        labels (list[torch.Tensor]): Optional a-priori labels per image for autolabelling; each row is
            (cls, x, y, w, h). Ignored when ``rotated`` is True.
        max_det (int): Maximum number of detections to keep per image.
        nc (int): Number of classes. Channels after 4 + nc are treated as extra values (e.g. masks).
        max_time_img (float): Time budget in seconds per image; the total limit is 2.0 + max_time_img * batch_size.
        max_nms (int): Maximum number of boxes passed into NMS per image (highest-confidence boxes are kept).
        max_wh (int): Maximum box width and height in pixels, used as the per-class offset for class-aware NMS.
        rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
        end2end (bool): Whether the model is end-to-end and doesn't require NMS. In that case only confidence,
            ``max_det`` and ``classes`` filtering are applied, and a plain list is returned even if ``return_idxs``
            is True.
        return_idxs (bool): Whether to also return the indices of kept detections.

    Returns:
        output (list[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_extra) containing (x1,
            y1, x2, y2, confidence, class, extra1, extra2, ...). For rotated boxes the first four columns remain xywh
            and the angle stays in the last column.
        keepi (list[torch.Tensor]): Only when return_idxs=True (non end-to-end path). Per image, 1-D int64 indices of
            the kept detections along the num_boxes axis of ``prediction``; rows originating from ``labels`` are -1.

    Notes:
        If processing exceeds the time limit, a warning is logged and the remaining images are left with empty results.
    """
    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output
    if classes is not None:
        # as_tensor accepts both Python sequences and existing tensors (torch.tensor warns when copying a tensor)
        classes = torch.as_tensor(classes, device=prediction.device)

    if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
        # NOTE: return_idxs is not supported here; rows are already final detections
        output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
        if classes is not None:
            output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
        return output

    bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    extra = prediction.shape[1] - nc - 4  # number of extra info
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
    xinds = torch.arange(prediction.shape[-1], device=prediction.device).expand(bs, -1)[..., None]  # to track idxs

    # Settings
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    # transpose returns a view: it must not be written to, otherwise the caller's tensor changes
    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)

    t = time.time()
    output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
    # empty placeholders match the dtype/rank of populated entries (1-D int64)
    keepi = [torch.zeros((0,), dtype=torch.long, device=prediction.device)] * bs  # to store the kept idxs
    for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
        # Apply confidence constraint; boolean indexing returns a copy, so in-place edits below are safe
        filt = xc[xi]
        x = x[filt]
        if not rotated:
            x[:, :4] = xywh2xyxy(x[:, :4])  # xywh to xyxy (candidates only)
        if return_idxs:
            xk = xk[filt]

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)
            if return_idxs:
                # label rows have no source prediction; pad with -1 so xk stays row-aligned with x
                xk = torch.cat((xk, xk.new_full((len(lb), 1), -1)), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, extra), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
            if return_idxs:
                xk = xk[i]
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            filt = conf.view(-1) > conf_thres
            x = torch.cat((box, conf, j.float(), mask), 1)[filt]
            if return_idxs:
                xk = xk[filt]

        # Filter by class
        if classes is not None:
            filt = (x[:, 5:6] == classes).any(1)
            x = x[filt]
            if return_idxs:
                xk = xk[filt]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
            x = x[filt]
            if return_idxs:
                xk = xk[filt]

        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        scores = x[:, 4]  # scores
        if rotated:
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = TorchNMS.fast_nms(boxes, scores, iou_thres, iou_func=batch_probiou)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class)
            # Speed strategy: torchvision for val or already loaded (faster), TorchNMS for predict (lower latency)
            if "torchvision" in sys.modules:
                import torchvision  # scope as slow import

                i = torchvision.ops.nms(boxes, scores, iou_thres)
            else:
                i = TorchNMS.nms(boxes, scores, iou_thres)
        i = i[:max_det]  # limit detections

        output[xi] = x[i]
        if return_idxs:
            keepi[xi] = xk[i].view(-1)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return (output, keepi) if return_idxs else output
167
+
168
+
169
class TorchNMS:
    """Ultralytics custom NMS implementation optimized for YOLO.

    This class provides static methods for performing non-maximum suppression (NMS) operations on bounding boxes,
    including standard NMS, Fast-NMS, and batched NMS for multi-class scenarios.

    Methods:
        fast_nms: Matrix-based Fast-NMS with an optional fixed-length, export-friendly output mode.
        nms: Optimized NMS with early termination that matches torchvision behavior exactly.
        batched_nms: Batched NMS for class-aware suppression.

    Examples:
        Perform standard NMS on boxes and scores
        >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
        >>> scores = torch.tensor([0.9, 0.8])
        >>> keep = TorchNMS.nms(boxes, scores, 0.5)
    """

    @staticmethod
    def fast_nms(
        boxes: torch.Tensor,
        scores: torch.Tensor,
        iou_threshold: float,
        use_triu: bool = True,
        iou_func=box_iou,
        exit_early: bool = True,
    ) -> torch.Tensor:
        """Fast-NMS implementation from https://arxiv.org/pdf/1904.02689 using upper triangular matrix operations.

        Unlike greedy NMS, a box is suppressed by any higher-scoring box whose IoU with it is >= ``iou_threshold``,
        even if that higher-scoring box was itself suppressed, so it can remove more boxes than ``nms``.

        Args:
            boxes (torch.Tensor): Boxes with shape (N, 4) in xyxy format for the default ``box_iou``, or any layout
                accepted by ``iou_func``.
            scores (torch.Tensor): Confidence scores with shape (N,).
            iou_threshold (float): IoU threshold for suppression.
            use_triu (bool): If True, return only the kept indices (variable length). If False, use an explicit mask
                instead of ``torch.triu``, zero the scores of suppressed boxes (``scores`` is modified in place) and
                return all N indices ordered by the updated scores, giving a fixed-length output.
            iou_func (callable): Function computing the pairwise IoU matrix between two sets of boxes.
            exit_early (bool): Whether to return an empty tensor immediately when ``boxes`` is empty.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS (see ``use_triu`` for the fixed-length variant).

        Examples:
            Apply Fast-NMS to a set of boxes
            >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> keep = TorchNMS.fast_nms(boxes, scores, 0.5)
        """
        if boxes.numel() == 0 and exit_early:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        sorted_idx = torch.argsort(scores, descending=True)
        boxes = boxes[sorted_idx]
        ious = iou_func(boxes, boxes)
        if use_triu:
            ious = ious.triu_(diagonal=1)
            # NOTE: no data-dependent if/else here (presumably to keep the graph exportable); zero boxes yields an
            # empty pick naturally. A column with any IoU >= threshold against a higher-scoring box is suppressed.
            pick = torch.nonzero((ious >= iou_threshold).sum(0) <= 0).squeeze_(-1)
        else:
            n = boxes.shape[0]
            row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
            col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
            upper_mask = row_idx < col_idx
            ious = ious * upper_mask
            # Zeroing these scores ensures the additional indices would not affect the final results
            scores_ = scores[sorted_idx]
            scores_[~((ious >= iou_threshold).sum(0) <= 0)] = 0
            scores[sorted_idx] = scores_  # update original tensor for NMSModel
            # NOTE: return indices with fixed length to avoid TFLite reshape error
            pick = torch.topk(scores_, scores_.shape[0]).indices
        return sorted_idx[pick]

    @staticmethod
    def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
        """Optimized NMS with early termination that matches torchvision behavior exactly.

        Greedily keeps the highest-scoring remaining box and discards remaining boxes whose IoU with it is strictly
        greater than ``iou_threshold``.

        Args:
            boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
            scores (torch.Tensor): Confidence scores with shape (N,).
            iou_threshold (float): IoU threshold for suppression.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS, in descending score order.

        Examples:
            Apply NMS to a set of boxes
            >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> keep = TorchNMS.nms(boxes, scores, 0.5)
        """
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        # Pre-allocate and extract coordinates once
        x1, y1, x2, y2 = boxes.unbind(1)
        areas = (x2 - x1) * (y2 - y1)

        # Sort by scores descending
        order = scores.argsort(0, descending=True)

        # Pre-allocate keep list with maximum possible size
        keep = torch.zeros(order.numel(), dtype=torch.int64, device=boxes.device)
        keep_idx = 0
        while order.numel() > 0:
            i = order[0]
            keep[keep_idx] = i
            keep_idx += 1

            if order.numel() == 1:
                break
            # Vectorized IoU calculation for remaining boxes
            rest = order[1:]
            xx1 = torch.maximum(x1[i], x1[rest])
            yy1 = torch.maximum(y1[i], y1[rest])
            xx2 = torch.minimum(x2[i], x2[rest])
            yy2 = torch.minimum(y2[i], y2[rest])

            # Fast intersection and IoU
            w = (xx2 - xx1).clamp_(min=0)
            h = (yy2 - yy1).clamp_(min=0)
            inter = w * h
            # Early exit: skip IoU calculation if no intersection
            if inter.sum() == 0:
                # No overlaps with current box, keep all remaining boxes
                order = rest
                continue
            iou = inter / (areas[i] + areas[rest] - inter)
            # Keep boxes with IoU <= threshold
            order = rest[iou <= iou_threshold]

        return keep[:keep_idx]

    @staticmethod
    def batched_nms(
        boxes: torch.Tensor,
        scores: torch.Tensor,
        idxs: torch.Tensor,
        iou_threshold: float,
        use_fast_nms: bool = False,
    ) -> torch.Tensor:
        """Batched NMS for class-aware suppression.

        Boxes with different ``idxs`` never suppress each other: each class is shifted to its own disjoint coordinate
        range before running a single NMS pass.

        Args:
            boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
            scores (torch.Tensor): Confidence scores with shape (N,).
            idxs (torch.Tensor): Class indices with shape (N,).
            iou_threshold (float): IoU threshold for suppression.
            use_fast_nms (bool): Whether to use the Fast-NMS implementation.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS.

        Examples:
            Apply batched NMS across multiple classes
            >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> idxs = torch.tensor([0, 1])
            >>> keep = TorchNMS.batched_nms(boxes, scores, idxs, 0.5)
        """
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        # Strategy: offset boxes by class index to prevent cross-class suppression.
        # Using the full coordinate span (max - min + 1) keeps per-class ranges disjoint even when some coordinates
        # are negative (an offset of max + 1 would let classes overlap there), and it is never larger than max + 1
        # when all coordinates are non-negative.
        coordinate_span = boxes.max() - boxes.min() + 1
        offsets = idxs.to(boxes) * coordinate_span
        boxes_for_nms = boxes + offsets[:, None]

        return (
            TorchNMS.fast_nms(boxes_for_nms, scores, iou_threshold)
            if use_fast_nms
            else TorchNMS.nms(boxes_for_nms, scores, iou_threshold)
        )