ultralytics 8.3.143__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148):
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +11 -11
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +15 -7
  96. ultralytics/solutions/object_cropper.py +3 -2
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +184 -75
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.143.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
@@ -44,7 +44,7 @@ class ClassificationPredictor(BasePredictor):
44
44
  tasks. It ensures the task is set to 'classify' regardless of input configuration.
45
45
 
46
46
  Args:
47
- cfg (dict): Default configuration dictionary containing prediction settings. Defaults to DEFAULT_CFG.
47
+ cfg (dict): Default configuration dictionary containing prediction settings.
48
48
  overrides (dict, optional): Configuration overrides that take precedence over cfg.
49
49
  _callbacks (list, optional): List of callback functions to be executed during prediction.
50
50
  """
@@ -53,7 +53,7 @@ class ClassificationPredictor(BasePredictor):
53
53
  self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
54
54
 
55
55
  def setup_source(self, source):
56
- """Sets up source and inference mode and classify transforms."""
56
+ """Set up source and inference mode and classify transforms."""
57
57
  super().setup_source(source)
58
58
  updated = (
59
59
  self.model.model.transforms.transforms[0].size != max(self.imgsz)
@@ -68,14 +68,14 @@ class ClassificationPredictor(BasePredictor):
68
68
  is_legacy_transform = any(
69
69
  self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
70
70
  )
71
- if is_legacy_transform: # to handle legacy transforms
71
+ if is_legacy_transform: # Handle legacy transforms
72
72
  img = torch.stack([self.transforms(im) for im in img], dim=0)
73
73
  else:
74
74
  img = torch.stack(
75
75
  [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
76
76
  )
77
77
  img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
78
- return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
78
+ return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
79
79
 
80
80
  def postprocess(self, preds, img, orig_imgs):
81
81
  """
@@ -89,7 +89,7 @@ class ClassificationPredictor(BasePredictor):
89
89
  Returns:
90
90
  (List[Results]): List of Results objects containing classification results for each image.
91
91
  """
92
- if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
92
+ if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list
93
93
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
94
94
 
95
95
  preds = preds[0] if isinstance(preds, (list, tuple)) else preds
@@ -1,6 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from copy import copy
4
+ from typing import Any, Dict, Optional
4
5
 
5
6
  import torch
6
7
 
@@ -15,14 +16,14 @@ from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_di
15
16
 
16
17
  class ClassificationTrainer(BaseTrainer):
17
18
  """
18
- A class extending the BaseTrainer class for training based on a classification model.
19
+ A trainer class extending BaseTrainer for training image classification models.
19
20
 
20
21
  This trainer handles the training process for image classification tasks, supporting both YOLO classification models
21
- and torchvision models.
22
+ and torchvision models with comprehensive dataset handling and validation.
22
23
 
23
24
  Attributes:
24
25
  model (ClassificationModel): The classification model to be trained.
