dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
@@ -19,11 +19,10 @@ from .val import YOLOEDetectValidator
19
19
 
20
20
 
21
21
  class YOLOETrainer(DetectionTrainer):
22
- """
23
- A trainer class for YOLOE object detection models.
22
+ """A trainer class for YOLOE object detection models.
24
23
 
25
- This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,
26
- including custom model initialization, validation, and dataset building with multi-modal support.
24
+ This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
25
+ model initialization, validation, and dataset building with multi-modal support.
27
26
 
28
27
  Attributes:
29
28
  loss_names (tuple): Names of loss components used during training.
@@ -35,8 +34,7 @@ class YOLOETrainer(DetectionTrainer):
35
34
  """
36
35
 
37
36
  def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
38
- """
39
- Initialize the YOLOE Trainer with specified configurations.
37
+ """Initialize the YOLOE Trainer with specified configurations.
40
38
 
41
39
  Args:
42
40
  cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
@@ -50,12 +48,11 @@ class YOLOETrainer(DetectionTrainer):
50
48
  super().__init__(cfg, overrides, _callbacks)
51
49
 
52
50
  def get_model(self, cfg=None, weights=None, verbose: bool = True):
53
- """
54
- Return a YOLOEModel initialized with the specified configuration and weights.
51
+ """Return a YOLOEModel initialized with the specified configuration and weights.
55
52
 
56
53
  Args:
57
- cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
58
- a direct path to a YAML file, or None to use default configuration.
54
+ cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
55
+ path to a YAML file, or None to use default configuration.
59
56
  weights (str | Path, optional): Path to pretrained weights file to load into the model.
60
57
  verbose (bool): Whether to display model information during initialization.
61
58
 
@@ -88,8 +85,7 @@ class YOLOETrainer(DetectionTrainer):
88
85
  )
89
86
 
90
87
  def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
91
- """
92
- Build YOLO Dataset.
88
+ """Build YOLO Dataset.
93
89
 
94
90
  Args:
95
91
  img_path (str): Path to the folder containing images.
@@ -106,19 +102,17 @@ class YOLOETrainer(DetectionTrainer):
106
102
 
107
103
 
108
104
  class YOLOEPETrainer(DetectionTrainer):
109
- """
110
- Fine-tune YOLOE model using linear probing approach.
105
+ """Fine-tune YOLOE model using linear probing approach.
111
106
 
112
- This trainer freezes most model layers and only trains specific projection layers for efficient
113
- fine-tuning on new datasets while preserving pretrained features.
107
+ This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
108
+ datasets while preserving pretrained features.
114
109
 
115
110
  Methods:
116
111
  get_model: Initialize YOLOEModel with frozen layers except projection layers.
117
112
  """
118
113
 
119
114
  def get_model(self, cfg=None, weights=None, verbose: bool = True):
120
- """
121
- Return YOLOEModel initialized with specified config and weights.
115
+ """Return YOLOEModel initialized with specified config and weights.
122
116
 
123
117
  Args:
124
118
  cfg (dict | str, optional): Model configuration.
@@ -153,18 +147,22 @@ class YOLOEPETrainer(DetectionTrainer):
153
147
  model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
154
148
  model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
155
149
  model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
156
- del model.pe
150
+
151
+ if getattr(model.model[-1], "one2one_cv3", None) is not None:
152
+ model.model[-1].one2one_cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
153
+ model.model[-1].one2one_cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
154
+ model.model[-1].one2one_cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
155
+
157
156
  model.train()
158
157
 
159
158
  return model
160
159
 
161
160
 
162
161
  class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
163
- """
164
- Train YOLOE models from scratch with text embedding support.
162
+ """Train YOLOE models from scratch with text embedding support.
165
163
 
166
- This trainer combines YOLOE training capabilities with world training features, enabling
167
- training from scratch with text embeddings and grounding datasets.
164
+ This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
165
+ text embeddings and grounding datasets.
168
166
 
169
167
  Methods:
170
168
  build_dataset: Build datasets for training with grounding support.
