dgenerate-ultralytics-headless 8.3.143-py3-none-any.whl → 8.3.145-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +52 -51
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +191 -161
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +4 -6
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +2 -2
  95. ultralytics/solutions/instance_segmentation.py +7 -4
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -11
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +189 -79
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +45 -29
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/yoloe/train.py

```diff
@@ -2,6 +2,8 @@
 
 import itertools
 from copy import copy, deepcopy
+from pathlib import Path
+from typing import Dict, List, Optional, Union
 
 import torch
 
@@ -17,9 +19,22 @@ from .val import YOLOEDetectValidator
 
 
 class YOLOETrainer(DetectionTrainer):
-    """A base trainer for YOLOE training."""
+    """
+    A trainer class for YOLOE object detection models.
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+    This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,
+    including custom model initialization, validation, and dataset building with multi-modal support.
+
+    Attributes:
+        loss_names (tuple): Names of loss components used during training.
+
+    Methods:
+        get_model: Initialize and return a YOLOEModel with specified configuration.
+        get_validator: Return a YOLOEDetectValidator for model validation.
+        build_dataset: Build YOLO dataset with multi-modal support for training.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict] = None, _callbacks=None):
         """
         Initialize the YOLOE Trainer with specified configurations.
 
@@ -36,14 +51,14 @@ class YOLOETrainer(DetectionTrainer):
         overrides["overlap_mask"] = False
         super().__init__(cfg, overrides, _callbacks)
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
+    def get_model(self, cfg=None, weights=None, verbose: bool = True):
         """
         Return a YOLOEModel initialized with the specified configuration and weights.
 
         Args:
-            cfg (dict | str | None): Model configuration. Can be a dictionary containing a 'yaml_file' key,
+            cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
                 a direct path to a YAML file, or None to use default configuration.
-            weights (str | Path | None): Path to pretrained weights file to load into the model.
+            weights (str | Path, optional): Path to pretrained weights file to load into the model.
             verbose (bool): Whether to display model information during initialization.
 
         Returns:
@@ -68,20 +83,20 @@ class YOLOETrainer(DetectionTrainer):
         return model
 
     def get_validator(self):
-        """Returns a DetectionValidator for YOLO model validation."""
+        """Return a YOLOEDetectValidator for YOLOE model validation."""
         self.loss_names = "box", "cls", "dfl"
         return YOLOEDetectValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def build_dataset(self, img_path, mode="train", batch=None):
+    def build_dataset(self, img_path: str, mode: str = "train", batch: Optional[int] = None):
         """
         Build YOLO Dataset.
 
         Args:
             img_path (str): Path to the folder containing images.
-            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`.
+            mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for rectangular training.
 
         Returns:
             (Dataset): YOLO dataset configured for training or validation.
@@ -93,9 +108,17 @@ class YOLOETrainer(DetectionTrainer):
 
 
 class YOLOEPETrainer(DetectionTrainer):
-    """Fine-tune YOLOE model in linear probing way."""
+    """
+    Fine-tune YOLOE model using linear probing approach.
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
+    This trainer freezes most model layers and only trains specific projection layers for efficient
+    fine-tuning on new datasets while preserving pretrained features.
+
+    Methods:
+        get_model: Initialize YOLOEModel with frozen layers except projection layers.
+    """
+
+    def get_model(self, cfg=None, weights=None, verbose: bool = True):
         """
         Return YOLOEModel initialized with specified config and weights.
 
@@ -139,9 +162,19 @@ class YOLOEPETrainer(DetectionTrainer):
 
 
 class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
-    """Train YOLOE models from scratch."""
+    """
+    Train YOLOE models from scratch with text embedding support.
+
+    This trainer combines YOLOE training capabilities with world training features, enabling
+    training from scratch with text embeddings and grounding datasets.
 
-    def build_dataset(self, img_path, mode="train", batch=None):
+    Methods:
+        build_dataset: Build datasets for training with grounding support.
+        preprocess_batch: Process batches with text features.
+        generate_text_embeddings: Generate and cache text embeddings for training.
+    """
+
+    def build_dataset(self, img_path: Union[List[str], str], mode: str = "train", batch: Optional[int] = None):
         """
         Build YOLO Dataset for training or validation.
 
@@ -168,7 +201,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
         batch["txt_feats"] = txt_feats
         return batch
 
-    def generate_text_embeddings(self, texts, batch, cache_dir):
+    def generate_text_embeddings(self, texts: List[str], batch: int, cache_dir: Path):
        """
        Generate text embeddings for a list of text samples.
 
@@ -196,21 +229,31 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
 
 
 class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
-    """Train prompt-free YOLOE model."""
+    """
+    Train prompt-free YOLOE model.
+
+    This trainer combines linear probing capabilities with from-scratch training for prompt-free
+    YOLOE models that don't require text prompts during inference.
+
+    Methods:
+        get_validator: Return standard DetectionValidator for validation.
+        preprocess_batch: Preprocess batches without text features.
+        set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).
+    """
 
     def get_validator(self):
-        """Returns a DetectionValidator for YOLO model validation."""
+        """Return a DetectionValidator for YOLO model validation."""
         self.loss_names = "box", "cls", "dfl"
         return DetectionValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
     def preprocess_batch(self, batch):
-        """Preprocesses a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
+        """Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
         batch = DetectionTrainer.preprocess_batch(self, batch)
         return batch
 
-    def set_text_embeddings(self, datasets, batch):
+    def set_text_embeddings(self, datasets, batch: int):
         """
         Set text embeddings for datasets to accelerate training by caching category names.
 
@@ -231,9 +274,18 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
 
 
 class YOLOEVPTrainer(YOLOETrainerFromScratch):
-    """Train YOLOE model with visual prompts."""
+    """
+    Train YOLOE model with visual prompts.
+
+    This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,
+    where visual cues are provided alongside images to guide the detection process.
+
+    Methods:
+        build_dataset: Build dataset with visual prompt loading transforms.
+        preprocess_batch: Preprocess batches with visual prompts.
+    """
 
-    def build_dataset(self, img_path, mode="train", batch=None):
+    def build_dataset(self, img_path: Union[List[str], str], mode: str = "train", batch: Optional[int] = None):
         """
         Build YOLO Dataset for training or validation with visual prompts.
 
@@ -263,7 +315,7 @@ class YOLOEVPTrainer(YOLOETrainerFromScratch):
         self.train_loader.dataset.transforms.append(LoadVisualPrompt())
 
     def preprocess_batch(self, batch):
-        """Preprocesses a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
+        """Preprocess a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
         batch = super().preprocess_batch(batch)
         batch["visuals"] = batch["visuals"].to(self.device)
         return batch
```
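
The typed signatures above (Optional[Dict] overrides, typed build_dataset) keep the usual trainer entry point. A minimal smoke-test sketch, assuming the yoloe subpackage re-exports YOLOETrainer; the model/data YAML names are illustrative placeholders, not taken from this diff:

```python
# Hedged sketch: drive YOLOETrainer through the standard overrides dict.
# "yoloe-v8s.yaml" and "coco8.yaml" are illustrative placeholders.
from ultralytics.models.yolo.yoloe import YOLOETrainer

overrides = {"model": "yoloe-v8s.yaml", "data": "coco8.yaml", "epochs": 1, "imgsz": 640}
trainer = YOLOETrainer(overrides=overrides)  # overrides: Optional[Dict] per the new signature
trainer.train()
```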
ultralytics/models/yolo/yoloe/train_seg.py

```diff
@@ -14,8 +14,8 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
     """
     Trainer class for YOLOE segmentation models.
 
-    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality
-    specifically for YOLOE segmentation models.
+    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
+    segmentation models, enabling both object detection and instance segmentation capabilities.
 
     Attributes:
         cfg (dict): Configuration dictionary with training parameters.
@@ -28,7 +28,7 @@
         Return YOLOESegModel initialized with specified config and weights.
 
         Args:
-            cfg (dict | str): Model configuration dictionary or YAML file path.
+            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
             weights (str, optional): Path to pretrained weights file.
             verbose (bool): Whether to display model information.
 
@@ -66,7 +66,10 @@ class YOLOEPESegTrainer(SegmentationTrainer):
     Fine-tune YOLOESeg model in linear probing way.
 
     This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
-    most of the model and only training specific layers.
+    most of the model and only training specific layers for efficient adaptation to new tasks.
+
+    Attributes:
+        data (dict): Dataset configuration containing channels, class names, and number of classes.
     """
 
     def get_model(self, cfg=None, weights=None, verbose=True):
@@ -74,7 +77,7 @@ class YOLOEPESegTrainer(SegmentationTrainer):
         Return YOLOESegModel initialized with specified config and weights for linear probing.
 
         Args:
-            cfg (dict | str): Model configuration dictionary or YAML file path.
+            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
             weights (str, optional): Path to pretrained weights file.
             verbose (bool): Whether to display model information.
 
@@ -113,12 +116,12 @@
 
 
 class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
-    """Trainer for YOLOE segmentation from scratch."""
+    """Trainer for YOLOE segmentation models trained from scratch without pretrained weights."""
 
     pass
 
 
 class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
-    """Trainer for YOLOE segmentation with VP."""
+    """Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities."""
 
     pass
```
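
Because the combined segmentation trainers above are empty cooperative-inheritance shells (their bodies are just `pass`), the method resolution order decides which base supplies build_dataset and preprocess_batch. A quick inspection sketch, assuming only the module path from the file list above:

```python
# Hedged sketch: print the MRO of the combined trainer to see which base
# class wins for each overridden method.
from ultralytics.models.yolo.yoloe.train_seg import YOLOESegVPTrainer

for cls in YOLOESegVPTrainer.__mro__:
    print(cls.__name__)  # YOLOESegVPTrainer, YOLOEVPTrainer, ..., object
```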
ultralytics/models/yolo/yoloe/val.py

```diff
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from copy import deepcopy
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torch.nn import functional as F
@@ -18,26 +19,40 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
 
 class YOLOEDetectValidator(DetectionValidator):
     """
-    A mixin class for YOLOE model validation that handles both text and visual prompt embeddings.
+    A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
 
-    This mixin provides functionality to validate YOLOE models using either text or visual prompt embeddings.
-    It includes methods for extracting visual prompt embeddings from samples, preprocessing batches, and
-    running validation with different prompt types.
+    This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.
+    It supports validation using either text prompts or visual prompt embeddings extracted from training samples,
+    enabling flexible evaluation strategies for prompt-based object detection.
 
     Attributes:
         device (torch.device): The device on which validation is performed.
         args (namespace): Configuration arguments for validation.
         dataloader (DataLoader): DataLoader for validation data.
+
+    Methods:
+        get_visual_pe: Extract visual prompt embeddings from training samples.
+        preprocess: Preprocess batch data ensuring visuals are on the same device as images.
+        get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.
+        __call__: Run validation using either text or visual prompt embeddings.
+
+    Examples:
+        Validate with text prompts
+        >>> validator = YOLOEDetectValidator()
+        >>> stats = validator(model=model, load_vp=False)
+
+        Validate with visual prompts
+        >>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
     """
 
     @smart_inference_mode()
-    def get_visual_pe(self, dataloader, model):
+    def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
         """
         Extract visual prompt embeddings from training samples.
 
-        This function processes a dataloader to compute visual prompt embeddings for each class
-        using a YOLOE model. It normalizes the embeddings and handles cases where no samples
-        exist for a class.
+        This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
+        It normalizes the embeddings and handles cases where no samples exist for a class by setting their
+        embeddings to zero.
 
         Args:
             dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
@@ -53,6 +68,7 @@ class YOLOEDetectValidator(DetectionValidator):
 
         desc = "Get visual prompt embeddings from samples"
 
+        # Count samples per class
         for batch in dataloader:
             cls = batch["cls"].squeeze(-1).to(torch.int).unique()
             count = torch.bincount(cls, minlength=len(names))
@@ -60,6 +76,7 @@ class YOLOEDetectValidator(DetectionValidator):
 
         cls_visual_num = cls_visual_num.to(self.device)
 
+        # Extract visual prompt embeddings
         pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
         for batch in pbar:
             batch = self.preprocess(batch)
@@ -73,30 +90,31 @@ class YOLOEDetectValidator(DetectionValidator):
             for c in cls:
                 visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]
 
+        # Normalize embeddings for classes with samples, set others to zero
         visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
         visual_pe[cls_visual_num == 0] = 0
         return visual_pe.unsqueeze(0)
 
-    def preprocess(self, batch):
+    def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """Preprocess batch data, ensuring visuals are on the same device as images."""
         batch = super().preprocess(batch)
         if "visuals" in batch:
             batch["visuals"] = batch["visuals"].to(batch["img"].device)
         return batch
 
-    def get_vpe_dataloader(self, data):
+    def get_vpe_dataloader(self, data: Dict[str, Any]) -> torch.utils.data.DataLoader:
         """
         Create a dataloader for LVIS training visual prompt samples.
 
-        This function prepares a dataloader for visual prompt embeddings (VPE) using the LVIS dataset.
-        It applies necessary transformations and configurations to the dataset and returns a dataloader
+        This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
+        It applies necessary transformations including LoadVisualPrompt and configurations to the dataset
         for validation purposes.
 
         Args:
             data (dict): Dataset configuration dictionary containing paths and settings.
 
         Returns:
-            (torch.utils.data.DataLoader): The dataLoader for visual prompt samples.
+            (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
         """
         dataset = build_yolo_dataset(
             self.args,
@@ -120,17 +138,23 @@ class YOLOEDetectValidator(DetectionValidator):
         )
 
     @smart_inference_mode()
-    def __call__(self, trainer=None, model=None, refer_data=None, load_vp=False):
+    def __call__(
+        self,
+        trainer: Optional[Any] = None,
+        model: Optional[Union[YOLOEModel, str]] = None,
+        refer_data: Optional[str] = None,
+        load_vp: bool = False,
+    ) -> Dict[str, Any]:
         """
         Run validation on the model using either text or visual prompt embeddings.
 
-        This method validates the model using either text prompts or visual prompts, depending
-        on the `load_vp` flag. It supports validation during training (using a trainer object)
-        or standalone validation with a provided model.
+        This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
+        It supports validation during training (using a trainer object) or standalone validation with a provided
+        model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.
 
         Args:
             trainer (object, optional): Trainer object containing the model and device.
-            model (YOLOEModel, optional): Model to validate. Required if `trainer` is not provided.
+            model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
             refer_data (str, optional): Path to reference data for visual prompts.
             load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
 
```
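Mirroring the new docstring Examples, standalone validation through the typed __call__ might look like this sketch (the weights filename and data path are placeholders):

```python
# Hedged sketch following the docstring Examples above.
from ultralytics.models.yolo.yoloe import YOLOEDetectValidator

validator = YOLOEDetectValidator()
stats_text = validator(model="yoloe-v8s-seg.pt", load_vp=False)  # text prompts
stats_vp = validator(model="yoloe-v8s-seg.pt", refer_data="path/to/data.yaml", load_vp=True)  # visual prompts
```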
ultralytics/nn/autobackend.py

```diff
@@ -6,7 +6,7 @@ import platform
 import zipfile
 from collections import OrderedDict, namedtuple
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cv2
 import numpy as np
@@ -19,8 +19,19 @@ from ultralytics.utils.checks import check_requirements, check_suffix, check_version
 from ultralytics.utils.downloads import attempt_download_asset, is_url
 
 
-def check_class_names(names):
-    """Check class names and convert to dict format if needed."""
+def check_class_names(names: Union[List, Dict]) -> Dict[int, str]:
+    """
+    Check class names and convert to dict format if needed.
+
+    Args:
+        names (list | dict): Class names as list or dict format.
+
+    Returns:
+        (dict): Class names in dict format with integer keys and string values.
+
+    Raises:
+        KeyError: If class indices are invalid for the dataset size.
+    """
     if isinstance(names, list):  # names is a list
         names = dict(enumerate(names))  # convert to dict
     if isinstance(names, dict):
@@ -38,8 +49,16 @@ def check_class_names(names):
     return names
 
 
-def default_class_names(data=None):
-    """Applies default class names to an input YAML file or returns numerical class names."""
+def default_class_names(data: Optional[Union[str, Path]] = None) -> Dict[int, str]:
+    """
+    Apply default class names to an input YAML file or return numerical class names.
+
+    Args:
+        data (str | Path, optional): Path to YAML file containing class names.
+
+    Returns:
+        (dict): Dictionary mapping class indices to class names.
+    """
     if data:
         try:
             return YAML.load(check_yaml(data))["names"]
```
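
The new docstrings pin down the list-to-dict contract of these helpers. A short behavioral sketch (the expected outputs in the comments are inferred from the code above, not stated in the diff):

```python
# Hedged sketch of the documented conversion behavior.
from ultralytics.nn.autobackend import check_class_names, default_class_names

print(check_class_names(["person", "bicycle"]))  # {0: 'person', 1: 'bicycle'}
print(default_class_names()[0])  # numeric fallback name, e.g. 'class0', when no YAML is given
```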
ultralytics/nn/autobackend.py (continued)

```diff
@@ -50,7 +69,7 @@
 
 
 class AutoBackend(nn.Module):
     """
-    Handles dynamic backend selection for running inference using Ultralytics YOLO models.
+    Handle dynamic backend selection for running inference using Ultralytics YOLO models.
 
     The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
     range of formats, each with specific naming conventions as outlined below:
@@ -82,6 +101,24 @@ class AutoBackend(nn.Module):
         names (dict): A dictionary of class names that the model can detect.
         stride (int): The model stride, typically 32 for YOLO models.
         fp16 (bool): Whether the model uses half-precision (FP16) inference.
+        nhwc (bool): Whether the model expects NHWC input format instead of NCHW.
+        pt (bool): Whether the model is a PyTorch model.
+        jit (bool): Whether the model is a TorchScript model.
+        onnx (bool): Whether the model is an ONNX model.
+        xml (bool): Whether the model is an OpenVINO model.
+        engine (bool): Whether the model is a TensorRT engine.
+        coreml (bool): Whether the model is a CoreML model.
+        saved_model (bool): Whether the model is a TensorFlow SavedModel.
+        pb (bool): Whether the model is a TensorFlow GraphDef.
+        tflite (bool): Whether the model is a TensorFlow Lite model.
+        edgetpu (bool): Whether the model is a TensorFlow Edge TPU model.
+        tfjs (bool): Whether the model is a TensorFlow.js model.
+        paddle (bool): Whether the model is a PaddlePaddle model.
+        mnn (bool): Whether the model is an MNN model.
+        ncnn (bool): Whether the model is an NCNN model.
+        imx (bool): Whether the model is an IMX model.
+        rknn (bool): Whether the model is an RKNN model.
+        triton (bool): Whether the model is a Triton Inference Server model.
 
     Methods:
         forward: Run inference on an input image.
@@ -113,7 +150,7 @@
             weights (str | List[str] | torch.nn.Module): Path to the model weights file or a module instance.
             device (torch.device): Device to run the model on.
             dnn (bool): Use OpenCV DNN module for ONNX inference.
-            data (str | Path | optional): Path to the additional data.yaml file containing class names.
+            data (str | Path, optional): Path to the additional data.yaml file containing class names.
             fp16 (bool): Enable half-precision inference. Supported only on specific backends.
             batch (int): Batch-size to assume for inference.
             fuse (bool): Fuse Conv2D + BatchNorm layers for optimization.
@@ -567,15 +604,22 @@
 
         self.__dict__.update(locals())  # assign all variables to self
 
-    def forward(self, im, augment=False, visualize=False, embed=None, **kwargs):
+    def forward(
+        self,
+        im: torch.Tensor,
+        augment: bool = False,
+        visualize: bool = False,
+        embed: Optional[List] = None,
+        **kwargs: Any,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
-        Runs inference on the YOLOv8 MultiBackend model.
+        Run inference on an AutoBackend model.
 
         Args:
             im (torch.Tensor): The image tensor to perform inference on.
             augment (bool): Whether to perform data augmentation during inference.
             visualize (bool): Whether to visualize the output predictions.
-            embed (list | None): A list of feature vectors/embeddings to return.
+            embed (list, optional): A list of feature vectors/embeddings to return.
             **kwargs (Any): Additional keyword arguments for model configuration.
 
         Returns:
@@ -632,7 +676,7 @@
             results = [None] * n  # preallocate list with None to match the number of images
 
             def callback(request, userdata):
-                """Places result in preallocated list using userdata index."""
+                """Place result in preallocated list using userdata index."""
                 results[userdata] = request.results
 
             # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
@@ -780,7 +824,7 @@
         else:
             return self.from_numpy(y)
 
-    def from_numpy(self, x):
+    def from_numpy(self, x: np.ndarray) -> torch.Tensor:
         """
         Convert a numpy array to a tensor.
 
@@ -792,7 +836,7 @@
         """
         return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
 
-    def warmup(self, imgsz=(1, 3, 640, 640)):
+    def warmup(self, imgsz: Tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:
         """
         Warm up the model by running one forward pass with a dummy input.
 
@@ -808,9 +852,9 @@
         self.forward(im)  # warmup
 
     @staticmethod
-    def _model_type(p="path/to/model.pt"):
+    def _model_type(p: str = "path/to/model.pt") -> List[bool]:
         """
-        Takes a path to a model file and returns the model type.
+        Take a path to a model file and return the model type.
 
         Args:
             p (str): Path to the model file.
```
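
A minimal end-to-end sketch of the typed forward()/warmup() API (the ONNX weights filename is a placeholder; any format from the table above should work):

```python
# Hedged sketch: one dummy inference pass through AutoBackend.
import torch

from ultralytics.nn.autobackend import AutoBackend

backend = AutoBackend("yolo11n.onnx", device=torch.device("cpu"))  # placeholder weights file
backend.warmup(imgsz=(1, 3, 640, 640))  # Tuple[int, int, int, int] per the new hint
y = backend.forward(torch.zeros(1, 3, 640, 640))  # torch.Tensor or List[torch.Tensor]
```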
ultralytics/nn/modules/__init__.py

```diff
@@ -1,12 +1,12 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
-Ultralytics modules.
+Ultralytics neural network modules.
 
-This module provides access to various neural network components used in Ultralytics models, including convolution blocks,
-attention mechanisms, transformer components, and detection/segmentation heads.
+This module provides access to various neural network components used in Ultralytics models, including convolution
+blocks, attention mechanisms, transformer components, and detection/segmentation heads.
 
 Examples:
-    Visualize a module with Netron.
+    Visualize a module with Netron
     >>> from ultralytics.nn.modules import *
     >>> import torch
     >>> import os
```
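
The docstring example is cut off by the hunk boundary. A hedged completion of the Netron workflow it describes (module choice and filename are illustrative):

```python
# Hedged sketch: export a single module to ONNX for visualization in Netron.
import torch

from ultralytics.nn.modules import Conv

m = Conv(128, 128)
x = torch.ones(1, 128, 40, 40)
torch.onnx.export(m, x, f"{m._get_name()}.onnx")
# Open the resulting .onnx file at https://netron.app to inspect the graph.
```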
ultralytics/nn/modules/activation.py

```diff
@@ -10,7 +10,7 @@ class AGLU(nn.Module):
     Unified activation function module from AGLU.
 
     This class implements a parameterized activation function with learnable parameters lambda and kappa, based on the
-    AGLU (Adaptive Gated Linear Unit) approach (https://github.com/kostas1515/AGLU).
+    AGLU (Adaptive Gated Linear Unit) approach.
 
     Attributes:
         act (nn.Softplus): Softplus activation function with negative beta.
@@ -27,6 +27,9 @@
         >>> output = m(input)
         >>> print(output.shape)
         torch.Size([2])
+
+    References:
+        https://github.com/kostas1515/AGLU
     """
 
     def __init__(self, device=None, dtype=None) -> None:
```
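
Matching the docstring example kept above, a quick shape-preservation check:

```python
# Hedged sketch matching the AGLU docstring example.
import torch

from ultralytics.nn.modules.activation import AGLU

m = AGLU()
x = torch.randn(2)
print(m(x).shape)  # torch.Size([2])
```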