ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247) hide show
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +37 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +111 -41
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +579 -244
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +191 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +226 -82
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +172 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +305 -112
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.63.dist-info/METADATA +370 -0
  235. ultralytics-8.3.63.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
@@ -1 +1 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import torch
4
4
  import torch.nn as nn
@@ -6,6 +6,7 @@ import torch.nn.functional as F
6
6
 
7
7
  from ultralytics.utils.loss import FocalLoss, VarifocalLoss
8
8
  from ultralytics.utils.metrics import bbox_iou
9
+
9
10
  from .ops import HungarianMatcher
10
11
 
11
12
 
@@ -33,15 +34,19 @@ class DETRLoss(nn.Module):
33
34
  self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
34
35
  ):
35
36
  """
36
- DETR loss function.
37
+ Initialize DETR loss function with customizable components and gains.
38
+
39
+ Uses default loss_gain if not provided. Initializes HungarianMatcher with
40
+ preset cost gains. Supports auxiliary losses and various loss types.
37
41
 
38
42
  Args:
39
- nc (int): The number of classes.
40
- loss_gain (dict): The coefficient of loss.
41
- aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
42
- use_vfl (bool): Use VarifocalLoss or not.
43
- use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
44
- uni_match_ind (int): The fixed indices of a layer.
43
+ nc (int): Number of classes.
44
+ loss_gain (dict): Coefficients for different loss components.
45
+ aux_loss (bool): Use auxiliary losses from each decoder layer.
46
+ use_fl (bool): Use FocalLoss.
47
+ use_vfl (bool): Use VarifocalLoss.
48
+ use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
49
+ uni_match_ind (int): Index of fixed layer for uni_match.
45
50
  """
46
51
  super().__init__()
47
52
 
@@ -81,9 +86,7 @@ class DETRLoss(nn.Module):
81
86
  return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
82
87
 
83
88
  def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
84
- """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
85
- boxes.
86
- """
89
+ """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
87
90
  # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
88
91
  name_bbox = f"loss_bbox{postfix}"
89
92
  name_giou = f"loss_giou{postfix}"
@@ -240,23 +243,32 @@ class DETRLoss(nn.Module):
240
243
  if len(gt_bboxes):
241
244
  gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
242
245
 
243
- loss = {}
244
- loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
245
- loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
246
- # if masks is not None and gt_mask is not None:
247
- # loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
248
- return loss
246
+ return {
247
+ **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
248
+ **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
249
+ # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
250
+ }
249
251
 
250
252
  def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
251
253
  """
254
+ Calculate loss for predicted bounding boxes and scores.
255
+
252
256
  Args:
253
- pred_bboxes (torch.Tensor): [l, b, query, 4]
254
- pred_scores (torch.Tensor): [l, b, query, num_classes]
255
- batch (dict): A dict includes:
256
- gt_cls (torch.Tensor) with shape [num_gts, ],
257
- gt_bboxes (torch.Tensor): [num_gts, 4],
258
- gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
259
- postfix (str): postfix of loss name.
257
+ pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
258
+ pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
259
+ batch (dict): Batch information containing:
260
+ cls (torch.Tensor): Ground truth classes, shape [num_gts].
261
+ bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
262
+ gt_groups (List[int]): Number of ground truths for each image in the batch.
263
+ postfix (str): Postfix for loss names.
264
+ **kwargs (Any): Additional arguments, may include 'match_indices'.
265
+
266
+ Returns:
267
+ (dict): Computed losses, including main and auxiliary (if enabled).
268
+
269
+ Note:
270
+ Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
271
+ self.aux_loss is True.
260
272
  """
261
273
  self.device = pred_bboxes.device
262
274
  match_indices = kwargs.get("match_indices", None)
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import torch
4
4
  import torch.nn as nn
@@ -32,9 +32,7 @@ class HungarianMatcher(nn.Module):
32
32
  """
33
33
 
34
34
  def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