@@ -172,11 +170,10 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
172
170
  """
173
171
 
174
172
  def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
175
- """
176
- Build YOLO Dataset for training or validation.
173
+ """Build YOLO Dataset for training or validation.
177
174
 
178
- This method constructs appropriate datasets based on the mode and input paths, handling both
179
- standard YOLO datasets and grounding datasets with different formats.
175
+ This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
176
+ datasets and grounding datasets with different formats.
180
177
 
181
178
  Args:
182
179
  img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -189,8 +186,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
189
186
  return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
190
187
 
191
188
  def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
192
- """
193
- Generate text embeddings for a list of text samples.
189
+ """Generate text embeddings for a list of text samples.
194
190
 
195
191
  Args:
196
192
  texts (list[str]): List of text samples to encode.
@@ -216,11 +212,10 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
216
212
 
217
213
 
218
214
  class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
219
- """
220
- Train prompt-free YOLOE model.
215
+ """Train prompt-free YOLOE model.
221
216
 
222
- This trainer combines linear probing capabilities with from-scratch training for prompt-free
223
- YOLOE models that don't require text prompts during inference.
217
+ This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
218
+ require text prompts during inference.
224
219
 
225
220
  Methods:
226
221
  get_validator: Return standard DetectionValidator for validation.
@@ -240,12 +235,11 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
240
235
  return DetectionTrainer.preprocess_batch(self, batch)
241
236
 
242
237
  def set_text_embeddings(self, datasets, batch: int):
243
- """
244
- Set text embeddings for datasets to accelerate training by caching category names.
238
+ """Set text embeddings for datasets to accelerate training by caching category names.
245
239
 
246
- This method collects unique category names from all datasets, generates text embeddings for them,
247
- and caches these embeddings to improve training efficiency. The embeddings are stored in a file
248
- in the parent directory of the first dataset's image path.
240
+ This method collects unique category names from all datasets, generates text embeddings for them, and caches
241
+ these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
242
+ the first dataset's image path.
249
243
 
250
244
  Args:
251
245
  datasets (list[Dataset]): List of datasets containing category names to process.
@@ -260,19 +254,17 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
260
254
 
261
255
 
262
256
  class YOLOEVPTrainer(YOLOETrainerFromScratch):
263
- """
264
- Train YOLOE model with visual prompts.
257
+ """Train YOLOE model with visual prompts.
265
258
 
266
- This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,
267
- where visual cues are provided alongside images to guide the detection process.
259
+ This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
260
+ alongside images to guide the detection process.
268
261
 
269
262
  Methods:
270
263
  build_dataset: Build dataset with visual prompt loading transforms.
271
264
  """
272
265
 
273
266
  def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
274
- """
275
- Build YOLO Dataset for training or validation with visual prompts.
267
+ """Build YOLO Dataset for training or validation with visual prompts.
276
268
 
277
269
  Args:
278
270
  img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -11,8 +11,7 @@ from .val import YOLOESegValidator
11
11
 
12
12
 
13
13
  class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
14
- """
15
- Trainer class for YOLOE segmentation models.
14
+ """Trainer class for YOLOE segmentation models.
16
15
 
17
16
  This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
18
17
  segmentation models, enabling both object detection and instance segmentation capabilities.
@@ -24,8 +23,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
24
23
  """
25
24
 
26
25
  def get_model(self, cfg=None, weights=None, verbose=True):
27
- """
28
- Return YOLOESegModel initialized with specified config and weights.
26
+ """Return YOLOESegModel initialized with specified config and weights.
29
27
 
30
28
  Args:
31
29
  cfg (dict | str, optional): Model configuration dictionary or YAML file path.
@@ -49,8 +47,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
49
47
  return model
50
48
 
51
49
  def get_validator(self):
52
- """
53
- Create and return a validator for YOLOE segmentation model evaluation.
50
+ """Create and return a validator for YOLOE segmentation model evaluation.
54
51
 
55
52
  Returns:
56
53
  (YOLOESegValidator): Validator for YOLOE segmentation models.
@@ -62,8 +59,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
62
59
 
63
60
 
64
61
  class YOLOEPESegTrainer(SegmentationTrainer):
65
- """
66
- Fine-tune YOLOESeg model in linear probing way.
62
+ """Fine-tune YOLOESeg model in linear probing way.
67
63
 
68
64
  This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
69
65
  most of the model and only training specific layers for efficient adaptation to new tasks.
@@ -73,8 +69,7 @@ class YOLOEPESegTrainer(SegmentationTrainer):
73
69
  """