25
- data (dict): Dictionary containing dataset information including class names and number of classes.
26
+ data (Dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
26
27
  loss_names (List[str]): Names of the loss functions used during training.
27
28
  validator (ClassificationValidator): Validator instance for model evaluation.
28
29
 
@@ -41,13 +42,14 @@ class ClassificationTrainer(BaseTrainer):
41
42
  plot_training_samples: Plot training samples with their annotations.
42
43
 
43
44
  Examples:
45
+ Initialize and train a classification model
44
46
  >>> from ultralytics.models.yolo.classify import ClassificationTrainer
45
47
  >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
46
48
  >>> trainer = ClassificationTrainer(overrides=args)
47
49
  >>> trainer.train()
48
50
  """
49
51
 
50
- def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
52
+ def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):
51
53
  """
52
54
  Initialize a ClassificationTrainer object.
53
55
 
@@ -55,11 +57,12 @@ class ClassificationTrainer(BaseTrainer):
55
57
  image size if not specified.
56
58
 
57
59
  Args:
58
- cfg (dict, optional): Default configuration dictionary containing training parameters.
59
- overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
60
- _callbacks (list, optional): List of callback functions to be executed during training.
60
+ cfg (Dict[str, Any], optional): Default configuration dictionary containing training parameters.
61
+ overrides (Dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
62
+ _callbacks (List[Any], optional): List of callback functions to be executed during training.
61
63
 
62
64
  Examples:
65
+ Create a trainer with custom configuration
63
66
  >>> from ultralytics.models.yolo.classify import ClassificationTrainer
64
67
  >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
65
68
  >>> trainer = ClassificationTrainer(overrides=args)
@@ -76,14 +79,14 @@ class ClassificationTrainer(BaseTrainer):
76
79
  """Set the YOLO model's class names from the loaded dataset."""
77
80
  self.model.names = self.data["names"]
78
81
 
79
- def get_model(self, cfg=None, weights=None, verbose=True):
82
+ def get_model(self, cfg=None, weights=None, verbose: bool = True):
80
83
  """
81
- Return a modified PyTorch model configured for training YOLO.
84
+ Return a modified PyTorch model configured for training YOLO classification.
82
85
 
83
86
  Args:
84
- cfg (Any): Model configuration.
85
- weights (Any): Pre-trained model weights.
86
- verbose (bool): Whether to display model information.
87
+ cfg (Any, optional): Model configuration.
88
+ weights (Any, optional): Pre-trained model weights.
89
+ verbose (bool, optional): Whether to display model information.
87
90
 
88
91
  Returns:
89
92
  (ClassificationModel): Configured PyTorch model for classification.
@@ -120,29 +123,29 @@ class ClassificationTrainer(BaseTrainer):
120
123
  ClassificationModel.reshape_outputs(self.model, self.data["nc"])
121
124
  return ckpt
122
125
 
123
- def build_dataset(self, img_path, mode="train", batch=None):
126
+ def build_dataset(self, img_path: str, mode: str = "train", batch=None):
124
127
  """
125
128
  Create a ClassificationDataset instance given an image path and mode.
126
129
 
127
130
  Args:
128
131
  img_path (str): Path to the dataset images.
129
- mode (str): Dataset mode ('train', 'val', or 'test').
130
- batch (Any): Batch information (unused in this implementation).
132
+ mode (str, optional): Dataset mode ('train', 'val', or 'test').
133
+ batch (Any, optional): Batch information (unused in this implementation).
131
134
 
132
135
  Returns:
133
136
  (ClassificationDataset): Dataset for the specified mode.
134
137
  """
135
138
  return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
136
139
 
137
- def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
140
+ def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
138
141
  """
139
142
  Return PyTorch DataLoader with transforms to preprocess images.
140
143
 
141
144
  Args:
142
145
  dataset_path (str): Path to the dataset.
143
- batch_size (int): Number of images per batch.
144
- rank (int): Process rank for distributed training.
145
- mode (str): 'train', 'val', or 'test' mode.
146
+ batch_size (int, optional): Number of images per batch.
147
+ rank (int, optional): Process rank for distributed training.
148
+ mode (str, optional): 'train', 'val', or 'test' mode.
146
149
 
147
150
  Returns:
148
151
  (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
@@ -159,14 +162,14 @@ class ClassificationTrainer(BaseTrainer):
159
162
  self.model.transforms = loader.dataset.torch_transforms
160
163
  return loader
161
164
 
162
- def preprocess_batch(self, batch):
163
- """Preprocesses a batch of images and classes."""
165
+ def preprocess_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
166
+ """Preprocess a batch of images and classes."""
164
167
  batch["img"] = batch["img"].to(self.device)
165
168
  batch["cls"] = batch["cls"].to(self.device)
166
169
  return batch
167
170
 
168
- def progress_string(self):
169
- """Returns a formatted string showing training progress."""
171
+ def progress_string(self) -> str:
172
+ """Return a formatted string showing training progress."""
170
173
  return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
171
174
  "Epoch",
172
175
  "GPU_mem",
@@ -176,22 +179,23 @@ class ClassificationTrainer(BaseTrainer):
176
179
  )
177
180
 
178
181
  def get_validator(self):
179
- """Returns an instance of ClassificationValidator for validation."""
182
+ """Return an instance of ClassificationValidator for validation."""
180
183
  self.loss_names = ["loss"]
181
184
  return yolo.classify.ClassificationValidator(
182
185
  self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
183
186
  )
184
187
 
185
- def label_loss_items(self, loss_items=None, prefix="train"):
188
+ def label_loss_items(self, loss_items: Optional[torch.Tensor] = None, prefix: str = "train"):
186
189
  """
187
190
  Return a loss dict with labelled training loss items tensor.
188
191
 
189
192
  Args:
190
193
  loss_items (torch.Tensor, optional): Loss tensor items.
191
- prefix (str): Prefix to prepend to loss names.
194
+ prefix (str, optional): Prefix to prepend to loss names.
192
195
 
193
196
  Returns:
194
- (Dict[str, float] | List[str]): Dictionary of loss items or list of loss keys if loss_items is None.
197
+ keys (List[str]): List of loss keys if loss_items is None.
198
+ loss_dict (Dict[str, float]): Dictionary of loss items if loss_items is provided.
195
199
  """
196
200
  keys = [f"{prefix}/{x}" for x in self.loss_names]
197
201
  if loss_items is None:
@@ -216,7 +220,7 @@ class ClassificationTrainer(BaseTrainer):
216
220
  self.metrics.pop("fitness", None)
217
221
  self.run_callbacks("on_fit_epoch_end")
218
222
 
219
- def plot_training_samples(self, batch, ni):
223
+ def plot_training_samples(self, batch: Dict[str, torch.Tensor], ni: int):
220
224
  """
221
225
  Plot training samples with their annotations.
222
226
 
@@ -52,9 +52,6 @@ class ClassificationValidator(BaseValidator):
52
52
  """
53
53
  Initialize ClassificationValidator with dataloader, save directory, and other parameters.
54
54
 
55
- This validator handles the validation process for classification models, including metrics calculation,
56
- confusion matrix generation, and visualization of results.
57
-
58
55
  Args:
59
56
  dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
60
57
  save_dir (str | Path, optional): Directory to save results.
@@ -101,8 +98,9 @@ class ClassificationValidator(BaseValidator):
101
98
  preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
102
99
  batch (dict): Batch data containing images and class labels.
103
100
 
104
- This method appends the top-N predictions (sorted by confidence in descending order) to the
105
- prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
101
+ Notes:
102
+ This method appends the top-N predictions (sorted by confidence in descending order) to the
103
+ prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
106
104
  """
107
105
  n5 = min(len(self.names), 5)
108
106
  self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
@@ -112,13 +110,14 @@ class ClassificationValidator(BaseValidator):
112
110
  """
113
111
  Finalize metrics including confusion matrix and processing speed.
114
112
 
115
- This method processes the accumulated predictions and targets to generate the confusion matrix,
116
- optionally plots it, and updates the metrics object with speed information.
117
-
118
113
  Args:
119
114
  *args (Any): Variable length argument list.
120
115
  **kwargs (Any): Arbitrary keyword arguments.
121
116
 
117
+ Notes:
118
+ This method processes the accumulated predictions and targets to generate the confusion matrix,
119
+ optionally plots it, and updates the metrics object with speed information.
120
+
122
121
  Examples:
123
122
  >>> validator = ClassificationValidator()
124
123
  >>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample
@@ -21,6 +21,7 @@ class DetectionPredictor(BasePredictor):
21
21
  postprocess: Process raw model predictions into detection results.
22
22
  construct_results: Build Results objects from processed predictions.
23
23
  construct_result: Create a single Result object from a prediction.
24
+ get_obj_feats: Extract object features from the feature maps.
24
25
 
25
26
  Examples:
26
27
  >>> from ultralytics.utils import ASSETS
@@ -3,6 +3,7 @@
3
3
  import math
4
4
  import random
5
5
  from copy import copy
6
+ from typing import Dict, List, Optional
6
7
 
7
8
  import numpy as np
8
9
  import torch.nn as nn
@@ -21,12 +22,12 @@ class DetectionTrainer(BaseTrainer):
21
22
  A class extending the BaseTrainer class for training based on a detection model.
22
23
 
23
24
  This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
24
- for object detection.
25
+ for object detection including dataset building, data loading, preprocessing, and model configuration.
25
26
 
26
27
  Attributes:
27
28
  model (DetectionModel): The YOLO detection model being trained.
28
- data (dict): Dictionary containing dataset information including class names and number of classes.
29
- loss_names (Tuple[str]): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
29
+ data (Dict): Dictionary containing dataset information including class names and number of classes.
30
+ loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
30
31
 
31
32
  Methods:
32
33
  build_dataset: Build YOLO dataset for training or validation.
@@ -49,14 +50,14 @@ class DetectionTrainer(BaseTrainer):
49
50
  >>> trainer.train()
50
51
  """
51
52
 
52
- def build_dataset(self, img_path, mode="train", batch=None):
53
+ def build_dataset(self, img_path: str, mode: str = "train", batch: Optional[int] = None):
53
54
  """
54
55
  Build YOLO Dataset for training or validation.
55
56
 
56
57
  Args:
57
58
  img_path (str): Path to the folder containing images.
58
- mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
59
- batch (int, optional): Size of batches, this is for `rect`.
59
+ mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
60
+ batch (int, optional): Size of batches, this is for 'rect' mode.
60
61
 
61
62
  Returns:
62
63
  (Dataset): YOLO dataset object configured for the specified mode.
@@ -64,7 +65,7 @@ class DetectionTrainer(BaseTrainer):
64
65
  gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
65
66
  return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
66
67
 
67
- def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
68
+ def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
68
69
  """
69
70
  Construct and return dataloader for the specified mode.
70
71
 
@@ -87,15 +88,15 @@ class DetectionTrainer(BaseTrainer):
87
88
  workers = self.args.workers if mode == "train" else self.args.workers * 2
88
89
  return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
89
90
 
90
- def preprocess_batch(self, batch):
91
+ def preprocess_batch(self, batch: Dict) -> Dict:
91
92
  """
92
93
  Preprocess a batch of images by scaling and converting to float.
93
94
 
94
95
  Args:
95
- batch (dict): Dictionary containing batch data with 'img' tensor.
96
+ batch (Dict): Dictionary containing batch data with 'img' tensor.
96
97
 
97
98
  Returns:
98
- (dict): Preprocessed batch with normalized images.
99
+ (Dict): Preprocessed batch with normalized images.
99
100
  """
100
101
  batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
101
102
  if self.args.multi_scale:
@@ -125,7 +126,7 @@ class DetectionTrainer(BaseTrainer):
125
126
  self.model.args = self.args # attach hyperparameters to model
126
127
  # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
127
128
 
128
- def get_model(self, cfg=None, weights=None, verbose=True):
129
+ def get_model(self, cfg: Optional[str] = None, weights: Optional[str] = None, verbose: bool = True):
129
130
  """
130
131
  Return a YOLO detection model.
131
132
 
@@ -149,7 +150,7 @@ class DetectionTrainer(BaseTrainer):
149
150
  self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
150
151
  )
151
152
 
152
- def label_loss_items(self, loss_items=None, prefix="train"):
153
+ def label_loss_items(self, loss_items: Optional[List[float]] = None, prefix: str = "train"):
153
154
  """
154
155
  Return a loss dict with labeled training loss items tensor.
155
156
 
@@ -177,12 +178,12 @@ class DetectionTrainer(BaseTrainer):
177
178
  "Size",
178
179
  )
179
180
 
180
- def plot_training_samples(self, batch, ni):
181
+ def plot_training_samples(self, batch: Dict, ni: int):
181
182
  """
182
183
  Plot training samples with their annotations.
183
184
 
184
185
  Args:
185
- batch (dict): Dictionary containing batch data.
186
+ batch (Dict): Dictionary containing batch data.
186
187
  ni (int): Number of iterations.
187
188
  """
188
189
  plot_images(
@@ -2,6 +2,7 @@
2
2
 
3
3
  import os
4
4
  from pathlib import Path
5
+ from typing import Any, Dict, List, Optional, Tuple
5
6
 
6
7
  import numpy as np
7
8
  import torch
@@ -26,13 +27,13 @@ class DetectionValidator(BaseValidator):
26
27
  nt_per_image (np.ndarray): Number of targets per image.
27
28
  is_coco (bool): Whether the dataset is COCO.
28
29
  is_lvis (bool): Whether the dataset is LVIS.
29
- class_map (list): Mapping from model class indices to dataset class indices.
30
+ class_map (List[int]): Mapping from model class indices to dataset class indices.
30
31
  metrics (DetMetrics): Object detection metrics calculator.
31
32
  iouv (torch.Tensor): IoU thresholds for mAP calculation.
32
33
  niou (int): Number of IoU thresholds.
33
- lb (list): List for storing ground truth labels for hybrid saving.
34
- jdict (list): List for storing JSON detection results.
35
- stats (dict): Dictionary for storing statistics during validation.
34
+ lb (List[Any]): List for storing ground truth labels for hybrid saving.
35
+ jdict (List[Dict[str, Any]]): List for storing JSON detection results.
36
+ stats (Dict[str, List[torch.Tensor]]): Dictionary for storing statistics during validation.
36
37
 
37
38
  Examples:
38
39
  >>> from ultralytics.models.yolo.detect import DetectionValidator
@@ -49,8 +50,8 @@ class DetectionValidator(BaseValidator):
49
50
  dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
50
51
  save_dir (Path, optional): Directory to save results.
51
52
  pbar (Any, optional): Progress bar for displaying progress.
52
- args (dict, optional): Arguments for the validator.
53
- _callbacks (list, optional): List of callback functions.
53
+ args (Dict[str, Any], optional): Arguments for the validator.
54
+ _callbacks (List[Any], optional): List of callback functions.
54
55
  """
55
56
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
56
57
  self.nt_per_class = None
@@ -63,15 +64,15 @@ class DetectionValidator(BaseValidator):
63
64
  self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
64
65
  self.niou = self.iouv.numel()
65
66
 
66
- def preprocess(self, batch):
67
+ def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
67
68
  """
68
69
  Preprocess batch of images for YOLO validation.
69
70
 
70
71
  Args:
71
- batch (dict): Batch containing images and annotations.
72
+ batch (Dict[str, Any]): Batch containing images and annotations.
72
73
 
73
74
  Returns:
74
- (dict): Preprocessed batch.
75
+ (Dict[str, Any]): Preprocessed batch.
75
76
  """
76
77
  batch["img"] = batch["img"].to(self.device, non_blocking=True)
77
78
  batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
@@ -80,7 +81,7 @@ class DetectionValidator(BaseValidator):
80
81
 
81
82
  return batch
82
83
 
83
- def init_metrics(self, model):
84
+ def init_metrics(self, model: torch.nn.Module) -> None:
84
85
  """
85
86
  Initialize evaluation metrics for YOLO detection validation.
86
87
 
@@ -106,11 +107,11 @@ class DetectionValidator(BaseValidator):
106
107
  self.jdict = []
107
108
  self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
108
109
 
109
- def get_desc(self):
110
+ def get_desc(self) -> str:
110
111
  """Return a formatted string summarizing class metrics of YOLO model."""
111
112
  return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
112
113
 
113
- def postprocess(self, preds):
114
+ def postprocess(self, preds: torch.Tensor) -> List[torch.Tensor]:
114
115
  """
115
116
  Apply Non-maximum suppression to prediction outputs.
116
117
 
@@ -132,16 +133,16 @@ class DetectionValidator(BaseValidator):
132
133
  rotated=self.args.task == "obb",
133
134
  )
