dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
@@ -6,20 +6,20 @@ from pathlib import Path
6
6
  from typing import Any
7
7
 
8
8
  import torch
9
+ import torch.distributed as dist
9
10
 
10
11
  from ultralytics.data import ClassificationDataset, build_dataloader
11
12
  from ultralytics.engine.validator import BaseValidator
12
- from ultralytics.utils import LOGGER
13
+ from ultralytics.utils import LOGGER, RANK
13
14
  from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
14
15
  from ultralytics.utils.plotting import plot_images
15
16
 
16
17
 
17
18
  class ClassificationValidator(BaseValidator):
18
- """
19
- A class extending the BaseValidator class for validation based on a classification model.
19
+ """A class extending the BaseValidator class for validation based on a classification model.
20
20
 
21
- This validator handles the validation process for classification models, including metrics calculation,
22
- confusion matrix generation, and visualization of results.
21
+ This validator handles the validation process for classification models, including metrics calculation, confusion
22
+ matrix generation, and visualization of results.
23
23
 
24
24
  Attributes:
25
25
  targets (list[torch.Tensor]): Ground truth class labels.
@@ -45,7 +45,7 @@ class ClassificationValidator(BaseValidator):
45
45
 
46
46
  Examples:
47
47
  >>> from ultralytics.models.yolo.classify import ClassificationValidator
48
- >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
48
+ >>> args = dict(model="yolo26n-cls.pt", data="imagenet10")
49
49
  >>> validator = ClassificationValidator(args=args)
50
50
  >>> validator()
51
51
 
@@ -54,20 +54,13 @@ class ClassificationValidator(BaseValidator):
54
54
  """
55
55
 
56
56
  def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
57
- """
58
- Initialize ClassificationValidator with dataloader, save directory, and other parameters.
57
+ """Initialize ClassificationValidator with dataloader, save directory, and other parameters.
59
58
 
60
59
  Args:
61
- dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
60
+ dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
62
61
  save_dir (str | Path, optional): Directory to save results.
63
62
  args (dict, optional): Arguments containing model and validation configuration.
64
63
  _callbacks (list, optional): List of callback functions to be called during validation.
65
-
66
- Examples:
67
- >>> from ultralytics.models.yolo.classify import ClassificationValidator
68
- >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
69
- >>> validator = ClassificationValidator(args=args)
70
- >>> validator()
71
64
  """
72
65
  super().__init__(dataloader, save_dir, args, _callbacks)
73
66
  self.targets = None
@@ -95,8 +88,7 @@ class ClassificationValidator(BaseValidator):
95
88
  return batch
96
89
 
97
90
  def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
98
- """
99
- Update running metrics with model predictions and batch targets.
91
+ """Update running metrics with model predictions and batch targets.
100
92
 
101
93
  Args:
102
94
  preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
@@ -111,12 +103,7 @@ class ClassificationValidator(BaseValidator):
111
103
  self.targets.append(batch["cls"].type(torch.int32).cpu())
112
104
 
113
105
  def finalize_metrics(self) -> None:
114
- """
115
- Finalize metrics including confusion matrix and processing speed.
116
-
117
- Notes:
118
- This method processes the accumulated predictions and targets to generate the confusion matrix,
119
- optionally plots it, and updates the metrics object with speed information.
106
+ """Finalize metrics including confusion matrix and processing speed.
120
107
 
121
108
  Examples:
122
109
  >>> validator = ClassificationValidator()
@@ -124,6 +111,10 @@ class ClassificationValidator(BaseValidator):
124
111
  >>> validator.targets = [torch.tensor([0])] # Ground truth class
125
112
  >>> validator.finalize_metrics()
126
113
  >>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