35
- """Initializes HungarianMatcher with cost coefficients, Focal Loss, mask prediction, sample points, and alpha
36
- gamma factors.
37
- """
35
+ """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
38
36
  super().__init__()
39
37
  if cost_gain is None:
40
38
  cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
@@ -70,7 +68,6 @@ class HungarianMatcher(nn.Module):
70
68
  For each batch element, it holds:
71
69
  len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
72
70
  """
73
-
74
71
  bs, nq, nc = pred_scores.shape
75
72
 
76
73
  if sum(gt_groups) == 0:
@@ -133,7 +130,7 @@ class HungarianMatcher(nn.Module):
133
130
  # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
134
131
  # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
135
132
  #
136
- # with torch.cuda.amp.autocast(False):
133
+ # with torch.amp.autocast("cuda", enabled=False):
137
134
  # # binary cross entropy cost
138
135
  # pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
139
136
  # neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
@@ -175,7 +172,6 @@ def get_cdn_group(
175
172
  bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
176
173
  is less than or equal to 0, the function returns None for all elements in the tuple.
177
174
  """
178
-
179
175
  if (not training) or num_dn <= 0:
180
176
  return None, None, None, None
181
177
  gt_groups = batch["gt_groups"]
@@ -1,7 +1,7 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
- from ultralytics.models.yolo import classify, detect, obb, pose, segment
3
+ from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
4
4
 
5
5
  from .model import YOLO, YOLOWorld
6
6
 
7
- __all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
7
+ __all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from ultralytics.models.yolo.classify.predict import ClassificationPredictor
4
4
  from ultralytics.models.yolo.classify.train import ClassificationTrainer
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import cv2
4
4
  import torch
@@ -21,7 +21,7 @@ class ClassificationPredictor(BasePredictor):
21
21
  from ultralytics.utils import ASSETS
22
22
  from ultralytics.models.yolo.classify import ClassificationPredictor
23
23
 
24
- args = dict(model='yolov8n-cls.pt', source=ASSETS)
24
+ args = dict(model="yolov8n-cls.pt", source=ASSETS)
25
25
  predictor = ClassificationPredictor(overrides=args)
26
26
  predictor.predict_cli()
27
27
  ```
@@ -53,9 +53,8 @@ class ClassificationPredictor(BasePredictor):
53
53
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
54
54
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
55
55
 
56
- results = []
57
- for i, pred in enumerate(preds):
58
- orig_img = orig_imgs[i]
59
- img_path = self.batch[0][i]
60
- results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
61
- return results
56
+ preds = preds[0] if isinstance(preds, (list, tuple)) else preds
57
+ return [
58
+ Results(orig_img, path=img_path, names=self.model.names, probs=pred)
59
+ for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
60
+ ]
@@ -1,13 +1,14 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from copy import copy
2
4
 
3
5
  import torch
4
- import torchvision
5
6
 
6
7
  from ultralytics.data import ClassificationDataset, build_dataloader
7
8
  from ultralytics.engine.trainer import BaseTrainer
8
9
  from ultralytics.models import yolo
9
- from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
10
- from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
10
+ from ultralytics.nn.tasks import ClassificationModel
11
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
11
12
  from ultralytics.utils.plotting import plot_images, plot_results
12
13
  from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
13
14
 
@@ -23,7 +24,7 @@ class ClassificationTrainer(BaseTrainer):
23
24
  ```python
24
25
  from ultralytics.models.yolo.classify import ClassificationTrainer
25
26
 
26
- args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
27
+ args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
27
28
  trainer = ClassificationTrainer(overrides=args)
28
29
  trainer.train()
29
30
  ```
@@ -59,23 +60,16 @@ class ClassificationTrainer(BaseTrainer):
59
60
 
60
61
  def setup_model(self):
61
62
  """Load, create or download model for any task."""
62
- if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
63
- return
64
-
65
- model, ckpt = str(self.model), None
66
- # Load a YOLO model locally, from torchvision, or from Ultralytics assets
67
- if model.endswith(".pt"):
68
- self.model, ckpt = attempt_load_one_weight(model, device="cpu")
69
- for p in self.model.parameters():
70
- p.requires_grad = True # for training
71
- elif model.split(".")[-1] in ("yaml", "yml"):
72
- self.model = self.get_model(cfg=model)
73
- elif model in torchvision.models.__dict__:
74
- self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
63
+ import torchvision # scope for faster 'import ultralytics'
64
+
65
+ if str(self.model) in torchvision.models.__dict__:
66
+ self.model = torchvision.models.__dict__[self.model](
67
+ weights="IMAGENET1K_V1" if self.args.pretrained else None
68
+ )
69
+ ckpt = None
75
70
  else:
76
- FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
71
+ ckpt = super().setup_model()
77
72
  ClassificationModel.reshape_outputs(self.model, self.data["nc"])
78
-
79
73
  return ckpt
80
74
 
81
75
  def build_dataset(self, img_path, mode="train", batch=None):
@@ -115,7 +109,9 @@ class ClassificationTrainer(BaseTrainer):
115
109
  def get_validator(self):
116
110
  """Returns an instance of ClassificationValidator for validation."""
117
111
  self.loss_names = ["loss"]
118
- return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
112
+ return yolo.classify.ClassificationValidator(
113
+ self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
114
+ )
119
115
 
120
116
  def label_loss_items(self, loss_items=None, prefix="train"):
121
117
  """
@@ -145,7 +141,6 @@ class ClassificationTrainer(BaseTrainer):
145
141
  self.metrics = self.validator(model=f)
146
142
  self.metrics.pop("fitness", None)
147
143
  self.run_callbacks("on_fit_epoch_end")
148
- LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
149
144
 
150
145
  def plot_training_samples(self, batch, ni):
151
146
  """Plots training samples with their annotations."""
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import torch
4
4
 
@@ -20,7 +20,7 @@ class ClassificationValidator(BaseValidator):
20
20
  ```python
21
21
  from ultralytics.models.yolo.classify import ClassificationValidator
22
22
 
23
- args = dict(model='yolov8n-cls.pt', data='imagenet10')
23
+ args = dict(model="yolov8n-cls.pt", data="imagenet10")
24
24
  validator = ClassificationValidator(args=args)
25
25
  validator()
26
26
  ```
@@ -56,8 +56,8 @@ class ClassificationValidator(BaseValidator):
56
56
  def update_metrics(self, preds, batch):
57
57
  """Updates running metrics with model predictions and batch targets."""
58
58
  n5 = min(len(self.names), 5)
59
- self.pred.append(preds.argsort(1, descending=True)[:, :n5])
60
- self.targets.append(batch["cls"])
59
+ self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
60
+ self.targets.append(batch["cls"].type(torch.int32).cpu())
61
61
 
62
62
  def finalize_metrics(self, *args, **kwargs):
63
63
  """Finalizes metrics of the model such as confusion_matrix and speed."""
@@ -71,6 +71,10 @@ class ClassificationValidator(BaseValidator):
71
71
  self.metrics.confusion_matrix = self.confusion_matrix
72
72
  self.metrics.save_dir = self.save_dir
73
73
 
74
+ def postprocess(self, preds):
75
+ """Preprocesses the classification predictions."""
76
+ return preds[0] if isinstance(preds, (list, tuple)) else preds
77
+
74
78
  def get_stats(self):
75
79
  """Returns a dictionary of metrics obtained by processing targets and predictions."""
76
80
  self.metrics.process(self.targets, self.pred)
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from .predict import DetectionPredictor
4
4
  from .train import DetectionTrainer
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from ultralytics.engine.predictor import BasePredictor
4
4
  from ultralytics.engine.results import Results
@@ -14,7 +14,7 @@ class DetectionPredictor(BasePredictor):
14
14
  from ultralytics.utils import ASSETS
15
15
  from ultralytics.models.yolo.detect import DetectionPredictor
16
16
 
17
- args = dict(model='yolov8n.pt', source=ASSETS)
17
+ args = dict(model="yolo11n.pt", source=ASSETS)
18
18
  predictor = DetectionPredictor(overrides=args)
19
19
  predictor.predict_cli()
20
20
  ```
@@ -35,9 +35,7 @@ class DetectionPredictor(BasePredictor):
35
35
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
36
36
 
37
37
  results = []
38
- for i, pred in enumerate(preds):
39
- orig_img = orig_imgs[i]
38
+ for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
40
39
  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
41
- img_path = self.batch[0][i]
42
40
  results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
43
41
  return results
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import math
4
4
  import random
@@ -24,7 +24,7 @@ class DetectionTrainer(BaseTrainer):
24
24
  ```python
25
25
  from ultralytics.models.yolo.detect import DetectionTrainer
26
26
 
27
- args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
27
+ args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
28
28
  trainer = DetectionTrainer(overrides=args)
29
29
  trainer.train()
30
30
  ```
@@ -44,7 +44,7 @@ class DetectionTrainer(BaseTrainer):
44
44
 
45
45
  def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
46
46
  """Construct and return dataloader."""
47
- assert mode in ["train", "val"]
47
+ assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
48
48
  with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
49
49
  dataset = self.build_dataset(dataset_path, mode, batch_size)
50
50
  shuffle = mode == "train"
@@ -60,7 +60,7 @@ class DetectionTrainer(BaseTrainer):
60
60
  if self.args.multi_scale:
61
61
  imgs = batch["img"]
62
62
  sz = (
63
- random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride)
63
+ random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
64
64
  // self.stride
65
65
  * self.stride
66
66
  ) # size
@@ -141,3 +141,10 @@ class DetectionTrainer(BaseTrainer):
141
141
  boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
142
142
  cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
143
143
  plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
144
+
145
+ def auto_batch(self):
146
+ """Get batch size by calculating memory occupation of model."""
147
+ train_dataset = self.build_dataset(self.trainset, mode="train", batch=16)
148
+ # 4 for mosaic augmentation
149
+ max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4
150
+ return super().auto_batch(max_num_obj)
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import os
4
4
  from pathlib import Path
@@ -22,7 +22,7 @@ class DetectionValidator(BaseValidator):
22
22
  ```python
23
23
  from ultralytics.models.yolo.detect import DetectionValidator
24
24
 
25
- args = dict(model='yolov8n.pt', data='coco8.yaml')
25
+ args = dict(model="yolo11n.pt", data="coco8.yaml")
26
26
  validator = DetectionValidator(args=args)
27
27
  validator()
28
28
  ```
@@ -32,13 +32,20 @@ class DetectionValidator(BaseValidator):
32
32
  """Initialize detection model with necessary variables and settings."""
33
33
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
34
34
  self.nt_per_class = None
35
+ self.nt_per_image = None
35
36
  self.is_coco = False
37
+ self.is_lvis = False
36
38
  self.class_map = None
37
39
  self.args.task = "detect"
38
40
  self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
39
41
  self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
40
42
  self.niou = self.iouv.numel()
41
43
  self.lb = [] # for autolabelling
44
+ if self.args.save_hybrid:
45
+ LOGGER.warning(
46
+ "WARNING ⚠️ 'save_hybrid=True' will append ground truth to predictions for autolabelling.\n"
47
+ "WARNING ⚠️ 'save_hybrid=True' will cause incorrect mAP.\n"
48
+ )
42
49
 
43
50
  def preprocess(self, batch):
44
51
  """Preprocesses batch of images for YOLO training."""
@@ -51,23 +58,24 @@ class DetectionValidator(BaseValidator):
51
58
  height, width = batch["img"].shape[2:]
52
59
  nb = len(batch["img"])
53
60
  bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
54
- self.lb = (
55
- [
56
- torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
57
- for i in range(nb)
58
- ]
59
- if self.args.save_hybrid
60
- else []
61
- ) # for autolabelling
61
+ self.lb = [
62
+ torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
63
+ for i in range(nb)
64
+ ]
62
65
 
63
66
  return batch
64
67
 
65
68
  def init_metrics(self, model):
66
69
  """Initialize evaluation metrics for YOLO."""
67
70
  val = self.data.get(self.args.split, "") # validation path
68
- self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
69
- self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
70
- self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
71
+ self.is_coco = (
72
+ isinstance(val, str)
73
+ and "coco" in val
74
+ and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
75
+ ) # is COCO
76
+ self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
77
+ self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
78
+ self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val
71
79
  self.names = model.names
72
80
  self.nc = len(model.names)
73
81
  self.metrics.names = self.names
@@ -75,7 +83,7 @@ class DetectionValidator(BaseValidator):
75
83
  self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
76
84
  self.seen = 0
77
85
  self.jdict = []
78
- self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
86
+ self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
79
87
 
80
88
  def get_desc(self):
81
89
  """Return a formatted string summarizing class metrics of YOLO model."""
@@ -89,7 +97,7 @@ class DetectionValidator(BaseValidator):
89
97
  self.args.iou,
90
98
  labels=self.lb,
91
99
  multi_label=True,
92
- agnostic=self.args.single_cls,
100
+ agnostic=self.args.single_cls or self.args.agnostic_nms,
93
101
  max_det=self.args.max_det,
94
102
  )
95
103
 
@@ -104,7 +112,7 @@ class DetectionValidator(BaseValidator):
104
112
  if len(cls):
105
113
  bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
106
114
  ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
107
- return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
115
+ return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
108
116
 
109
117
  def _prepare_pred(self, pred, pbatch):
110
118
  """Prepares a batch of images and annotations for validation."""
@@ -128,6 +136,7 @@ class DetectionValidator(BaseValidator):
128
136
  cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
129
137
  nl = len(cls)
130
138
  stat["target_cls"] = cls
139
+ stat["target_img"] = cls.unique()
131
140
  if npr == 0:
132
141
  if nl:
133
142
  for k in self.stats.keys():
@@ -146,8 +155,8 @@ class DetectionValidator(BaseValidator):
146
155
  # Evaluate
147
156
  if nl:
148
157
  stat["tp"] = self._process_batch(predn, bbox, cls)
149
- if self.args.plots:
150
- self.confusion_matrix.process_batch(predn, bbox, cls)
158
+ if self.args.plots:
159
+ self.confusion_matrix.process_batch(predn, bbox, cls)
151
160
  for k in self.stats.keys():
152
161
  self.stats[k].append(stat[k])
153
162
 
@@ -155,8 +164,12 @@ class DetectionValidator(BaseValidator):
155
164
  if self.args.save_json:
156
165
  self.pred_to_json(predn, batch["im_file"][si])
157
166
  if self.args.save_txt:
158
- file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
159
- self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
167
+ self.save_one_txt(
168
+ predn,
169
+ self.args.save_conf,
170
+ pbatch["ori_shape"],
171
+ self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
172
+ )
160
173
 
161
174
  def finalize_metrics(self, *args, **kwargs):
162
175
  """Set final values for metrics speed and confusion matrix."""
@@ -166,11 +179,11 @@ class DetectionValidator(BaseValidator):
166
179
  def get_stats(self):
167
180
  """Returns metrics statistics and results dictionary."""
168
181
  stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
182
+ self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
183
+ self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
184
+ stats.pop("target_img", None)
169
185
  if len(stats) and stats["tp"].any():
170
186
  self.metrics.process(**stats)
171
- self.nt_per_class = np.bincount(
172
- stats["target_cls"].astype(int), minlength=self.nc
173
- ) # number of targets per class
174
187
  return self.metrics.results_dict
175
188
 
176
189
  def print_results(self):
@@ -183,7 +196,9 @@ class DetectionValidator(BaseValidator):
183
196
  # Print results per class
184
197
  if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
185
198
  for i, c in enumerate(self.metrics.ap_class_index):
186
- LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
199
+ LOGGER.info(
200
+ pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
201
+ )
187
202
 
188
203
  if self.args.plots:
189
204
  for normalize in True, False:
@@ -196,13 +211,18 @@ class DetectionValidator(BaseValidator):
196
211
  Return correct prediction matrix.
197
212
 
198
213
  Args:
199
- detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
200
- Each detection is of the format: x1, y1, x2, y2, conf, class.
201
- labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
202
- Each label is of the format: class, x1, y1, x2, y2.
214
+ detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
215
+ (x1, y1, x2, y2, conf, class).
216
+ gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
217
+ bounding box is of the format: (x1, y1, x2, y2).
218
+ gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.
203
219
 
204
220
  Returns:
205
- (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
221
+ (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
222
+
223
+ Note:
224
+ The function does not return any value directly usable for metrics calculation. Instead, it provides an
225
+ intermediate representation used for evaluating predictions against ground truth.
206
226
  """
207
227
  iou = box_iou(gt_bboxes, detections[:, :4])
208
228
  return self.match_predictions(detections[:, 5], gt_cls, iou)
@@ -249,12 +269,14 @@ class DetectionValidator(BaseValidator):
249
269
 
250
270
  def save_one_txt(self, predn, save_conf, shape, file):
251
271
  """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
252
- gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
253
- for *xyxy, conf, cls in predn.tolist():
254
- xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
255
- line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
256
- with open(file, "a") as f:
257
- f.write(("%g " * len(line)).rstrip() % line + "\n")
272
+ from ultralytics.engine.results import Results
273
+
274
+ Results(
275
+ np.zeros((shape[0], shape[1]), dtype=np.uint8),
276
+ path=None,
277
+ names=self.names,
278
+ boxes=predn[:, :6],
279
+ ).save_txt(file, save_conf=save_conf)
258
280
 
259
281
  def pred_to_json(self, predn, filename):
260
282
  """Serialize YOLO predictions to COCO json format."""
@@ -274,26 +296,42 @@ class DetectionValidator(BaseValidator):
274
296
 
275
297
  def eval_json(self, stats):
276
298
  """Evaluates YOLO output in JSON format and returns performance statistics."""
277
- if self.args.save_json and self.is_coco and len(self.jdict):
278
- anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
299
+ if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
279
300
  pred_json = self.save_dir / "predictions.json" # predictions
280
- LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
301
+ anno_json = (
302
+ self.data["path"]
303
+ / "annotations"
304
+ / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
305
+ ) # annotations
306
+ pkg = "pycocotools" if self.is_coco else "lvis"
307
+ LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
281
308
  try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
282
- check_requirements("pycocotools>=2.0.6")
283
- from pycocotools.coco import COCO # noqa
284
- from pycocotools.cocoeval import COCOeval # noqa
285
-
286
- for x in anno_json, pred_json:
309
+ for x in pred_json, anno_json:
287
310
  assert x.is_file(), f"{x} file not found"
288
- anno = COCO(str(anno_json)) # init annotations api
289
- pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
290
- eval = COCOeval(anno, pred, "bbox")
311
+ check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
291
312
  if self.is_coco:
292
- eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
293
- eval.evaluate()
294
- eval.accumulate()
295
- eval.summarize()
296
- stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
313
+ from pycocotools.coco import COCO # noqa
314
+ from pycocotools.cocoeval import COCOeval # noqa
315
+
316
+ anno = COCO(str(anno_json)) # init annotations api
317
+ pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
318
+ val = COCOeval(anno, pred, "bbox")
319
+ else:
320
+ from lvis import LVIS, LVISEval
321
+
322
+ anno = LVIS(str(anno_json)) # init annotations api
323
+ pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)
324
+ val = LVISEval(anno, pred, "bbox")
325
+ val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
326
+ val.evaluate()
327
+ val.accumulate()
328
+ val.summarize()
329
+ if self.is_lvis:
330
+ val.print_results() # explicitly call print_results
331
+ # update mAP50-95 and mAP50
332
+ stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
333
+ val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
334
+ )
297
335
  except Exception as e:
298
- LOGGER.warning(f"pycocotools unable to run: {e}")
336
+ LOGGER.warning(f"{pkg} unable to run: {e}")
299
337
  return stats