ultralytics_opencv_headless-8.3.242-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as published in their public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/models/utils/ops.py
@@ -0,0 +1,315 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from scipy.optimize import linear_sum_assignment
+
+ from ultralytics.utils.metrics import bbox_iou
+ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
+
+
+ class HungarianMatcher(nn.Module):
+     """A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
+
+     HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
+     function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
+     used in end-to-end object detection models like DETR.
+
+     Attributes:
+         cost_gain (dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
+             components.
+         use_fl (bool): Whether to use Focal Loss for classification cost calculation.
+         with_mask (bool): Whether the model makes mask predictions.
+         num_sample_points (int): Number of sample points used in mask cost calculation.
+         alpha (float): Alpha factor in Focal Loss calculation.
+         gamma (float): Gamma factor in Focal Loss calculation.
+
+     Methods:
+         forward: Compute optimal assignment between predictions and ground truths for a batch.
+         _cost_mask: Compute mask cost and dice cost if masks are predicted.
+
+     Examples:
+         Initialize a HungarianMatcher with custom cost gains
+         >>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
+
+         Perform matching between predictions and ground truth
+         >>> pred_boxes = torch.rand(2, 100, 4)  # batch_size=2, num_queries=100
+         >>> pred_scores = torch.rand(2, 100, 80)  # 80 classes
+         >>> gt_boxes = torch.rand(10, 4)  # 10 ground truth boxes
+         >>> gt_classes = torch.randint(0, 80, (10,))
+         >>> gt_groups = [5, 5]  # 5 GT boxes per image
+         >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
+     """
+
+     def __init__(
+         self,
+         cost_gain: dict[str, float] | None = None,
+         use_fl: bool = True,
+         with_mask: bool = False,
+         num_sample_points: int = 12544,
+         alpha: float = 0.25,
+         gamma: float = 2.0,
+     ):
+         """Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
+
+         Args:
+             cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
+                 components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
+             use_fl (bool): Whether to use Focal Loss for classification cost calculation.
+             with_mask (bool): Whether the model makes mask predictions.
+             num_sample_points (int): Number of sample points used in mask cost calculation.
+             alpha (float): Alpha factor in Focal Loss calculation.
+             gamma (float): Gamma factor in Focal Loss calculation.
+         """
+         super().__init__()
+         if cost_gain is None:
+             cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
+         self.cost_gain = cost_gain
+         self.use_fl = use_fl
+         self.with_mask = with_mask
+         self.num_sample_points = num_sample_points
+         self.alpha = alpha
+         self.gamma = gamma
+
+     def forward(
+         self,
+         pred_bboxes: torch.Tensor,
+         pred_scores: torch.Tensor,
+         gt_bboxes: torch.Tensor,
+         gt_cls: torch.Tensor,
+         gt_groups: list[int],
+         masks: torch.Tensor | None = None,
+         gt_mask: list[torch.Tensor] | None = None,
+     ) -> list[tuple[torch.Tensor, torch.Tensor]]:
+         """Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
+
+         This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
+         mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
+
+         Args:
+             pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
+             pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
+                 num_classes).
+             gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
+             gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
+             gt_groups (list[int]): Number of ground truth boxes for each image in the batch.
+             masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
+             gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
+
+         Returns:
+             (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
+                 index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
+                 the tensor of indices of the corresponding selected ground truth targets (in order).
+                 For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
+         """
+         bs, nq, nc = pred_scores.shape
+
+         if sum(gt_groups) == 0:
+             return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
+
+         # Flatten to compute cost matrices in batch format
+         pred_scores = pred_scores.detach().view(-1, nc)
+         pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
+         pred_bboxes = pred_bboxes.detach().view(-1, 4)
+
+         # Compute classification cost
+         pred_scores = pred_scores[:, gt_cls]
+         if self.use_fl:
+             neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
+             pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
+             cost_class = pos_cost_class - neg_cost_class
+         else:
+             cost_class = -pred_scores
+
+         # Compute L1 cost between boxes
+         cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)  # (bs*num_queries, num_gt)
+
+         # Compute GIoU cost between boxes, (bs*num_queries, num_gt)
+         cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
+
+         # Combine costs into final cost matrix
+         C = (
+             self.cost_gain["class"] * cost_class
+             + self.cost_gain["bbox"] * cost_bbox
+             + self.cost_gain["giou"] * cost_giou
+         )
+
+         # Add mask costs if available
+         if self.with_mask:
+             C += self._cost_mask(bs, gt_groups, masks, gt_mask)
+
+         # Set invalid values (NaNs and infinities) to 0
+         C[C.isnan() | C.isinf()] = 0.0
+
+         C = C.view(bs, nq, -1).cpu()
+         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
+         gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # (idx for queries, idx for gt)
+         return [
+             (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
+             for k, (i, j) in enumerate(indices)
+         ]
+
+     # This function is for future RT-DETR Segment models
+     # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
+     #     assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
+     #     # all masks share the same set of points for efficient matching
+     #     sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
+     #     sample_points = 2.0 * sample_points - 1.0
+     #
+     #     out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
+     #     out_mask = out_mask.flatten(0, 1)
+     #
+     #     tgt_mask = torch.cat(gt_mask).unsqueeze(1)
+     #     sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
+     #     tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
+     #
+     #     with torch.amp.autocast("cuda", enabled=False):
+     #         # binary cross entropy cost
+     #         pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
+     #         neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
+     #         cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
+     #         cost_mask /= self.num_sample_points
+     #
+     #         # dice cost
+     #         out_mask = F.sigmoid(out_mask)
+     #         numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
+     #         denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
+     #         cost_dice = 1 - (numerator + 1) / (denominator + 1)
+     #
+     #         C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
+     #     return C
+
+
+ def get_cdn_group(
+     batch: dict[str, Any],
+     num_classes: int,
+     num_queries: int,
+     class_embed: torch.Tensor,
+     num_dn: int = 100,
+     cls_noise_ratio: float = 0.5,
+     box_noise_scale: float = 1.0,
+     training: bool = False,
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
+     """Generate contrastive denoising training group with positive and negative samples from ground truths.
+
+     This function creates denoising queries for contrastive denoising training by adding noise to ground truth bounding
+     boxes and class labels. It generates both positive and negative samples to improve model robustness.
+
+     Args:
+         batch (dict[str, Any]): Batch dictionary containing 'cls' (torch.Tensor with shape (num_gts,)), 'bboxes'
+             (torch.Tensor with shape (num_gts, 4)), 'batch_idx' (torch.Tensor with shape (num_gts,)), and 'gt_groups'
+             (list[int]) indicating the number of ground truths per image.
+         num_classes (int): Total number of object classes.
+         num_queries (int): Number of object queries.
+         class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
+         num_dn (int): Number of denoising queries to generate.
+         cls_noise_ratio (float): Noise ratio for class labels.
+         box_noise_scale (float): Noise scale for bounding box coordinates.
+         training (bool): Whether model is in training mode.
+
+     Returns:
+         padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
+         padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
+         attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
+         dn_meta (dict[str, Any] | None): Meta information dictionary containing denoising parameters.
+
+     Examples:
+         Generate denoising group for training
+         >>> batch = {
+         ...     "cls": torch.tensor([0, 1, 2]),
+         ...     "bboxes": torch.rand(3, 4),
+         ...     "batch_idx": torch.tensor([0, 0, 1]),
+         ...     "gt_groups": [2, 1],
+         ... }
+         >>> class_embed = torch.rand(80, 256)  # 80 classes, 256 embedding dim
+         >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
+     """
+     if (not training) or num_dn <= 0 or batch is None:
+         return None, None, None, None
+     gt_groups = batch["gt_groups"]
+     total_num = sum(gt_groups)
+     max_nums = max(gt_groups)
+     if max_nums == 0:
+         return None, None, None, None
+
+     num_group = num_dn // max_nums
+     num_group = 1 if num_group == 0 else num_group
+     # Pad gt to max_num of a batch
+     bs = len(gt_groups)
+     gt_cls = batch["cls"]  # (bs*num, )
+     gt_bbox = batch["bboxes"]  # bs*num, 4
+     b_idx = batch["batch_idx"]
+
+     # Each group has positive and negative queries
+     dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
+     dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
+     dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )
+
+     # Positive and negative mask
+     # (bs*num*num_group, ), the second total_num*num_group part as negative samples
+     neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
+
+     if cls_noise_ratio > 0:
+         # Apply class label noise to half of the samples
+         mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
+         idx = torch.nonzero(mask).squeeze(-1)
+         # Randomly assign new class labels
+         new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
+         dn_cls[idx] = new_label
+
+     if box_noise_scale > 0:
+         known_bbox = xywh2xyxy(dn_bbox)
+
+         diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4
+
+         rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
+         rand_part = torch.rand_like(dn_bbox)
+         rand_part[neg_idx] += 1.0
+         rand_part *= rand_sign
+         known_bbox += rand_part * diff
+         known_bbox.clip_(min=0.0, max=1.0)
+         dn_bbox = xyxy2xywh(known_bbox)
+         dn_bbox = torch.logit(dn_bbox, eps=1e-6)  # inverse sigmoid
+
+     num_dn = int(max_nums * 2 * num_group)  # total denoising queries
+     dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
+     padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
+     padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
+
+     map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
+     pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
+
+     map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
+     padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
+     padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
+
+     tgt_size = num_dn + num_queries
+     attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
+     # Match query cannot see the reconstruct
+     attn_mask[num_dn:, :num_dn] = True
+     # Reconstruct cannot see each other
+     for i in range(num_group):
+         if i == 0:
+             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+         if i == num_group - 1:
+             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
+         else:
+             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
+     dn_meta = {
+         "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
+         "dn_num_group": num_group,
+         "dn_num_split": [num_dn, num_queries],
+     }
+
+     return (
+         padding_cls.to(class_embed.device),
+         padding_bbox.to(class_embed.device),
+         attn_mask.to(class_embed.device),
+         dn_meta,
+     )
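
The assignment step above is worth unpacking: the combined cost matrix C is reshaped to (batch, queries, total_gt), split per image along the last dimension via `C.split(gt_groups, -1)`, and each per-image sub-matrix is solved independently by SciPy's Hungarian solver. A minimal standalone sketch with made-up costs (illustration only, not part of the package):

    import torch
    from scipy.optimize import linear_sum_assignment

    # One image: 3 queries x 2 ground-truth boxes; lower cost = better match.
    C = torch.tensor([[0.9, 0.1],
                      [0.4, 0.8],
                      [0.2, 0.7]])

    row_ind, col_ind = linear_sum_assignment(C.numpy())
    print(row_ind, col_ind)  # [0 2] [1 0]: query 0 -> GT 1, query 2 -> GT 0
    # This is the minimum-total-cost bipartite matching (0.1 + 0.2 = 0.3).

In `forward`, the ground-truth indices returned for each image are then shifted by the cumulative `gt_groups` offsets so they index into the single concatenated `gt_bboxes` tensor.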
ultralytics/models/yolo/__init__.py
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe
+
+ from .model import YOLO, YOLOE, YOLOWorld
+
+ __all__ = "YOLO", "YOLOE", "YOLOWorld", "classify", "detect", "obb", "pose", "segment", "world", "yoloe"
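
Because `__all__` here names both the model classes and the task submodules, either import style below resolves to the same objects; a quick sanity check (assuming the package is installed):

    from ultralytics.models import yolo
    from ultralytics.models.yolo import YOLO, classify

    assert yolo.YOLO is YOLO  # class re-export from .model
    assert yolo.classify is classify  # task submodule re-export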
ultralytics/models/yolo/classify/__init__.py
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from ultralytics.models.yolo.classify.predict import ClassificationPredictor
+ from ultralytics.models.yolo.classify.train import ClassificationTrainer
+ from ultralytics.models.yolo.classify.val import ClassificationValidator
+
+ __all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
ultralytics/models/yolo/classify/predict.py
@@ -0,0 +1,90 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import cv2
+ import torch
+ from PIL import Image
+
+ from ultralytics.data.augment import classify_transforms
+ from ultralytics.engine.predictor import BasePredictor
+ from ultralytics.engine.results import Results
+ from ultralytics.utils import DEFAULT_CFG, ops
+
+
+ class ClassificationPredictor(BasePredictor):
+     """A class extending the BasePredictor class for prediction based on a classification model.
+
+     This predictor handles the specific requirements of classification models, including preprocessing images and
+     postprocessing predictions to generate classification results.
+
+     Attributes:
+         args (dict): Configuration arguments for the predictor.
+
+     Methods:
+         preprocess: Convert input images to model-compatible format.
+         postprocess: Process model predictions into Results objects.
+
+     Examples:
+         >>> from ultralytics.utils import ASSETS
+         >>> from ultralytics.models.yolo.classify import ClassificationPredictor
+         >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
+         >>> predictor = ClassificationPredictor(overrides=args)
+         >>> predictor.predict_cli()
+
+     Notes:
+         - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
+
+         This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
+         tasks. It ensures the task is set to 'classify' regardless of input configuration.
+
+         Args:
+             cfg (dict): Default configuration dictionary containing prediction settings.
+             overrides (dict, optional): Configuration overrides that take precedence over cfg.
+             _callbacks (list, optional): List of callback functions to be executed during prediction.
+         """
+         super().__init__(cfg, overrides, _callbacks)
+         self.args.task = "classify"
+
+     def setup_source(self, source):
+         """Set up the source, inference mode, and classification transforms."""
+         super().setup_source(source)
+         updated = (
+             self.model.model.transforms.transforms[0].size != max(self.imgsz)
+             if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
+             else False
+         )
+         self.transforms = (
+             classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
+         )
+
+     def preprocess(self, img):
+         """Convert input images to model-compatible tensor format with appropriate normalization."""
+         if not isinstance(img, torch.Tensor):
+             img = torch.stack(
+                 [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
+             )
+         img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
+         return img.half() if self.model.fp16 else img.float()  # Convert uint8 to fp16/32
+
+     def postprocess(self, preds, img, orig_imgs):
+         """Process predictions to return Results objects with classification probabilities.
+
+         Args:
+             preds (torch.Tensor): Raw predictions from the model.
+             img (torch.Tensor): Input images after preprocessing.
+             orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.
+
+         Returns:
+             (list[Results]): List of Results objects containing classification results for each image.
+         """
+         if not isinstance(orig_imgs, list):  # Input images are a torch.Tensor, not a list
+             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
+
+         preds = preds[0] if isinstance(preds, (list, tuple)) else preds
+         return [
+             Results(orig_img, path=img_path, names=self.model.names, probs=pred)
+             for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+         ]
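
A hedged end-to-end sketch of this predictor driven through the high-level `YOLO` API instead of `predict_cli` (assumes the `yolo11n-cls.pt` checkpoint is available for download; `bus.jpg` ships in `ultralytics/assets`, see the file list above):

    from ultralytics import YOLO
    from ultralytics.utils import ASSETS

    model = YOLO("yolo11n-cls.pt")  # classification checkpoint sets task "classify"
    results = model.predict(ASSETS / "bus.jpg")  # routed through ClassificationPredictor

    for r in results:
        top5 = r.probs.top5  # indices of the five highest-probability classes
        print([model.names[i] for i in top5], r.probs.top5conf.tolist())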
ultralytics/models/yolo/classify/train.py
@@ -0,0 +1,202 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ from copy import copy
+ from typing import Any
+
+ import torch
+
+ from ultralytics.data import ClassificationDataset, build_dataloader
+ from ultralytics.engine.trainer import BaseTrainer
+ from ultralytics.models import yolo
+ from ultralytics.nn.tasks import ClassificationModel
+ from ultralytics.utils import DEFAULT_CFG, RANK
+ from ultralytics.utils.plotting import plot_images
+ from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_first
+
+
+ class ClassificationTrainer(BaseTrainer):
+     """A trainer class extending BaseTrainer for training image classification models.
+
+     This trainer handles the training process for image classification tasks, supporting both YOLO classification
+     models and torchvision models with comprehensive dataset handling and validation.
+
+     Attributes:
+         model (ClassificationModel): The classification model to be trained.
+         data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
+         loss_names (list[str]): Names of the loss functions used during training.
+         validator (ClassificationValidator): Validator instance for model evaluation.
+
+     Methods:
+         set_model_attributes: Set the model's class names from the loaded dataset.
+         get_model: Return a modified PyTorch model configured for training.
+         setup_model: Load, create or download model for classification.
+         build_dataset: Create a ClassificationDataset instance.
+         get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
+         preprocess_batch: Preprocess a batch of images and classes.
+         progress_string: Return a formatted string showing training progress.
+         get_validator: Return an instance of ClassificationValidator.
+         label_loss_items: Return a loss dict with labeled training loss items.
+         final_eval: Evaluate trained model and save validation results.
+         plot_training_samples: Plot training samples with their annotations.
+
+     Examples:
+         Initialize and train a classification model
+         >>> from ultralytics.models.yolo.classify import ClassificationTrainer
+         >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
+         >>> trainer = ClassificationTrainer(overrides=args)
+         >>> trainer.train()
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
+         """Initialize a ClassificationTrainer object.
+
+         Args:
+             cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
+             overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
+             _callbacks (list[Any], optional): List of callback functions to be executed during training.
+         """
+         if overrides is None:
+             overrides = {}
+         overrides["task"] = "classify"
+         if overrides.get("imgsz") is None:
+             overrides["imgsz"] = 224
+         super().__init__(cfg, overrides, _callbacks)
+
+     def set_model_attributes(self):
+         """Set the YOLO model's class names from the loaded dataset."""
+         self.model.names = self.data["names"]
+
+     def get_model(self, cfg=None, weights=None, verbose: bool = True):
+         """Return a modified PyTorch model configured for training YOLO classification.
+
+         Args:
+             cfg (Any, optional): Model configuration.
+             weights (Any, optional): Pre-trained model weights.
+             verbose (bool, optional): Whether to display model information.
+
+         Returns:
+             (ClassificationModel): Configured PyTorch model for classification.
+         """
+         model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
+         if weights:
+             model.load(weights)
+
+         for m in model.modules():
+             if not self.args.pretrained and hasattr(m, "reset_parameters"):
+                 m.reset_parameters()
+             if isinstance(m, torch.nn.Dropout) and self.args.dropout:
+                 m.p = self.args.dropout  # set dropout
+         for p in model.parameters():
+             p.requires_grad = True  # for training
+         return model
+
+     def setup_model(self):
+         """Load, create or download model for classification tasks.
+
+         Returns:
+             (Any): Model checkpoint if applicable, otherwise None.
+         """
+         import torchvision  # scope for faster 'import ultralytics'
+
+         if str(self.model) in torchvision.models.__dict__:
+             self.model = torchvision.models.__dict__[self.model](
+                 weights="IMAGENET1K_V1" if self.args.pretrained else None
+             )
+             ckpt = None
+         else:
+             ckpt = super().setup_model()
+         ClassificationModel.reshape_outputs(self.model, self.data["nc"])
+         return ckpt
+
+     def build_dataset(self, img_path: str, mode: str = "train", batch=None):
+         """Create a ClassificationDataset instance given an image path and mode.
+
+         Args:
+             img_path (str): Path to the dataset images.
+             mode (str, optional): Dataset mode ('train', 'val', or 'test').
+             batch (Any, optional): Batch information (unused in this implementation).
+
+         Returns:
+             (ClassificationDataset): Dataset for the specified mode.
+         """
+         return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
+
+     def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
+         """Return PyTorch DataLoader with transforms to preprocess images.
+
+         Args:
+             dataset_path (str): Path to the dataset.
+             batch_size (int, optional): Number of images per batch.
+             rank (int, optional): Process rank for distributed training.
+             mode (str, optional): 'train', 'val', or 'test' mode.
+
+         Returns:
+             (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
+         """
+         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+             dataset = self.build_dataset(dataset_path, mode)
+
+         loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
+         # Attach inference transforms
+         if mode != "train":
+             if is_parallel(self.model):
+                 self.model.module.transforms = loader.dataset.torch_transforms
+             else:
+                 self.model.transforms = loader.dataset.torch_transforms
+         return loader
+
+     def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         """Preprocess a batch of images and classes."""
+         batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
+         batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
+         return batch
+
+     def progress_string(self) -> str:
+         """Return a formatted string showing training progress."""
+         return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+             "Epoch",
+             "GPU_mem",
+             *self.loss_names,
+             "Instances",
+             "Size",
+         )
+
+     def get_validator(self):
+         """Return an instance of ClassificationValidator for validation."""
+         self.loss_names = ["loss"]
+         return yolo.classify.ClassificationValidator(
+             self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+         )
+
+     def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
+         """Return a loss dict with labeled training loss items tensor.
+
+         Args:
+             loss_items (torch.Tensor, optional): Loss tensor items.
+             prefix (str, optional): Prefix to prepend to loss names.
+
+         Returns:
+             keys (list[str]): List of loss keys if loss_items is None.
+             loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
+         """
+         keys = [f"{prefix}/{x}" for x in self.loss_names]
+         if loss_items is None:
+             return keys
+         loss_items = [round(float(loss_items), 5)]
+         return dict(zip(keys, loss_items))
+
+     def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
+         """Plot training samples with their annotations.
+
+         Args:
+             batch (dict[str, torch.Tensor]): Batch containing images and class labels.
+             ni (int): Number of iterations.
+         """
+         batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
+         plot_images(
+             labels=batch,
+             fname=self.save_dir / f"train_batch{ni}.jpg",
+             on_plot=self.on_plot,
+         )
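
Beyond the docstring example, the torchvision branch in `setup_model` above means a bare architecture name can be trained directly. A minimal sketch, assuming torchvision is installed and that `imagenet10` (the small dataset named in the docstring example) resolves locally or via download:

    from ultralytics.models.yolo.classify import ClassificationTrainer

    # "resnet18" is looked up in torchvision.models.__dict__, loaded with
    # IMAGENET1K_V1 weights when pretrained is set, and its output head is
    # reshaped to the dataset's class count by ClassificationModel.reshape_outputs.
    args = dict(model="resnet18", data="imagenet10", epochs=1, imgsz=224, pretrained=True)
    trainer = ClassificationTrainer(overrides=args)
    trainer.train()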