114
+
115
+ Notes:
116
+ This method processes the accumulated predictions and targets to generate the confusion matrix,
117
+ optionally plots it, and updates the metrics object with speed information.
127
118
  """
128
119
  self.confusion_matrix.process_cls_preds(self.pred, self.targets)
129
120
  if self.args.plots:
@@ -142,13 +133,25 @@ class ClassificationValidator(BaseValidator):
142
133
  self.metrics.process(self.targets, self.pred)
143
134
  return self.metrics.results_dict
144
135
 
136
+ def gather_stats(self) -> None:
137
+ """Gather stats from all GPUs."""
138
+ if RANK == 0:
139
+ gathered_preds = [None] * dist.get_world_size()
140
+ gathered_targets = [None] * dist.get_world_size()
141
+ dist.gather_object(self.pred, gathered_preds, dst=0)
142
+ dist.gather_object(self.targets, gathered_targets, dst=0)
143
+ self.pred = [pred for rank in gathered_preds for pred in rank]
144
+ self.targets = [targets for rank in gathered_targets for targets in rank]
145
+ elif RANK > 0:
146
+ dist.gather_object(self.pred, None, dst=0)
147
+ dist.gather_object(self.targets, None, dst=0)
148
+
145
149
  def build_dataset(self, img_path: str) -> ClassificationDataset:
146
150
  """Create a ClassificationDataset instance for validation."""
147
151
  return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
148
152
 
149
153
  def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
150
- """
151
- Build and return a data loader for classification validation.
154
+ """Build and return a data loader for classification validation.
152
155
 
153
156
  Args:
154
157
  dataset_path (str | Path): Path to the dataset directory.
@@ -166,8 +169,7 @@ class ClassificationValidator(BaseValidator):
166
169
  LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
167
170
 
168
171
  def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
169
- """
170
- Plot validation image samples with their ground truth labels.
172
+ """Plot validation image samples with their ground truth labels.
171
173
 
172
174
  Args:
173
175
  batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
@@ -187,8 +189,7 @@ class ClassificationValidator(BaseValidator):
187
189
  )
188
190
 
189
191
  def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
190
- """
191
- Plot images with their predicted class labels and save the visualization.
192
+ """Plot images with their predicted class labels and save the visualization.
192
193
 
193
194
  Args:
194
195
  batch (dict[str, Any]): Batch data containing images and other information.
@@ -6,8 +6,7 @@ from ultralytics.utils import nms, ops
6
6
 
7
7
 
8
8
  class DetectionPredictor(BasePredictor):
9
- """
10
- A class extending the BasePredictor class for prediction based on a detection model.
9
+ """A class extending the BasePredictor class for prediction based on a detection model.
11
10
 
12
11
  This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
13
12
  with bounding boxes and class predictions.
@@ -26,14 +25,13 @@ class DetectionPredictor(BasePredictor):
26
25
  Examples:
27
26
  >>> from ultralytics.utils import ASSETS
28
27
  >>> from ultralytics.models.yolo.detect import DetectionPredictor
29
- >>> args = dict(model="yolo11n.pt", source=ASSETS)
28
+ >>> args = dict(model="yolo26n.pt", source=ASSETS)
30
29
  >>> predictor = DetectionPredictor(overrides=args)
31
30
  >>> predictor.predict_cli()
32
31
  """
33
32
 
34
33
  def postprocess(self, preds, img, orig_imgs, **kwargs):
35
- """
36
- Post-process predictions and return a list of Results objects.
34
+ """Post-process predictions and return a list of Results objects.
37
35
 
38
36
  This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
39
37
  further analysis.
@@ -48,7 +46,7 @@ class DetectionPredictor(BasePredictor):
48
46
  (list): List of Results objects containing the post-processed predictions.
49
47
 
50
48
  Examples:
51
- >>> predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt"))
49
+ >>> predictor = DetectionPredictor(overrides=dict(model="yolo26n.pt"))
52
50
  >>> results = predictor.predict("path/to/image.jpg")
53
51
  >>> processed_results = predictor.postprocess(preds, img, orig_imgs)
