ultralytics 8.3.97__py3-none-any.whl → 8.3.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. tests/test_python.py +56 -0
  2. ultralytics/__init__.py +3 -2
  3. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  4. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  5. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  6. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  7. ultralytics/data/augment.py +101 -5
  8. ultralytics/data/dataset.py +165 -12
  9. ultralytics/engine/exporter.py +13 -13
  10. ultralytics/engine/trainer.py +16 -7
  11. ultralytics/models/__init__.py +2 -2
  12. ultralytics/models/nas/model.py +1 -0
  13. ultralytics/models/nas/predict.py +4 -24
  14. ultralytics/models/nas/val.py +1 -4
  15. ultralytics/models/yolo/__init__.py +3 -3
  16. ultralytics/models/yolo/detect/val.py +6 -1
  17. ultralytics/models/yolo/model.py +182 -3
  18. ultralytics/models/yolo/segment/val.py +43 -16
  19. ultralytics/models/yolo/yoloe/__init__.py +21 -0
  20. ultralytics/models/yolo/yoloe/predict.py +170 -0
  21. ultralytics/models/yolo/yoloe/train.py +355 -0
  22. ultralytics/models/yolo/yoloe/train_seg.py +141 -0
  23. ultralytics/models/yolo/yoloe/val.py +187 -0
  24. ultralytics/nn/autobackend.py +3 -2
  25. ultralytics/nn/modules/__init__.py +18 -1
  26. ultralytics/nn/modules/block.py +17 -1
  27. ultralytics/nn/modules/head.py +359 -22
  28. ultralytics/nn/tasks.py +276 -10
  29. ultralytics/nn/text_model.py +193 -0
  30. ultralytics/utils/callbacks/comet.py +3 -6
  31. ultralytics/utils/downloads.py +6 -2
  32. ultralytics/utils/instance.py +7 -2
  33. ultralytics/utils/loss.py +67 -6
  34. ultralytics/utils/plotting.py +1 -1
  35. ultralytics/utils/tal.py +1 -1
  36. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/METADATA +69 -67
  37. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/RECORD +41 -31
  38. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
  39. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
  40. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
  41. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/yoloe/train.py (new file)
@@ -0,0 +1,355 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import itertools
+from copy import copy, deepcopy
+from pathlib import Path
+
+import torch
+
+from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
+from ultralytics.data.augment import LoadVisualPrompt
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
+from ultralytics.nn.tasks import YOLOEModel
+from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+from ultralytics.utils.torch_utils import de_parallel
+
+from .val import YOLOEDetectValidator
+
+
+class YOLOETrainer(DetectionTrainer):
+    """A base trainer for YOLOE training."""
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initialize the YOLOE Trainer with specified configurations.
+
+        This method sets up the YOLOE trainer with the provided configuration and overrides, initializing
+        the training environment, model, and callbacks for YOLOE object detection training.
+
+        Args:
+            cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
+            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (list, optional): List of callback functions to be applied during training.
+        """
+        if overrides is None:
+            overrides = {}
+        overrides["overlap_mask"] = False
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return a YOLOEModel initialized with the specified config and weights."""
+        # NOTE: This `nc` is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, `nc` is hard-coded to 80 for now.
+        model = YOLOEModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def get_validator(self):
+        """Return a YOLOEDetectValidator for YOLOE model validation."""
+        self.loss_names = "box", "cls", "dfl"
+        return YOLOEDetectValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode; users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, used for `rect` mode.
+
+        Returns:
+            (Dataset): YOLO dataset configured for training or validation.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(
+            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+        )
+
+    def preprocess_batch(self, batch):
+        """Process batch for training; text features are attached by subclasses where needed."""
+        batch = super().preprocess_batch(batch)
+        return batch
+
+
+class YOLOEPETrainer(DetectionTrainer):
+    """Fine-tune a YOLOE model in a linear-probing way."""
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """
+        Return YOLOEModel initialized with specified config and weights.
+
+        Args:
+            cfg (dict | str, optional): Model configuration.
+            weights (str, optional): Path to pretrained weights.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (YOLOEModel): Initialized model with all layers frozen except specific projection layers.
+        """
+        # NOTE: This `nc` is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, `nc` is hard-coded to 80 for now.
+        model = YOLOEModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=self.data["nc"],
+            verbose=verbose and RANK == -1,
+        )
+
+        del model.model[-1].savpe
+
+        assert weights is not None, "Pretrained weights must be provided for linear probing."
+        if weights:
+            model.load(weights)
+
+        model.eval()
+        names = list(self.data["names"].values())
+        # NOTE: `get_text_pe` depends on the text model and YOLOEDetect.reprta;
+        # it produces correct results as long as proper pretrained weights are loaded.
+        tpe = model.get_text_pe(names)
+        model.set_classes(names, tpe)
+        model.model[-1].fuse(model.pe)  # fuse text embeddings into the classification head
+        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
+        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
+        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
+        del model.pe
+        model.train()
+
+        return model
+
+
+class YOLOETrainerFromScratch(YOLOETrainer):
+    """Train YOLOE models from scratch."""
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initialize the YOLOETrainerFromScratch class.
+
+        This class extends YOLOETrainer to train YOLOE models from scratch. It inherits all functionality from
+        the parent class while providing specialized initialization for training without pretrained weights.
+
+        Args:
+            cfg (dict, optional): Configuration dictionary with training parameters. Defaults to DEFAULT_CFG.
+            overrides (dict, optional): Dictionary of parameter overrides for configuration.
+            _callbacks (list, optional): List of callback functions to be executed during training.
+
+        Examples:
+            >>> from ultralytics.models.yolo.yoloe.train import YOLOETrainerFromScratch
+            >>> trainer = YOLOETrainerFromScratch()
+            >>> trainer.train()
+        """
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset for training or validation.
+
+        This method constructs appropriate datasets based on the mode and input paths, handling both
+        standard YOLO datasets and grounding datasets with different formats.
+
+        Args:
+            img_path (List[str] | str): Path to the folder containing images, or a list of such paths.
+            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
+            batch (int, optional): Size of batches, used for rectangular training/validation.
+
+        Returns:
+            (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        if mode != "train":
+            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)
+        datasets = [
+            build_yolo_dataset(self.args, im_path, batch, self.training_data[im_path], stride=gs, multi_modal=True)
+            if isinstance(im_path, str)
+            else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+            for im_path in img_path
+        ]
+        self.set_text_embeddings(datasets, batch)  # cache text embeddings to accelerate training
+        return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
+
+    def set_text_embeddings(self, datasets, batch):
+        """Set text embeddings for datasets to accelerate training by caching category names."""
+        # TODO: open up an interface to determine whether to cache
+        category_names = set()
+        for dataset in datasets:
+            if not hasattr(dataset, "category_names"):
+                continue
+            category_names |= dataset.category_names
+
+        # TODO: enable updating the path, or use a more general way to get the path
+        img_path = datasets[0].img_path
+        self.text_embeddings = self.generate_text_embeddings(
+            category_names, batch, cache_path=Path(img_path).parent / "text_embeddings.pt"
+        )
+
+    def preprocess_batch(self, batch):
+        """Process batch for training, moving text features to the appropriate device."""
+        batch = super().preprocess_batch(batch)
+
+        texts = list(itertools.chain(*batch["texts"]))
+        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
+        txt_feats = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
+        batch["txt_feats"] = txt_feats
+        return batch
+
+    def generate_text_embeddings(self, texts, batch, cache_path="embeddings.pt"):
+        """
+        Generate text embeddings for a list of text samples.
+
+        Args:
+            texts (List[str]): List of text samples to encode.
+            batch (int): Batch size for processing.
+            cache_path (str | Path): Path to save/load cached embeddings.
+
+        Returns:
+            (dict): Dictionary mapping text samples to their embeddings.
+        """
+        if cache_path.exists():
+            return torch.load(cache_path)
+        assert self.model is not None
+        txt_feats = self.model.get_text_pe(texts, batch, without_reprta=True)
+        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
+        torch.save(txt_map, cache_path)
+        return txt_map
+
+    def get_dataset(self):
+        """
+        Get train and validation paths from the data dictionary.
+
+        Processes the data configuration to extract paths for training and validation datasets,
+        handling both YOLO detection datasets and grounding datasets.
+
+        Returns:
+            (str): Train dataset path.
+            (str): Validation dataset path.
+
+        Raises:
+            AssertionError: If train or validation datasets are not found, or if validation uses more than one dataset.
+        """
+        final_data = {}
+        data_yaml = self.args.data
+        assert data_yaml.get("train", False), "train dataset not found"  # object365.yaml
+        assert data_yaml.get("val", False), "validation dataset not found"  # lvis.yaml
+        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
+        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
+        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
+        for d in data["val"]:
+            if d.get("minival") is None:  # not an lvis dataset
+                continue
+            d["minival"] = str(d["path"] / d["minival"])
+        for s in ["train", "val"]:
+            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
+            # save grounding data if there's any
+            grounding_data = data_yaml[s].get("grounding_data")
+            if grounding_data is None:
+                continue
+            grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
+            for g in grounding_data:
+                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+            final_data[s] += grounding_data
+        # NOTE: to make training work properly, set `nc` and `names`
+        final_data["nc"] = data["val"][0]["nc"]
+        final_data["names"] = data["val"][0]["names"]
+        # NOTE: add the path of the validation (lvis) dataset
+        final_data["path"] = data["val"][0]["path"]
+        self.data = final_data
+        if self.args.single_cls:  # consistent with base trainer
+            LOGGER.info("Overriding class names with single class.")
+            self.data["names"] = {0: "object"}
+            self.data["nc"] = 1
+        self.training_data = {}
+        for d in data["train"]:
+            if self.args.single_cls:
+                d["names"] = {0: "object"}
+                d["nc"] = 1
+            self.training_data[d["train"]] = d
+        return final_data["train"], final_data["val"][0]
+
+    def plot_training_labels(self):
+        """Do not plot labels for YOLOE training."""
+        pass
+
+    def final_eval(self):
+        """
+        Perform final evaluation on the validation dataset.
+
+        Configures the validator with the appropriate dataset and split before running evaluation.
+
+        Returns:
+            (dict): Evaluation metrics.
+        """
+        val = self.args.data["val"]["yolo_data"][0]
+        self.validator.args.data = val
+        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
+        return super().final_eval()
+
+
+class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
+    """Train a prompt-free YOLOE model."""
+
+    def get_validator(self):
+        """Return a DetectionValidator for YOLO model validation."""
+        self.loss_names = "box", "cls", "dfl"
+        return DetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def preprocess_batch(self, batch):
+        """Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
+        batch = super(YOLOETrainer, self).preprocess_batch(batch)
+        return batch
+
+    def set_text_embeddings(self, datasets, batch):
+        """No need to set text embeddings for prompt-free fine-tuning."""
+        pass
+
+
+class YOLOEVPTrainer(YOLOETrainerFromScratch):
+    """Train a YOLOE model with visual prompts."""
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset for training or validation with visual prompts.
+
+        Args:
+            img_path (List[str] | str): Path to the folder containing images, or a list of such paths.
+            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
+            batch (int, optional): Size of batches, used for rectangular training/validation.
+
+        Returns:
+            (Dataset): YOLO dataset configured for training or validation, with visual prompts for training mode.
+        """
+        dataset = super().build_dataset(img_path, mode, batch)
+        if isinstance(dataset, YOLOConcatDataset):
+            for d in dataset.datasets:
+                d.transforms.append(LoadVisualPrompt())
+        else:
+            dataset.transforms.append(LoadVisualPrompt())
+        return dataset
+
+    def _close_dataloader_mosaic(self):
+        """Close mosaic augmentation and add visual prompt loading to the training dataset."""
+        super()._close_dataloader_mosaic()
+        if isinstance(self.train_loader.dataset, YOLOConcatDataset):
+            for d in self.train_loader.dataset.datasets:
+                d.transforms.append(LoadVisualPrompt())
+        else:
+            self.train_loader.dataset.transforms.append(LoadVisualPrompt())
+
+    def preprocess_batch(self, batch):
+        """Preprocess a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
+        batch = super().preprocess_batch(batch)
+        batch["visuals"] = batch["visuals"].to(self.device)
+        return batch
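
Taken together, get_dataset and build_dataset above mean that YOLOETrainerFromScratch expects `data` to be a dict of per-split sources rather than a single dataset YAML: each split lists ordinary detection datasets under `yolo_data`, and the train split may add `grounding_data` entries pairing an image folder with an annotation JSON. A minimal usage sketch follows; the Objects365/LVIS YAML names mirror the comments in get_dataset, while the Flickr paths and the `yoloe-v8s.yaml` model config are hypothetical placeholders, not files verified to ship with this release:

    from ultralytics.models.yolo.yoloe.train import YOLOETrainerFromScratch

    # Hypothetical multi-source data config: `yolo_data` entries are standard
    # detection dataset YAMLs; each `grounding_data` entry pairs an image folder
    # with a grounding-annotation JSON (consumed by build_grounding above).
    data = dict(
        train=dict(
            yolo_data=["Objects365.yaml"],
            grounding_data=[
                dict(img_path="flickr/images", json_file="flickr/annotations.json"),
            ],
        ),
        val=dict(yolo_data=["lvis.yaml"]),
    )
    trainer = YOLOETrainerFromScratch(overrides=dict(model="yoloe-v8s.yaml", data=data, epochs=1))
    trainer.train()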
ultralytics/models/yolo/yoloe/train_seg.py (new file)
@@ -0,0 +1,141 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import copy, deepcopy
+
+from ultralytics.models.yolo.segment import SegmentationTrainer
+from ultralytics.nn.tasks import YOLOESegModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+
+from .train import YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
+from .val import YOLOESegValidator
+
+
+class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
+    """
+    Trainer class for YOLOE segmentation models.
+
+    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality
+    specifically for YOLOE segmentation models.
+
+    Attributes:
+        cfg (dict): Configuration dictionary with training parameters.
+        overrides (dict): Dictionary with parameter overrides.
+        _callbacks (list): List of callback functions for training events.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initialize the YOLOESegTrainer class.
+
+        Args:
+            cfg (dict): Configuration dictionary with training parameters.
+            overrides (dict, optional): Dictionary with parameter overrides.
+            _callbacks (list, optional): List of callback functions for training events.
+        """
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """
+        Return YOLOESegModel initialized with specified config and weights.
+
+        Args:
+            cfg (dict | str): Model configuration dictionary or YAML file path.
+            weights (str, optional): Path to pretrained weights file.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (YOLOESegModel): Initialized YOLOE segmentation model.
+        """
+        # NOTE: This `nc` is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, `nc` is hard-coded to 80 for now.
+        model = YOLOESegModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def get_validator(self):
+        """
+        Create and return a validator for YOLOE segmentation model evaluation.
+
+        Returns:
+            (YOLOESegValidator): Validator for YOLOE segmentation models.
+        """
+        self.loss_names = "box", "seg", "cls", "dfl"
+        return YOLOESegValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+
+class YOLOEPESegTrainer(SegmentationTrainer):
+    """
+    Fine-tune a YOLOESeg model in a linear-probing way.
+
+    This trainer specializes in fine-tuning YOLOESeg models using a linear-probing approach, which involves
+    freezing most of the model and training only specific layers.
+    """
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """
+        Return YOLOESegModel initialized with specified config and weights for linear probing.
+
+        Args:
+            cfg (dict | str): Model configuration dictionary or YAML file path.
+            weights (str, optional): Path to pretrained weights file.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (YOLOESegModel): Initialized YOLOE segmentation model configured for linear probing.
+        """
+        # NOTE: This `nc` is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, `nc` is hard-coded to 80 for now.
+        model = YOLOESegModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=self.data["nc"],
+            verbose=verbose and RANK == -1,
+        )
+
+        del model.model[-1].savpe
+
+        assert weights is not None, "Pretrained weights must be provided for linear probing."
+        if weights:
+            model.load(weights)
+
+        model.eval()
+        names = list(self.data["names"].values())
+        # NOTE: `get_text_pe` depends on the text model and YOLOEDetect.reprta;
+        # it produces correct results as long as proper pretrained weights are loaded.
+        tpe = model.get_text_pe(names)
+        model.set_classes(names, tpe)
+        model.model[-1].fuse(model.pe)
+        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
+        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
+        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
+        del model.pe
+        model.train()
+
+        return model
+
+
+class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
+    """Trainer for YOLOE segmentation models trained from scratch."""
+
+    pass
+
+
+class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
+    """Trainer for YOLOE segmentation models with visual prompts."""
+
+    pass
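
As the get_model implementations above show, the linear-probing trainers keep the whole network frozen in eval mode and re-enable gradients only on deep copies of the three cv3[*][2] classification branches, after fusing the text embeddings into the head. A sketch of how such a fine-tune might be launched, assuming a pretrained YOLOE segmentation checkpoint is available locally; the weight and dataset names are placeholders:

    from ultralytics.models.yolo.yoloe.train_seg import YOLOEPESegTrainer

    # `yoloe-v8s-seg.pt` is a hypothetical local checkpoint; get_model() asserts
    # that pretrained weights are provided, so a weights path is mandatory here.
    trainer = YOLOEPESegTrainer(
        overrides=dict(model="yoloe-v8s-seg.pt", data="coco128-seg.yaml", epochs=10)
    )
    trainer.train()  # only the re-enabled cv3[*][2] branches receive gradients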
ultralytics/models/yolo/yoloe/val.py (new file)
@@ -0,0 +1,187 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import deepcopy
+
+import torch
+from torch.nn import functional as F
+
+from ultralytics.data import YOLOConcatDataset, build_dataloader, build_yolo_dataset
+from ultralytics.data.augment import LoadVisualPrompt
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.models.yolo.model import YOLOEModel
+from ultralytics.models.yolo.segment import SegmentationValidator
+from ultralytics.nn.modules.head import YOLOEDetect
+from ultralytics.utils import LOGGER, TQDM
+from ultralytics.utils.torch_utils import select_device, smart_inference_mode
+
+
+class YOLOEDetectValidator(DetectionValidator):
+    """
+    A validator for YOLOE models that handles both text and visual prompt embeddings.
+
+    This class provides functionality to validate YOLOE models using either text or visual prompt embeddings.
+    It includes methods for extracting visual prompt embeddings from samples, preprocessing batches, and
+    running validation with different prompt types.
+
+    Attributes:
+        device (torch.device): The device on which validation is performed.
+        args (namespace): Configuration arguments for validation.
+        dataloader (DataLoader): DataLoader for validation data.
+    """
+
+    @smart_inference_mode()
+    def get_visual_pe(self, dataloader, model):
+        """
+        Extract visual prompt embeddings from training samples.
+
+        This function processes a dataloader to compute visual prompt embeddings for each class
+        using a YOLOE model. It normalizes the embeddings and handles cases where no samples
+        exist for a class.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
+            model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.
+
+        Returns:
+            (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
+        """
+        assert isinstance(model, YOLOEModel)
+        names = [name.split("/")[0] for name in list(dataloader.dataset.data["names"].values())]
+        visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
+        cls_visual_num = torch.zeros(len(names))
+
+        desc = "Get visual prompt embeddings from samples"
+
+        # Count the number of visual samples per class
+        for batch in dataloader:
+            cls = batch["cls"].squeeze(-1).to(torch.int).unique()
+            count = torch.bincount(cls, minlength=len(names))
+            cls_visual_num += count
+
+        cls_visual_num = cls_visual_num.to(self.device)
+
+        # Accumulate per-class embedding averages across the dataloader
+        pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
+        for batch in pbar:
+            batch = self.preprocess(batch)
+            preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)
+
+            batch_idx = batch["batch_idx"]
+            for i in range(preds.shape[0]):
+                cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
+                pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
+                pad_cls[: len(cls)] = cls
+                for c in cls:
+                    visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]
+
+        visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
+        visual_pe[cls_visual_num == 0] = 0
+        return visual_pe.unsqueeze(0)
+
+    def preprocess(self, batch):
+        """Preprocess batch data, ensuring visuals are on the same device as images."""
+        batch = super().preprocess(batch)
+        if "visuals" in batch:
+            batch["visuals"] = batch["visuals"].to(batch["img"].device)
+        return batch
+
+    def get_vpe_dataloader(self, data):
+        """
+        Create a dataloader for LVIS training visual prompt samples.
+
+        This function prepares a dataloader for visual prompt embeddings (VPE) using the LVIS dataset.
+        It applies necessary transformations and configurations to the dataset and returns a dataloader
+        for validation purposes.
+
+        Args:
+            data (dict): Dataset configuration dictionary containing paths and settings.
+
+        Returns:
+            (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
+        """
+        dataset = build_yolo_dataset(
+            self.args,
+            data.get(self.args.split, data.get("val")),
+            self.args.batch,
+            data,
+            mode="val",
+            rect=False,
+        )
+        if isinstance(dataset, YOLOConcatDataset):
+            for d in dataset.datasets:
+                d.transforms.append(LoadVisualPrompt())
+        else:
+            dataset.transforms.append(LoadVisualPrompt())
+        return build_dataloader(
+            dataset,
+            self.args.batch,
+            self.args.workers,
+            shuffle=False,
+            rank=-1,
+        )
+
+    @smart_inference_mode()
+    def __call__(self, trainer=None, model=None, refer_data=None, load_vp=False):
+        """
+        Run validation on the model using either text or visual prompt embeddings.
+
+        This method validates the model using either text prompts or visual prompts, depending
+        on the `load_vp` flag. It supports validation during training (using a trainer object)
+        or standalone validation with a provided model.
+
+        Args:
+            trainer (object, optional): Trainer object containing the model and device.
+            model (YOLOEModel, optional): Model to validate. Required if `trainer` is not provided.
+            refer_data (str, optional): Path to reference data for visual prompts.
+            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
+
+        Returns:
+            (dict): Validation statistics containing metrics computed during validation.
+        """
+        if trainer is not None:
+            self.device = trainer.device
+            model = trainer.ema.ema
+            names = [name.split("/")[0] for name in list(self.dataloader.dataset.data["names"].values())]
+
+            if load_vp:
+                LOGGER.info("Validate using the visual prompt.")
+                self.args.half = False
+                # Directly use the same dataloader for visual embeddings extracted during training
+                vpe = self.get_visual_pe(self.dataloader, model)
+                model.set_classes(names, vpe)
+            else:
+                LOGGER.info("Validate using the text prompt.")
+                tpe = model.get_text_pe(names)
+                model.set_classes(names, tpe)
+            stats = super().__call__(trainer, model)
+        else:
+            if refer_data is not None:
+                assert load_vp, "Refer data is only used for visual prompt validation."
+            self.device = select_device(self.args.device)
+
+            model.eval().to(self.device)
+            data = check_det_dataset(refer_data or self.args.data)
+            names = [name.split("/")[0] for name in list(data["names"].values())]
+
+            if load_vp:
+                LOGGER.info("Validate using the visual prompt.")
+                self.args.half = False
+                # TODO: need to check whether the names from the refer data are consistent with the evaluated dataset;
+                # either the same dataset or the refer data could be used to extract visual prompt embeddings
+                dataloader = self.get_vpe_dataloader(data)
+                vpe = self.get_visual_pe(dataloader, model)
+                model.set_classes(names, vpe)
+                stats = super().__call__(model=deepcopy(model))
+            elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
+                return super().__call__(trainer, model)
+            else:
+                LOGGER.info("Validate using the text prompt.")
+                tpe = model.get_text_pe(names)
+                model.set_classes(names, tpe)
+                stats = super().__call__(model=deepcopy(model))
+        return stats
+
+
+class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):
+    """YOLOE segmentation validator that supports both text and visual prompt embeddings."""
+
+    pass
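
Because __call__ accepts either a trainer or a bare model, the validator can also be driven standalone, switching between text and visual prompts via `load_vp` and optionally sourcing the visual prompts from a separate dataset via `refer_data`. A sketch under the assumption that this release exports a YOLOE model class from the top-level package and that any YOLOE detection checkpoint plus a dataset YAML with a `val` (or LVIS `minival`) split behaves the same way; the file names are placeholders:

    from ultralytics import YOLOE
    from ultralytics.models.yolo.yoloe.val import YOLOEDetectValidator

    weights = "yoloe-v8s.pt"  # hypothetical checkpoint name
    model = YOLOE(weights).model  # unwrap the underlying YOLOEModel

    validator = YOLOEDetectValidator(args=dict(model=weights, data="lvis.yaml", batch=16))
    stats_text = validator(model=model)              # text prompt embeddings
    stats_vp = validator(model=model, load_vp=True)  # visual prompt embeddings
    # A different dataset can supply the visual prompts:
    # validator(model=model, refer_data="coco.yaml", load_vp=True)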