dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,9 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  from copy import copy
6
+ from typing import Any
4
7
 
5
8
  import torch
6
9
 
@@ -8,22 +11,21 @@ from ultralytics.data import ClassificationDataset, build_dataloader
8
11
  from ultralytics.engine.trainer import BaseTrainer
9
12
  from ultralytics.models import yolo
10
13
  from ultralytics.nn.tasks import ClassificationModel
11
- from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
12
- from ultralytics.utils.plotting import plot_images, plot_results
13
- from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
14
+ from ultralytics.utils import DEFAULT_CFG, RANK
15
+ from ultralytics.utils.plotting import plot_images
16
+ from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_first
14
17
 
15
18
 
16
19
  class ClassificationTrainer(BaseTrainer):
17
- """
18
- A class extending the BaseTrainer class for training based on a classification model.
20
+ """A trainer class extending BaseTrainer for training image classification models.
19
21
 
20
22
  This trainer handles the training process for image classification tasks, supporting both YOLO classification models
21
- and torchvision models.
23
+ and torchvision models with comprehensive dataset handling and validation.
22
24
 
23
25
  Attributes:
24
26
  model (ClassificationModel): The classification model to be trained.
25
- data (dict): Dictionary containing dataset information including class names and number of classes.
26
- loss_names (List[str]): Names of the loss functions used during training.
27
+ data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
28
+ loss_names (list[str]): Names of the loss functions used during training.
27
29
  validator (ClassificationValidator): Validator instance for model evaluation.
28
30
 
29
31
  Methods:
@@ -35,35 +37,25 @@ class ClassificationTrainer(BaseTrainer):
35
37
  preprocess_batch: Preprocess a batch of images and classes.
36
38
  progress_string: Return a formatted string showing training progress.
37
39
  get_validator: Return an instance of ClassificationValidator.
38
- label_loss_items: Return a loss dict with labelled training loss items.
39
- plot_metrics: Plot metrics from a CSV file.
40
+ label_loss_items: Return a loss dict with labeled training loss items.
40
41
  final_eval: Evaluate trained model and save validation results.
41
42
  plot_training_samples: Plot training samples with their annotations.
42
43
 
43
44
  Examples:
45
+ Initialize and train a classification model
44
46
  >>> from ultralytics.models.yolo.classify import ClassificationTrainer
45
47
  >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
46
48
  >>> trainer = ClassificationTrainer(overrides=args)
47
49
  >>> trainer.train()
48
50
  """
49
51
 
50
- def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
51
- """
52
- Initialize a ClassificationTrainer object.
53
-
54
- This constructor sets up a trainer for image classification tasks, configuring the task type and default
55
- image size if not specified.
52
+ def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
53
+ """Initialize a ClassificationTrainer object.
56
54
 
57
55
  Args:
58
- cfg (dict, optional): Default configuration dictionary containing training parameters.
59
- overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
60
- _callbacks (list, optional): List of callback functions to be executed during training.
61
-
62
- Examples:
63
- >>> from ultralytics.models.yolo.classify import ClassificationTrainer
64
- >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
65
- >>> trainer = ClassificationTrainer(overrides=args)
66
- >>> trainer.train()
56
+ cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
57
+ overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
58
+ _callbacks (list[Any], optional): List of callback functions to be executed during training.
67
59
  """
68
60
  if overrides is None:
69
61
  overrides = {}
@@ -76,14 +68,13 @@ class ClassificationTrainer(BaseTrainer):
76
68
  """Set the YOLO model's class names from the loaded dataset."""
77
69
  self.model.names = self.data["names"]
78
70
 
79
- def get_model(self, cfg=None, weights=None, verbose=True):
80
- """
81
- Return a modified PyTorch model configured for training YOLO.
71
+ def get_model(self, cfg=None, weights=None, verbose: bool = True):
72
+ """Return a modified PyTorch model configured for training YOLO classification.
82
73
 
83
74
  Args:
84
- cfg (Any): Model configuration.
85
- weights (Any): Pre-trained model weights.
86
- verbose (bool): Whether to display model information.
75
+ cfg (Any, optional): Model configuration.
76
+ weights (Any, optional): Pre-trained model weights.
77
+ verbose (bool, optional): Whether to display model information.
87
78
 