54
52
  """
@@ -67,7 +65,7 @@ class DetectionPredictor(BasePredictor):
67
65
  )
68
66
 
69
67
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
70
- orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
68
+ orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
71
69
 
72
70
  if save_feats:
73
71
  obj_feats = self.get_obj_feats(self._feats, preds[1])
@@ -81,7 +79,8 @@ class DetectionPredictor(BasePredictor):
81
79
 
82
80
  return results
83
81
 
84
- def get_obj_feats(self, feat_maps, idxs):
82
+ @staticmethod
83
+ def get_obj_feats(feat_maps, idxs):
85
84
  """Extract object features from the feature maps."""
86
85
  import torch
87
86
 
@@ -92,8 +91,7 @@ class DetectionPredictor(BasePredictor):
92
91
  return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
93
92
 
94
93
  def construct_results(self, preds, img, orig_imgs):
95
- """
96
- Construct a list of Results objects from model predictions.
94
+ """Construct a list of Results objects from model predictions.
97
95
 
98
96
  Args:
99
97
  preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
@@ -109,8 +107,7 @@ class DetectionPredictor(BasePredictor):
109
107
  ]
110
108
 
111
109
  def construct_result(self, pred, img, orig_img, img_path):
112
- """
113
- Construct a single Results object from one image prediction.
110
+ """Construct a single Results object from one image prediction.
114
111
 
115
112
  Args:
116
113
  pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
@@ -22,11 +22,10 @@ from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_m
22
22
 
23
23
 
24
24
  class DetectionTrainer(BaseTrainer):
25
- """
26
- A class extending the BaseTrainer class for training based on a detection model.
25
+ """A class extending the BaseTrainer class for training based on a detection model.
27
26
 
28
- This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
29
- for object detection including dataset building, data loading, preprocessing, and model configuration.
27
+ This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
28
+ object detection including dataset building, data loading, preprocessing, and model configuration.
30
29
 
31
30
  Attributes:
32
31
  model (DetectionModel): The YOLO detection model being trained.
@@ -48,14 +47,13 @@ class DetectionTrainer(BaseTrainer):
48
47
 
49
48
  Examples:
50
49
  >>> from ultralytics.models.yolo.detect import DetectionTrainer
51
- >>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
50
+ >>> args = dict(model="yolo26n.pt", data="coco8.yaml", epochs=3)
52
51
  >>> trainer = DetectionTrainer(overrides=args)
53
52
  >>> trainer.train()
54
53
  """
55
54
 
56
55
  def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
57
- """
58
- Initialize a DetectionTrainer object for training YOLO object detection model training.
56
+ """Initialize a DetectionTrainer object for training YOLO object detection models.
59
57
 
60
58
  Args:
61
59
  cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -65,8 +63,7 @@ class DetectionTrainer(BaseTrainer):
65
63
  super().__init__(cfg, overrides, _callbacks)
66
64
 
67
65
  def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
68
- """
69
- Build YOLO Dataset for training or validation.
66
+ """Build YOLO Dataset for training or validation.
70
67
 
71
68
  Args:
72
69
  img_path (str): Path to the folder containing images.
@@ -80,8 +77,7 @@ class DetectionTrainer(BaseTrainer):
80
77
  return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
81
78
 
82
79
  def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
83
- """
84
- Construct and return dataloader for the specified mode.
80
+ """Construct and return dataloader for the specified mode.
85
81
 
86
82
  Args:
87
83
  dataset_path (str): Path to the dataset.
@@ -109,8 +105,7 @@ class DetectionTrainer(BaseTrainer):
109
105
  )
110
106
 
111
107
  def preprocess_batch(self, batch: dict) -> dict:
112
- """
113
- Preprocess a batch of images by scaling and converting to float.
108
+ """Preprocess a batch of images by scaling and converting to float.
114
109
 
115
110
  Args:
116
111
  batch (dict): Dictionary containing batch data with 'img' tensor.
@@ -122,10 +117,13 @@ class DetectionTrainer(BaseTrainer):
122
117
  if isinstance(v, torch.Tensor):
123
118
  batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
124
119
  batch["img"] = batch["img"].float() / 255
125
- if self.args.multi_scale:
120
+ if self.args.multi_scale > 0.0:
126
121
  imgs = batch["img"]
