dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff shows the changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (243)
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/classify/val.py
@@ -6,20 +6,20 @@ from pathlib import Path
 from typing import Any
 
 import torch
+import torch.distributed as dist
 
 from ultralytics.data import ClassificationDataset, build_dataloader
 from ultralytics.engine.validator import BaseValidator
-from ultralytics.utils import LOGGER
+from ultralytics.utils import LOGGER, RANK
 from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
 from ultralytics.utils.plotting import plot_images
 
 
 class ClassificationValidator(BaseValidator):
-    """
-    A class extending the BaseValidator class for validation based on a classification model.
+    """A class extending the BaseValidator class for validation based on a classification model.
 
-    This validator handles the validation process for classification models, including metrics calculation,
-    confusion matrix generation, and visualization of results.
+    This validator handles the validation process for classification models, including metrics calculation, confusion
+    matrix generation, and visualization of results.
 
     Attributes:
         targets (list[torch.Tensor]): Ground truth class labels.
@@ -54,20 +54,13 @@ class ClassificationValidator(BaseValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize ClassificationValidator with dataloader, save directory, and other parameters.
+        """Initialize ClassificationValidator with dataloader, save directory, and other parameters.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
             save_dir (str | Path, optional): Directory to save results.
            args (dict, optional): Arguments containing model and validation configuration.
             _callbacks (list, optional): List of callback functions to be called during validation.
-
-        Examples:
-            >>> from ultralytics.models.yolo.classify import ClassificationValidator
-            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
-            >>> validator = ClassificationValidator(args=args)
-            >>> validator()
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
         self.targets = None
@@ -89,14 +82,13 @@ class ClassificationValidator(BaseValidator):
 
     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Preprocess input batch by moving data to device and converting to appropriate dtype."""
-        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
-        batch["cls"] = batch["cls"].to(self.device, non_blocking=True)
+        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
         return batch
 
     def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
-        """
-        Update running metrics with model predictions and batch targets.
+        """Update running metrics with model predictions and batch targets.
 
         Args:
             preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
@@ -111,12 +103,7 @@ class ClassificationValidator(BaseValidator):
         self.targets.append(batch["cls"].type(torch.int32).cpu())
 
     def finalize_metrics(self) -> None:
-        """
-        Finalize metrics including confusion matrix and processing speed.
-
-        Notes:
-            This method processes the accumulated predictions and targets to generate the confusion matrix,
-            optionally plots it, and updates the metrics object with speed information.
+        """Finalize metrics including confusion matrix and processing speed.
 
         Examples:
             >>> validator = ClassificationValidator()
@@ -124,6 +111,10 @@ class ClassificationValidator(BaseValidator):
             >>> validator.targets = [torch.tensor([0])]  # Ground truth class
             >>> validator.finalize_metrics()
             >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
+
+        Notes:
+            This method processes the accumulated predictions and targets to generate the confusion matrix,
+            optionally plots it, and updates the metrics object with speed information.
         """
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
@@ -142,13 +133,25 @@ class ClassificationValidator(BaseValidator):
         self.metrics.process(self.targets, self.pred)
         return self.metrics.results_dict
 
+    def gather_stats(self) -> None:
+        """Gather stats from all GPUs."""
+        if RANK == 0:
+            gathered_preds = [None] * dist.get_world_size()
+            gathered_targets = [None] * dist.get_world_size()
+            dist.gather_object(self.pred, gathered_preds, dst=0)
+            dist.gather_object(self.targets, gathered_targets, dst=0)
+            self.pred = [pred for rank in gathered_preds for pred in rank]
+            self.targets = [targets for rank in gathered_targets for targets in rank]
+        elif RANK > 0:
+            dist.gather_object(self.pred, None, dst=0)
+            dist.gather_object(self.targets, None, dst=0)
+
     def build_dataset(self, img_path: str) -> ClassificationDataset:
         """Create a ClassificationDataset instance for validation."""
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
     def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
-        """
-        Build and return a data loader for classification validation.
+        """Build and return a data loader for classification validation.
 
         Args:
             dataset_path (str | Path): Path to the dataset directory.
@@ -166,8 +169,7 @@ class ClassificationValidator(BaseValidator):
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 
     def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot validation image samples with their ground truth labels.
+        """Plot validation image samples with their ground truth labels.
 
         Args:
             batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
@@ -178,7 +180,7 @@ class ClassificationValidator(BaseValidator):
             >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
             >>> validator.plot_val_samples(batch, 0)
         """
-        batch["batch_idx"] = torch.arange(len(batch["img"]))  # add batch index for plotting
+        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
         plot_images(
             labels=batch,
             fname=self.save_dir / f"val_batch{ni}_labels.jpg",
@@ -187,8 +189,7 @@ class ClassificationValidator(BaseValidator):
         )
 
     def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
-        """
-        Plot images with their predicted class labels and save the visualization.
+        """Plot images with their predicted class labels and save the visualization.
 
         Args:
             batch (dict[str, Any]): Batch data containing images and other information.
@@ -203,8 +204,9 @@ class ClassificationValidator(BaseValidator):
         """
         batched_preds = dict(
            img=batch["img"],
-            batch_idx=torch.arange(len(batch["img"])),
+            batch_idx=torch.arange(batch["img"].shape[0]),
             cls=torch.argmax(preds, dim=1),
+            conf=torch.amax(preds, dim=1),
         )
         plot_images(
             batched_preds,
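The new `gather_stats()` method above collects each DDP rank's accumulated `pred`/`targets` lists onto rank 0 with `torch.distributed.gather_object`, so top-1/top-5 accuracy and the confusion matrix are computed over the full validation set rather than a single shard. Below is a minimal, self-contained sketch of that gather-and-flatten pattern; the 2-process gloo setup, port, and function names are illustrative assumptions, not part of the package.

```python
# Sketch of the gather pattern used by ClassificationValidator.gather_stats():
# each rank holds a Python list of tensors, and rank 0 collects and flattens
# every rank's list before computing metrics.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank produces its own partial list of per-batch predictions
    local_preds = [torch.tensor([rank]), torch.tensor([rank + 10])]

    if rank == 0:
        gathered = [None] * dist.get_world_size()
        dist.gather_object(local_preds, gathered, dst=0)  # receive every rank's list
        flat = [p for per_rank in gathered for p in per_rank]  # flatten, as gather_stats() does
        print(f"rank 0 now sees {len(flat)} prediction batches: {flat}")
    else:
        dist.gather_object(local_preds, None, dst=0)  # non-destination ranks only send

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```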
ultralytics/models/yolo/detect/predict.py
@@ -6,8 +6,7 @@ from ultralytics.utils import nms, ops
 
 
 class DetectionPredictor(BasePredictor):
-    """
-    A class extending the BasePredictor class for prediction based on a detection model.
+    """A class extending the BasePredictor class for prediction based on a detection model.
 
     This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
     with bounding boxes and class predictions.
@@ -32,8 +31,7 @@ class DetectionPredictor(BasePredictor):
     """
 
     def postprocess(self, preds, img, orig_imgs, **kwargs):
-        """
-        Post-process predictions and return a list of Results objects.
+        """Post-process predictions and return a list of Results objects.
 
         This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
         further analysis.
@@ -67,7 +65,7 @@ class DetectionPredictor(BasePredictor):
         )
 
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
 
         if save_feats:
             obj_feats = self.get_obj_feats(self._feats, preds[1])
@@ -81,7 +79,8 @@ class DetectionPredictor(BasePredictor):
 
         return results
 
-    def get_obj_feats(self, feat_maps, idxs):
+    @staticmethod
+    def get_obj_feats(feat_maps, idxs):
         """Extract object features from the feature maps."""
         import torch
 
@@ -89,11 +88,10 @@ class DetectionPredictor(BasePredictor):
         obj_feats = torch.cat(
             [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
         )  # mean reduce all vectors to same length
-        return [feats[idx] if len(idx) else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch
+        return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch
 
     def construct_results(self, preds, img, orig_imgs):
-        """
-        Construct a list of Results objects from model predictions.
+        """Construct a list of Results objects from model predictions.
 
         Args:
             preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
@@ -109,8 +107,7 @@ class DetectionPredictor(BasePredictor):
         ]
 
     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct a single Results object from one image prediction.
+        """Construct a single Results object from one image prediction.
 
         Args:
             pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
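One behavioral change above: when `orig_imgs` arrives as a `torch.Tensor`, `postprocess()` now appends `[..., ::-1]` to the converted numpy batch, reversing the channel axis of each BHWC image (an RGB↔BGR swap), presumably so tensor inputs match the BGR channel order used for original images elsewhere in the pipeline. A tiny illustration of what that slice does, in plain NumPy rather than library code:

```python
# [..., ::-1] reverses only the last (channel) axis of a BHWC batch,
# leaving batch, height, and width dimensions untouched.
import numpy as np

batch = np.zeros((2, 4, 4, 3), dtype=np.uint8)  # BHWC
batch[..., 0] = 255  # make channel 0 ("R" if the batch is RGB) fully bright

flipped = batch[..., ::-1]  # per-pixel channel order reversed
assert flipped[..., 2].max() == 255 and flipped[..., 0].max() == 0
print("channel 0 moved to channel 2 after [..., ::-1]")
```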
ultralytics/models/yolo/detect/train.py
@@ -17,16 +17,15 @@ from ultralytics.models import yolo
 from ultralytics.nn.tasks import DetectionModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
 from ultralytics.utils.patches import override_configs
-from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
+from ultralytics.utils.plotting import plot_images, plot_labels
 from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
 
 
 class DetectionTrainer(BaseTrainer):
-    """
-    A class extending the BaseTrainer class for training based on a detection model.
+    """A class extending the BaseTrainer class for training based on a detection model.
 
-    This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
-    for object detection including dataset building, data loading, preprocessing, and model configuration.
+    This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
+    object detection including dataset building, data loading, preprocessing, and model configuration.
 
     Attributes:
         model (DetectionModel): The YOLO detection model being trained.
@@ -43,7 +42,6 @@ class DetectionTrainer(BaseTrainer):
         label_loss_items: Return a loss dictionary with labeled training loss items.
         progress_string: Return a formatted string of training progress.
         plot_training_samples: Plot training samples with their annotations.
-        plot_metrics: Plot metrics from a CSV file.
         plot_training_labels: Create a labeled training plot of the YOLO model.
         auto_batch: Calculate optimal batch size based on model memory requirements.
 
@@ -55,8 +53,7 @@ class DetectionTrainer(BaseTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a DetectionTrainer object for training YOLO object detection model training.
+        """Initialize a DetectionTrainer object for training YOLO object detection models.
 
         Args:
             cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -64,11 +61,9 @@ class DetectionTrainer(BaseTrainer):
             _callbacks (list, optional): List of callback functions to be executed during training.
         """
         super().__init__(cfg, overrides, _callbacks)
-        self.dynamic_tensors = ["batch_idx", "cls", "bboxes"]
 
     def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -82,8 +77,7 @@ class DetectionTrainer(BaseTrainer):
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
 
     def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
-        """
-        Construct and return dataloader for the specified mode.
+        """Construct and return dataloader for the specified mode.
 
         Args:
             dataset_path (str): Path to the dataset.
@@ -111,8 +105,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def preprocess_batch(self, batch: dict) -> dict:
-        """
-        Preprocess a batch of images by scaling and converting to float.
+        """Preprocess a batch of images by scaling and converting to float.
 
         Args:
             batch (dict): Dictionary containing batch data with 'img' tensor.
@@ -122,7 +115,7 @@ class DetectionTrainer(BaseTrainer):
         """
         for k, v in batch.items():
             if isinstance(v, torch.Tensor):
-                batch[k] = v.to(self.device, non_blocking=True)
+                batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
         batch["img"] = batch["img"].float() / 255
         if self.args.multi_scale:
             imgs = batch["img"]
@@ -138,10 +131,6 @@ class DetectionTrainer(BaseTrainer):
             ]  # new shape (stretched to gs-multiple)
             imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
             batch["img"] = imgs
-
-        if self.args.compile:
-            for k in self.dynamic_tensors:
-                torch._dynamo.maybe_mark_dynamic(batch[k], 0)
         return batch
 
     def set_model_attributes(self):
@@ -156,8 +145,7 @@ class DetectionTrainer(BaseTrainer):
         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
 
     def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
-        """
-        Return a YOLO detection model.
+        """Return a YOLO detection model.
 
         Args:
             cfg (str, optional): Path to model configuration file.
@@ -180,8 +168,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
-        """
-        Return a loss dict with labeled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
         Args:
             loss_items (list[float], optional): List of loss values.
@@ -208,8 +195,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot training samples with their annotations.
+        """Plot training samples with their annotations.
 
         Args:
             batch (dict[str, Any]): Dictionary containing batch data.
@@ -222,10 +208,6 @@ class DetectionTrainer(BaseTrainer):
             on_plot=self.on_plot,
         )
 
-    def plot_metrics(self):
-        """Plot metrics from a CSV file."""
-        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png
-
     def plot_training_labels(self):
         """Create a labeled training plot of the YOLO model."""
         boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
@@ -233,8 +215,7 @@ class DetectionTrainer(BaseTrainer):
         plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
 
     def auto_batch(self):
-        """
-        Get optimal batch size by calculating memory occupation of model.
+        """Get optimal batch size by calculating memory occupation of model.
 
         Returns:
             (int): Optimal batch size.
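A recurring change in this release (seen in `preprocess_batch()` above and in `ClassificationValidator.preprocess()`) is replacing `non_blocking=True` with `non_blocking=self.device.type == "cuda"`: asynchronous host-to-device copies are only meaningful for CUDA targets (ideally from pinned memory), and unconditional non-blocking copies have been reported to misbehave on other backends such as MPS. A minimal sketch of the pattern outside the trainer follows; the `to_device` helper is hypothetical, not part of the package.

```python
# Move every tensor in a batch dict to the target device, requesting an
# asynchronous copy only when that device is CUDA.
import torch


def to_device(batch: dict, device: torch.device) -> dict:
    non_blocking = device.type == "cuda"  # async copy only for CUDA targets
    return {
        k: v.to(device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = {"img": torch.rand(2, 3, 64, 64), "cls": torch.tensor([1, 0]), "im_file": ["a.jpg", "b.jpg"]}
    batch = to_device(batch, device)
    print({k: (v.device if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
```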