88
79
  Returns:
89
80
  (ClassificationModel): Configured PyTorch model for classification.
@@ -102,8 +93,7 @@ class ClassificationTrainer(BaseTrainer):
102
93
  return model
103
94
 
104
95
  def setup_model(self):
105
- """
106
- Load, create or download model for classification tasks.
96
+ """Load, create or download model for classification tasks.
107
97
 
108
98
  Returns:
109
99
  (Any): Model checkpoint if applicable, otherwise None.
@@ -120,29 +110,27 @@ class ClassificationTrainer(BaseTrainer):
120
110
  ClassificationModel.reshape_outputs(self.model, self.data["nc"])
121
111
  return ckpt
122
112
 
123
- def build_dataset(self, img_path, mode="train", batch=None):
124
- """
125
- Create a ClassificationDataset instance given an image path and mode.
113
+ def build_dataset(self, img_path: str, mode: str = "train", batch=None):
114
+ """Create a ClassificationDataset instance given an image path and mode.
126
115
 
127
116
  Args:
128
117
  img_path (str): Path to the dataset images.
129
- mode (str): Dataset mode ('train', 'val', or 'test').
130
- batch (Any): Batch information (unused in this implementation).
118
+ mode (str, optional): Dataset mode ('train', 'val', or 'test').
119
+ batch (Any, optional): Batch information (unused in this implementation).
131
120
 
132
121
  Returns:
133
122
  (ClassificationDataset): Dataset for the specified mode.
134
123
  """
135
124
  return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
136
125
 
137
- def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
138
- """
139
- Return PyTorch DataLoader with transforms to preprocess images.
126
+ def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
127
+ """Return PyTorch DataLoader with transforms to preprocess images.
140
128
 
141
129
  Args:
142
130
  dataset_path (str): Path to the dataset.
143
- batch_size (int): Number of images per batch.
144
- rank (int): Process rank for distributed training.
145
- mode (str): 'train', 'val', or 'test' mode.
131
+ batch_size (int, optional): Number of images per batch.
132
+ rank (int, optional): Process rank for distributed training.
133
+ mode (str, optional): 'train', 'val', or 'test' mode.
146
134
 
147
135
  Returns:
148
136
  (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
@@ -150,7 +138,7 @@ class ClassificationTrainer(BaseTrainer):
150
138
  with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
151
139
  dataset = self.build_dataset(dataset_path, mode)
152
140
 
153
- loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
141
+ loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
154
142
  # Attach inference transforms
155
143
  if mode != "train":
156
144
  if is_parallel(self.model):
@@ -159,14 +147,14 @@ class ClassificationTrainer(BaseTrainer):
159
147
  self.model.transforms = loader.dataset.torch_transforms
160
148
  return loader
161
149
 
162
- def preprocess_batch(self, batch):
163
- """Preprocesses a batch of images and classes."""
164
- batch["img"] = batch["img"].to(self.device)
165
- batch["cls"] = batch["cls"].to(self.device)
150
+ def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
151
+ """Preprocess a batch of images and classes."""
152
+ batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
153
+ batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
166
154
  return batch
167
155
 
168
- def progress_string(self):
169
- """Returns a formatted string showing training progress."""
156
+ def progress_string(self) -> str:
157
+ """Return a formatted string showing training progress."""
170
158
  return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
171
159
  "Epoch",
172
160
  "GPU_mem",
@@ -176,22 +164,22 @@ class ClassificationTrainer(BaseTrainer):
176
164
  )
177
165
 
178
166
  def get_validator(self):
179
- """Returns an instance of ClassificationValidator for validation."""
167
+ """Return an instance of ClassificationValidator for validation."""
180
168
  self.loss_names = ["loss"]
181
169
  return yolo.classify.ClassificationValidator(
182
170
  self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
183
171
  )
184
172
 
185
- def label_loss_items(self, loss_items=None, prefix="train"):
186
- """
187
- Return a loss dict with labelled training loss items tensor.
173
+ def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
174
+ """Return a loss dict with labeled training loss items tensor.
188
175
 
189
176
  Args:
190
177
  loss_items (torch.Tensor, optional): Loss tensor items.
191
- prefix (str): Prefix to prepend to loss names.
178
+ prefix (str, optional): Prefix to prepend to loss names.
192
179
 
193
180
  Returns:
194
- (Dict[str, float] | List[str]): Dictionary of loss items or list of loss keys if loss_items is None.
181
+ keys (list[str]): List of loss keys if loss_items is None.
182
+ loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
195
183
  """