134
135
 
135
- def _prepare_batch(self, si, batch):
136
+ def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
136
137
  """
137
138
  Prepare a batch of images and annotations for validation.
138
139
 
139
140
  Args:
140
141
  si (int): Batch index.
141
- batch (dict): Batch data containing images and annotations.
142
+ batch (Dict[str, Any]): Batch data containing images and annotations.
142
143
 
143
144
  Returns:
144
- (dict): Prepared batch with processed annotations.
145
+ (Dict[str, Any]): Prepared batch with processed annotations.
145
146
  """
146
147
  idx = batch["batch_idx"] == si
147
148
  cls = batch["cls"][idx].squeeze(-1)
@@ -154,13 +155,13 @@ class DetectionValidator(BaseValidator):
154
155
  ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
155
156
  return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
156
157
 
157
- def _prepare_pred(self, pred, pbatch):
158
+ def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
158
159
  """
159
160
  Prepare predictions for evaluation against ground truth.
160
161
 
161
162
  Args:
162
163
  pred (torch.Tensor): Model predictions.
163
- pbatch (dict): Prepared batch information.
164
+ pbatch (Dict[str, Any]): Prepared batch information.
164
165
 
165
166
  Returns:
166
167
  (torch.Tensor): Prepared predictions in native space.
@@ -171,13 +172,13 @@ class DetectionValidator(BaseValidator):
171
172
  ) # native-space pred
172
173
  return predn
173
174
 
174
- def update_metrics(self, preds, batch):
175
+ def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
175
176
  """
