ultralytics-opencv-headless 8.3.242__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/models/yolo/classify/val.py
@@ -0,0 +1,216 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ import torch.distributed as dist
+
+ from ultralytics.data import ClassificationDataset, build_dataloader
+ from ultralytics.engine.validator import BaseValidator
+ from ultralytics.utils import LOGGER, RANK
+ from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
+ from ultralytics.utils.plotting import plot_images
+
+
+ class ClassificationValidator(BaseValidator):
+     """A class extending the BaseValidator class for validation based on a classification model.
+
+     This validator handles the validation process for classification models, including metrics calculation, confusion
+     matrix generation, and visualization of results.
+
+     Attributes:
+         targets (list[torch.Tensor]): Ground truth class labels.
+         pred (list[torch.Tensor]): Model predictions.
+         metrics (ClassifyMetrics): Object to calculate and store classification metrics.
+         names (dict): Mapping of class indices to class names.
+         nc (int): Number of classes.
+         confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.
+
+     Methods:
+         get_desc: Return a formatted string summarizing classification metrics.
+         init_metrics: Initialize confusion matrix, class names, and tracking containers.
+         preprocess: Preprocess input batch by moving data to device.
+         update_metrics: Update running metrics with model predictions and batch targets.
+         finalize_metrics: Finalize metrics including confusion matrix and processing speed.
+         postprocess: Extract the primary prediction from model output.
+         get_stats: Calculate and return a dictionary of metrics.
+         build_dataset: Create a ClassificationDataset instance for validation.
+         get_dataloader: Build and return a data loader for classification validation.
+         print_results: Print evaluation metrics for the classification model.
+         plot_val_samples: Plot validation image samples with their ground truth labels.
+         plot_predictions: Plot images with their predicted class labels.
+
+     Examples:
+         >>> from ultralytics.models.yolo.classify import ClassificationValidator
+         >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
+         >>> validator = ClassificationValidator(args=args)
+         >>> validator()
+
+     Notes:
+         Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+     """
+
+     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
+         """Initialize ClassificationValidator with dataloader, save directory, and other parameters.
+
+         Args:
+             dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
+             save_dir (str | Path, optional): Directory to save results.
+             args (dict, optional): Arguments containing model and validation configuration.
+             _callbacks (list, optional): List of callback functions to be called during validation.
+         """
+         super().__init__(dataloader, save_dir, args, _callbacks)
+         self.targets = None
+         self.pred = None
+         self.args.task = "classify"
+         self.metrics = ClassifyMetrics()
+
+     def get_desc(self) -> str:
+         """Return a formatted string summarizing classification metrics."""
+         return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
+
+     def init_metrics(self, model: torch.nn.Module) -> None:
+         """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
+         self.names = model.names
+         self.nc = len(model.names)
+         self.pred = []
+         self.targets = []
+         self.confusion_matrix = ConfusionMatrix(names=model.names)
+
+     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
+         """Preprocess input batch by moving data to device and converting to appropriate dtype."""
+         batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
+         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
+         batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
+         return batch
+
+     def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
+         """Update running metrics with model predictions and batch targets.
+
+         Args:
+             preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
+             batch (dict): Batch data containing images and class labels.
+
+         Notes:
+             This method appends the top-N predictions (sorted by confidence in descending order) to the
+             prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
+         """
+         n5 = min(len(self.names), 5)
+         self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
+         self.targets.append(batch["cls"].type(torch.int32).cpu())
+
+     def finalize_metrics(self) -> None:
+         """Finalize metrics including confusion matrix and processing speed.
+
+         Examples:
+             >>> validator = ClassificationValidator()
+             >>> validator.pred = [torch.tensor([[0, 1, 2]])]  # Top-3 predictions for one sample
+             >>> validator.targets = [torch.tensor([0])]  # Ground truth class
+             >>> validator.finalize_metrics()
+             >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
+
+         Notes:
+             This method processes the accumulated predictions and targets to generate the confusion matrix,
+             optionally plots it, and updates the metrics object with speed information.
+         """
+         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
+         if self.args.plots:
+             for normalize in True, False:
+                 self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
+         self.metrics.speed = self.speed
+         self.metrics.save_dir = self.save_dir
+         self.metrics.confusion_matrix = self.confusion_matrix
+
+     def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
+         """Extract the primary prediction from model output if it's in a list or tuple format."""
+         return preds[0] if isinstance(preds, (list, tuple)) else preds
+
+     def get_stats(self) -> dict[str, float]:
+         """Calculate and return a dictionary of metrics by processing targets and predictions."""
+         self.metrics.process(self.targets, self.pred)
+         return self.metrics.results_dict
+
+     def gather_stats(self) -> None:
+         """Gather stats from all GPUs."""
+         if RANK == 0:
+             gathered_preds = [None] * dist.get_world_size()
+             gathered_targets = [None] * dist.get_world_size()
+             dist.gather_object(self.pred, gathered_preds, dst=0)
+             dist.gather_object(self.targets, gathered_targets, dst=0)
+             self.pred = [pred for rank in gathered_preds for pred in rank]
+             self.targets = [targets for rank in gathered_targets for targets in rank]
+         elif RANK > 0:
+             dist.gather_object(self.pred, None, dst=0)
+             dist.gather_object(self.targets, None, dst=0)
+
+     def build_dataset(self, img_path: str) -> ClassificationDataset:
+         """Create a ClassificationDataset instance for validation."""
+         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
+
+     def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
+         """Build and return a data loader for classification validation.
+
+         Args:
+             dataset_path (str | Path): Path to the dataset directory.
+             batch_size (int): Number of samples per batch.
+
+         Returns:
+             (torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
+         """
+         dataset = self.build_dataset(dataset_path)
+         return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
+
+     def print_results(self) -> None:
+         """Print evaluation metrics for the classification model."""
+         pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
+         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
+
+     def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
+         """Plot validation image samples with their ground truth labels.
+
+         Args:
+             batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
+             ni (int): Batch index used for naming the output file.
+
+         Examples:
+             >>> validator = ClassificationValidator()
+             >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
+             >>> validator.plot_val_samples(batch, 0)
+         """
+         batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
+         plot_images(
+             labels=batch,
+             fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+             names=self.names,
+             on_plot=self.on_plot,
+         )
+
+     def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
+         """Plot images with their predicted class labels and save the visualization.
+
+         Args:
+             batch (dict[str, Any]): Batch data containing images and other information.
+             preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
+             ni (int): Batch index used for naming the output file.
+
+         Examples:
+             >>> validator = ClassificationValidator()
+             >>> batch = {"img": torch.rand(16, 3, 224, 224)}
+             >>> preds = torch.rand(16, 10)  # 16 images, 10 classes
+             >>> validator.plot_predictions(batch, preds, 0)
+         """
+         batched_preds = dict(
+             img=batch["img"],
+             batch_idx=torch.arange(batch["img"].shape[0]),
+             cls=torch.argmax(preds, dim=1),
+             conf=torch.amax(preds, dim=1),
+         )
+         plot_images(
+             batched_preds,
+             fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+             names=self.names,
+             on_plot=self.on_plot,
+         )  # pred
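The validator accumulates per-sample top-5 class indices in update_metrics and ground-truth labels alongside them, then hands both to ClassifyMetrics.process in get_stats. A minimal standalone sketch of the accuracy arithmetic this enables (illustrative only, not the literal ClassifyMetrics internals):

import torch

# Two samples; top-5 predicted class indices sorted by descending confidence,
# in the format stored by update_metrics(). Values are made up for illustration.
pred = torch.tensor([[0, 3, 1, 2, 4], [2, 1, 0, 3, 4]])
targets = torch.tensor([0, 1])  # ground-truth classes

correct = pred == targets[:, None]           # (N, 5) hit matrix
top1 = correct[:, 0].float().mean().item()   # first column only -> 0.5
top5 = correct.any(1).float().mean().item()  # a hit anywhere in the five -> 1.0
print(top1, top5)  # 0.5 1.0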
ultralytics/models/yolo/detect/__init__.py
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from .predict import DetectionPredictor
+ from .train import DetectionTrainer
+ from .val import DetectionValidator
+
+ __all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
ultralytics/models/yolo/detect/predict.py
@@ -0,0 +1,122 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from ultralytics.engine.predictor import BasePredictor
+ from ultralytics.engine.results import Results
+ from ultralytics.utils import nms, ops
+
+
+ class DetectionPredictor(BasePredictor):
+     """A class extending the BasePredictor class for prediction based on a detection model.
+
+     This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
+     with bounding boxes and class predictions.
+
+     Attributes:
+         args (namespace): Configuration arguments for the predictor.
+         model (nn.Module): The detection model used for inference.
+         batch (list): Batch of images and metadata for processing.
+
+     Methods:
+         postprocess: Process raw model predictions into detection results.
+         construct_results: Build Results objects from processed predictions.
+         construct_result: Create a single Result object from a prediction.
+         get_obj_feats: Extract object features from the feature maps.
+
+     Examples:
+         >>> from ultralytics.utils import ASSETS
+         >>> from ultralytics.models.yolo.detect import DetectionPredictor
+         >>> args = dict(model="yolo11n.pt", source=ASSETS)
+         >>> predictor = DetectionPredictor(overrides=args)
+         >>> predictor.predict_cli()
+     """
+
+     def postprocess(self, preds, img, orig_imgs, **kwargs):
+         """Post-process predictions and return a list of Results objects.
+
+         This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
+         further analysis.
+
+         Args:
+             preds (torch.Tensor): Raw predictions from the model.
+             img (torch.Tensor): Processed input image tensor in model input format.
+             orig_imgs (torch.Tensor | list): Original input images before preprocessing.
+             **kwargs (Any): Additional keyword arguments.
+
+         Returns:
+             (list): List of Results objects containing the post-processed predictions.
+
+         Examples:
+             >>> predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt"))
+             >>> results = predictor.predict("path/to/image.jpg")
+             >>> processed_results = predictor.postprocess(preds, img, orig_imgs)
+         """
+         save_feats = getattr(self, "_feats", None) is not None
+         preds = nms.non_max_suppression(
+             preds,
+             self.args.conf,
+             self.args.iou,
+             self.args.classes,
+             self.args.agnostic_nms,
+             max_det=self.args.max_det,
+             nc=0 if self.args.task == "detect" else len(self.model.names),
+             end2end=getattr(self.model, "end2end", False),
+             rotated=self.args.task == "obb",
+             return_idxs=save_feats,
+         )
+
+         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
+
+         if save_feats:
+             obj_feats = self.get_obj_feats(self._feats, preds[1])
+             preds = preds[0]
+
+         results = self.construct_results(preds, img, orig_imgs, **kwargs)
+
+         if save_feats:
+             for r, f in zip(results, obj_feats):
+                 r.feats = f  # add object features to results
+
+         return results
+
+     @staticmethod
+     def get_obj_feats(feat_maps, idxs):
+         """Extract object features from the feature maps."""
+         import torch
+
+         s = min(x.shape[1] for x in feat_maps)  # find shortest vector length
+         obj_feats = torch.cat(
+             [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
+         )  # mean reduce all vectors to same length
+         return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch
+
+     def construct_results(self, preds, img, orig_imgs):
+         """Construct a list of Results objects from model predictions.
+
+         Args:
+             preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
+             img (torch.Tensor): Batch of preprocessed images used for inference.
+             orig_imgs (list[np.ndarray]): List of original images before preprocessing.
+
+         Returns:
+             (list[Results]): List of Results objects containing detection information for each image.
+         """
+         return [
+             self.construct_result(pred, img, orig_img, img_path)
+             for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+         ]
+
+     def construct_result(self, pred, img, orig_img, img_path):
+         """Construct a single Results object from one image prediction.
+
+         Args:
+             pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
+             img (torch.Tensor): Preprocessed image tensor used for inference.
+             orig_img (np.ndarray): Original image before preprocessing.
+             img_path (str): Path to the original image file.
+
+         Returns:
+             (Results): Results object containing the original image, image path, class names, and scaled bounding boxes.
+         """
+         pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
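The channel folding in get_obj_feats is easy to misread: each feature map is flattened over its spatial grid, and its channel dimension is folded down to the smallest channel count s across levels by averaging groups of C // s channels, so per-location vectors from all pyramid levels can be concatenated. A standalone sketch of the same tensor arithmetic, with made-up shapes:

import torch

# Two pyramid levels with different channel counts, (B, C, H, W) each.
feat_maps = [torch.randn(1, 64, 8, 8), torch.randn(1, 128, 4, 4)]
s = min(x.shape[1] for x in feat_maps)  # shortest channel length, s = 64

# Fold channels to length s by averaging groups of C // s, then concatenate
# the per-location vectors from both levels along the location axis.
obj_feats = torch.cat(
    [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps],
    dim=1,
)
print(obj_feats.shape)  # torch.Size([1, 80, 64]): 8*8 + 4*4 = 80 locations, 64-dim each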
ultralytics/models/yolo/detect/train.py
@@ -0,0 +1,227 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ import math
+ import random
+ from copy import copy
+ from typing import Any
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from ultralytics.data import build_dataloader, build_yolo_dataset
+ from ultralytics.engine.trainer import BaseTrainer
+ from ultralytics.models import yolo
+ from ultralytics.nn.tasks import DetectionModel
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+ from ultralytics.utils.patches import override_configs
+ from ultralytics.utils.plotting import plot_images, plot_labels
+ from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
+
+
+ class DetectionTrainer(BaseTrainer):
+     """A class extending the BaseTrainer class for training based on a detection model.
+
+     This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
+     object detection including dataset building, data loading, preprocessing, and model configuration.
+
+     Attributes:
+         model (DetectionModel): The YOLO detection model being trained.
+         data (dict): Dictionary containing dataset information including class names and number of classes.
+         loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
+
+     Methods:
+         build_dataset: Build YOLO dataset for training or validation.
+         get_dataloader: Construct and return dataloader for the specified mode.
+         preprocess_batch: Preprocess a batch of images by scaling and converting to float.
+         set_model_attributes: Set model attributes based on dataset information.
+         get_model: Return a YOLO detection model.
+         get_validator: Return a validator for model evaluation.
+         label_loss_items: Return a loss dictionary with labeled training loss items.
+         progress_string: Return a formatted string of training progress.
+         plot_training_samples: Plot training samples with their annotations.
+         plot_training_labels: Create a labeled training plot of the YOLO model.
+         auto_batch: Calculate optimal batch size based on model memory requirements.
+
+     Examples:
+         >>> from ultralytics.models.yolo.detect import DetectionTrainer
+         >>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
+         >>> trainer = DetectionTrainer(overrides=args)
+         >>> trainer.train()
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
+         """Initialize a DetectionTrainer object for training YOLO object detection models.
+
+         Args:
+             cfg (dict, optional): Default configuration dictionary containing training parameters.
+             overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+             _callbacks (list, optional): List of callback functions to be executed during training.
+         """
+         super().__init__(cfg, overrides, _callbacks)
+
+     def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
+         """Build YOLO Dataset for training or validation.
+
+         Args:
+             img_path (str): Path to the folder containing images.
+             mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
+             batch (int, optional): Size of batches, this is for 'rect' mode.
+
+         Returns:
+             (Dataset): YOLO dataset object configured for the specified mode.
+         """
+         gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
+         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+
+     def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
+         """Construct and return dataloader for the specified mode.
+
+         Args:
+             dataset_path (str): Path to the dataset.
+             batch_size (int): Number of images per batch.
+             rank (int): Process rank for distributed training.
+             mode (str): 'train' for training dataloader, 'val' for validation dataloader.
+
+         Returns:
+             (DataLoader): PyTorch dataloader object.
+         """
+         assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
+         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+             dataset = self.build_dataset(dataset_path, mode, batch_size)
+         shuffle = mode == "train"
+         if getattr(dataset, "rect", False) and shuffle:
+             LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
+             shuffle = False
+         return build_dataloader(
+             dataset,
+             batch=batch_size,
+             workers=self.args.workers if mode == "train" else self.args.workers * 2,
+             shuffle=shuffle,
+             rank=rank,
+             drop_last=self.args.compile and mode == "train",
+         )
+
+     def preprocess_batch(self, batch: dict) -> dict:
+         """Preprocess a batch of images by scaling and converting to float.
+
+         Args:
+             batch (dict): Dictionary containing batch data with 'img' tensor.
+
+         Returns:
+             (dict): Preprocessed batch with normalized images.
+         """
+         for k, v in batch.items():
+             if isinstance(v, torch.Tensor):
+                 batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
+         batch["img"] = batch["img"].float() / 255
+         if self.args.multi_scale:
+             imgs = batch["img"]
+             sz = (
+                 random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
+                 // self.stride
+                 * self.stride
+             )  # size
+             sf = sz / max(imgs.shape[2:])  # scale factor
+             if sf != 1:
+                 ns = [
+                     math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
+                 ]  # new shape (stretched to gs-multiple)
+                 imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
+             batch["img"] = imgs
+         return batch
+
+     def set_model_attributes(self):
+         """Set model attributes based on dataset information."""
+         # Nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
+         # self.args.box *= 3 / nl  # scale to layers
+         # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
+         # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
+         self.model.nc = self.data["nc"]  # attach number of classes to model
+         self.model.names = self.data["names"]  # attach class names to model
+         self.model.args = self.args  # attach hyperparameters to model
+         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
+
+     def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
+         """Return a YOLO detection model.
+
+         Args:
+             cfg (str, optional): Path to model configuration file.
+             weights (str, optional): Path to model weights.
+             verbose (bool): Whether to display model information.
+
+         Returns:
+             (DetectionModel): YOLO detection model.
+         """
+         model = DetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
+         if weights:
+             model.load(weights)
+         return model
+
+     def get_validator(self):
+         """Return a DetectionValidator for YOLO model validation."""
+         self.loss_names = "box_loss", "cls_loss", "dfl_loss"
+         return yolo.detect.DetectionValidator(
+             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+         )
+
+     def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
+         """Return a loss dict with labeled training loss items tensor.
+
+         Args:
+             loss_items (list[float], optional): List of loss values.
+             prefix (str): Prefix for keys in the returned dictionary.
+
+         Returns:
+             (dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
+         """
+         keys = [f"{prefix}/{x}" for x in self.loss_names]
+         if loss_items is not None:
+             loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
+             return dict(zip(keys, loss_items))
+         else:
+             return keys
+
+     def progress_string(self):
+         """Return a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
+         return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+             "Epoch",
+             "GPU_mem",
+             *self.loss_names,
+             "Instances",
+             "Size",
+         )
+
+     def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
+         """Plot training samples with their annotations.
+
+         Args:
+             batch (dict[str, Any]): Dictionary containing batch data.
+             ni (int): Number of iterations.
+         """
+         plot_images(
+             labels=batch,
+             paths=batch["im_file"],
+             fname=self.save_dir / f"train_batch{ni}.jpg",
+             on_plot=self.on_plot,
+         )
+
+     def plot_training_labels(self):
+         """Create a labeled training plot of the YOLO model."""
+         boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
+         cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
+         plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
+
+     def auto_batch(self):
+         """Get optimal batch size by calculating memory occupation of model.
+
+         Returns:
+             (int): Optimal batch size.
+         """
+         with override_configs(self.args, overrides={"cache": False}) as self.args:
+             train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
+         max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4  # 4 for mosaic augmentation
+         del train_dataset  # free memory
+         return super().auto_batch(max_num_obj)
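The multi-scale branch of preprocess_batch picks a random training size in [0.5x, 1.5x] of imgsz and snaps it down to a stride multiple before interpolating the batch. A standalone sketch of just the size arithmetic (imgsz=640 and stride=32 are assumed here as typical defaults):

import random

imgsz, stride = 640, 32  # assumed typical defaults, not read from any config
# Random size in [0.5 * imgsz, 1.5 * imgsz + stride), floored to a stride multiple.
sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + stride)) // stride * stride
assert sz % stride == 0 and 320 <= sz <= 960  # always lands on a valid grid size
print(sz)  # e.g. 736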