196
184
  keys = [f"{prefix}/{x}" for x in self.loss_names]
197
185
  if loss_items is None:
@@ -199,35 +187,16 @@ class ClassificationTrainer(BaseTrainer):
199
187
  loss_items = [round(float(loss_items), 5)]
200
188
  return dict(zip(keys, loss_items))
201
189
 
202
- def plot_metrics(self):
203
- """Plot metrics from a CSV file."""
204
- plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
205
-
206
- def final_eval(self):
207
- """Evaluate trained model and save validation results."""
208
- for f in self.last, self.best:
209
- if f.exists():
210
- strip_optimizer(f) # strip optimizers
211
- if f is self.best:
212
- LOGGER.info(f"\nValidating {f}...")
213
- self.validator.args.data = self.args.data
214
- self.validator.args.plots = self.args.plots
215
- self.metrics = self.validator(model=f)
216
- self.metrics.pop("fitness", None)
217
- self.run_callbacks("on_fit_epoch_end")
218
-
219
- def plot_training_samples(self, batch, ni):
220
- """
221
- Plot training samples with their annotations.
190
+ def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
191
+ """Plot training samples with their annotations.
222
192
 
223
193
  Args:
224
- batch (Dict[str, torch.Tensor]): Batch containing images and class labels.
194
+ batch (dict[str, torch.Tensor]): Batch containing images and class labels.
225
195
  ni (int): Number of iterations.
226
196
  """
197
+ batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
227
198
  plot_images(
228
- images=batch["img"],
229
- batch_idx=torch.arange(len(batch["img"])),
230
- cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
199
+ labels=batch,
231
200
  fname=self.save_dir / f"train_batch{ni}.jpg",
232
201
  on_plot=self.on_plot,
233
202
  )
@@ -1,24 +1,29 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
3
8
  import torch
9
+ import torch.distributed as dist
4
10
 
5
11
  from ultralytics.data import ClassificationDataset, build_dataloader
6
12
  from ultralytics.engine.validator import BaseValidator
7
- from ultralytics.utils import LOGGER
13
+ from ultralytics.utils import LOGGER, RANK
8
14
  from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
9
15
  from ultralytics.utils.plotting import plot_images
10
16
 
11
17
 
12
18
  class ClassificationValidator(BaseValidator):
13
- """
14
- A class extending the BaseValidator class for validation based on a classification model.
19
+ """A class extending the BaseValidator class for validation based on a classification model.
15
20
 
16
- This validator handles the validation process for classification models, including metrics calculation,
17
- confusion matrix generation, and visualization of results.
21
+ This validator handles the validation process for classification models, including metrics calculation, confusion
22
+ matrix generation, and visualization of results.
18
23
 
19
24
  Attributes:
20
- targets (List[torch.Tensor]): Ground truth class labels.
21
- pred (List[torch.Tensor]): Model predictions.
25
+ targets (list[torch.Tensor]): Ground truth class labels.
26
+ pred (list[torch.Tensor]): Model predictions.
22
27
  metrics (ClassifyMetrics): Object to calculate and store classification metrics.
23
28
  names (dict): Mapping of class indices to class names.
24
29
  nc (int): Number of classes.
@@ -48,17 +53,12 @@ class ClassificationValidator(BaseValidator):
48
53
  Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
49
54
  """
50
55
 