176
177
  Update metrics with new predictions and ground truth.
177
178
 
178
179
  Args:
179
180
  preds (List[torch.Tensor]): List of predictions from the model.
180
- batch (dict): Batch data containing ground truth.
181
+ batch (Dict[str, Any]): Batch data containing ground truth.
181
182
  """
182
183
  for si, pred in enumerate(preds):
183
184
  self.seen += 1
@@ -226,7 +227,7 @@ class DetectionValidator(BaseValidator):
226
227
  self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
227
228
  )
228
229
 
229
- def finalize_metrics(self, *args, **kwargs):
230
+ def finalize_metrics(self, *args: Any, **kwargs: Any) -> None:
230
231
  """
231
232
  Set final values for metrics speed and confusion matrix.
232
233
 
@@ -237,12 +238,12 @@ class DetectionValidator(BaseValidator):
237
238
  self.metrics.speed = self.speed
238
239
  self.metrics.confusion_matrix = self.confusion_matrix
239
240
 
240
- def get_stats(self):
241
+ def get_stats(self) -> Dict[str, Any]:
241
242
  """
242
243
  Calculate and return metrics statistics.
243
244
 
244
245
  Returns:
245
- (dict): Dictionary containing metrics results.
246
+ (Dict[str, Any]): Dictionary containing metrics results.
246
247
  """