127
122
  sz = (
128
- random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
123
+ random.randrange(
124
+ int(self.args.imgsz * (1.0 - self.args.multi_scale)),
125
+ int(self.args.imgsz * (1.0 + self.args.multi_scale) + self.stride),
126
+ )
129
127
  // self.stride
130
128
  * self.stride
131
129
  ) # size
@@ -150,8 +148,7 @@ class DetectionTrainer(BaseTrainer):
150
148
  # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
151
149
 
152
150
  def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
153
- """
154
- Return a YOLO detection model.
151
+ """Return a YOLO detection model.
155
152
 
156
153
  Args:
157
154
  cfg (str, optional): Path to model configuration file.
@@ -174,8 +171,7 @@ class DetectionTrainer(BaseTrainer):
174
171
  )
175
172
 
176
173
  def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
177
- """
178
- Return a loss dict with labeled training loss items tensor.
174
+ """Return a loss dict with labeled training loss items tensor.
179
175
 
180
176
  Args:
181
177
  loss_items (list[float], optional): List of loss values.
@@ -202,8 +198,7 @@ class DetectionTrainer(BaseTrainer):
202
198
  )
203
199
 
204
200
  def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
205
- """
206
- Plot training samples with their annotations.
201
+ """Plot training samples with their annotations.
207
202
 
208
203
  Args:
209
204
  batch (dict[str, Any]): Dictionary containing batch data.
@@ -223,8 +218,7 @@ class DetectionTrainer(BaseTrainer):
223
218
  plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
224
219
 
225
220
  def auto_batch(self):
226
- """
227
- Get optimal batch size by calculating memory occupation of model.
221
+ """Get optimal batch size by calculating memory occupation of model.
228
222
 
229
223
  Returns:
230
224
  (int): Optimal batch size.
@@ -8,18 +8,18 @@ from typing import Any
8
8
 
9
9
  import numpy as np
10
10
  import torch
11
+ import torch.distributed as dist
11
12
 
12
13
  from ultralytics.data import build_dataloader, build_yolo_dataset, converter
13
14
  from ultralytics.engine.validator import BaseValidator
14
- from ultralytics.utils import LOGGER, nms, ops
15
+ from ultralytics.utils import LOGGER, RANK, nms, ops
15
16
  from ultralytics.utils.checks import check_requirements
16
17
  from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
17
18
  from ultralytics.utils.plotting import plot_images
18
19
 
19
20
 
20
21
  class DetectionValidator(BaseValidator):
21
- """
22
- A class extending the BaseValidator class for validation based on a detection model.
22
+ """A class extending the BaseValidator class for validation based on a detection model.
23
23
 
24
24
  This class implements validation functionality specific to object detection tasks, including metrics calculation,
25
25
  prediction processing, and visualization of results.
@@ -37,17 +37,16 @@ class DetectionValidator(BaseValidator):
37
37
 
38
38
  Examples:
39
39
  >>> from ultralytics.models.yolo.detect import DetectionValidator
40
- >>> args = dict(model="yolo11n.pt", data="coco8.yaml")
40
+ >>> args = dict(model="yolo26n.pt", data="coco8.yaml")
41
41
  >>> validator = DetectionValidator(args=args)
42
42
  >>> validator()