51
- def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
52
- """
53
- Initialize ClassificationValidator with dataloader, save directory, and other parameters.
54
-
55
- This validator handles the validation process for classification models, including metrics calculation,
56
- confusion matrix generation, and visualization of results.
56
+ def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
57
+ """Initialize ClassificationValidator with dataloader, save directory, and other parameters.
57
58
 
58
59
  Args:
59
60
  dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
60
61
  save_dir (str | Path, optional): Directory to save results.
61
- pbar (bool, optional): Display a progress bar.
62
62
  args (dict, optional): Arguments containing model and validation configuration.
63
63
  _callbacks (list, optional): List of callback functions to be called during validation.
64
64
 
@@ -68,56 +68,48 @@ class ClassificationValidator(BaseValidator):
68
68
  >>> validator = ClassificationValidator(args=args)
69
69
  >>> validator()
70
70
  """
71
- super().__init__(dataloader, save_dir, pbar, args, _callbacks)
71
+ super().__init__(dataloader, save_dir, args, _callbacks)
72
72
  self.targets = None
73
73
  self.pred = None
74
74
  self.args.task = "classify"
75
75
  self.metrics = ClassifyMetrics()
76
76
 
77
- def get_desc(self):
77
+ def get_desc(self) -> str:
78
78
  """Return a formatted string summarizing classification metrics."""
79
79
  return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
80
80
 
81
- def init_metrics(self, model):
81
+ def init_metrics(self, model: torch.nn.Module) -> None:
82
82
  """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
83
83
  self.names = model.names
84
84
  self.nc = len(model.names)
85
- self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
86
85
  self.pred = []
87
86
  self.targets = []
87
+ self.confusion_matrix = ConfusionMatrix(names=model.names)
88
88
 
89
- def preprocess(self, batch):
89
+ def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
90
90
  """Preprocess input batch by moving data to device and converting to appropriate dtype."""
91
- batch["img"] = batch["img"].to(self.device, non_blocking=True)
91
+ batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
92
92
  batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
93
- batch["cls"] = batch["cls"].to(self.device)
93
+ batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
94
94
  return batch
95
95
 
96
- def update_metrics(self, preds, batch):
97
- """
98
- Update running metrics with model predictions and batch targets.
96
+ def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
97
+ """Update running metrics with model predictions and batch targets.
99
98
 
100
99
  Args:
101
100
  preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
102
101
  batch (dict): Batch data containing images and class labels.
103
102
 
104
- This method appends the top-N predictions (sorted by confidence in descending order) to the
105
- prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
103
+ Notes:
104
+ This method appends the top-N predictions (sorted by confidence in descending order) to the
105
+ prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
106
106
  """
107
107
  n5 = min(len(self.names), 5)
108
108
  self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
109
109
  self.targets.append(batch["cls"].type(torch.int32).cpu())
110
110
 
111
- def finalize_metrics(self, *args, **kwargs):
112
- """
113
- Finalize metrics including confusion matrix and processing speed.
114
-
115
- This method processes the accumulated predictions and targets to generate the confusion matrix,
116
- optionally plots it, and updates the metrics object with speed information.
117
-
118
- Args:
119
- *args (Any): Variable length argument list.
120
- **kwargs (Any): Arbitrary keyword arguments.
111
+ def finalize_metrics(self) -> None:
112
+ """Finalize metrics including confusion matrix and processing speed.
121
113
 
122
114
  Examples:
123
115
  >>> validator = ClassificationValidator()
@@ -125,33 +117,47 @@ class ClassificationValidator(BaseValidator):
125
117
  >>> validator.targets = [torch.tensor([0])] # Ground truth class
126
118
  >>> validator.finalize_metrics()
127
119
  >>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
120
+
121
+ Notes:
122
+ This method processes the accumulated predictions and targets to generate the confusion matrix,
123
+ optionally plots it, and updates the metrics object with speed information.
128
124
  """
129
125
  self.confusion_matrix.process_cls_preds(self.pred, self.targets)
130
126
  if self.args.plots:
131
127
  for normalize in True, False:
132
- self.confusion_matrix.plot(
133
- save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
134
- )
128
+ self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
135
129
  self.metrics.speed = self.speed