247
248
  stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
248
249
  self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
@@ -252,7 +253,7 @@ class DetectionValidator(BaseValidator):
252
253
  self.metrics.process(**stats, on_plot=self.on_plot)
253
254
  return self.metrics.results_dict
254
255
 
255
- def print_results(self):
256
+ def print_results(self) -> None:
256
257
  """Print training/validation set metrics per class."""
257
258
  pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
258
259
  LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
@@ -272,7 +273,7 @@ class DetectionValidator(BaseValidator):
272
273
  save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
273
274
  )
274
275
 
275
- def _process_batch(self, detections, gt_bboxes, gt_cls):
276
+ def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
276
277
  """
277
278
  Return correct prediction matrix.
278
279
 
@@ -289,7 +290,7 @@ class DetectionValidator(BaseValidator):
289
290
  iou = box_iou(gt_bboxes, detections[:, :4])
290
291
  return self.match_predictions(detections[:, 5], gt_cls, iou)
291
292
 
292
- def build_dataset(self, img_path, mode="val", batch=None):
293
+ def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None):
293
294
  """
294
295
  Build YOLO Dataset.
295
296
 
@@ -303,7 +304,7 @@ class DetectionValidator(BaseValidator):
303
304
  """
304
305
  return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
305
306
 
306
- def get_dataloader(self, dataset_path, batch_size):
307
+ def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
307
308
  """
