dgenerate_ultralytics_headless-8.3.253-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
ultralytics/models/rtdetr/train.py
@@ -0,0 +1,89 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ from copy import copy
+
+ from ultralytics.models.yolo.detect import DetectionTrainer
+ from ultralytics.nn.tasks import RTDETRDetectionModel
+ from ultralytics.utils import RANK, colorstr
+
+ from .val import RTDETRDataset, RTDETRValidator
+
+
+ class RTDETRTrainer(DetectionTrainer):
+     """Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
+
+     This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
+     RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and
+     adaptable inference speed.
+
+     Attributes:
+         loss_names (tuple): Names of the loss components used for training.
+         data (dict): Dataset configuration containing class count and other parameters.
+         args (dict): Training arguments and hyperparameters.
+         save_dir (Path): Directory to save training results.
+         test_loader (DataLoader): DataLoader for validation/testing data.
+
+     Methods:
+         get_model: Initialize and return an RT-DETR model for object detection tasks.
+         build_dataset: Build and return an RT-DETR dataset for training or validation.
+         get_validator: Return a DetectionValidator suitable for RT-DETR model validation.
+
+     Examples:
+         >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
+         >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
+         >>> trainer = RTDETRTrainer(overrides=args)
+         >>> trainer.train()
+
+     Notes:
+         - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
+         - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
+     """
+
+     def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
+         """Initialize and return an RT-DETR model for object detection tasks.
+
+         Args:
+             cfg (dict, optional): Model configuration.
+             weights (str, optional): Path to pre-trained model weights.
+             verbose (bool): Verbose logging if True.
+
+         Returns:
+             (RTDETRDetectionModel): Initialized model.
+         """
+         model = RTDETRDetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
+         if weights:
+             model.load(weights)
+         return model
+
+     def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
+         """Build and return an RT-DETR dataset for training or validation.
+
+         Args:
+             img_path (str): Path to the folder containing images.
+             mode (str): Dataset mode, either 'train' or 'val'.
+             batch (int, optional): Batch size for rectangle training.
+
+         Returns:
+             (RTDETRDataset): Dataset object for the specific mode.
+         """
+         return RTDETRDataset(
+             img_path=img_path,
+             imgsz=self.args.imgsz,
+             batch_size=batch,
+             augment=mode == "train",
+             hyp=self.args,
+             rect=False,
+             cache=self.args.cache or None,
+             single_cls=self.args.single_cls or False,
+             prefix=colorstr(f"{mode}: "),
+             classes=self.args.classes,
+             data=self.data,
+             fraction=self.args.fraction if mode == "train" else 1.0,
+         )
+
+     def get_validator(self):
+         """Return a DetectionValidator suitable for RT-DETR model validation."""
+         self.loss_names = "giou_loss", "cls_loss", "l1_loss"
+         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
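The trainer is driven entirely through its overrides dictionary, as in the docstring example. A minimal sketch, assuming the bundled coco8.yaml dataset is reachable; `amp=False` is an optional override shown only to illustrate the AMP caveat from the Notes section, not a required setting:

from ultralytics.models.rtdetr.train import RTDETRTrainer

# Same overrides as the docstring example; AMP disabled per the Notes caveat
# about NaN outputs during bipartite graph matching (optional).
args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3, amp=False)
trainer = RTDETRTrainer(overrides=args)
trainer.train()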
ultralytics/models/rtdetr/val.py
@@ -0,0 +1,216 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+
+ from ultralytics.data import YOLODataset
+ from ultralytics.data.augment import Compose, Format, v8_transforms
+ from ultralytics.models.yolo.detect import DetectionValidator
+ from ultralytics.utils import colorstr, ops
+
+ __all__ = ("RTDETRValidator",)  # tuple or list
+
+
+ class RTDETRDataset(YOLODataset):
+     """Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+
+     This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
+     real-time detection and tracking tasks.
+
+     Attributes:
+         augment (bool): Whether to apply data augmentation.
+         rect (bool): Whether to use rectangular training.
+         use_segments (bool): Whether to use segmentation masks.
+         use_keypoints (bool): Whether to use keypoint annotations.
+         imgsz (int): Target image size for training.
+
+     Methods:
+         load_image: Load one image from dataset index.
+         build_transforms: Build transformation pipeline for the dataset.
+
+     Examples:
+         Initialize an RT-DETR dataset
+         >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
+         >>> image, hw0, hw = dataset.load_image(0)
+     """
+
+     def __init__(self, *args, data=None, **kwargs):
+         """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
+
+         This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
+         model, building upon the base YOLODataset functionality.
+
+         Args:
+             *args (Any): Variable length argument list passed to the parent YOLODataset class.
+             data (dict | None): Dictionary containing dataset information. If None, default values will be used.
+             **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
+         """
+         super().__init__(*args, data=data, **kwargs)
+
+     def load_image(self, i, rect_mode=False):
+         """Load one image from dataset index 'i'.
+
+         Args:
+             i (int): Index of the image to load.
+             rect_mode (bool, optional): Whether to use rectangular mode for batch inference.
+
+         Returns:
+             im (np.ndarray): Loaded image as a NumPy array.
+             hw_original (tuple[int, int]): Original image dimensions in (height, width) format.
+             hw_resized (tuple[int, int]): Resized image dimensions in (height, width) format.
+
+         Examples:
+             Load an image from the dataset
+             >>> dataset = RTDETRDataset(img_path="path/to/images")
+             >>> image, hw0, hw = dataset.load_image(0)
+         """
+         return super().load_image(i=i, rect_mode=rect_mode)
+
+     def build_transforms(self, hyp=None):
+         """Build transformation pipeline for the dataset.
+
+         Args:
+             hyp (dict, optional): Hyperparameters for transformations.
+
+         Returns:
+             (Compose): Composition of transformation functions.
+         """
+         if self.augment:
+             hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
+             hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
+             hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0
+             transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
+         else:
+             # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
+             transforms = Compose([])
+         transforms.append(
+             Format(
+                 bbox_format="xywh",
+                 normalize=True,
+                 return_mask=self.use_segments,
+                 return_keypoint=self.use_keypoints,
+                 batch_idx=True,
+                 mask_ratio=hyp.mask_ratio,
+                 mask_overlap=hyp.overlap_mask,
+             )
+         )
+         return transforms
+
+
+ class RTDETRValidator(DetectionValidator):
+     """RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored
+     for the RT-DETR (Real-Time DETR) object detection model.
+
+     The class allows building of an RTDETR-specific dataset for validation, applies confidence-based filtering to the
+     model's end-to-end predictions (RT-DETR requires no non-maximum suppression), and updates evaluation metrics
+     accordingly.
+
+     Attributes:
+         args (Namespace): Configuration arguments for validation.
+         data (dict): Dataset configuration dictionary.
+
+     Methods:
+         build_dataset: Build an RTDETR Dataset for validation.
+         postprocess: Filter and format raw prediction outputs.
+
+     Examples:
+         Initialize and run RT-DETR validation
+         >>> from ultralytics.models.rtdetr import RTDETRValidator
+         >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
+         >>> validator = RTDETRValidator(args=args)
+         >>> validator()
+
+     Notes:
+         For further details on the attributes and methods, refer to the parent DetectionValidator class.
+     """
+
+     def build_dataset(self, img_path, mode="val", batch=None):
+         """Build an RTDETR Dataset.
+
+         Args:
+             img_path (str): Path to the folder containing images.
+             mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for
+                 each mode.
+             batch (int, optional): Size of batches, this is for `rect`.
+
+         Returns:
+             (RTDETRDataset): Dataset configured for RT-DETR validation.
+         """
+         return RTDETRDataset(
+             img_path=img_path,
+             imgsz=self.args.imgsz,
+             batch_size=batch,
+             augment=False,  # no augmentation
+             hyp=self.args,
+             rect=False,  # no rect
+             cache=self.args.cache or None,
+             prefix=colorstr(f"{mode}: "),
+             data=self.data,
+         )
+
+     def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+         """Return predictions unchanged; RT-DETR boxes are rescaled to native space in pred_to_json instead."""
+         return predn
+
+     def postprocess(
+         self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
+     ) -> list[dict[str, torch.Tensor]]:
+         """Filter and format raw prediction outputs by confidence (RT-DETR is end-to-end and needs no NMS).
+
+         Args:
+             preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
+                 (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
+                 class scores.
+
+         Returns:
+             (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
+                 - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
+                 - 'conf': Tensor of shape (N,) with confidence scores
+                 - 'cls': Tensor of shape (N,) with class indices
+         """
+         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
+             preds = [preds, None]
+
+         bs, _, nd = preds[0].shape
+         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
+         bboxes *= self.args.imgsz
+         outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
+         for i, bbox in enumerate(bboxes):  # (300, 4)
+             bbox = ops.xywh2xyxy(bbox)
+             score, cls = scores[i].max(-1)  # (300, )
+             pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1)  # (300, 6)
+             # Sort by confidence to correctly get internal metrics, then keep detections above the threshold
+             pred = pred[score.argsort(descending=True)]
+             outputs[i] = pred[pred[:, 4] > self.args.conf]
+
+         return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
+
+     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+         """Serialize YOLO predictions to COCO json format.
+
+         Args:
+             predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                 bounding box coordinates, confidence scores, and class predictions.
+             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
+         """
+         path = Path(pbatch["im_file"])
+         stem = path.stem
+         image_id = int(stem) if stem.isnumeric() else stem
+         box = predn["bboxes"].clone()
+         box[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+         box[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
+         box = ops.xyxy2xywh(box)  # xywh
+         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+         for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
+             self.jdict.append(
+                 {
+                     "image_id": image_id,
+                     "file_name": path.name,
+                     "category_id": self.class_map[int(c)],
+                     "bbox": [round(x, 3) for x in b],
+                     "score": round(s, 5),
+                 }
+             )
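To make the decode step in postprocess concrete, here is a self-contained sketch of the same logic on a dummy prediction tensor; imgsz, conf, and nc are stand-in values for self.args.imgsz, self.args.conf, and the class count:

import torch

imgsz, conf, nc = 640, 0.25, 80  # stand-ins for self.args.imgsz, self.args.conf, class count
preds = torch.rand(2, 300, 4 + nc)  # dummy RT-DETR output: (batch, queries, 4 + nc)

bboxes, scores = preds.split((4, nc), dim=-1)
bboxes = bboxes * imgsz  # normalized xywh -> pixel coordinates
for i, bbox in enumerate(bboxes):
    xy, wh = bbox.chunk(2, dim=-1)
    bbox = torch.cat([xy - wh / 2, xy + wh / 2], dim=-1)  # xywh -> xyxy, as ops.xywh2xyxy does
    score, cls = scores[i].max(-1)  # best class per query
    pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1)  # (300, 6)
    pred = pred[score.argsort(descending=True)]  # sort by confidence
    kept = pred[pred[:, 4] > conf]  # confidence filter; no NMS needed
    print(i, kept.shape)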
ultralytics/models/sam/__init__.py
@@ -0,0 +1,25 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from .model import SAM
+ from .predict import (
+     Predictor,
+     SAM2DynamicInteractivePredictor,
+     SAM2Predictor,
+     SAM2VideoPredictor,
+     SAM3Predictor,
+     SAM3SemanticPredictor,
+     SAM3VideoPredictor,
+     SAM3VideoSemanticPredictor,
+ )
+
+ __all__ = (
+     "SAM",
+     "Predictor",
+     "SAM2DynamicInteractivePredictor",
+     "SAM2Predictor",
+     "SAM2VideoPredictor",
+     "SAM3Predictor",
+     "SAM3SemanticPredictor",
+     "SAM3VideoPredictor",
+     "SAM3VideoSemanticPredictor",
+ )  # tuple or list of exportable items
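For orientation, these predictors back the high-level SAM model class re-exported here. A minimal prompted-inference sketch, assuming a sam_b.pt checkpoint and the bundled bus.jpg asset are available:

from ultralytics import SAM

model = SAM("sam_b.pt")  # assumes the checkpoint can be found or downloaded
# Standard SAM prompts: a bounding box (xyxy) or foreground points with labels
results = model("ultralytics/assets/bus.jpg", bboxes=[100, 200, 500, 700])
results = model("ultralytics/assets/bus.jpg", points=[[400, 500]], labels=[1])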
ultralytics/models/sam/amg.py
@@ -0,0 +1,275 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ import math
+ from collections.abc import Generator
+ from itertools import product
+ from typing import Any
+
+ import numpy as np
+ import torch
+
+
+ def is_box_near_crop_edge(
+     boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
+ ) -> torch.Tensor:
+     """Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+
+     Args:
+         boxes (torch.Tensor): Bounding boxes in XYXY format.
+         crop_box (list[int]): Crop box coordinates in [x0, y0, x1, y1] format.
+         orig_box (list[int]): Original image box coordinates in [x0, y0, x1, y1] format.
+         atol (float, optional): Absolute tolerance for edge proximity detection.
+
+     Returns:
+         (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.
+
+     Examples:
+         >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
+         >>> crop_box = [0, 0, 200, 200]
+         >>> orig_box = [0, 0, 300, 300]
+         >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
+     """
+     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+     boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+     near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+     near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+     near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+     return torch.any(near_crop_edge, dim=1)
+
+
+ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
+     """Yield batches of data from input arguments with specified batch size for efficient processing.
+
+     This function takes a batch size and any number of iterables, then yields batches of elements from those
+     iterables. All input iterables must have the same length.
+
+     Args:
+         batch_size (int): Size of each batch to yield.
+         *args (Any): Variable length input iterables to batch. All iterables must have the same length.
+
+     Yields:
+         (list[Any]): A list of batched elements from each input iterable.
+
+     Examples:
+         >>> data = [1, 2, 3, 4, 5]
+         >>> labels = ["a", "b", "c", "d", "e"]
+         >>> for batch in batch_iterator(2, data, labels):
+         ...     print(batch)
+         [[1, 2], ['a', 'b']]
+         [[3, 4], ['c', 'd']]
+         [[5], ['e']]
+     """
+     assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
+     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+     for b in range(n_batches):
+         yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
+     """Compute the stability score for a batch of masks.
+
+     The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high
+     and low values.
+
+     Args:
+         masks (torch.Tensor): Batch of predicted mask logits.
+         mask_threshold (float): Threshold value for creating binary masks.
+         threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
+
+     Returns:
+         (torch.Tensor): Stability scores for each mask in the batch.
+
+     Examples:
+         >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
+         >>> mask_threshold = 0.5
+         >>> threshold_offset = 0.1
+         >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+
+     Notes:
+         - One mask is always contained inside the other.
+         - Memory is saved by preventing unnecessary cast to torch.int64.
+     """
+     intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+     unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+     return intersections / unions
+
+
+ def build_point_grid(n_per_side: int) -> np.ndarray:
+     """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
+     offset = 1 / (2 * n_per_side)
+     points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+     points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+     points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+     return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+
+
+ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
+     """Generate point grids for multiple crop layers with varying scales and densities."""
+     return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
+
+
+ def generate_crop_boxes(
+     im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
+ ) -> tuple[list[list[int]], list[int]]:
+     """Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+
+     Args:
+         im_size (tuple[int, ...]): Height and width of the input image.
+         n_layers (int): Number of layers to generate crop boxes for.
+         overlap_ratio (float): Ratio of overlap between adjacent crop boxes.
+
+     Returns:
+         crop_boxes (list[list[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+         layer_idxs (list[int]): List of layer indices corresponding to each crop box.
+
+     Examples:
+         >>> im_size = (800, 1200)  # Height, width
+         >>> n_layers = 3
+         >>> overlap_ratio = 0.25
+         >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
+     """
+     crop_boxes, layer_idxs = [], []
+     im_h, im_w = im_size
+     short_side = min(im_h, im_w)
+
+     # Original image
+     crop_boxes.append([0, 0, im_w, im_h])
+     layer_idxs.append(0)
+
+     def crop_len(orig_len, n_crops, overlap):
+         """Calculate the length of each crop given the original length, number of crops, and overlap."""
+         return math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)
+
+     for i_layer in range(n_layers):
+         n_crops_per_side = 2 ** (i_layer + 1)
+         overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+         crop_w = crop_len(im_w, n_crops_per_side, overlap)
+         crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+         crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+         crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+         # Crops in XYXY format
+         for x0, y0 in product(crop_box_x0, crop_box_y0):
+             box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+             crop_boxes.append(box)
+             layer_idxs.append(i_layer + 1)
+
+     return crop_boxes, layer_idxs
+
+
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
+     """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
+     x0, y0, _, _ = crop_box
+     offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+     # Check if boxes has a channel dimension
+     if len(boxes.shape) == 3:
+         offset = offset.unsqueeze(1)
+     return boxes + offset
+
+
+ def uncrop_points(points: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
+     """Uncrop points by adding the crop box offset to their coordinates."""
+     x0, y0, _, _ = crop_box
+     offset = torch.tensor([[x0, y0]], device=points.device)
+     # Check if points has a channel dimension
+     if len(points.shape) == 3:
+         offset = offset.unsqueeze(1)
+     return points + offset
+
+
+ def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
+     """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
+     x0, y0, x1, y1 = crop_box
+     if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+         return masks
+     # Coordinate transform masks
+     pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+     pad = (x0, pad_x - x0, y0, pad_y - y0)
+     return torch.nn.functional.pad(masks, pad, value=0)
+
+
+ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
+     """Remove small disconnected regions or holes in a mask based on area threshold and mode.
+
+     Args:
+         mask (np.ndarray): Binary mask to process.
+         area_thresh (float): Area threshold below which regions will be removed.
+         mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
+             regions.
+
+     Returns:
+         processed_mask (np.ndarray): Processed binary mask with small regions removed.
+         modified (bool): Whether any regions were modified.
+
+     Examples:
+         >>> mask = np.zeros((100, 100), dtype=np.bool_)
+         >>> mask[40:60, 40:60] = True  # Create a square
+         >>> mask[45:55, 45:55] = False  # Create a hole
+         >>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
+     """
+     import cv2  # type: ignore
+
+     assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
+     correct_holes = mode == "holes"
+     working_mask = (correct_holes ^ mask).astype(np.uint8)
+     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+     sizes = stats[:, -1][1:]  # Row 0 is background label
+     small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+     if not small_regions:
+         return mask, False
+     fill_labels = [0, *small_regions]
+     if not correct_holes:
+         # If every region is below threshold, keep largest
+         fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
+     mask = np.isin(regions, fill_labels)
+     return mask, True
+
+
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+     """Calculate bounding boxes in XYXY format around binary masks.
+
+     Args:
+         masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
+
+     Returns:
+         (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).
+
+     Notes:
+         - Handles empty masks by returning zero boxes.
+         - Preserves input tensor dimensions in the output.
+     """
+     # torch.max below raises an error on empty inputs, just skip in this case
+     if torch.numel(masks) == 0:
+         return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+     # Normalize shape to CxHxW
+     shape = masks.shape
+     h, w = shape[-2:]
+     masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
+     # Get top and bottom edges
+     in_height, _ = torch.max(masks, dim=-1)
+     in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+     bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+     in_height_coords = in_height_coords + h * (~in_height)
+     top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+     # Get left and right edges
+     in_width, _ = torch.max(masks, dim=-2)
+     in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+     right_edges, _ = torch.max(in_width_coords, dim=-1)
+     in_width_coords = in_width_coords + w * (~in_width)
+     left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+     # If the mask is empty the right edge will be to the left of the left edge.
+     # Replace these boxes with [0, 0, 0, 0]
+     empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+     out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+     out = out * (~empty_filter).unsqueeze(-1)
+
+     # Return to original shape
+     return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
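A short sketch exercising the helpers above on dummy data, showing the point grid, the layered crop boxes, batched iteration over prompts, and the mask-to-box conversion; the printed values follow directly from the definitions:

import torch

from ultralytics.models.sam.amg import batch_iterator, batched_mask_to_box, build_point_grid, generate_crop_boxes

points = build_point_grid(32)  # 32x32 prompt grid in [0, 1] x [0, 1]
print(points.shape)  # (1024, 2)

crop_boxes, layer_idxs = generate_crop_boxes((480, 640), n_layers=1, overlap_ratio=0.25)
print(len(crop_boxes), layer_idxs)  # 5 [0, 1, 1, 1, 1] -> full image + 2x2 overlapping crops

for (batch,) in batch_iterator(256, points):  # iterate prompts in batches of 256
    print(batch.shape)  # (256, 2)

masks = torch.zeros(1, 64, 64, dtype=torch.bool)
masks[0, 10:20, 30:40] = True  # one rectangular mask
print(batched_mask_to_box(masks))  # tensor([[30, 10, 39, 19]])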