43
43
  """
44
44
 
45
45
  def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
46
- """
47
- Initialize detection validator with necessary variables and settings.
46
+ """Initialize detection validator with necessary variables and settings.
48
47
 
49
48
  Args:
50
- dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
49
+ dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
51
50
  save_dir (Path, optional): Directory to save results.
52
51
  args (dict[str, Any], optional): Arguments for the validator.
53
52
  _callbacks (list[Any], optional): List of callback functions.
@@ -62,8 +61,7 @@ class DetectionValidator(BaseValidator):
62
61
  self.metrics = DetMetrics()
63
62
 
64
63
  def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
65
- """
66
- Preprocess batch of images for YOLO validation.
64
+ """Preprocess batch of images for YOLO validation.
67
65
 
68
66
  Args:
69
67
  batch (dict[str, Any]): Batch containing images and annotations.
@@ -78,8 +76,7 @@ class DetectionValidator(BaseValidator):
78
76
  return batch
79
77
 
80
78
  def init_metrics(self, model: torch.nn.Module) -> None:
81
- """
82
- Initialize evaluation metrics for YOLO detection validation.
79
+ """Initialize evaluation metrics for YOLO detection validation.
83
80
 
84
81
  Args:
85
82
  model (torch.nn.Module): Model to validate.
@@ -106,15 +103,14 @@ class DetectionValidator(BaseValidator):
106
103
  return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
107
104
 
108
105
  def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
109
- """
110
- Apply Non-maximum suppression to prediction outputs.
106
+ """Apply Non-maximum suppression to prediction outputs.
111
107
 
112
108
  Args:
113
109
  preds (torch.Tensor): Raw predictions from the model.
114
110
 
115
111
  Returns:
116
- (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
117
- 'bboxes', 'conf', 'cls', and 'extra' tensors.
112
+ (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains 'bboxes', 'conf',
113
+ 'cls', and 'extra' tensors.
118
114
  """
119
115
  outputs = nms.non_max_suppression(
120
116
  preds,
@@ -130,8 +126,7 @@ class DetectionValidator(BaseValidator):
130
126
  return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
131
127
 
132
128
  def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
133
- """
134
- Prepare a batch of images and annotations for validation.
129
+ """Prepare a batch of images and annotations for validation.
135
130
 
136
131
  Args:
137
132
  si (int): Batch index.
@@ -158,8 +153,7 @@ class DetectionValidator(BaseValidator):
158
153
  }
159
154
 
160
155
  def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
161
- """
162
- Prepare predictions for evaluation against ground truth.
156
+ """Prepare predictions for evaluation against ground truth.
163
157
 
164
158
  Args:
165
159
  pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
@@ -172,8 +166,7 @@ class DetectionValidator(BaseValidator):
172
166
  return pred
173
167
 
174
168
  def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
175
- """
176
- Update metrics with new predictions and ground truth.
169
+ """Update metrics with new predictions and ground truth.
177
170
 
178
171
  Args:
179
172
  preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
@@ -226,9 +219,30 @@ class DetectionValidator(BaseValidator):
226
219
  self.metrics.confusion_matrix = self.confusion_matrix
227
220
  self.metrics.save_dir = self.save_dir
228
221
 
222
+ def gather_stats(self) -> None:
223
+ """Gather stats from all GPUs."""
224
+ if RANK == 0:
225
+ gathered_stats = [None] * dist.get_world_size()
226
+ dist.gather_object(self.metrics.stats, gathered_stats, dst=0)
227
+ merged_stats = {key: [] for key in self.metrics.stats.keys()}
228
+ for stats_dict in gathered_stats:
229
+ for key in merged_stats:
230
+ merged_stats[key].extend(stats_dict[key])
231
+ gathered_jdict = [None] * dist.get_world_size()
232
+ dist.gather_object(self.jdict, gathered_jdict, dst=0)
233
+ self.jdict = []
234
+ for jdict in gathered_jdict:
235
+ self.jdict.extend(jdict)
236
+ self.metrics.stats = merged_stats
237
+ self.seen = len(self.dataloader.dataset) # total image count from dataset
238
+ elif RANK > 0:
239
+ dist.gather_object(self.metrics.stats, None, dst=0)
240
+ dist.gather_object(self.jdict, None, dst=0)
241
+ self.jdict = []
242
+ self.metrics.clear_stats()
243
+
229
244
  def get_stats(self) -> dict[str, Any]:
230
- """
231
- Calculate and return metrics statistics.
245
+ """Calculate and return metrics statistics.
232
246
 
233
247
  Returns:
234
248
  (dict[str, Any]): Dictionary containing metrics results.
