dgenerate-ultralytics-headless 8.3.194__py3-none-any.whl → 8.3.196__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107):
  1. {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/METADATA +1 -2
  2. {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/RECORD +107 -106
  3. tests/test_python.py +1 -1
  4. ultralytics/__init__.py +1 -1
  5. ultralytics/cfg/__init__.py +9 -8
  6. ultralytics/cfg/default.yaml +1 -0
  7. ultralytics/data/annotator.py +1 -1
  8. ultralytics/data/augment.py +76 -76
  9. ultralytics/data/base.py +12 -12
  10. ultralytics/data/build.py +5 -1
  11. ultralytics/data/converter.py +4 -4
  12. ultralytics/data/dataset.py +7 -7
  13. ultralytics/data/loaders.py +15 -15
  14. ultralytics/data/split_dota.py +10 -10
  15. ultralytics/data/utils.py +12 -12
  16. ultralytics/engine/exporter.py +19 -31
  17. ultralytics/engine/model.py +13 -13
  18. ultralytics/engine/predictor.py +16 -14
  19. ultralytics/engine/results.py +21 -21
  20. ultralytics/engine/trainer.py +15 -4
  21. ultralytics/engine/validator.py +6 -2
  22. ultralytics/hub/google/__init__.py +2 -2
  23. ultralytics/hub/session.py +7 -7
  24. ultralytics/models/fastsam/model.py +5 -5
  25. ultralytics/models/fastsam/predict.py +11 -11
  26. ultralytics/models/nas/model.py +1 -1
  27. ultralytics/models/rtdetr/predict.py +2 -2
  28. ultralytics/models/rtdetr/val.py +4 -4
  29. ultralytics/models/sam/amg.py +6 -6
  30. ultralytics/models/sam/build.py +9 -9
  31. ultralytics/models/sam/model.py +7 -7
  32. ultralytics/models/sam/modules/blocks.py +6 -6
  33. ultralytics/models/sam/modules/decoders.py +1 -1
  34. ultralytics/models/sam/modules/encoders.py +27 -27
  35. ultralytics/models/sam/modules/sam.py +4 -4
  36. ultralytics/models/sam/modules/tiny_encoder.py +18 -18
  37. ultralytics/models/sam/modules/utils.py +8 -8
  38. ultralytics/models/sam/predict.py +63 -63
  39. ultralytics/models/utils/loss.py +22 -22
  40. ultralytics/models/utils/ops.py +8 -8
  41. ultralytics/models/yolo/classify/predict.py +2 -2
  42. ultralytics/models/yolo/classify/train.py +9 -19
  43. ultralytics/models/yolo/classify/val.py +4 -4
  44. ultralytics/models/yolo/detect/predict.py +3 -3
  45. ultralytics/models/yolo/detect/train.py +38 -12
  46. ultralytics/models/yolo/detect/val.py +38 -37
  47. ultralytics/models/yolo/model.py +6 -6
  48. ultralytics/models/yolo/obb/train.py +1 -10
  49. ultralytics/models/yolo/obb/val.py +13 -13
  50. ultralytics/models/yolo/pose/train.py +1 -9
  51. ultralytics/models/yolo/pose/val.py +12 -12
  52. ultralytics/models/yolo/segment/predict.py +4 -4
  53. ultralytics/models/yolo/segment/train.py +2 -10
  54. ultralytics/models/yolo/segment/val.py +15 -15
  55. ultralytics/models/yolo/world/train.py +13 -13
  56. ultralytics/models/yolo/world/train_world.py +3 -3
  57. ultralytics/models/yolo/yoloe/predict.py +4 -4
  58. ultralytics/models/yolo/yoloe/train.py +7 -16
  59. ultralytics/models/yolo/yoloe/val.py +0 -7
  60. ultralytics/nn/autobackend.py +2 -2
  61. ultralytics/nn/modules/block.py +6 -6
  62. ultralytics/nn/modules/conv.py +2 -2
  63. ultralytics/nn/modules/head.py +6 -5
  64. ultralytics/nn/tasks.py +17 -15
  65. ultralytics/nn/text_model.py +3 -3
  66. ultralytics/solutions/ai_gym.py +2 -2
  67. ultralytics/solutions/analytics.py +3 -3
  68. ultralytics/solutions/config.py +5 -5
  69. ultralytics/solutions/distance_calculation.py +2 -2
  70. ultralytics/solutions/heatmap.py +1 -1
  71. ultralytics/solutions/instance_segmentation.py +4 -4
  72. ultralytics/solutions/object_counter.py +4 -4
  73. ultralytics/solutions/parking_management.py +7 -7
  74. ultralytics/solutions/queue_management.py +3 -3
  75. ultralytics/solutions/region_counter.py +4 -4
  76. ultralytics/solutions/similarity_search.py +2 -2
  77. ultralytics/solutions/solutions.py +48 -48
  78. ultralytics/solutions/streamlit_inference.py +1 -1
  79. ultralytics/solutions/trackzone.py +4 -4
  80. ultralytics/solutions/vision_eye.py +1 -1
  81. ultralytics/trackers/byte_tracker.py +11 -11
  82. ultralytics/trackers/utils/gmc.py +3 -3
  83. ultralytics/trackers/utils/matching.py +5 -5
  84. ultralytics/utils/__init__.py +30 -19
  85. ultralytics/utils/autodevice.py +2 -2
  86. ultralytics/utils/benchmarks.py +10 -10
  87. ultralytics/utils/callbacks/clearml.py +1 -1
  88. ultralytics/utils/callbacks/comet.py +5 -5
  89. ultralytics/utils/callbacks/tensorboard.py +2 -2
  90. ultralytics/utils/checks.py +7 -5
  91. ultralytics/utils/cpu.py +90 -0
  92. ultralytics/utils/dist.py +1 -1
  93. ultralytics/utils/downloads.py +2 -2
  94. ultralytics/utils/export.py +5 -5
  95. ultralytics/utils/instance.py +2 -2
  96. ultralytics/utils/loss.py +14 -8
  97. ultralytics/utils/metrics.py +35 -35
  98. ultralytics/utils/nms.py +4 -4
  99. ultralytics/utils/ops.py +1 -1
  100. ultralytics/utils/patches.py +2 -2
  101. ultralytics/utils/plotting.py +10 -9
  102. ultralytics/utils/torch_utils.py +113 -15
  103. ultralytics/utils/triton.py +5 -5
  104. {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/WHEEL +0 -0
  105. {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/entry_points.txt +0 -0
  106. {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/licenses/LICENSE +0 -0
  107. {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/top_level.txt +0 -0
@@ -8,16 +8,17 @@ from copy import copy
8
8
  from typing import Any
9
9
 
10
10
  import numpy as np
11
+ import torch
11
12
  import torch.nn as nn
12
13
 
13
14
  from ultralytics.data import build_dataloader, build_yolo_dataset
14
15
  from ultralytics.engine.trainer import BaseTrainer
15
16
  from ultralytics.models import yolo
16
17
  from ultralytics.nn.tasks import DetectionModel
17
- from ultralytics.utils import LOGGER, RANK
18
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
18
19
  from ultralytics.utils.patches import override_configs
19
20
  from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
20
- from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
21
+ from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
21
22
 
22
23
 
23
24
  class DetectionTrainer(BaseTrainer):
@@ -29,7 +30,7 @@ class DetectionTrainer(BaseTrainer):
29
30
 
30
31
  Attributes:
31
32
  model (DetectionModel): The YOLO detection model being trained.
32
- data (Dict): Dictionary containing dataset information including class names and number of classes.
33
+ data (dict): Dictionary containing dataset information including class names and number of classes.
33
34
  loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
34
35
 
35
36
  Methods:
@@ -53,6 +54,18 @@ class DetectionTrainer(BaseTrainer):
53
54
  >>> trainer.train()
54
55
  """
55
56
 
57
+ def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
58
+ """
59
+ Initialize a DetectionTrainer object for training YOLO object detection model training.
60
+
61
+ Args:
62
+ cfg (dict, optional): Default configuration dictionary containing training parameters.
63
+ overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
64
+ _callbacks (list, optional): List of callback functions to be executed during training.
65
+ """
66
+ super().__init__(cfg, overrides, _callbacks)
67
+ self.dynamic_tensors = ["batch_idx", "cls", "bboxes"]
68
+
56
69
  def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
57
70
  """
58
71
  Build YOLO Dataset for training or validation.
@@ -65,7 +78,7 @@ class DetectionTrainer(BaseTrainer):
65
78
  Returns:
66
79
  (Dataset): YOLO dataset object configured for the specified mode.
67
80
  """
68
- gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
81
+ gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
69
82
  return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
70
83
 
71
84
  def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
@@ -88,20 +101,29 @@ class DetectionTrainer(BaseTrainer):
88
101
  if getattr(dataset, "rect", False) and shuffle:
89
102
  LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
90
103
  shuffle = False
91
- workers = self.args.workers if mode == "train" else self.args.workers * 2
92
- return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
104
+ return build_dataloader(
105
+ dataset,
106
+ batch=batch_size,
107
+ workers=self.args.workers if mode == "train" else self.args.workers * 2,
108
+ shuffle=shuffle,
109
+ rank=rank,
110
+ drop_last=self.args.compile and mode == "train",
111
+ )
93
112
 
94
113
  def preprocess_batch(self, batch: dict) -> dict:
95
114
  """
96
115
  Preprocess a batch of images by scaling and converting to float.
97
116
 
98
117
  Args:
99
- batch (Dict): Dictionary containing batch data with 'img' tensor.
118
+ batch (dict): Dictionary containing batch data with 'img' tensor.
100
119
 
101
120
  Returns:
102
- (Dict): Preprocessed batch with normalized images.
121
+ (dict): Preprocessed batch with normalized images.
103
122
  """
104
- batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
123
+ for k, v in batch.items():
124
+ if isinstance(v, torch.Tensor):
125
+ batch[k] = v.to(self.device, non_blocking=True)
126
+ batch["img"] = batch["img"].float() / 255
105
127
  if self.args.multi_scale:
106
128
  imgs = batch["img"]
107
129
  sz = (
@@ -116,6 +138,10 @@ class DetectionTrainer(BaseTrainer):
116
138
  ] # new shape (stretched to gs-multiple)
117
139
  imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
118
140
  batch["img"] = imgs
141
+
142
+ if self.args.compile:
143
+ for k in self.dynamic_tensors:
144
+ torch._dynamo.maybe_mark_dynamic(batch[k], 0)
119
145
  return batch
120
146
 
121
147
  def set_model_attributes(self):
@@ -158,11 +184,11 @@ class DetectionTrainer(BaseTrainer):
158
184
  Return a loss dict with labeled training loss items tensor.
159
185
 
160
186
  Args:
161
- loss_items (List[float], optional): List of loss values.
187
+ loss_items (list[float], optional): List of loss values.
162
188
  prefix (str): Prefix for keys in the returned dictionary.
163
189
 
164
190
  Returns:
165
- (Dict | List): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
191
+ (dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
166
192
  """
167
193
  keys = [f"{prefix}/{x}" for x in self.loss_names]
168
194
  if loss_items is not None:
@@ -186,7 +212,7 @@ class DetectionTrainer(BaseTrainer):
186
212
  Plot training samples with their annotations.
187
213
 
188
214
  Args:
189
- batch (Dict[str, Any]): Dictionary containing batch data.
215
+ batch (dict[str, Any]): Dictionary containing batch data.
190
216
  ni (int): Number of iterations.
191
217
  """
192
218
  plot_images(
@@ -27,13 +27,13 @@ class DetectionValidator(BaseValidator):
27
27
  Attributes:
28
28
  is_coco (bool): Whether the dataset is COCO.
29
29
  is_lvis (bool): Whether the dataset is LVIS.
30
- class_map (List[int]): Mapping from model class indices to dataset class indices.
30
+ class_map (list[int]): Mapping from model class indices to dataset class indices.
31
31
  metrics (DetMetrics): Object detection metrics calculator.
32
32
  iouv (torch.Tensor): IoU thresholds for mAP calculation.
33
33
  niou (int): Number of IoU thresholds.
34
- lb (List[Any]): List for storing ground truth labels for hybrid saving.
35
- jdict (List[Dict[str, Any]]): List for storing JSON detection results.
36
- stats (Dict[str, List[torch.Tensor]]): Dictionary for storing statistics during validation.
34
+ lb (list[Any]): List for storing ground truth labels for hybrid saving.
35
+ jdict (list[dict[str, Any]]): List for storing JSON detection results.
36
+ stats (dict[str, list[torch.Tensor]]): Dictionary for storing statistics during validation.
37
37
 
38
38
  Examples:
39
39
  >>> from ultralytics.models.yolo.detect import DetectionValidator
@@ -49,8 +49,8 @@ class DetectionValidator(BaseValidator):
49
49
  Args:
50
50
  dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
51
51
  save_dir (Path, optional): Directory to save results.
52
- args (Dict[str, Any], optional): Arguments for the validator.
53
- _callbacks (List[Any], optional): List of callback functions.
52
+ args (dict[str, Any], optional): Arguments for the validator.
53
+ _callbacks (list[Any], optional): List of callback functions.
54
54
  """
55
55
  super().__init__(dataloader, save_dir, args, _callbacks)
56
56
  self.is_coco = False
@@ -66,16 +66,15 @@ class DetectionValidator(BaseValidator):
66
66
  Preprocess batch of images for YOLO validation.
67
67
 
68
68
  Args:
69
- batch (Dict[str, Any]): Batch containing images and annotations.
69
+ batch (dict[str, Any]): Batch containing images and annotations.
70
70
 
71
71
  Returns:
72
- (Dict[str, Any]): Preprocessed batch.
72
+ (dict[str, Any]): Preprocessed batch.
73
73
  """
74
- batch["img"] = batch["img"].to(self.device, non_blocking=True)
74
+ for k, v in batch.items():
75
+ if isinstance(v, torch.Tensor):
76
+ batch[k] = v.to(self.device, non_blocking=True)
75
77
  batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
76
- for k in {"batch_idx", "cls", "bboxes"}:
77
- batch[k] = batch[k].to(self.device, non_blocking=True)
78
-
79
78
  return batch
80
79
 
81
80
  def init_metrics(self, model: torch.nn.Module) -> None:
@@ -114,7 +113,7 @@ class DetectionValidator(BaseValidator):
114
113
  preds (torch.Tensor): Raw predictions from the model.
115
114
 
116
115
  Returns:
117
- (List[Dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
116
+ (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
118
117
  'bboxes', 'conf', 'cls', and 'extra' tensors.
119
118
  """
120
119
  outputs = nms.non_max_suppression(
@@ -136,10 +135,10 @@ class DetectionValidator(BaseValidator):
136
135
 
137
136
  Args:
138
137
  si (int): Batch index.
139
- batch (Dict[str, Any]): Batch data containing images and annotations.
138
+ batch (dict[str, Any]): Batch data containing images and annotations.
140
139
 
141
140
  Returns:
142
- (Dict[str, Any]): Prepared batch with processed annotations.
141
+ (dict[str, Any]): Prepared batch with processed annotations.
143
142
  """
144
143
  idx = batch["batch_idx"] == si
145
144
  cls = batch["cls"][idx].squeeze(-1)
@@ -163,10 +162,10 @@ class DetectionValidator(BaseValidator):
163
162
  Prepare predictions for evaluation against ground truth.
164
163
 
165
164
  Args:
166
- pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
165
+ pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
167
166
 
168
167
  Returns:
169
- (Dict[str, torch.Tensor]): Prepared predictions in native space.
168
+ (dict[str, torch.Tensor]): Prepared predictions in native space.
170
169
  """
171
170
  if self.args.single_cls:
172
171
  pred["cls"] *= 0
@@ -177,8 +176,8 @@ class DetectionValidator(BaseValidator):
177
176
  Update metrics with new predictions and ground truth.
178
177
 
179
178
  Args:
180
- preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
181
- batch (Dict[str, Any]): Batch data containing ground truth.
179
+ preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
180
+ batch (dict[str, Any]): Batch data containing ground truth.
182
181
  """
183
182
  for si, pred in enumerate(preds):
184
183
  self.seen += 1
@@ -232,7 +231,7 @@ class DetectionValidator(BaseValidator):
232
231
  Calculate and return metrics statistics.
233
232
 
234
233
  Returns:
235
- (Dict[str, Any]): Dictionary containing metrics results.
234
+ (dict[str, Any]): Dictionary containing metrics results.
236
235
  """
237
236
  self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)
238
237
  self.metrics.clear_stats()
@@ -263,11 +262,11 @@ class DetectionValidator(BaseValidator):
263
262
  Return correct prediction matrix.
264
263
 
265
264
  Args:
266
- preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
267
- batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
265
+ preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
266
+ batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
268
267
 
269
268
  Returns:
270
- (Dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
269
+ (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
271
270
  """
272
271
  if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
273
272
  return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
@@ -300,14 +299,16 @@ class DetectionValidator(BaseValidator):
300
299
  (torch.utils.data.DataLoader): Dataloader for validation.
301
300
  """
302
301
  dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
303
- return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
302
+ return build_dataloader(
303
+ dataset, batch_size, self.args.workers, shuffle=False, rank=-1, drop_last=self.args.compile
304
+ )
304
305
 
305
306
  def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
306
307
  """
307
308
  Plot validation image samples.
308
309
 
309
310
  Args:
310
- batch (Dict[str, Any]): Batch containing images and annotations.
311
+ batch (dict[str, Any]): Batch containing images and annotations.
311
312
  ni (int): Batch index.
312
313
  """
313
314
  plot_images(
@@ -325,8 +326,8 @@ class DetectionValidator(BaseValidator):
325
326
  Plot predicted bounding boxes on input images and save the result.
326
327
 
327
328
  Args:
328
- batch (Dict[str, Any]): Batch containing images and annotations.
329
- preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
329
+ batch (dict[str, Any]): Batch containing images and annotations.
330
+ preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
330
331
  ni (int): Batch index.
331
332
  max_det (Optional[int]): Maximum number of detections to plot.
332
333
  """
@@ -352,9 +353,9 @@ class DetectionValidator(BaseValidator):
352
353
  Save YOLO detections to a txt file in normalized coordinates in a specific format.
353
354
 
354
355
  Args:
355
- predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
356
+ predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
356
357
  save_conf (bool): Whether to save confidence scores.
357
- shape (Tuple[int, int]): Shape of the original image (height, width).
358
+ shape (tuple[int, int]): Shape of the original image (height, width).
358
359
  file (Path): File path to save the detections.
359
360
  """
360
361
  from ultralytics.engine.results import Results
@@ -371,9 +372,9 @@ class DetectionValidator(BaseValidator):
371
372
  Serialize YOLO predictions to COCO json format.
372
373
 
373
374
  Args:
374
- predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
375
+ predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
375
376
  with bounding box coordinates, confidence scores, and class predictions.
376
- pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
377
+ pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
377
378
 
378
379
  Examples:
379
380
  >>> result = {
@@ -417,10 +418,10 @@ class DetectionValidator(BaseValidator):
417
418
  Evaluate YOLO output in JSON format and return performance statistics.
418
419
 
419
420
  Args:
420
- stats (Dict[str, Any]): Current statistics dictionary.
421
+ stats (dict[str, Any]): Current statistics dictionary.
421
422
 
422
423
  Returns:
423
- (Dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
424
+ (dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
424
425
  """
425
426
  pred_json = self.save_dir / "predictions.json" # predictions
426
427
  anno_json = (
@@ -446,16 +447,16 @@ class DetectionValidator(BaseValidator):
446
447
  including mAP50, mAP50-95, and LVIS-specific metrics if applicable.
447
448
 
448
449
  Args:
449
- stats (Dict[str, Any]): Dictionary to store computed metrics and statistics.
450
+ stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
450
451
  pred_json (str | Path]): Path to JSON file containing predictions in COCO format.
451
452
  anno_json (str | Path]): Path to JSON file containing ground truth annotations in COCO format.
452
- iou_types (str | List[str]]): IoU type(s) for evaluation. Can be single string or list of strings.
453
+ iou_types (str | list[str]]): IoU type(s) for evaluation. Can be single string or list of strings.
453
454
  Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
454
- suffix (str | List[str]]): Suffix to append to metric names in stats dictionary. Should correspond
455
+ suffix (str | list[str]]): Suffix to append to metric names in stats dictionary. Should correspond
455
456
  to iou_types if multiple types provided. Defaults to "Box".
456
457
 
457
458
  Returns:
458
- (Dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
459
+ (dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
459
460
  """
460
461
  if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
461
462
  LOGGER.info(f"\nEvaluating faster-coco-eval mAP using {pred_json} and {anno_json}...")
@@ -185,7 +185,7 @@ class YOLOWorld(Model):
185
185
  Set the model's class names for detection.
186
186
 
187
187
  Args:
188
- classes (List[str]): A list of categories i.e. ["person"].
188
+ classes (list[str]): A list of categories i.e. ["person"].
189
189
  """
190
190
  self.model.set_classes(classes)
191
191
  # Remove background if it's given
@@ -299,8 +299,8 @@ class YOLOE(Model):
299
299
  classification tasks. The model must be an instance of YOLOEModel.
300
300
 
301
301
  Args:
302
- vocab (List[str]): Vocabulary list containing tokens or words used by the model for text processing.
303
- names (List[str]): List of class names that the model can detect or classify.
302
+ vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
303
+ names (list[str]): List of class names that the model can detect or classify.
304
304
 
305
305
  Raises:
306
306
  AssertionError: If the model is not an instance of YOLOEModel.
@@ -322,7 +322,7 @@ class YOLOE(Model):
322
322
  Set the model's class names and embeddings for detection.
323
323
 
324
324
  Args:
325
- classes (List[str]): A list of categories i.e. ["person"].
325
+ classes (list[str]): A list of categories i.e. ["person"].
326
326
  embeddings (torch.Tensor): Embeddings corresponding to the classes.
327
327
  """
328
328
  assert isinstance(self.model, YOLOEModel)
@@ -381,7 +381,7 @@ class YOLOE(Model):
381
381
  directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
382
382
  stream (bool): Whether to stream the prediction results. If True, results are yielded as a
383
383
  generator as they are computed.
384
- visual_prompts (Dict[str, List]): Dictionary containing visual prompts for the model. Must include
384
+ visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
385
385
  'bboxes' and 'cls' keys when non-empty.
386
386
  refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
387
387
  predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
@@ -389,7 +389,7 @@ class YOLOE(Model):
389
389
  **kwargs (Any): Additional keyword arguments passed to the predictor.
390
390
 
391
391
  Returns:
392
- (List | generator): List of Results objects or generator of Results objects if stream=True.
392
+ (list | generator): List of Results objects or generator of Results objects if stream=True.
393
393
 
394
394
  Examples:
395
395
  >>> model = YOLOE("yoloe-11s-seg.pt")
@@ -37,21 +37,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
37
37
  """
38
38
  Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
39
39
 
40
- This trainer extends the DetectionTrainer class to specialize in training models that detect oriented
41
- bounding boxes. It automatically sets the task to 'obb' in the configuration.
42
-
43
40
  Args:
44
41
  cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
45
42
  model configuration.
46
43
  overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
47
44
  will take precedence over those in cfg.
48
- _callbacks (List[Any], optional): List of callback functions to be invoked during training.
49
-
50
- Examples:
51
- >>> from ultralytics.models.yolo.obb import OBBTrainer
52
- >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
53
- >>> trainer = OBBTrainer(overrides=args)
54
- >>> trainer.train()
45
+ _callbacks (list[Any], optional): List of callback functions to be invoked during training.
55
46
  """
56
47
  if overrides is None:
57
48
  overrides = {}
@@ -77,13 +77,13 @@ class OBBValidator(DetectionValidator):
77
77
  Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
78
78
 
79
79
  Args:
80
- preds (Dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
80
+ preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
81
81
  class labels and bounding boxes.
82
- batch (Dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
82
+ batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
83
83
  class labels and bounding boxes.
84
84
 
85
85
  Returns:
86
- (Dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
86
+ (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
87
87
  array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
88
88
  of predictions compared to the ground truth.
89
89
 
@@ -104,7 +104,7 @@ class OBBValidator(DetectionValidator):
104
104
  preds (torch.Tensor): Raw predictions from the model.
105
105
 
106
106
  Returns:
107
- (List[Dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
107
+ (list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
108
108
  """
109
109
  preds = super().postprocess(preds)
110
110
  for pred in preds:
@@ -117,7 +117,7 @@ class OBBValidator(DetectionValidator):
117
117
 
118
118
  Args:
119
119
  si (int): Batch index to process.
120
- batch (Dict[str, Any]): Dictionary containing batch data with keys:
120
+ batch (dict[str, Any]): Dictionary containing batch data with keys:
121
121
  - batch_idx: Tensor of batch indices
122
122
  - cls: Tensor of class labels
123
123
  - bboxes: Tensor of bounding boxes
@@ -126,7 +126,7 @@ class OBBValidator(DetectionValidator):
126
126
  - ratio_pad: Ratio and padding information
127
127
 
128
128
  Returns:
129
- (Dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
129
+ (dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
130
130
  """
131
131
  idx = batch["batch_idx"] == si
132
132
  cls = batch["cls"][idx].squeeze(-1)
@@ -150,8 +150,8 @@ class OBBValidator(DetectionValidator):
150
150
  Plot predicted bounding boxes on input images and save the result.
151
151
 
152
152
  Args:
153
- batch (Dict[str, Any]): Batch data containing images, file paths, and other metadata.
154
- preds (List[torch.Tensor]): List of prediction tensors for each image in the batch.
153
+ batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
154
+ preds (list[torch.Tensor]): List of prediction tensors for each image in the batch.
155
155
  ni (int): Batch index used for naming the output file.
156
156
 
157
157
  Examples:
@@ -170,9 +170,9 @@ class OBBValidator(DetectionValidator):
170
170
  Convert YOLO predictions to COCO JSON format with rotated bounding box information.
171
171
 
172
172
  Args:
173
- predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
173
+ predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
174
174
  with bounding box coordinates, confidence scores, and class predictions.
175
- pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
175
+ pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
176
176
 
177
177
  Notes:
178
178
  This method processes rotated bounding box predictions and converts them to both rbox format
@@ -204,7 +204,7 @@ class OBBValidator(DetectionValidator):
204
204
  predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
205
205
  class predictions, and angles in format (x, y, w, h, conf, cls, angle).
206
206
  save_conf (bool): Whether to save confidence scores in the text file.
207
- shape (Tuple[int, int]): Original image shape in format (height, width).
207
+ shape (tuple[int, int]): Original image shape in format (height, width).
208
208
  file (Path): Output file path to save detections.
209
209
 
210
210
  Examples:
@@ -237,10 +237,10 @@ class OBBValidator(DetectionValidator):
237
237
  Evaluate YOLO output in JSON format and save predictions in DOTA format.
238
238
 
239
239
  Args:
240
- stats (Dict[str, Any]): Performance statistics dictionary.
240
+ stats (dict[str, Any]): Performance statistics dictionary.
241
241
 
242
242
  Returns:
243
- (Dict[str, Any]): Updated performance statistics.
243
+ (dict[str, Any]): Updated performance statistics.
244
244
  """
245
245
  if self.args.save_json and self.is_dota and len(self.jdict):
246
246
  import json
@@ -44,9 +44,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
44
44
  """
45
45
  Initialize a PoseTrainer object for training YOLO pose estimation models.
46
46
 
47
- This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
48
- handling specific configurations needed for keypoint detection models.
49
-
50
47
  Args:
51
48
  cfg (dict, optional): Default configuration dictionary containing training parameters.
52
49
  overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
@@ -55,17 +52,12 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
55
52
  Notes:
56
53
  This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
57
54
  A warning is issued when using Apple MPS device due to known bugs with pose models.
58
-
59
- Examples:
60
- >>> from ultralytics.models.yolo.pose import PoseTrainer
61
- >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
62
- >>> trainer = PoseTrainer(overrides=args)
63
- >>> trainer.train()
64
55
  """
65
56
  if overrides is None:
66
57
  overrides = {}
67
58
  overrides["task"] = "pose"
68
59
  super().__init__(cfg, overrides, _callbacks)
60
+ self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "keypoints"]
69
61
 
70
62
  if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
71
63
  LOGGER.warning(
@@ -22,7 +22,7 @@ class PoseValidator(DetectionValidator):
22
22
 
23
23
  Attributes:
24
24
  sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
25
- kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
25
+ kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
26
26
  args (dict): Arguments for the validator including task set to "pose".
27
27
  metrics (PoseMetrics): Metrics object for pose evaluation.
28
28
 
@@ -86,7 +86,7 @@ class PoseValidator(DetectionValidator):
86
86
  def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
87
87
  """Preprocess batch by converting keypoints data to float and moving it to the device."""
88
88
  batch = super().preprocess(batch)
89
- batch["keypoints"] = batch["keypoints"].to(self.device, non_blocking=True).float()
89
+ batch["keypoints"] = batch["keypoints"].float()
90
90
  return batch
91
91
 
92
92
  def get_desc(self) -> str:
@@ -132,7 +132,7 @@ class PoseValidator(DetectionValidator):
132
132
  bounding boxes, confidence scores, class predictions, and keypoint data.
133
133
 
134
134
  Returns:
135
- (Dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
135
+ (dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
136
136
  - 'bboxes': Bounding box coordinates
137
137
  - 'conf': Confidence scores
138
138
  - 'cls': Class predictions
@@ -154,10 +154,10 @@ class PoseValidator(DetectionValidator):
154
154
 
155
155
  Args:
156
156
  si (int): Batch index.
157
- batch (Dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
157
+ batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
158
158
 
159
159
  Returns:
160
- (Dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
160
+ (dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
161
161
 
162
162
  Notes:
163
163
  This method extends the parent class's _prepare_batch method by adding keypoint processing.
@@ -177,13 +177,13 @@ class PoseValidator(DetectionValidator):
177
177
  Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
178
178
 
179
179
  Args:
180
- preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
180
+ preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
181
181
  and 'keypoints' for keypoint predictions.
182
- batch (Dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
182
+ batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
183
183
  'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
184
184
 
185
185
  Returns:
186
- (Dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
186
+ (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
187
187
  true positives across 10 IoU levels.
188
188
 
189
189
  Notes:
@@ -207,9 +207,9 @@ class PoseValidator(DetectionValidator):
207
207
  Save YOLO pose detections to a text file in normalized coordinates.
208
208
 
209
209
  Args:
210
- predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
210
+ predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
211
211
  save_conf (bool): Whether to save confidence scores.
212
- shape (Tuple[int, int]): Shape of the original image (height, width).
212
+ shape (tuple[int, int]): Shape of the original image (height, width).
213
213
  file (Path): Output file path to save detections.
214
214
 
215
215
  Notes:
@@ -234,9 +234,9 @@ class PoseValidator(DetectionValidator):
234
234
  to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
235
235
 
236
236
  Args:
237
- predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
237
+ predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
238
238
  and 'keypoints' tensors.
239
- pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
239
+ pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
240
240
 
241
241
  Notes:
242
242
  The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
@@ -71,13 +71,13 @@ class SegmentationPredictor(DetectionPredictor):
71
71
  Construct a list of result objects from the predictions.
72
72
 
73
73
  Args:
74
- preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
74
+ preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
75
75
  img (torch.Tensor): The image after preprocessing.
76
- orig_imgs (List[np.ndarray]): List of original images before preprocessing.
77
- protos (List[torch.Tensor]): List of prototype masks.
76
+ orig_imgs (list[np.ndarray]): List of original images before preprocessing.
77
+ protos (list[torch.Tensor]): List of prototype masks.
78
78
 
79
79
  Returns:
80
- (List[Results]): List of result objects containing the original images, image paths, class names,
80
+ (list[Results]): List of result objects containing the original images, image paths, class names,
81
81
  bounding boxes, and masks.
82
82
  """
83
83
  return [
@@ -19,7 +19,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
19
19
  functionality including model initialization, validation, and visualization.
20
20
 
21
21
  Attributes:
22
- loss_names (Tuple[str]): Names of the loss components used during training.
22
+ loss_names (tuple[str]): Names of the loss components used during training.
23
23
 
24
24
  Examples:
25
25
  >>> from ultralytics.models.yolo.segment import SegmentationTrainer
@@ -32,24 +32,16 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
32
32
  """
33
33
  Initialize a SegmentationTrainer object.
34
34
 
35
- This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific
36
- functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.
37
-
38
35
  Args:
39
36
  cfg (dict): Configuration dictionary with default training settings.
40
37
  overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
41
38
  _callbacks (list, optional): List of callback functions to be executed during training.
42
-
43
- Examples:
44
- >>> from ultralytics.models.yolo.segment import SegmentationTrainer
45
- >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
46
- >>> trainer = SegmentationTrainer(overrides=args)
47
- >>> trainer.train()
48
39
  """
49
40
  if overrides is None:
50
41
  overrides = {}
51
42
  overrides["task"] = "segment"
52
43
  super().__init__(cfg, overrides, _callbacks)
44
+ self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "masks"]
53
45
 
54
46
  def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
55
47
  """