dgenerate_ultralytics_headless-8.3.137-py3-none-any.whl → dgenerate_ultralytics_headless-8.3.224-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
--- a/ultralytics/models/yolo/yoloe/train.py
+++ b/ultralytics/models/yolo/yoloe/train.py
@@ -1,7 +1,9 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-import itertools
+from __future__ import annotations
+
 from copy import copy, deepcopy
+from pathlib import Path
 
 import torch
 
@@ -10,21 +12,29 @@ from ultralytics.data.augment import LoadVisualPrompt
 from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
 from ultralytics.nn.tasks import YOLOEModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
-from ultralytics.utils.torch_utils import de_parallel
+from ultralytics.utils.torch_utils import unwrap_model
 
 from ..world.train_world import WorldTrainerFromScratch
 from .val import YOLOEDetectValidator
 
 
 class YOLOETrainer(DetectionTrainer):
-    """A base trainer for YOLOE training."""
+    """A trainer class for YOLOE object detection models.
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the YOLOE Trainer with specified configurations.
+    This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
+    model initialization, validation, and dataset building with multi-modal support.
+
+    Attributes:
+        loss_names (tuple): Names of loss components used during training.
 
-        This method sets up the YOLOE trainer with the provided configuration and overrides, initializing
-        the training environment, model, and callbacks for YOLOE object detection training.
+    Methods:
+        get_model: Initialize and return a YOLOEModel with specified configuration.
+        get_validator: Return a YOLOEDetectValidator for model validation.
+        build_dataset: Build YOLO dataset with multi-modal support for training.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
+        """Initialize the YOLOE Trainer with specified configurations.
 
         Args:
             cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
@@ -33,17 +43,17 @@ class YOLOETrainer(DetectionTrainer):
        """
        if overrides is None:
            overrides = {}
+        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
        overrides["overlap_mask"] = False
        super().__init__(cfg, overrides, _callbacks)
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Return a YOLOEModel initialized with the specified configuration and weights.
+    def get_model(self, cfg=None, weights=None, verbose: bool = True):
+        """Return a YOLOEModel initialized with the specified configuration and weights.
 
        Args:
-            cfg (dict | str | None): Model configuration. Can be a dictionary containing a 'yaml_file' key,
-                a direct path to a YAML file, or None to use default configuration.
-            weights (str | Path | None): Path to pretrained weights file to load into the model.
+            cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
+                path to a YAML file, or None to use default configuration.
+            weights (str | Path, optional): Path to pretrained weights file to load into the model.
            verbose (bool): Whether to display model information during initialization.
 
        Returns:
@@ -68,36 +78,41 @@ class YOLOETrainer(DetectionTrainer):
        return model
 
    def get_validator(self):
-        """Returns a DetectionValidator for YOLO model validation."""
+        """Return a YOLOEDetectValidator for YOLOE model validation."""
        self.loss_names = "box", "cls", "dfl"
        return YOLOEDetectValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )
 
-    def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset.
+    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
+        """Build YOLO Dataset.
 
        Args:
            img_path (str): Path to the folder containing images.
-            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`.
+            mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for rectangular training.
 
        Returns:
            (Dataset): YOLO dataset configured for training or validation.
        """
-        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )
 
 
 class YOLOEPETrainer(DetectionTrainer):
-    """Fine-tune YOLOE model in linear probing way."""
+    """Fine-tune YOLOE model using linear probing approach.
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Return YOLOEModel initialized with specified config and weights.
+    This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
+    datasets while preserving pretrained features.
+
+    Methods:
+        get_model: Initialize YOLOEModel with frozen layers except projection layers.
+    """
+
+    def get_model(self, cfg=None, weights=None, verbose: bool = True):
+        """Return YOLOEModel initialized with specified config and weights.
 
        Args:
            cfg (dict | str, optional): Model configuration.
@@ -139,17 +154,24 @@ class YOLOEPETrainer(DetectionTrainer):
 
 
 class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
-    """Train YOLOE models from scratch."""
+    """Train YOLOE models from scratch with text embedding support.
 
-    def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset for training or validation.
+    This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
+    text embeddings and grounding datasets.
+
+    Methods:
+        build_dataset: Build datasets for training with grounding support.
+        generate_text_embeddings: Generate and cache text embeddings for training.
+    """
+
+    def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
+        """Build YOLO Dataset for training or validation.
 
-        This method constructs appropriate datasets based on the mode and input paths, handling both
-        standard YOLO datasets and grounding datasets with different formats.
+        This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
+        datasets and grounding datasets with different formats.
 
        Args:
-            img_path (List[str] | str): Path to the folder containing images or list of paths.
+            img_path (list[str] | str): Path to the folder containing images or list of paths.
            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
            batch (int, optional): Size of batches, used for rectangular training/validation.
 
@@ -158,22 +180,11 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
        """
        return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
 
-    def preprocess_batch(self, batch):
-        """Process batch for training, moving text features to the appropriate device."""
-        batch = DetectionTrainer.preprocess_batch(self, batch)
-
-        texts = list(itertools.chain(*batch["texts"]))
-        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
-        txt_feats = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
-        batch["txt_feats"] = txt_feats
-        return batch
-
-    def generate_text_embeddings(self, texts, batch, cache_dir):
-        """
-        Generate text embeddings for a list of text samples.
+    def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
+        """Generate text embeddings for a list of text samples.
 
        Args:
-            texts (List[str]): List of text samples to encode.
+            texts (list[str]): List of text samples to encode.
            batch (int): Batch size for processing.
            cache_dir (Path): Directory to save/load cached embeddings.
 
@@ -184,42 +195,49 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
        if cache_path.exists():
            LOGGER.info(f"Reading existed cache from '{cache_path}'")
-            txt_map = torch.load(cache_path)
+            txt_map = torch.load(cache_path, map_location=self.device)
            if sorted(txt_map.keys()) == sorted(texts):
                return txt_map
        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
        assert self.model is not None
-        txt_feats = self.model.get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
+        txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
        torch.save(txt_map, cache_path)
        return txt_map
 
 
 class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
-    """Train prompt-free YOLOE model."""
+    """Train prompt-free YOLOE model.
+
+    This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
+    require text prompts during inference.
+
+    Methods:
+        get_validator: Return standard DetectionValidator for validation.
+        preprocess_batch: Preprocess batches without text features.
+        set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).
+    """
 
    def get_validator(self):
-        """Returns a DetectionValidator for YOLO model validation."""
+        """Return a DetectionValidator for YOLO model validation."""
        self.loss_names = "box", "cls", "dfl"
        return DetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )
 
    def preprocess_batch(self, batch):
-        """Preprocesses a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
-        batch = DetectionTrainer.preprocess_batch(self, batch)
-        return batch
+        """Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
+        return DetectionTrainer.preprocess_batch(self, batch)
 
-    def set_text_embeddings(self, datasets, batch):
-        """
-        Set text embeddings for datasets to accelerate training by caching category names.
+    def set_text_embeddings(self, datasets, batch: int):
+        """Set text embeddings for datasets to accelerate training by caching category names.
 
-        This method collects unique category names from all datasets, generates text embeddings for them,
-        and caches these embeddings to improve training efficiency. The embeddings are stored in a file
-        in the parent directory of the first dataset's image path.
+        This method collects unique category names from all datasets, generates text embeddings for them, and caches
+        these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
+        the first dataset's image path.
 
        Args:
-            datasets (List[Dataset]): List of datasets containing category names to process.
+            datasets (list[Dataset]): List of datasets containing category names to process.
            batch (int): Batch size for processing text embeddings.
 
        Notes:
@@ -231,14 +249,20 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
 
 
 class YOLOEVPTrainer(YOLOETrainerFromScratch):
-    """Train YOLOE model with visual prompts."""
+    """Train YOLOE model with visual prompts.
 
-    def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset for training or validation with visual prompts.
+    This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
+    alongside images to guide the detection process.
+
+    Methods:
+        build_dataset: Build dataset with visual prompt loading transforms.
+    """
+
+    def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
+        """Build YOLO Dataset for training or validation with visual prompts.
 
        Args:
-            img_path (List[str] | str): Path to the folder containing images or list of paths.
+            img_path (list[str] | str): Path to the folder containing images or list of paths.
            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
            batch (int, optional): Size of batches, used for rectangular training/validation.
 
@@ -261,9 +285,3 @@ class YOLOEVPTrainer(YOLOETrainerFromScratch):
                d.transforms.append(LoadVisualPrompt())
        else:
            self.train_loader.dataset.transforms.append(LoadVisualPrompt())
-
-    def preprocess_batch(self, batch):
-        """Preprocesses a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
-        batch = super().preprocess_batch(batch)
-        batch["visuals"] = batch["visuals"].to(self.device)
-        return batch
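
Note on the de_parallel → unwrap_model rename above: both the stride lookup in build_dataset and the get_text_pe call now go through unwrap_model, so downstream code that imported de_parallel needs the new name on 8.3.224. A minimal compatibility sketch, assuming unwrap_model keeps de_parallel's behavior of returning the module beneath a DP/DDP wrapper:

try:
    from ultralytics.utils.torch_utils import unwrap_model  # name in 8.3.224
except ImportError:  # fall back to the 8.3.137 name
    from ultralytics.utils.torch_utils import de_parallel as unwrap_model


def grid_stride(model) -> int:
    # Mirrors the build_dataset hunk above: max model stride, floored at 32.
    return max(int(unwrap_model(model).stride.max()), 32)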
--- a/ultralytics/models/yolo/yoloe/train_seg.py
+++ b/ultralytics/models/yolo/yoloe/train_seg.py
@@ -11,11 +11,10 @@ from .val import YOLOESegValidator
 
 
 class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
-    """
-    Trainer class for YOLOE segmentation models.
+    """Trainer class for YOLOE segmentation models.
 
-    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality
-    specifically for YOLOE segmentation models.
+    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
+    segmentation models, enabling both object detection and instance segmentation capabilities.
 
    Attributes:
        cfg (dict): Configuration dictionary with training parameters.
@@ -24,11 +23,10 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
    """
 
    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Return YOLOESegModel initialized with specified config and weights.
+        """Return YOLOESegModel initialized with specified config and weights.
 
        Args:
-            cfg (dict | str): Model configuration dictionary or YAML file path.
+            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
            weights (str, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information.
 
@@ -49,8 +47,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
        return model
 
    def get_validator(self):
-        """
-        Create and return a validator for YOLOE segmentation model evaluation.
+        """Create and return a validator for YOLOE segmentation model evaluation.
 
        Returns:
            (YOLOESegValidator): Validator for YOLOE segmentation models.
@@ -62,19 +59,20 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
 
 
 class YOLOEPESegTrainer(SegmentationTrainer):
-    """
-    Fine-tune YOLOESeg model in linear probing way.
+    """Fine-tune YOLOESeg model in linear probing way.
 
    This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
-    most of the model and only training specific layers.
+    most of the model and only training specific layers for efficient adaptation to new tasks.
+
+    Attributes:
+        data (dict): Dataset configuration containing channels, class names, and number of classes.
    """
 
    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Return YOLOESegModel initialized with specified config and weights for linear probing.
+        """Return YOLOESegModel initialized with specified config and weights for linear probing.
 
        Args:
-            cfg (dict | str): Model configuration dictionary or YAML file path.
+            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
            weights (str, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information.
 
@@ -113,12 +111,12 @@ class YOLOEPESegTrainer(SegmentationTrainer):
 
 
 class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
-    """Trainer for YOLOE segmentation from scratch."""
+    """Trainer for YOLOE segmentation models trained from scratch without pretrained weights."""
 
    pass
 
 
 class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
-    """Trainer for YOLOE segmentation with VP."""
+    """Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities."""
 
    pass
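
The linear-probing trainers above (YOLOEPETrainer, YOLOEPESegTrainer) describe freezing most of the model and training only specific projection layers. As a rough illustration of that setup, not the exact layers upstream unfreezes, a sketch of the freeze/unfreeze pattern:

import torch.nn as nn


def freeze_for_linear_probe(model: nn.Module, trainable_head: nn.Module) -> None:
    # Freeze every parameter, then re-enable gradients on the probed head only.
    for p in model.parameters():
        p.requires_grad = False
    for p in trainable_head.parameters():
        p.requires_grad = True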
--- a/ultralytics/models/yolo/yoloe/val.py
+++ b/ultralytics/models/yolo/yoloe/val.py
@@ -1,6 +1,10 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from copy import deepcopy
+from pathlib import Path
+from typing import Any
 
 import torch
 from torch.nn import functional as F
@@ -17,27 +21,39 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
 
 
 class YOLOEDetectValidator(DetectionValidator):
-    """
-    A mixin class for YOLOE model validation that handles both text and visual prompt embeddings.
+    """A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
 
-    This mixin provides functionality to validate YOLOE models using either text or visual prompt embeddings.
-    It includes methods for extracting visual prompt embeddings from samples, preprocessing batches, and
-    running validation with different prompt types.
+    This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
+    validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
+    evaluation strategies for prompt-based object detection.
 
    Attributes:
        device (torch.device): The device on which validation is performed.
        args (namespace): Configuration arguments for validation.
        dataloader (DataLoader): DataLoader for validation data.
+
+    Methods:
+        get_visual_pe: Extract visual prompt embeddings from training samples.
+        preprocess: Preprocess batch data ensuring visuals are on the same device as images.
+        get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.
+        __call__: Run validation using either text or visual prompt embeddings.
+
+    Examples:
+        Validate with text prompts
+        >>> validator = YOLOEDetectValidator()
+        >>> stats = validator(model=model, load_vp=False)
+
+        Validate with visual prompts
+        >>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
    """
 
    @smart_inference_mode()
-    def get_visual_pe(self, dataloader, model):
-        """
-        Extract visual prompt embeddings from training samples.
+    def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
+        """Extract visual prompt embeddings from training samples.
 
-        This function processes a dataloader to compute visual prompt embeddings for each class
-        using a YOLOE model. It normalizes the embeddings and handles cases where no samples
-        exist for a class.
+        This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
+        normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
+        zero.
 
        Args:
            dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
@@ -47,12 +63,13 @@ class YOLOEDetectValidator(DetectionValidator):
            (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
        """
        assert isinstance(model, YOLOEModel)
-        names = [name.split("/")[0] for name in list(dataloader.dataset.data["names"].values())]
+        names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
        visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
        cls_visual_num = torch.zeros(len(names))
 
        desc = "Get visual prompt embeddings from samples"
 
+        # Count samples per class
        for batch in dataloader:
            cls = batch["cls"].squeeze(-1).to(torch.int).unique()
            count = torch.bincount(cls, minlength=len(names))
@@ -60,6 +77,7 @@ class YOLOEDetectValidator(DetectionValidator):
            cls_visual_num += count
        cls_visual_num = cls_visual_num.to(self.device)
 
+        # Extract visual prompt embeddings
        pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
        for batch in pbar:
            batch = self.preprocess(batch)
@@ -69,34 +87,26 @@ class YOLOEDetectValidator(DetectionValidator):
            for i in range(preds.shape[0]):
                cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
                pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
-                pad_cls[: len(cls)] = cls
+                pad_cls[: cls.shape[0]] = cls
                for c in cls:
                    visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]
 
+        # Normalize embeddings for classes with samples, set others to zero
        visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
        visual_pe[cls_visual_num == 0] = 0
        return visual_pe.unsqueeze(0)
 
-    def preprocess(self, batch):
-        """Preprocess batch data, ensuring visuals are on the same device as images."""
-        batch = super().preprocess(batch)
-        if "visuals" in batch:
-            batch["visuals"] = batch["visuals"].to(batch["img"].device)
-        return batch
-
-    def get_vpe_dataloader(self, data):
-        """
-        Create a dataloader for LVIS training visual prompt samples.
+    def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
+        """Create a dataloader for LVIS training visual prompt samples.
 
-        This function prepares a dataloader for visual prompt embeddings (VPE) using the LVIS dataset.
-        It applies necessary transformations and configurations to the dataset and returns a dataloader
-        for validation purposes.
+        This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
+        necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
 
        Args:
            data (dict): Dataset configuration dictionary containing paths and settings.
 
        Returns:
-            (torch.utils.data.DataLoader): The dataLoader for visual prompt samples.
+            (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
        """
        dataset = build_yolo_dataset(
            self.args,
@@ -120,17 +130,22 @@ class YOLOEDetectValidator(DetectionValidator):
        )
 
    @smart_inference_mode()
-    def __call__(self, trainer=None, model=None, refer_data=None, load_vp=False):
-        """
-        Run validation on the model using either text or visual prompt embeddings.
-
-        This method validates the model using either text prompts or visual prompts, depending
-        on the `load_vp` flag. It supports validation during training (using a trainer object)
-        or standalone validation with a provided model.
+    def __call__(
+        self,
+        trainer: Any | None = None,
+        model: YOLOEModel | str | None = None,
+        refer_data: str | None = None,
+        load_vp: bool = False,
+    ) -> dict[str, Any]:
+        """Run validation on the model using either text or visual prompt embeddings.
+
+        This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
+        supports validation during training (using a trainer object) or standalone validation with a provided model. For
+        visual prompts, reference data can be specified to extract embeddings from a different dataset.
 
        Args:
            trainer (object, optional): Trainer object containing the model and device.
-            model (YOLOEModel, optional): Model to validate. Required if `trainer` is not provided.
+            model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
            refer_data (str, optional): Path to reference data for visual prompts.
            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
 
@@ -140,7 +155,7 @@ class YOLOEDetectValidator(DetectionValidator):
        if trainer is not None:
            self.device = trainer.device
            model = trainer.ema.ema
-            names = [name.split("/")[0] for name in list(self.dataloader.dataset.data["names"].values())]
+            names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]
 
            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
@@ -156,15 +171,15 @@ class YOLOEDetectValidator(DetectionValidator):
        else:
            if refer_data is not None:
                assert load_vp, "Refer data is only used for visual prompt validation."
-            self.device = select_device(self.args.device)
+            self.device = select_device(self.args.device, verbose=False)
 
-            if isinstance(model, str):
-                from ultralytics.nn.tasks import attempt_load_weights
+            if isinstance(model, (str, Path)):
+                from ultralytics.nn.tasks import load_checkpoint
 
-                model = attempt_load_weights(model, device=self.device, inplace=True)
+                model, _ = load_checkpoint(model, device=self.device)  # model, ckpt
            model.eval().to(self.device)
            data = check_det_dataset(refer_data or self.args.data)
-            names = [name.split("/")[0] for name in list(data["names"].values())]
+            names = [name.split("/", 1)[0] for name in list(data["names"].values())]
 
            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
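
A small note on the repeated name.split("/")[0] → name.split("/", 1)[0] change in val.py: both pick the first segment of class names that pack alternatives into one "a/b/c" string; maxsplit=1 merely stops after the first separator instead of splitting the whole string. A quick check with a made-up class name:

>>> "sofa/couch/settee".split("/")
['sofa', 'couch', 'settee']
>>> "sofa/couch/settee".split("/", 1)  # at most one split is performed
['sofa', 'couch/settee']
>>> "sofa/couch/settee".split("/", 1)[0]
'sofa'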
--- a/ultralytics/nn/__init__.py
+++ b/ultralytics/nn/__init__.py
@@ -5,25 +5,23 @@ from .tasks import (
    ClassificationModel,
    DetectionModel,
    SegmentationModel,
-    attempt_load_one_weight,
-    attempt_load_weights,
    guess_model_scale,
    guess_model_task,
+    load_checkpoint,
    parse_model,
    torch_safe_load,
    yaml_model_load,
 )
 
 __all__ = (
-    "attempt_load_one_weight",
-    "attempt_load_weights",
-    "parse_model",
-    "yaml_model_load",
-    "guess_model_task",
-    "guess_model_scale",
-    "torch_safe_load",
+    "BaseModel",
+    "ClassificationModel",
    "DetectionModel",
    "SegmentationModel",
-    "ClassificationModel",
-    "BaseModel",
+    "guess_model_scale",
+    "guess_model_task",
+    "load_checkpoint",
+    "parse_model",
+    "torch_safe_load",
+    "yaml_model_load",
 )
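
This last hunk completes the loader rename seen in val.py: attempt_load_one_weight and attempt_load_weights leave the public ultralytics.nn surface and load_checkpoint replaces them. Per the val.py hunk above, load_checkpoint returns a (model, ckpt) tuple rather than a bare model. A hedged migration sketch, using a placeholder weights path and only the device keyword shown in this diff:

from ultralytics.nn.tasks import load_checkpoint

model, ckpt = load_checkpoint("path/to/weights.pt", device="cpu")  # placeholder path
model.eval()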