@@ -242,7 +256,7 @@ class DetectionValidator(BaseValidator):
242
256
  pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
243
257
  LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
244
258
  if self.metrics.nt_per_class.sum() == 0:
245
- LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
259
+ LOGGER.warning(f"no labels found in {self.args.task} set, cannot compute metrics without labels")
246
260
 
247
261
  # Print results per class
248
262
  if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
@@ -258,15 +272,15 @@ class DetectionValidator(BaseValidator):
258
272
  )
259
273
 
260
274
  def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
261
- """
262
- Return correct prediction matrix.
275
+ """Return correct prediction matrix.
263
276
 
264
277
  Args:
265
278
  preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
266
279
  batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
267
280
 
268
281
  Returns:
269
- (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
282
+ (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for
283
+ 10 IoU levels.
270
284
  """
271
285
  if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
272
286
  return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
@@ -274,8 +288,7 @@ class DetectionValidator(BaseValidator):
274
288
  return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
275
289
 
276
290
  def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
277
- """
278
- Build YOLO Dataset.
291
+ """Build YOLO Dataset.
279
292
 
280
293
  Args:
281
294
  img_path (str): Path to the folder containing images.
@@ -288,24 +301,28 @@ class DetectionValidator(BaseValidator):
288
301
  return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
289
302
 
290
303
  def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
291
- """
292
- Construct and return dataloader.
304
+ """Construct and return dataloader.
293
305
 
294
306
  Args:
295
307
  dataset_path (str): Path to the dataset.
296
308
  batch_size (int): Size of each batch.
297
309
 
298
310
  Returns:
299
- (torch.utils.data.DataLoader): Dataloader for validation.
311
+ (torch.utils.data.DataLoader): DataLoader for validation.
300
312
  """
301
313
  dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
302
314
  return build_dataloader(
303
- dataset, batch_size, self.args.workers, shuffle=False, rank=-1, drop_last=self.args.compile
315
+ dataset,
316
+ batch_size,
317
+ self.args.workers,
318
+ shuffle=False,
319
+ rank=-1,
320
+ drop_last=self.args.compile,
321
+ pin_memory=self.training,
304
322
  )
305
323
 
306
324
  def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
307
- """
308
- Plot validation image samples.
325
+ """Plot validation image samples.
309
326
 
310
327
  Args:
311
328
  batch (dict[str, Any]): Batch containing images and annotations.
@@ -322,8 +339,7 @@ class DetectionValidator(BaseValidator):
322
339
  def plot_predictions(
323
340
  self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
324
341
  ) -> None:
325
- """
326
- Plot predicted bounding boxes on input images and save the result.
342
+ """Plot predicted bounding boxes on input images and save the result.
327
343
 
328
344
  Args:
329
345
  batch (dict[str, Any]): Batch containing images and annotations.
@@ -331,14 +347,14 @@ class DetectionValidator(BaseValidator):
331
347
  ni (int): Batch index.
332
348
  max_det (Optional[int]): Maximum number of detections to plot.
333
349
  """
334
- # TODO: optimize this
350
+ if not preds:
351
+ return
335
352
  for i, pred in enumerate(preds):
336
353
  pred["batch_idx"] = torch.ones_like(pred["conf"]) * i # add batch index to predictions
337
354
  keys = preds[0].keys()
338
355
  max_det = max_det or self.args.max_det
339
356
  batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
340
- # TODO: fix this
341
- batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4]) # convert to xywh format
357
+ batched_preds["bboxes"] = ops.xyxy2xywh(batched_preds["bboxes"]) # convert to xywh format
342
358
  plot_images(
343
359
  images=batch["img"],
344
360
  labels=batched_preds,
@@ -349,8 +365,7 @@ class DetectionValidator(BaseValidator):
349
365
  ) # pred
350
366
 
351
367
  def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
352
- """
353
- Save YOLO detections to a txt file in normalized coordinates in a specific format.
368
+ """Save YOLO detections to a txt file in normalized coordinates in a specific format.
354
369
 
355
370
  Args:
356
371
  predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
@@ -368,12 +383,11 @@ class DetectionValidator(BaseValidator):
368
383
  ).save_txt(file, save_conf=save_conf)
369
384
 
370
385
  def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
371
- """
372
- Serialize YOLO predictions to COCO json format.
386
+ """Serialize YOLO predictions to COCO json format.
373
387
 
374
388
  Args:
375
- predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
376
- with bounding box coordinates, confidence scores, and class predictions.
389
+ predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
390
+ bounding box coordinates, confidence scores, and class predictions.
377
391
  pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
378
392
 
379
393
  Examples:
@@ -414,8 +428,7 @@ class DetectionValidator(BaseValidator):
414
428
  }
415
429
 
416
430
  def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
417
- """
418
- Evaluate YOLO output in JSON format and return performance statistics.
431
+ """Evaluate YOLO output in JSON format and return performance statistics.
419
432
 
420
433
  Args:
421
434
  stats (dict[str, Any]): Current statistics dictionary.
@@ -439,21 +452,20 @@ class DetectionValidator(BaseValidator):
439
452
  iou_types: str | list[str] = "bbox",
440
453
  suffix: str | list[str] = "Box",
441
454
  ) -> dict[str, Any]:
442
- """
443
- Evaluate COCO/LVIS metrics using faster-coco-eval library.
455
+ """Evaluate COCO/LVIS metrics using faster-coco-eval library.
444
456
 
445
- Performs evaluation using the faster-coco-eval library to compute mAP metrics
446
- for object detection. Updates the provided stats dictionary with computed metrics
447
- including mAP50, mAP50-95, and LVIS-specific metrics if applicable.
457
+ Performs evaluation using the faster-coco-eval library to compute mAP metrics for object detection. Updates the
458
+ provided stats dictionary with computed metrics including mAP50, mAP50-95, and LVIS-specific metrics if
459
+ applicable.
448
460
 
449
461
  Args:
450
462
  stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
451
- pred_json (str | Path]): Path to JSON file containing predictions in COCO format.
452
- anno_json (str | Path]): Path to JSON file containing ground truth annotations in COCO format.
453
- iou_types (str | list[str]]): IoU type(s) for evaluation. Can be single string or list of strings.
454
- Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
455
- suffix (str | list[str]]): Suffix to append to metric names in stats dictionary. Should correspond
456
- to iou_types if multiple types provided. Defaults to "Box".
463
+ pred_json (str | Path): Path to JSON file containing predictions in COCO format.
464
+ anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
465
+ iou_types (str | list[str]): IoU type(s) for evaluation. Can be single string or list of strings. Common
466
+ values include "bbox", "segm", "keypoints". Defaults to "bbox".
467
+ suffix (str | list[str]): Suffix to append to metric names in stats dictionary. Should correspond to
468
+ iou_types if multiple types provided. Defaults to "Box".
457
469
 
458
470
  Returns:
459
471
  (dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
@@ -482,6 +494,12 @@ class DetectionValidator(BaseValidator):
482
494
  # update mAP50-95 and mAP50
483
495
  stats[f"metrics/mAP50({suffix[i][0]})"] = val.stats_as_dict["AP_50"]
484
496
  stats[f"metrics/mAP50-95({suffix[i][0]})"] = val.stats_as_dict["AP_all"]
497
+ # record mAP for small, medium, large objects as well
498
+ stats["metrics/mAP_small(B)"] = val.stats_as_dict["AP_small"]
499
+ stats["metrics/mAP_medium(B)"] = val.stats_as_dict["AP_medium"]
500
+ stats["metrics/mAP_large(B)"] = val.stats_as_dict["AP_large"]
501
+ # update fitness
502
+ stats["fitness"] = 0.9 * val.stats_as_dict["AP_all"] + 0.1 * val.stats_as_dict["AP_50"]
485
503
 
486
504
  if self.is_lvis:
487
505
  stats[f"metrics/APr({suffix[i][0]})"] = val.stats_as_dict["APr"]