ultralytics 8.1.37__py3-none-any.whl → 8.1.39__py3-none-any.whl
This diff compares the publicly released contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +1 -2
- ultralytics/cfg/datasets/lvis.yaml +1239 -0
- ultralytics/cfg/default.yaml +2 -2
- ultralytics/data/__init__.py +18 -2
- ultralytics/data/augment.py +123 -2
- ultralytics/data/base.py +2 -0
- ultralytics/data/build.py +25 -3
- ultralytics/data/converter.py +22 -4
- ultralytics/data/dataset.py +143 -27
- ultralytics/data/utils.py +25 -1
- ultralytics/engine/exporter.py +1 -3
- ultralytics/engine/model.py +4 -1
- ultralytics/engine/trainer.py +48 -44
- ultralytics/models/fastsam/prompt.py +1 -1
- ultralytics/models/yolo/__init__.py +2 -2
- ultralytics/models/yolo/detect/val.py +36 -17
- ultralytics/models/yolo/model.py +1 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +91 -0
- ultralytics/models/yolo/world/train_world.py +108 -0
- ultralytics/nn/autobackend.py +1 -1
- ultralytics/nn/modules/block.py +4 -2
- ultralytics/nn/modules/head.py +9 -0
- ultralytics/nn/tasks.py +29 -13
- ultralytics/solutions/heatmap.py +84 -46
- ultralytics/solutions/object_counter.py +79 -64
- ultralytics/trackers/utils/gmc.py +1 -1
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/loss.py +1 -1
- ultralytics/utils/plotting.py +35 -21
- ultralytics/utils/torch_utils.py +14 -0
- ultralytics/utils/tuner.py +2 -2
- {ultralytics-8.1.37.dist-info → ultralytics-8.1.39.dist-info}/METADATA +1 -1
- {ultralytics-8.1.37.dist-info → ultralytics-8.1.39.dist-info}/RECORD +39 -35
- {ultralytics-8.1.37.dist-info → ultralytics-8.1.39.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.37.dist-info → ultralytics-8.1.39.dist-info}/WHEEL +0 -0
- {ultralytics-8.1.37.dist-info → ultralytics-8.1.39.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.37.dist-info → ultralytics-8.1.39.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
```diff
@@ -42,7 +42,7 @@ from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
     EarlyStopping,
     ModelEMA,
-
+    convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
     select_device,
@@ -126,22 +126,7 @@ class BaseTrainer:

         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-
-            if self.args.task == "classify":
-                self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
-                "detect",
-                "segment",
-                "pose",
-                "obb",
-            ):
-                self.data = check_det_dataset(self.args.data)
-                if "yaml_file" in self.data:
-                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
-        except Exception as e:
-            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
-
-        self.trainset, self.testset = self.get_dataset(self.data)
+        self.trainset, self.testset = self.get_dataset()
        self.ema = None

         # Optimization utils init
```
```diff
@@ -477,40 +462,59 @@ class BaseTrainer:

     def save_model(self):
         """Save model training checkpoints with additional metadata."""
+        import io
         import pandas as pd  # scope for faster startup

-  [20 removed lines of the previous save_model body, not captured in this diff view]
+        # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
+        buffer = io.BytesIO()
+        torch.save(
+            {
+                "epoch": self.epoch,
+                "best_fitness": self.best_fitness,
+                "model": None,  # resume and final checkpoints derive from EMA
+                "ema": deepcopy(self.ema.ema).half(),
+                "updates": self.ema.updates,
+                "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+                "train_args": vars(self.args),  # save as dict
+                "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
+                "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
+                "date": datetime.now().isoformat(),
+                "version": __version__,
+                "license": "AGPL-3.0 (https://ultralytics.com/license)",
+                "docs": "https://docs.ultralytics.com",
+            },
+            buffer,
+        )
+        serialized_ckpt = buffer.getvalue()  # get the serialized content to save
+
+        # Save checkpoints
+        self.last.write_bytes(serialized_ckpt)  # save last.pt
         if self.best_fitness == self.fitness:
-
+            self.best.write_bytes(serialized_ckpt)  # save best.pt
         if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
-
+            (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'

-
-    def get_dataset(data):
+    def get_dataset(self):
         """
         Get train, val path from data dict if it exists.

         Returns None if data format is not recognized.
         """
+        try:
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            ):
+                data = check_det_dataset(self.args.data)
+                if "yaml_file" in data:
+                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+        self.data = data
         return data["train"], data.get("val") or data.get("test")

     def setup_model(self):
```
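The new save_model() serializes the checkpoint dict once into an in-memory buffer and then writes the same bytes to last.pt, best.pt, and any periodic epoch file, instead of calling torch.save() per destination; the optimizer state is also converted to FP16 before saving. A minimal sketch of the serialize-once pattern (file names and dict contents here are illustrative, not the trainer's real checkpoint):

```python
import io
from pathlib import Path

import torch

ckpt = {"epoch": 3, "weights": {"w": torch.zeros(2, 2)}}  # toy checkpoint dict

buffer = io.BytesIO()
torch.save(ckpt, buffer)        # serialize once into memory
serialized = buffer.getvalue()  # raw bytes of the checkpoint

for dst in (Path("last.pt"), Path("best.pt")):
    dst.write_bytes(serialized)  # write identical bytes, no re-serialization per file
```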
```diff
@@ -522,7 +526,7 @@ class BaseTrainer:
         ckpt = None
         if str(model).endswith(".pt"):
             weights, ckpt = attempt_load_one_weight(model)
-            cfg =
+            cfg = weights.yaml
         else:
             cfg = model
         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
@@ -661,8 +665,8 @@ class BaseTrainer:
         if ckpt is None:
             return
         best_fitness = 0.0
-        start_epoch = ckpt
-        if ckpt
+        start_epoch = ckpt.get("epoch", -1) + 1
+        if ckpt.get("optimizer", None) is not None:
             self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
             best_fitness = ckpt["best_fitness"]
         if self.ema and ckpt.get("ema"):
```
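resume_training() now reads optional checkpoint fields with dict.get(), so a checkpoint missing "epoch" or "optimizer" no longer raises a KeyError. The same defensive pattern in isolation (the checkpoint dict below is hypothetical):

```python
ckpt = {"best_fitness": 0.5}  # hypothetical checkpoint with no "epoch" or "optimizer" entries

start_epoch = ckpt.get("epoch", -1) + 1      # falls back to 0 when "epoch" is absent
if ckpt.get("optimizer", None) is not None:  # restore optimizer state only when present
    print("restoring optimizer state")
print(start_epoch)  # 0
```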
ultralytics/models/fastsam/prompt.py
CHANGED

```diff
@@ -35,7 +35,7 @@ class FastSAMPrompt:
         except ImportError:
             from ultralytics.utils.checks import check_requirements

-            check_requirements("git+https://github.com/
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
         self.clip = clip

```
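FastSAMPrompt (and WorldModel below) now install CLIP from the ultralytics/CLIP fork when the package is missing. The install-on-ImportError pattern, as used in the changed code (a sketch; check_requirements performs the pip install at runtime):

```python
try:
    import clip
except ImportError:
    from ultralytics.utils.checks import check_requirements

    check_requirements("git+https://github.com/ultralytics/CLIP.git")  # pip-installs the CLIP fork
    import clip
```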
ultralytics/models/yolo/__init__.py
CHANGED

```diff
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-from ultralytics.models.yolo import classify, detect, obb, pose, segment
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world

 from .model import YOLO, YOLOWorld

-__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
```
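With `world` added to `__all__`, the YOLO-World training code becomes reachable from the yolo namespace. A hedged usage sketch, assuming the new world/__init__.py (whose +5 lines are not shown in this diff) re-exports WorldTrainer:

```python
from ultralytics.models.yolo import world

# Fine-tune a pretrained YOLO-World checkpoint on a small closed-set dataset.
args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=1)
trainer = world.WorldTrainer(overrides=args)
trainer.train()
```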
ultralytics/models/yolo/detect/val.py
CHANGED

```diff
@@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
         self.is_coco = False
+        self.is_lvis = False
         self.class_map = None
         self.args.task = "detect"
         self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
@@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
         """Initialize evaluation metrics for YOLO."""
         val = self.data.get(self.args.split, "")  # validation path
         self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
-        self.
-        self.
+        self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco  # is LVIS
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
+        self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training  # run on final val if training COCO
         self.names = model.names
         self.nc = len(model.names)
         self.metrics.names = self.names
```
```diff
@@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])]
+                    "category_id": self.class_map[int(p[5])]
+                    + (1 if self.is_lvis else 0),  # index starts from 1 if it's lvis
                     "bbox": [round(x, 3) for x in b],
                     "score": round(p[4], 5),
                 }
```
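LVIS category ids are 1-based, so the JSON records written for save_json shift the mapped class index by one when the validation set is LVIS (COCO keeps its 80-to-91 remap). A toy illustration of the record construction (all values made up):

```python
is_lvis = True
class_map = list(range(1203))   # identity map used for LVIS; COCO uses coco80_to_coco91_class()

cls_idx = 17                    # predicted class index
box = [10.0, 20.0, 30.0, 40.0]  # xywh, top-left origin
score = 0.87654

record = {
    "image_id": 42,
    "category_id": class_map[cls_idx] + (1 if is_lvis else 0),  # 1-based ids for LVIS
    "bbox": [round(x, 3) for x in box],
    "score": round(score, 5),
}
print(record["category_id"])  # 18
```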
```diff
@@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):

     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
-
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
-
-                for x in anno_json, pred_json:
+                for x in pred_json, anno_json:
                     assert x.is_file(), f"{x} file not found"
-
-                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                eval = COCOeval(anno, pred, "bbox")
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
                 if self.is_coco:
-
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    eval = COCOeval(anno, pred, "bbox")
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))  # init annotations api
+                    pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
+                    eval = LVISEval(anno, pred, "bbox")
+                eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
                 eval.evaluate()
                 eval.accumulate()
                 eval.summarize()
-
+                if self.is_lvis:
+                    eval.print_results()  # explicitly call print_results
+                # update mAP50-95 and mAP50
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
+                    eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
+                )
             except Exception as e:
-                LOGGER.warning(f"
+                LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats
```
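In practice this means validating on the new lvis.yaml with save_json enabled runs the `lvis` evaluator on predictions.json, while COCO datasets keep using pycocotools. A hedged end-to-end sketch with the public API (model and split choices are illustrative):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
# With save_json=True the validator writes predictions.json and, for COCO/LVIS datasets,
# scores it with the matching evaluator (pycocotools or lvis) during the final validation.
metrics = model.val(data="lvis.yaml", split="val", save_json=True)
print(metrics.box.map)  # mAP50-95 as reported by the validator
```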
ultralytics/models/yolo/world/train.py
ADDED
````python
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.models import yolo
from ultralytics.nn.tasks import WorldModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.data import build_yolo_dataset
from ultralytics.utils.torch_utils import de_parallel
from ultralytics.utils.checks import check_requirements
import itertools

try:
    import clip
except ImportError:
    check_requirements("git+https://github.com/ultralytics/CLIP.git")
    import clip


def on_pretrain_routine_end(trainer):
    """Callback."""
    if RANK in (-1, 0):
        # NOTE: for evaluation
        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
    device = next(trainer.model.parameters()).device
    text_model, _ = clip.load("ViT-B/32", device=device)
    for p in text_model.parameters():
        p.requires_grad_(False)
    trainer.text_model = text_model


class WorldTrainer(yolo.detect.DetectionTrainer):
    """
    A class to fine-tune a world model on a close-set dataset.

    Example:
        ```python
        from ultralytics.models.yolo.world import WorldModel

        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
        trainer = WorldTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a WorldTrainer object with given arguments."""
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return WorldModel initialized with specified config and weights."""
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = WorldModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=3,
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)
        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)

        return model

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )

    def preprocess_batch(self, batch):
        """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
        batch = super().preprocess_batch(batch)

        # NOTE: add text features
        texts = list(itertools.chain(*batch["texts"]))
        text_token = clip.tokenize(texts).to(batch["img"].device)
        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
        return batch
````
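For most users the new trainer is reached through the high-level YOLOWorld API rather than instantiated directly; a hedged sketch (weights and dataset names are examples):

```python
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")     # pretrained YOLO-World weights
model.train(data="coco8.yaml", epochs=3)  # fine-tunes on a closed-set dataset via WorldTrainer
```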
ultralytics/models/yolo/world/train_world.py
ADDED

````python
from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.world import WorldTrainer
from ultralytics.utils.torch_utils import de_parallel
from ultralytics.utils import DEFAULT_CFG


class WorldTrainerFromScratch(WorldTrainer):
    """
    A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.

    Example:
        ```python
        from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
        from ultralytics import YOLOWorld

        data = dict(
            train=dict(
                yolo_data=["Objects365.yaml"],
                grounding_data=[
                    dict(
                        img_path="../datasets/flickr30k/images",
                        json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
                    ),
                    dict(
                        img_path="../datasets/GQA/images",
                        json_file="../datasets/GQA/final_mixed_train_no_coco.json",
                    ),
                ],
            ),
            val=dict(yolo_data=["lvis.yaml"]),
        )

        model = YOLOWorld("yolov8s-worldv2.yaml")
        model.train(data=data, trainer=WorldTrainerFromScratch)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a WorldTrainer object with given arguments."""
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (List[str] | str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        if mode == "train":
            dataset = [
                build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
                if isinstance(im_path, str)
                else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
                for im_path in img_path
            ]
            return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
        else:
            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)

    def get_dataset(self):
        """
        Get train, val path from data dict if it exists.

        Returns None if data format is not recognized.
        """
        final_data = dict()
        data_yaml = self.args.data
        assert data_yaml.get("train", False)  # object365.yaml
        assert data_yaml.get("val", False)  # lvis.yaml
        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
        for d in data["val"]:
            if d.get("minival") is None:  # for lvis dataset
                continue
            d["minival"] = str(d["path"] / d["minival"])
        for s in ["train", "val"]:
            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
            # save grounding data if there's one
            grounding_data = data_yaml[s].get("grounding_data")
            if grounding_data is None:
                continue
            grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
            for g in grounding_data:
                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
            final_data[s] += grounding_data
        # NOTE: to make training work properly, set `nc` and `names`
        final_data["nc"] = data["val"][0]["nc"]
        final_data["names"] = data["val"][0]["names"]
        self.data = final_data
        return final_data["train"], final_data["val"][0]

    def plot_training_labels(self):
        """DO NOT plot labels."""
        pass

    def final_eval(self):
        """Performs final evaluation and validation for object detection YOLO-World model."""
        val = self.args.data["val"]["yolo_data"][0]
        self.validator.args.data = val
        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
        return super().final_eval()
````
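WorldTrainerFromScratch builds one dataset per source (YOLO detection data plus grounding data) and concatenates them for training. A minimal sketch of that concatenation idea, using torch's generic ConcatDataset as a stand-in for YOLOConcatDataset:

```python
from torch.utils.data import ConcatDataset, Dataset


class ToySet(Dataset):
    """Tiny stand-in for a per-source detection or grounding dataset."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return i


sources = [ToySet(3), ToySet(5)]  # e.g. one YOLO dataset plus one grounding dataset
train_set = ConcatDataset(sources) if len(sources) > 1 else sources[0]
print(len(train_set))  # 8 samples drawn from both sources
```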
ultralytics/nn/autobackend.py
CHANGED
```diff
@@ -543,7 +543,7 @@ class AutoBackend(nn.Module):
                 if integer:
                     scale, zero_point = output["quantization"]
                     x = (x.astype(np.float32) - zero_point) * scale  # re-scale
-                    if x.ndim
+                    if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
                         # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                         # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
                         x[:, [0, 2]] *= w
```
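The added ndim check lets 2-D classification outputs pass through while 3-D detection outputs are still dequantized and rescaled. The dequantization follows the standard TFLite affine rule shown in the hunk; a toy numpy illustration (scale, zero point, and values are made up):

```python
import numpy as np

# Fake int8 detection output shaped (batch, channels, anchors); rows 0-3 hold normalized xywh.
x = np.array([[[12], [25], [40], [60], [117], [3]]], dtype=np.int8)
scale, zero_point = 0.005, 4  # illustrative values from output["quantization"]
w = h = 640                   # inference image size

x = (x.astype(np.float32) - zero_point) * scale  # re-scale integer output to float
if x.ndim == 3:                                  # detection output; classification (ndim=2) is left as-is
    x[:, [0, 2]] *= w                            # denormalize x-center and width
    x[:, [1, 3]] *= h                            # denormalize y-center and height
```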
ultralytics/nn/modules/block.py
CHANGED
```diff
@@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
     def __init__(self):
         """Initializes ContrastiveHead with specified region-text similarity parameters."""
         super().__init__()
-
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())

     def forward(self, x, w):
@@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
         """Initialize ContrastiveHead with region-text similarity parameters."""
         super().__init__()
         self.norm = nn.BatchNorm2d(embed_dims)
-
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         # use -1.0 is more stable
         self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

```
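Starting the contrastive-head bias at -10.0 keeps the initial region-text scores near zero after the sigmoid, which is what the NOTE about keeping the initial cls loss consistent with the other losses refers to. A quick check of the effect:

```python
import torch

bias = torch.tensor([-10.0])
print(torch.sigmoid(bias))  # tensor([4.5398e-05]) -> essentially zero initial class scores
```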
ultralytics/nn/modules/head.py
CHANGED
```diff
@@ -250,6 +250,15 @@ class WorldDetect(Detect):
         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)

+    def bias_init(self):
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+

 class RTDETRDecoder(nn.Module):
     """
```
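The new WorldDetect.bias_init() only touches the box branch, setting the last conv bias of each cv2 level to 1.0 (the class-bias formula stays commented out). A tiny sketch of the in-place assignment on a toy sequential branch (the layer layout is illustrative):

```python
import torch.nn as nn

branch = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.Conv2d(8, 4, 1))  # toy per-level box branch
branch[-1].bias.data[:] = 1.0  # same in-place pattern as a[-1].bias.data[:] = 1.0 above
print(branch[-1].bias)         # Parameter of shape (4,), all ones
```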
ultralytics/nn/tasks.py
CHANGED
```diff
@@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
         self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

-    def set_classes(self, text):
-        """
+    def set_classes(self, text, batch=80, cache_clip_model=True):
+        """Set classes in advance so that model could do offline-inference without clip model."""
         try:
             import clip
         except ImportError:
-            check_requirements("git+https://github.com/
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip

-        if
+        if (
+            not getattr(self, "clip_model", None) and cache_clip_model
+        ):  # for backwards compatibility of models lacking clip_model attribute
             self.clip_model = clip.load("ViT-B/32")[0]
-
+        model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
+        device = next(model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats =
+        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
+        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
+        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
         self.model[-1].nc = len(text)

-    def
-        """Initialize the loss criterion for the model."""
-        raise NotImplementedError
-
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.

```
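set_classes() now encodes prompts in batches and can skip caching the CLIP text encoder, but the user-facing flow of fixing an open vocabulary before prediction is unchanged; a hedged usage sketch (weights and image are examples):

```python
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")
model.set_classes(["person", "bus", "backpack"])  # encodes the prompts with CLIP and stores text features
results = model.predict("https://ultralytics.com/images/bus.jpg")
```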
```diff
@@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
             x (torch.Tensor): The input tensor.
             profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
+            txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
             embed (list, optional): A list of feature vectors/embeddings to return.

         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
+        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
         if len(txt_feats) != len(x):
             txt_feats = txt_feats.repeat(len(x), 1, 1)
         ori_txt_feats = txt_feats.clone()
@@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
                 return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x

+    def loss(self, batch, preds=None):
+        """
+        Compute loss.
+
+        Args:
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+        """
+        if not hasattr(self, "criterion"):
+            self.criterion = self.init_criterion()
+
+        if preds is None:
+            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
+        return self.criterion(preds, batch)
+

 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
```