74
70
 
75
71
  def get_model(self, cfg=None, weights=None, verbose=True):
76
- """
77
- Return YOLOESegModel initialized with specified config and weights for linear probing.
72
+ """Return YOLOESegModel initialized with specified config and weights for linear probing.
78
73
 
79
74
  Args:
80
75
  cfg (dict | str, optional): Model configuration dictionary or YAML file path.
@@ -109,7 +104,12 @@ class YOLOEPESegTrainer(SegmentationTrainer):
109
104
  model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
110
105
  model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
111
106
  model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
112
- del model.pe
107
+
108
+ if getattr(model.model[-1], "one2one_cv3", None) is not None:
109
+ model.model[-1].one2one_cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
110
+ model.model[-1].one2one_cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
111
+ model.model[-1].one2one_cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
112
+
113
113
  model.train()
114
114
 
115
115
  return model
@@ -21,12 +21,11 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
21
21
 
22
22
 
23
23
  class YOLOEDetectValidator(DetectionValidator):
24
- """
25
- A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
24
+ """A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
26
25
 
27
- This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.
28
- It supports validation using either text prompts or visual prompt embeddings extracted from training samples,
29
- enabling flexible evaluation strategies for prompt-based object detection.
26
+ This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
27
+ validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
28
+ evaluation strategies for prompt-based object detection.
30
29
 
31
30
  Attributes:
32
31
  device (torch.device): The device on which validation is performed.
@@ -50,12 +49,11 @@ class YOLOEDetectValidator(DetectionValidator):
50
49
 
51
50
  @smart_inference_mode()
52
51
  def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
53
- """
54
- Extract visual prompt embeddings from training samples.
52
+ """Extract visual prompt embeddings from training samples.
55
53
 
56
- This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
57
- It normalizes the embeddings and handles cases where no samples exist for a class by setting their
58
- embeddings to zero.
54
+ This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
55
+ normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
56
+ zero.
59
57
 
60
58
  Args:
61
59
  dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
@@ -99,12 +97,10 @@ class YOLOEDetectValidator(DetectionValidator):
99
97
  return visual_pe.unsqueeze(0)
100
98
 
101
99
  def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
102
- """
103
- Create a dataloader for LVIS training visual prompt samples.
100
+ """Create a dataloader for LVIS training visual prompt samples.
104
101
 
105
- This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
106
- It applies necessary transformations including LoadVisualPrompt and configurations to the dataset
107
- for validation purposes.
102
+ This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
103
+ necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
108
104
 
109
105
  Args:
110
106
  data (dict): Dataset configuration dictionary containing paths and settings.
@@ -141,12 +137,11 @@ class YOLOEDetectValidator(DetectionValidator):
141
137
  refer_data: str | None = None,
142
138
  load_vp: bool = False,
143
139
  ) -> dict[str, Any]:
144
- """
145
- Run validation on the model using either text or visual prompt embeddings.
140
+ """Run validation on the model using either text or visual prompt embeddings.
146
141
 
147
- This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
148
- It supports validation during training (using a trainer object) or standalone validation with a provided
149
- model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.
142
+ This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
143
+ supports validation during training (using a trainer object) or standalone validation with a provided model. For
144
+ visual prompts, reference data can be specified to extract embeddings from a different dataset.
150
145
 
151
146
  Args:
152
147
  trainer (object, optional): Trainer object containing the model and device.
@@ -14,14 +14,14 @@ from .tasks import (
14
14
  )
15
15
 
16
16
  __all__ = (
17
+ "BaseModel",
18
+ "ClassificationModel",
19
+ "DetectionModel",
20
+ "SegmentationModel",
21
+ "guess_model_scale",
22
+ "guess_model_task",
17
23
  "load_checkpoint",
18
24
  "parse_model",
19
- "yaml_model_load",
20
- "guess_model_task",
21
- "guess_model_scale",
22
25
  "torch_safe_load",
23
- "DetectionModel",
24
- "SegmentationModel",
25
- "ClassificationModel",
26
- "BaseModel",
26
+ "yaml_model_load",
27
27
  )