dgenerate_ultralytics_headless-8.3.134-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
ultralytics/models/yolo/detect/val.py
@@ -0,0 +1,451 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.data import build_dataloader, build_yolo_dataset, converter
+from ultralytics.engine.validator import BaseValidator
+from ultralytics.utils import LOGGER, ops
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
+from ultralytics.utils.plotting import output_to_target, plot_images
+
+
+class DetectionValidator(BaseValidator):
+    """
+    A class extending the BaseValidator class for validation based on a detection model.
+
+    This class implements validation functionality specific to object detection tasks, including metrics calculation,
+    prediction processing, and visualization of results.
+
+    Attributes:
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
+        is_coco (bool): Whether the dataset is COCO.
+        is_lvis (bool): Whether the dataset is LVIS.
+        class_map (list): Mapping from model class indices to dataset class indices.
+        metrics (DetMetrics): Object detection metrics calculator.
+        iouv (torch.Tensor): IoU thresholds for mAP calculation.
+        niou (int): Number of IoU thresholds.
+        lb (list): List for storing ground truth labels for hybrid saving.
+        jdict (list): List for storing JSON detection results.
+        stats (dict): Dictionary for storing statistics during validation.
+
+    Examples:
+        >>> from ultralytics.models.yolo.detect import DetectionValidator
+        >>> args = dict(model="yolo11n.pt", data="coco8.yaml")
+        >>> validator = DetectionValidator(args=args)
+        >>> validator()
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """
+        Initialize detection validator with necessary variables and settings.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (Any, optional): Progress bar for displaying progress.
+            args (dict, optional): Arguments for the validator.
+            _callbacks (list, optional): List of callback functions.
+        """
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.nt_per_class = None
+        self.nt_per_image = None
+        self.is_coco = False
+        self.is_lvis = False
+        self.class_map = None
+        self.args.task = "detect"
+        self.metrics = DetMetrics(save_dir=self.save_dir)
+        self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
+        self.niou = self.iouv.numel()
+
+    def preprocess(self, batch):
+        """
+        Preprocess batch of images for YOLO validation.
+
+        Args:
+            batch (dict): Batch containing images and annotations.
+
+        Returns:
+            (dict): Preprocessed batch.
+        """
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
+        for k in ["batch_idx", "cls", "bboxes"]:
+            batch[k] = batch[k].to(self.device)
+
+        return batch
+
+    def init_metrics(self, model):
+        """
+        Initialize evaluation metrics for YOLO detection validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
+        val = self.data.get(self.args.split, "")  # validation path
+        self.is_coco = (
+            isinstance(val, str)
+            and "coco" in val
+            and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
+        )  # is COCO
+        self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco  # is LVIS
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
+        self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training  # run final val
+        self.names = model.names
+        self.nc = len(model.names)
+        self.end2end = getattr(model, "end2end", False)
+        self.metrics.names = self.names
+        self.metrics.plot = self.args.plots
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
+        self.seen = 0
+        self.jdict = []
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+
+    def get_desc(self):
+        """Return a formatted string summarizing class metrics of YOLO model."""
+        return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
+
+    def postprocess(self, preds):
+        """
+        Apply Non-maximum suppression to prediction outputs.
+
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[torch.Tensor]): Processed predictions after NMS.
+        """
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            nc=0 if self.args.task == "detect" else self.nc,
+            multi_label=True,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            end2end=self.end2end,
+            rotated=self.args.task == "obb",
+        )
+
+    def _prepare_batch(self, si, batch):
+        """
+        Prepare a batch of images and annotations for validation.
+
+        Args:
+            si (int): Batch index.
+            batch (dict): Batch data containing images and annotations.
+
+        Returns:
+            (dict): Prepared batch with processed annotations.
+        """
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
+        if len(cls):
+            bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
+            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
+        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+
+    def _prepare_pred(self, pred, pbatch):
+        """
+        Prepare predictions for evaluation against ground truth.
+
+        Args:
+            pred (torch.Tensor): Model predictions.
+            pbatch (dict): Prepared batch information.
+
+        Returns:
+            (torch.Tensor): Prepared predictions in native space.
+        """
+        predn = pred.clone()
+        ops.scale_boxes(
+            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+        )  # native-space pred
+        return predn
+
+    def update_metrics(self, preds, batch):
+        """
+        Update metrics with new predictions and ground truth.
+
+        Args:
+            preds (List[torch.Tensor]): List of predictions from the model.
+            batch (dict): Batch data containing ground truth.
+        """
+        for si, pred in enumerate(preds):
+            self.seen += 1
+            npr = len(pred)
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
+            pbatch = self._prepare_batch(si, batch)
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
+            nl = len(cls)
+            stat["target_cls"] = cls
+            stat["target_img"] = cls.unique()
+            if npr == 0:
+                if nl:
+                    for k in self.stats.keys():
+                        self.stats[k].append(stat[k])
+                    if self.args.plots:
+                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
+                continue
+
+            # Predictions
+            if self.args.single_cls:
+                pred[:, 5] = 0
+            predn = self._prepare_pred(pred, pbatch)
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
+
+            # Evaluate
+            if nl:
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+            if self.args.plots:
+                self.confusion_matrix.process_batch(predn, bbox, cls)
+            for k in self.stats.keys():
+                self.stats[k].append(stat[k])
+
+            # Save
+            if self.args.save_json:
+                self.pred_to_json(predn, batch["im_file"][si])
+            if self.args.save_txt:
+                self.save_one_txt(
+                    predn,
+                    self.args.save_conf,
+                    pbatch["ori_shape"],
+                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+                )
+
+    def finalize_metrics(self, *args, **kwargs):
+        """
+        Set final values for metrics speed and confusion matrix.
+
+        Args:
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+        """
+        self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
+
+    def get_stats(self):
+        """
+        Calculate and return metrics statistics.
+
+        Returns:
+            (dict): Dictionary containing metrics results.
+        """
+        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
+        stats.pop("target_img", None)
+        if len(stats):
+            self.metrics.process(**stats, on_plot=self.on_plot)
+        return self.metrics.results_dict
+
+    def print_results(self):
+        """Print training/validation set metrics per class."""
+        pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
+        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+        if self.nt_per_class.sum() == 0:
+            LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
+
+        # Print results per class
+        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
+            for i, c in enumerate(self.metrics.ap_class_index):
+                LOGGER.info(
+                    pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
+                )
+
+        if self.args.plots:
+            for normalize in True, False:
+                self.confusion_matrix.plot(
+                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+                )
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls):
+        """
+        Return correct prediction matrix.
+
+        Args:
+            detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
+                (x1, y1, x2, y2, conf, class).
+            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
+                bounding box is of the format: (x1, y1, x2, y2).
+            gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.
+
+        Returns:
+            (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
+        """
+        iou = box_iou(gt_bboxes, detections[:, :4])
+        return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+    def build_dataset(self, img_path, mode="val", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (Dataset): YOLO dataset.
+        """
+        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """
+        Construct and return dataloader.
+
+        Args:
+            dataset_path (str): Path to the dataset.
+            batch_size (int): Size of each batch.
+
+        Returns:
+            (torch.utils.data.DataLoader): Dataloader for validation.
+        """
+        dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
+        return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader
+
+    def plot_val_samples(self, batch, ni):
+        """
+        Plot validation image samples.
+
+        Args:
+            batch (dict): Batch containing images and annotations.
+            ni (int): Batch index.
+        """
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
+
+    def plot_predictions(self, batch, preds, ni):
+        """
+        Plot predicted bounding boxes on input images and save the result.
+
+        Args:
+            batch (dict): Batch containing images and annotations.
+            preds (List[torch.Tensor]): List of predictions from the model.
+            ni (int): Batch index.
+        """
+        plot_images(
+            batch["img"],
+            *output_to_target(preds, max_det=self.args.max_det),
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
+
+    def save_one_txt(self, predn, save_conf, shape, file):
+        """
+        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            save_conf (bool): Whether to save confidence scores.
+            shape (tuple): Shape of the original image.
+            file (Path): File path to save the detections.
+        """
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+        ).save_txt(file, save_conf=save_conf)
+
+    def pred_to_json(self, predn, filename):
+        """
+        Serialize YOLO predictions to COCO json format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            filename (str): Image filename.
+        """
+        stem = Path(filename).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        for p, b in zip(predn.tolist(), box.tolist()):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(p[4], 5),
+                }
+            )
+
+    def eval_json(self, stats):
+        """
+        Evaluate YOLO output in JSON format and return performance statistics.
+
+        Args:
+            stats (dict): Current statistics dictionary.
+
+        Returns:
+            (dict): Updated statistics dictionary with COCO/LVIS evaluation results.
+        """
+        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
+            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+                for x in pred_json, anno_json:
+                    assert x.is_file(), f"{x} file not found"
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
+                if self.is_coco:
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    val = COCOeval(anno, pred, "bbox")
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))  # init annotations api
+                    pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
+                    val = LVISEval(anno, pred, "bbox")
+                val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                val.evaluate()
+                val.accumulate()
+                val.summarize()
+                if self.is_lvis:
+                    val.print_results()  # explicitly call print_results
+                # update mAP50-95 and mAP50
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
+                    val.stats[:2] if self.is_coco else [val.results["AP"], val.results["AP50"]]
+                )
+                if self.is_lvis:
+                    stats["metrics/APr(B)"] = val.results["APr"]
+                    stats["metrics/APc(B)"] = val.results["APc"]
+                    stats["metrics/APf(B)"] = val.results["APf"]
+                    stats["fitness"] = val.results["AP"]
+            except Exception as e:
+                LOGGER.warning(f"{pkg} unable to run: {e}")
+        return stats
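
The class docstring above already shows the intended standalone usage of this validator. For reference, here is a minimal sketch based on that example, assuming the package from this wheel is installed and that the `yolo11n.pt` weights and `coco8.yaml` dataset can be resolved (Ultralytics normally fetches both on first use); the final `results_dict` line is an assumption drawn from `get_stats()` above, not an official recipe.

```python
# Minimal standalone validation sketch based on the DetectionValidator docstring above.
from ultralytics.models.yolo.detect import DetectionValidator

args = dict(model="yolo11n.pt", data="coco8.yaml")  # weights and dataset are assumed available
validator = DetectionValidator(args=args)
validator()  # runs the validation loop: preprocess -> inference -> postprocess (NMS) -> update_metrics

# Aggregated metrics (e.g. mAP50, mAP50-95) are exposed via the DetMetrics object used in get_stats()
print(validator.metrics.results_dict)
```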