dgenerate_ultralytics_headless-8.3.134-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
ultralytics/models/utils/ops.py
@@ -0,0 +1,268 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from scipy.optimize import linear_sum_assignment
+
+ from ultralytics.utils.metrics import bbox_iou
+ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
+
+
+ class HungarianMatcher(nn.Module):
+     """
+     A module implementing the HungarianMatcher, a differentiable module that solves the assignment problem in an
+     end-to-end fashion.
+
+     HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
+     function that considers classification scores, bounding box coordinates, and optionally mask predictions.
+
+     Attributes:
+         cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
+         use_fl (bool): Whether to use Focal Loss for the classification cost calculation.
+         with_mask (bool): Whether the model makes mask predictions.
+         num_sample_points (int): Number of sample points used in the mask cost calculation.
+         alpha (float): Alpha factor in the Focal Loss calculation.
+         gamma (float): Gamma factor in the Focal Loss calculation.
+
+     Methods:
+         forward: Compute the assignment between predictions and ground truths for a batch.
+         _cost_mask: Compute the mask cost and dice cost if masks are predicted.
+     """
+
+     def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
+         """
+         Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes.
+
+         The HungarianMatcher uses a cost function that considers classification scores, bounding box coordinates,
+         and optionally mask predictions to perform optimal bipartite matching between predictions and ground truths.
+
+         Args:
+             cost_gain (dict, optional): Dictionary of cost coefficients for different components of the matching cost.
+                 Should contain the keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
+             use_fl (bool, optional): Whether to use Focal Loss for the classification cost calculation.
+             with_mask (bool, optional): Whether the model makes mask predictions.
+             num_sample_points (int, optional): Number of sample points used in the mask cost calculation.
+             alpha (float, optional): Alpha factor in the Focal Loss calculation.
+             gamma (float, optional): Gamma factor in the Focal Loss calculation.
+         """
+         super().__init__()
+         if cost_gain is None:
+             cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
+         self.cost_gain = cost_gain
+         self.use_fl = use_fl
+         self.with_mask = with_mask
+         self.num_sample_points = num_sample_points
+         self.alpha = alpha
+         self.gamma = gamma
+
+     def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
+         """
+         Compute assignment costs between predictions and ground truths and find the optimal bipartite matching.
+
+         Args:
+             pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
+             pred_scores (torch.Tensor): Predicted scores with shape (batch_size, num_queries, num_classes).
+             gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
+             gt_cls (torch.Tensor): Ground truth classes with shape (num_gts, ).
+             gt_groups (List[int]): List of length batch_size containing the number of ground truths for each image.
+             masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
+             gt_mask (List[torch.Tensor], optional): List of ground truth masks, each with shape (num_masks, Height, Width).
+
+         Returns:
+             (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, where each element is a tuple
+                 (index_i, index_j):
+                 - index_i is the tensor of indices of the selected predictions (in order)
+                 - index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
+                 For each batch element, len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
+         """
+         bs, nq, nc = pred_scores.shape
+
+         if sum(gt_groups) == 0:
+             return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
+
+         # Flatten to compute the cost matrices over the whole batch
+         # (batch_size * num_queries, num_classes)
+         pred_scores = pred_scores.detach().view(-1, nc)
+         pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
+         # (batch_size * num_queries, 4)
+         pred_bboxes = pred_bboxes.detach().view(-1, 4)
+
+         # Compute the classification cost
+         pred_scores = pred_scores[:, gt_cls]
+         if self.use_fl:
+             neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
+             pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
+             cost_class = pos_cost_class - neg_cost_class
+         else:
+             cost_class = -pred_scores
+
+         # Compute the L1 cost between boxes
+         cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)  # (bs*num_queries, num_gt)
+
+         # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
+         cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
+
+         # Final cost matrix
+         C = (
+             self.cost_gain["class"] * cost_class
+             + self.cost_gain["bbox"] * cost_bbox
+             + self.cost_gain["giou"] * cost_giou
+         )
+         # Compute the mask cost and dice cost
+         if self.with_mask:
+             C += self._cost_mask(bs, gt_groups, masks, gt_mask)
+
+         # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries)
+         C[C.isnan() | C.isinf()] = 0.0
+
+         C = C.view(bs, nq, -1).cpu()
+         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
+         gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # offsets mapping per-image gt indices to global indices
+         return [
+             (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
+             for k, (i, j) in enumerate(indices)
+         ]
+
+     # This function is for future RT-DETR Segment models
+     # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
+     #     assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
+     #     # all masks share the same set of points for efficient matching
+     #     sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
+     #     sample_points = 2.0 * sample_points - 1.0
+     #
+     #     out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
+     #     out_mask = out_mask.flatten(0, 1)
+     #
+     #     tgt_mask = torch.cat(gt_mask).unsqueeze(1)
+     #     sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
+     #     tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
+     #
+     #     with torch.amp.autocast("cuda", enabled=False):
+     #         # binary cross entropy cost
+     #         pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
+     #         neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
+     #         cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
+     #         cost_mask /= self.num_sample_points
+     #
+     #         # dice cost
+     #         out_mask = F.sigmoid(out_mask)
+     #         numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
+     #         denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
+     #         cost_dice = 1 - (numerator + 1) / (denominator + 1)
+     #
+     #         C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
+     #     return C
+
+
+ def get_cdn_group(
+     batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
+ ):
+     """
+     Get a contrastive denoising training group with positive and negative samples from ground truths.
+
+     Args:
+         batch (dict): A dict that includes 'cls' (torch.Tensor with shape (num_gts, )), 'bboxes' (torch.Tensor with
+             shape (num_gts, 4)), 'batch_idx' (torch.Tensor with shape (num_gts, )), and 'gt_groups' (List[int]), a
+             list of length batch_size indicating the number of ground truths per image.
+         num_classes (int): Number of classes.
+         num_queries (int): Number of queries.
+         class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
+         num_dn (int, optional): Number of denoising queries.
+         cls_noise_ratio (float, optional): Noise ratio for class labels.
+         box_noise_scale (float, optional): Noise scale for bounding box coordinates.
+         training (bool, optional): Whether the model is in training mode.
+
+     Returns:
+         padding_cls (Optional[torch.Tensor]): The modified class embeddings for denoising.
+         padding_bbox (Optional[torch.Tensor]): The modified bounding boxes for denoising.
+         attn_mask (Optional[torch.Tensor]): The attention mask for denoising.
+         dn_meta (Optional[Dict]): Meta information for denoising.
+     """
+     if (not training) or num_dn <= 0 or batch is None:
+         return None, None, None, None
+     gt_groups = batch["gt_groups"]
+     total_num = sum(gt_groups)
+     max_nums = max(gt_groups)
+     if max_nums == 0:
+         return None, None, None, None
+
+     num_group = num_dn // max_nums
+     num_group = 1 if num_group == 0 else num_group
+     # Pad gt to max_num of a batch
+     bs = len(gt_groups)
+     gt_cls = batch["cls"]  # (bs*num, )
+     gt_bbox = batch["bboxes"]  # bs*num, 4
+     b_idx = batch["batch_idx"]
+
+     # Each group has positive and negative queries.
+     dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
+     dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
+     dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )
+
+     # Positive and negative mask
+     # (bs*num*num_group, ); the second total_num*num_group entries are the negative samples
+     neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
+
+     if cls_noise_ratio > 0:
+         # Flip each label with probability cls_noise_ratio * 0.5
+         mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
+         idx = torch.nonzero(mask).squeeze(-1)
+         # Replace the selected labels with random classes
+         new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
+         dn_cls[idx] = new_label
+
+     if box_noise_scale > 0:
+         known_bbox = xywh2xyxy(dn_bbox)
+
+         diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4
+
+         rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
+         rand_part = torch.rand_like(dn_bbox)
+         rand_part[neg_idx] += 1.0
+         rand_part *= rand_sign
+         known_bbox += rand_part * diff
+         known_bbox.clip_(min=0.0, max=1.0)
+         dn_bbox = xyxy2xywh(known_bbox)
+         dn_bbox = torch.logit(dn_bbox, eps=1e-6)  # inverse sigmoid
+
+     num_dn = int(max_nums * 2 * num_group)  # total denoising queries
+     # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
+     dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
+     padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
+     padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
+
+     map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
+     pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
+
+     map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
+     padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
+     padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
+
+     tgt_size = num_dn + num_queries
+     attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
+     # Matching queries cannot see the denoising (reconstruction) queries
+     attn_mask[num_dn:, :num_dn] = True
+     # Denoising groups cannot see each other
+     for i in range(num_group):
+         if i == 0:
+             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+         if i == num_group - 1:
+             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
+         else:
+             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+             attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
+     dn_meta = {
+         "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
+         "dn_num_group": num_group,
+         "dn_num_split": [num_dn, num_queries],
+     }
+
+     return (
+         padding_cls.to(class_embed.device),
+         padding_bbox.to(class_embed.device),
+         attn_mask.to(class_embed.device),
+         dn_meta,
+     )
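
For orientation, here is a minimal sketch (not part of the diff above) showing how HungarianMatcher and get_cdn_group can be exercised with dummy tensors. The shapes follow the docstrings; all concrete values are illustrative:

import torch
from ultralytics.models.utils.ops import HungarianMatcher, get_cdn_group

bs, nq, nc = 2, 10, 80                        # batch size, queries per image, classes
pred_bboxes = torch.rand(bs, nq, 4)           # normalized xywh predictions
pred_scores = torch.randn(bs, nq, nc)         # raw logits; sigmoid is applied internally when use_fl=True
gt_groups = [3, 1]                            # 3 ground truths in image 0, 1 in image 1
gt_bboxes = torch.rand(sum(gt_groups), 4)     # all ground truths concatenated across the batch
gt_cls = torch.randint(0, nc, (sum(gt_groups),))

matcher = HungarianMatcher()                  # default cost_gain: class=1, bbox=5, giou=2
indices = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups)
for k, (pred_idx, gt_idx) in enumerate(indices):
    print(f"image {k}: queries {pred_idx.tolist()} -> gt rows {gt_idx.tolist()}")

# The denoising-group builder consumes the same batch layout; batch_idx maps each gt row to its image
batch = {"cls": gt_cls, "bboxes": gt_bboxes, "batch_idx": torch.tensor([0, 0, 0, 1]), "gt_groups": gt_groups}
class_embed = torch.randn(nc, 256)            # illustrative class-embedding table
padding_cls, padding_bbox, attn_mask, dn_meta = get_cdn_group(batch, nc, nq, class_embed, training=True)
print(padding_cls.shape, padding_bbox.shape, attn_mask.shape, dn_meta["dn_num_group"])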
ultralytics/models/yolo/__init__.py
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe
+
+ from .model import YOLO, YOLOE, YOLOWorld
+
+ __all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"
ultralytics/models/yolo/classify/__init__.py
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from ultralytics.models.yolo.classify.predict import ClassificationPredictor
+ from ultralytics.models.yolo.classify.train import ClassificationTrainer
+ from ultralytics.models.yolo.classify.val import ClassificationValidator
+
+ __all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
ultralytics/models/yolo/classify/predict.py
@@ -0,0 +1,88 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import cv2
+ import torch
+ from PIL import Image
+
+ from ultralytics.engine.predictor import BasePredictor
+ from ultralytics.engine.results import Results
+ from ultralytics.utils import DEFAULT_CFG, ops
+
+
+ class ClassificationPredictor(BasePredictor):
+     """
+     A class extending the BasePredictor class for prediction based on a classification model.
+
+     This predictor handles the specific requirements of classification models, including preprocessing images
+     and postprocessing predictions to generate classification results.
+
+     Attributes:
+         args (dict): Configuration arguments for the predictor.
+         _legacy_transform_name (str): Name of the legacy transform class for backward compatibility.
+
+     Methods:
+         preprocess: Convert input images to model-compatible format.
+         postprocess: Process model predictions into Results objects.
+
+     Notes:
+         - Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
+
+     Examples:
+         >>> from ultralytics.utils import ASSETS
+         >>> from ultralytics.models.yolo.classify import ClassificationPredictor
+         >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
+         >>> predictor = ClassificationPredictor(overrides=args)
+         >>> predictor.predict_cli()
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """
+         Initialize the ClassificationPredictor with the specified configuration and set the task to 'classify'.
+
+         This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for
+         classification tasks. It ensures the task is set to 'classify' regardless of the input configuration.
+
+         Args:
+             cfg (dict): Default configuration dictionary containing prediction settings. Defaults to DEFAULT_CFG.
+             overrides (dict, optional): Configuration overrides that take precedence over cfg.
+             _callbacks (list, optional): List of callback functions to be executed during prediction.
+         """
+         super().__init__(cfg, overrides, _callbacks)
+         self.args.task = "classify"
+         self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
+
+     def preprocess(self, img):
+         """Convert input images to model-compatible tensor format with appropriate normalization."""
+         if not isinstance(img, torch.Tensor):
+             is_legacy_transform = any(
+                 self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
+             )
+             if is_legacy_transform:  # to handle legacy transforms
+                 img = torch.stack([self.transforms(im) for im in img], dim=0)
+             else:
+                 img = torch.stack(
+                     [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
+                 )
+         img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
+         return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
+
+     def postprocess(self, preds, img, orig_imgs):
+         """
+         Process predictions to return Results objects with classification probabilities.
+
+         Args:
+             preds (torch.Tensor): Raw predictions from the model.
+             img (torch.Tensor): Input images after preprocessing.
+             orig_imgs (List[np.ndarray] | torch.Tensor): Original images before preprocessing.
+
+         Returns:
+             (List[Results]): List of Results objects containing classification results for each image.
+         """
+         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+         preds = preds[0] if isinstance(preds, (list, tuple)) else preds
+         return [
+             Results(orig_img, path=img_path, names=self.model.names, probs=pred)
+             for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+         ]
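
In practice this predictor is usually reached through the high-level YOLO API rather than instantiated directly. A usage sketch; the checkpoint name and asset path are assumptions, and probs.top1/top1conf come from the Results/Probs API defined elsewhere in the package:

from ultralytics import YOLO

model = YOLO("yolo11n-cls.pt")                         # classification checkpoint routes to ClassificationPredictor
results = model.predict("ultralytics/assets/bus.jpg")  # preprocess -> inference -> postprocess shown above
probs = results[0].probs                               # per-class probabilities attached by postprocess()
print(results[0].names[int(probs.top1)], float(probs.top1conf))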
ultralytics/models/yolo/classify/train.py
@@ -0,0 +1,233 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from copy import copy
+
+ import torch
+
+ from ultralytics.data import ClassificationDataset, build_dataloader
+ from ultralytics.engine.trainer import BaseTrainer
+ from ultralytics.models import yolo
+ from ultralytics.nn.tasks import ClassificationModel
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+ from ultralytics.utils.plotting import plot_images, plot_results
+ from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
+
+
+ class ClassificationTrainer(BaseTrainer):
+     """
+     A class extending the BaseTrainer class for training based on a classification model.
+
+     This trainer handles the training process for image classification tasks, supporting both YOLO classification
+     models and torchvision models.
+
+     Attributes:
+         model (ClassificationModel): The classification model to be trained.
+         data (dict): Dictionary containing dataset information, including class names and the number of classes.
+         loss_names (List[str]): Names of the loss functions used during training.
+         validator (ClassificationValidator): Validator instance for model evaluation.
+
+     Methods:
+         set_model_attributes: Set the model's class names from the loaded dataset.
+         get_model: Return a modified PyTorch model configured for training.
+         setup_model: Load, create, or download a model for classification.
+         build_dataset: Create a ClassificationDataset instance.
+         get_dataloader: Return a PyTorch DataLoader with transforms for image preprocessing.
+         preprocess_batch: Preprocess a batch of images and classes.
+         progress_string: Return a formatted string showing training progress.
+         get_validator: Return an instance of ClassificationValidator.
+         label_loss_items: Return a loss dict with labelled training loss items.
+         plot_metrics: Plot metrics from a CSV file.
+         final_eval: Evaluate the trained model and save validation results.
+         plot_training_samples: Plot training samples with their annotations.
+
+     Examples:
+         >>> from ultralytics.models.yolo.classify import ClassificationTrainer
+         >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
+         >>> trainer = ClassificationTrainer(overrides=args)
+         >>> trainer.train()
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """
+         Initialize a ClassificationTrainer object.
+
+         This constructor sets up a trainer for image classification tasks, configuring the task type and a default
+         image size if none is specified.
+
+         Args:
+             cfg (dict, optional): Default configuration dictionary containing training parameters.
+             overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+             _callbacks (list, optional): List of callback functions to be executed during training.
+
+         Examples:
+             >>> from ultralytics.models.yolo.classify import ClassificationTrainer
+             >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
+             >>> trainer = ClassificationTrainer(overrides=args)
+             >>> trainer.train()
+         """
+         if overrides is None:
+             overrides = {}
+         overrides["task"] = "classify"
+         if overrides.get("imgsz") is None:
+             overrides["imgsz"] = 224
+         super().__init__(cfg, overrides, _callbacks)
+
+     def set_model_attributes(self):
+         """Set the YOLO model's class names from the loaded dataset."""
+         self.model.names = self.data["names"]
+
+     def get_model(self, cfg=None, weights=None, verbose=True):
+         """
+         Return a modified PyTorch model configured for training YOLO.
+
+         Args:
+             cfg (Any): Model configuration.
+             weights (Any): Pre-trained model weights.
+             verbose (bool): Whether to display model information.
+
+         Returns:
+             (ClassificationModel): Configured PyTorch model for classification.
+         """
+         model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
+         if weights:
+             model.load(weights)
+
+         for m in model.modules():
+             if not self.args.pretrained and hasattr(m, "reset_parameters"):
+                 m.reset_parameters()
+             if isinstance(m, torch.nn.Dropout) and self.args.dropout:
+                 m.p = self.args.dropout  # set dropout
+         for p in model.parameters():
+             p.requires_grad = True  # for training
+         return model
+
+     def setup_model(self):
+         """
+         Load, create, or download a model for classification tasks.
+
+         Returns:
+             (Any): Model checkpoint if applicable, otherwise None.
+         """
+         import torchvision  # scope for faster 'import ultralytics'
+
+         if str(self.model) in torchvision.models.__dict__:
+             self.model = torchvision.models.__dict__[self.model](
+                 weights="IMAGENET1K_V1" if self.args.pretrained else None
+             )
+             ckpt = None
+         else:
+             ckpt = super().setup_model()
+         ClassificationModel.reshape_outputs(self.model, self.data["nc"])
+         return ckpt
+
+     def build_dataset(self, img_path, mode="train", batch=None):
+         """
+         Create a ClassificationDataset instance given an image path and mode.
+
+         Args:
+             img_path (str): Path to the dataset images.
+             mode (str): Dataset mode ('train', 'val', or 'test').
+             batch (Any): Batch information (unused in this implementation).
+
+         Returns:
+             (ClassificationDataset): Dataset for the specified mode.
+         """
+         return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
+
+     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+         """
+         Return a PyTorch DataLoader with transforms to preprocess images.
+
+         Args:
+             dataset_path (str): Path to the dataset.
+             batch_size (int): Number of images per batch.
+             rank (int): Process rank for distributed training.
+             mode (str): 'train', 'val', or 'test' mode.
+
+         Returns:
+             (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
+         """
+         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+             dataset = self.build_dataset(dataset_path, mode)
+
+         loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
+         # Attach inference transforms
+         if mode != "train":
+             if is_parallel(self.model):
+                 self.model.module.transforms = loader.dataset.torch_transforms
+             else:
+                 self.model.transforms = loader.dataset.torch_transforms
+         return loader
+
+     def preprocess_batch(self, batch):
+         """Preprocess a batch of images and classes by moving them to the training device."""
+         batch["img"] = batch["img"].to(self.device)
+         batch["cls"] = batch["cls"].to(self.device)
+         return batch
+
+     def progress_string(self):
+         """Return a formatted string showing training progress."""
+         return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+             "Epoch",
+             "GPU_mem",
+             *self.loss_names,
+             "Instances",
+             "Size",
+         )
+
+     def get_validator(self):
+         """Return an instance of ClassificationValidator for validation."""
+         self.loss_names = ["loss"]
+         return yolo.classify.ClassificationValidator(
+             self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+         )
+
+     def label_loss_items(self, loss_items=None, prefix="train"):
+         """
+         Return a loss dict with labelled training loss items tensor.
+
+         Args:
+             loss_items (torch.Tensor, optional): Loss tensor items.
+             prefix (str): Prefix to prepend to loss names.
+
+         Returns:
+             (Dict[str, float] | List[str]): Dictionary of loss items or list of loss keys if loss_items is None.
+         """
+         keys = [f"{prefix}/{x}" for x in self.loss_names]
+         if loss_items is None:
+             return keys
+         loss_items = [round(float(loss_items), 5)]
+         return dict(zip(keys, loss_items))
+
+     def plot_metrics(self):
+         """Plot metrics from a CSV file."""
+         plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png
+
+     def final_eval(self):
+         """Evaluate the trained model and save validation results."""
+         for f in self.last, self.best:
+             if f.exists():
+                 strip_optimizer(f)  # strip optimizers
+                 if f is self.best:
+                     LOGGER.info(f"\nValidating {f}...")
+                     self.validator.args.data = self.args.data
+                     self.validator.args.plots = self.args.plots
+                     self.metrics = self.validator(model=f)
+                     self.metrics.pop("fitness", None)
+                     self.run_callbacks("on_fit_epoch_end")
+
+     def plot_training_samples(self, batch, ni):
+         """
+         Plot training samples with their annotations.
+
+         Args:
+             batch (Dict[str, torch.Tensor]): Batch containing images and class labels.
+             ni (int): Number of iterations.
+         """
+         plot_images(
+             images=batch["img"],
+             batch_idx=torch.arange(len(batch["img"])),
+             cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+             fname=self.save_dir / f"train_batch{ni}.jpg",
+             on_plot=self.on_plot,
+         )
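
One branch of setup_model() above is easy to miss: if the model string names a torchvision architecture, the trainer instantiates it directly (with IMAGENET1K_V1 weights when pretrained=True) and then reshapes the output layer to the dataset's class count via ClassificationModel.reshape_outputs. A sketch of that path; the dataset name and epoch count are illustrative:

from ultralytics.models.yolo.classify import ClassificationTrainer

# "resnet18" is found in torchvision.models.__dict__ by setup_model(),
# so no YOLO checkpoint is loaded for this run
args = dict(model="resnet18", data="imagenet10", epochs=1, imgsz=224)
trainer = ClassificationTrainer(overrides=args)
trainer.train()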