308
309
  Construct and return dataloader.
309
310
 
@@ -317,12 +318,12 @@ class DetectionValidator(BaseValidator):
317
318
  dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
318
319
  return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
319
320
 
320
- def plot_val_samples(self, batch, ni):
321
+ def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
321
322
  """
322
323
  Plot validation image samples.
323
324
 
324
325
  Args:
325
- batch (dict): Batch containing images and annotations.
326
+ batch (Dict[str, Any]): Batch containing images and annotations.
326
327
  ni (int): Batch index.
327
328
  """
328
329
  plot_images(
@@ -336,12 +337,12 @@ class DetectionValidator(BaseValidator):
336
337
  on_plot=self.on_plot,
337
338
  )
338
339
 
339
- def plot_predictions(self, batch, preds, ni):
340
+ def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
340
341
  """
341
342
  Plot predicted bounding boxes on input images and save the result.
342
343
 
343
344
  Args:
344
- batch (dict): Batch containing images and annotations.
345
+ batch (Dict[str, Any]): Batch containing images and annotations.
345
346
  preds (List[torch.Tensor]): List of predictions from the model.
346
347
  ni (int): Batch index.
347
348
  """
@@ -354,14 +355,14 @@ class DetectionValidator(BaseValidator):
354
355
  on_plot=self.on_plot,