136
- self.metrics.confusion_matrix = self.confusion_matrix
137
130
  self.metrics.save_dir = self.save_dir
131
+ self.metrics.confusion_matrix = self.confusion_matrix
138
132
 
139
- def postprocess(self, preds):
133
+ def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
140
134
  """Extract the primary prediction from model output if it's in a list or tuple format."""
141
135
  return preds[0] if isinstance(preds, (list, tuple)) else preds
142
136
 
143
- def get_stats(self):
137
+ def get_stats(self) -> dict[str, float]:
144
138
  """Calculate and return a dictionary of metrics by processing targets and predictions."""
145
139
  self.metrics.process(self.targets, self.pred)
146
140
  return self.metrics.results_dict
147
141
 
148
- def build_dataset(self, img_path):
142
+ def gather_stats(self) -> None:
143
+ """Gather stats from all GPUs."""
144
+ if RANK == 0:
145
+ gathered_preds = [None] * dist.get_world_size()
146
+ gathered_targets = [None] * dist.get_world_size()
147
+ dist.gather_object(self.pred, gathered_preds, dst=0)
148
+ dist.gather_object(self.targets, gathered_targets, dst=0)
149
+ self.pred = [pred for rank in gathered_preds for pred in rank]
150
+ self.targets = [targets for rank in gathered_targets for targets in rank]
151
+ elif RANK > 0:
152
+ dist.gather_object(self.pred, None, dst=0)
153
+ dist.gather_object(self.targets, None, dst=0)
154
+
155
+ def build_dataset(self, img_path: str) -> ClassificationDataset:
149
156
  """Create a ClassificationDataset instance for validation."""
150
157
  return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
151
158
 
152
- def get_dataloader(self, dataset_path, batch_size):
153
- """
154
- Build and return a data loader for classification validation.
159
+ def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
160
+ """Build and return a data loader for classification validation.
155
161
 
156
162
  Args:
157
163
  dataset_path (str | Path): Path to the dataset directory.
@@ -163,17 +169,16 @@ class ClassificationValidator(BaseValidator):
163
169
  dataset = self.build_dataset(dataset_path)
164
170
  return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
165
171
 
166
- def print_results(self):
172
+ def print_results(self) -> None:
167
173
  """Print evaluation metrics for the classification model."""
168
174
  pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
169
175
  LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
170
176
 
171
- def plot_val_samples(self, batch, ni):
172
- """
173
- Plot validation image samples with their ground truth labels.
177
+ def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
178
+ """Plot validation image samples with their ground truth labels.
174
179
 
175
180
  Args:
176
- batch (dict): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
181
+ batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
177
182
  ni (int): Batch index used for naming the output file.
178
183
 
179
184
  Examples:
@@ -181,21 +186,19 @@ class ClassificationValidator(BaseValidator):
181
186
  >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
182
187
  >>> validator.plot_val_samples(batch, 0)
183
188
  """
189
+ batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
184
190
  plot_images(
185
- images=batch["img"],
186
- batch_idx=torch.arange(len(batch["img"])),
187
- cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
191
+ labels=batch,
188
192
  fname=self.save_dir / f"val_batch{ni}_labels.jpg",
189
193
  names=self.names,
190
194
  on_plot=self.on_plot,
191
195
  )
192
196
 
193
- def plot_predictions(self, batch, preds, ni):
194
- """
195
- Plot images with their predicted class labels and save the visualization.
197
+ def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
198
+ """Plot images with their predicted class labels and save the visualization.
196
199
 
197
200
  Args:
198
- batch (dict): Batch data containing images and other information.
201
+ batch (dict[str, Any]): Batch data containing images and other information.
199
202
  preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
200
203
  ni (int): Batch index used for naming the output file.
201
204
 
@@ -205,10 +208,14 @@ class ClassificationValidator(BaseValidator):
205
208
  >>> preds = torch.rand(16, 10) # 16 images, 10 classes
206
209
  >>> validator.plot_predictions(batch, preds, 0)
207
210
  """