355
356
  ) # pred
356
357
 
357
- def save_one_txt(self, predn, save_conf, shape, file):
358
+ def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
358
359
  """
359
360
  Save YOLO detections to a txt file in normalized coordinates in a specific format.
360
361
 
361
362
  Args:
362
363
  predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
363
364
  save_conf (bool): Whether to save confidence scores.
364
- shape (tuple): Shape of the original image.
365
+ shape (Tuple[int, int]): Shape of the original image.
365
366
  file (Path): File path to save the detections.
366
367
  """
367
368
  from ultralytics.engine.results import Results
@@ -373,7 +374,7 @@ class DetectionValidator(BaseValidator):
373
374
  boxes=predn[:, :6],
374
375
  ).save_txt(file, save_conf=save_conf)
375
376
 
376
- def pred_to_json(self, predn, filename):
377
+ def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
377
378
  """
378
379
  Serialize YOLO predictions to COCO json format.
379
380
 
@@ -395,15 +396,15 @@ class DetectionValidator(BaseValidator):
395
396
  }
396
397
  )
397
398
 
398
- def eval_json(self, stats):
399
+ def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
399
400
  """
400
401
  Evaluate YOLO output in JSON format and return performance statistics.
401
402
 
402
403
  Args:
403
- stats (dict): Current statistics dictionary.
404
+ stats (Dict[str, Any]): Current statistics dictionary.
404
405
 
405
406
  Returns:
406
- (dict): Updated statistics dictionary with COCO/LVIS evaluation results.
407
+ (Dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
407
408
  """
408
409
  if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
409
410
  pred_json = self.save_dir / "predictions.json" # predictions