208
- plot_images(
209
- batch["img"],
210
- batch_idx=torch.arange(len(batch["img"])),
211
+ batched_preds = dict(
212
+ img=batch["img"],
213
+ batch_idx=torch.arange(batch["img"].shape[0]),
211
214
  cls=torch.argmax(preds, dim=1),
215
+ conf=torch.amax(preds, dim=1),
216
+ )
217
+ plot_images(
218
+ batched_preds,
212
219
  fname=self.save_dir / f"val_batch{ni}_pred.jpg",
213
220
  names=self.names,
214
221
  on_plot=self.on_plot,
@@ -2,12 +2,11 @@
2
2
 
3
3
  from ultralytics.engine.predictor import BasePredictor
4
4
  from ultralytics.engine.results import Results
5
- from ultralytics.utils import ops
5
+ from ultralytics.utils import nms, ops
6
6
 
7
7
 
8
8
  class DetectionPredictor(BasePredictor):
9
- """
10
- A class extending the BasePredictor class for prediction based on a detection model.
9
+ """A class extending the BasePredictor class for prediction based on a detection model.
11
10
 
12
11
  This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
13
12
  with bounding boxes and class predictions.
@@ -21,6 +20,7 @@ class DetectionPredictor(BasePredictor):
21
20
  postprocess: Process raw model predictions into detection results.
22
21
  construct_results: Build Results objects from processed predictions.
23
22
  construct_result: Create a single Result object from a prediction.
23
+ get_obj_feats: Extract object features from the feature maps.
24
24
 
25
25
  Examples:
26
26
  >>> from ultralytics.utils import ASSETS
@@ -31,8 +31,7 @@ class DetectionPredictor(BasePredictor):
31
31
  """
32
32
 
33
33
  def postprocess(self, preds, img, orig_imgs, **kwargs):
34
- """
35
- Post-process predictions and return a list of Results objects.
34
+ """Post-process predictions and return a list of Results objects.
36
35
 
37
36
  This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
38
37
  further analysis.
@@ -52,7 +51,7 @@ class DetectionPredictor(BasePredictor):
52
51
  >>> processed_results = predictor.postprocess(preds, img, orig_imgs)
53
52
  """
54
53
  save_feats = getattr(self, "_feats", None) is not None
55
- preds = ops.non_max_suppression(
54
+ preds = nms.non_max_suppression(
56
55
  preds,
57
56
  self.args.conf,
58
57
  self.args.iou,
@@ -84,23 +83,22 @@ class DetectionPredictor(BasePredictor):
84
83
  """Extract object features from the feature maps."""
85
84
  import torch
86
85
 
87
- s = min([x.shape[1] for x in feat_maps]) # find smallest vector length
86
+ s = min(x.shape[1] for x in feat_maps) # find shortest vector length
88
87
  obj_feats = torch.cat(
89
88
  [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
90
89
  ) # mean reduce all vectors to same length
91
- return [feats[idx] if len(idx) else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
90
+ return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
92
91
 
93
92
  def construct_results(self, preds, img, orig_imgs):
94
- """
95
- Construct a list of Results objects from model predictions.
93
+ """Construct a list of Results objects from model predictions.
96
94
 
97
95
  Args:
98
- preds (List[torch.Tensor]): List of predicted bounding boxes and scores for each image.
96
+ preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
99
97
  img (torch.Tensor): Batch of preprocessed images used for inference.
100
- orig_imgs (List[np.ndarray]): List of original images before preprocessing.
98
+ orig_imgs (list[np.ndarray]): List of original images before preprocessing.
101
99
 
102
100
  Returns:
103
- (List[Results]): List of Results objects containing detection information for each image.
101
+ (list[Results]): List of Results objects containing detection information for each image.
104
102
  """
105
103
  return [
106
104
  self.construct_result(pred, img, orig_img, img_path)
@@ -108,8 +106,7 @@ class DetectionPredictor(BasePredictor):
108
106
  ]
109
107
 
110
108
  def construct_result(self, pred, img, orig_img, img_path):
111
- """
112
- Construct a single Results object from one image prediction.
109
+ """Construct a single Results object from one image prediction.
113
110
 
114
111
  Args:
115
112
  pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.