ultralytics 8.1.38__py3-none-any.whl → 8.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

ultralytics/engine/trainer.py CHANGED
@@ -42,7 +42,7 @@ from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
     EarlyStopping,
     ModelEMA,
-    de_parallel,
+    convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
     select_device,
@@ -126,22 +126,7 @@ class BaseTrainer:
 
         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        try:
-            if self.args.task == "classify":
-                self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
-                "detect",
-                "segment",
-                "pose",
-                "obb",
-            ):
-                self.data = check_det_dataset(self.args.data)
-                if "yaml_file" in self.data:
-                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
-        except Exception as e:
-            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
-
-        self.trainset, self.testset = self.get_dataset(self.data)
+        self.trainset, self.testset = self.get_dataset()
         self.ema = None
 
         # Optimization utils init
@@ -477,40 +462,59 @@ class BaseTrainer:
 
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
+        import io
         import pandas as pd  # scope for faster startup
 
-        metrics = {**self.metrics, **{"fitness": self.fitness}}
-        results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
-        ckpt = {
-            "epoch": self.epoch,
-            "best_fitness": self.best_fitness,
-            "model": deepcopy(de_parallel(self.model)).half(),
-            "ema": deepcopy(self.ema.ema).half(),
-            "updates": self.ema.updates,
-            "optimizer": self.optimizer.state_dict(),
-            "train_args": vars(self.args),  # save as dict
-            "train_metrics": metrics,
-            "train_results": results,
-            "date": datetime.now().isoformat(),
-            "version": __version__,
-            "license": "AGPL-3.0 (https://ultralytics.com/license)",
-            "docs": "https://docs.ultralytics.com",
-        }
-
-        # Save last and best
-        torch.save(ckpt, self.last)
+        # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
+        buffer = io.BytesIO()
+        torch.save(
+            {
+                "epoch": self.epoch,
+                "best_fitness": self.best_fitness,
+                "model": None,  # resume and final checkpoints derive from EMA
+                "ema": deepcopy(self.ema.ema).half(),
+                "updates": self.ema.updates,
+                "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+                "train_args": vars(self.args),  # save as dict
+                "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
+                "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
+                "date": datetime.now().isoformat(),
+                "version": __version__,
+                "license": "AGPL-3.0 (https://ultralytics.com/license)",
+                "docs": "https://docs.ultralytics.com",
+            },
+            buffer,
+        )
+        serialized_ckpt = buffer.getvalue()  # get the serialized content to save
+
+        # Save checkpoints
+        self.last.write_bytes(serialized_ckpt)  # save last.pt
         if self.best_fitness == self.fitness:
-            torch.save(ckpt, self.best)
+            self.best.write_bytes(serialized_ckpt)  # save best.pt
         if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
-            torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
+            (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
 
-    @staticmethod
-    def get_dataset(data):
+    def get_dataset(self):
         """
         Get train, val path from data dict if it exists.
 
         Returns None if data format is not recognized.
         """
+        try:
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            ):
+                data = check_det_dataset(self.args.data)
+                if "yaml_file" in data:
+                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+        self.data = data
         return data["train"], data.get("val") or data.get("test")
 
     def setup_model(self):
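The reworked save_model() above serializes the checkpoint dict to an in-memory buffer once and then writes the same bytes to each destination, instead of calling torch.save() separately for last.pt, best.pt and epochN.pt. A minimal standalone sketch of that pattern (the checkpoint contents and file names below are illustrative, not the trainer's real state):

```python
import io
from pathlib import Path

import torch

# Illustrative checkpoint; the real one holds EMA weights, optimizer state, metrics, etc.
ckpt = {"epoch": 3, "weights": torch.zeros(10)}

buffer = io.BytesIO()
torch.save(ckpt, buffer)        # serialize to memory a single time
serialized = buffer.getvalue()  # raw bytes of the serialized checkpoint

# Reuse the same bytes for every file instead of re-serializing per destination
for name in ("last.pt", "best.pt", "epoch3.pt"):
    Path(name).write_bytes(serialized)
```

The cost is then one serialization per save regardless of how many checkpoint files are written; the hunk also shrinks the payload by storing the optimizer state in FP16 via convert_optimizer_state_dict_to_fp16() and by dropping the redundant "model" entry in favor of the EMA weights.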
@@ -522,7 +526,7 @@ class BaseTrainer:
         ckpt = None
         if str(model).endswith(".pt"):
             weights, ckpt = attempt_load_one_weight(model)
-            cfg = ckpt["model"].yaml
+            cfg = weights.yaml
         else:
             cfg = model
         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
@@ -661,8 +665,8 @@ class BaseTrainer:
         if ckpt is None:
             return
         best_fitness = 0.0
-        start_epoch = ckpt["epoch"] + 1
-        if ckpt["optimizer"] is not None:
+        start_epoch = ckpt.get("epoch", -1) + 1
+        if ckpt.get("optimizer", None) is not None:
             self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
             best_fitness = ckpt["best_fitness"]
         if self.ema and ckpt.get("ema"):
ultralytics/models/fastsam/prompt.py CHANGED
@@ -35,7 +35,7 @@ class FastSAMPrompt:
         except ImportError:
             from ultralytics.utils.checks import check_requirements
 
-            check_requirements("git+https://github.com/openai/CLIP.git")
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
         self.clip = clip
 
ultralytics/models/yolo/__init__.py CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from ultralytics.models.yolo import classify, detect, obb, pose, segment
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
 
 from .model import YOLO, YOLOWorld
 
-__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
ultralytics/models/yolo/detect/val.py CHANGED
@@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
         self.is_coco = False
+        self.is_lvis = False
         self.class_map = None
         self.args.task = "detect"
         self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
@@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
         """Initialize evaluation metrics for YOLO."""
         val = self.data.get(self.args.split, "")  # validation path
         self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
-        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
-        self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO
+        self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco  # is LVIS
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
+        self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training  # run on final val if training COCO
         self.names = model.names
         self.nc = len(model.names)
         self.metrics.names = self.names
@@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(p[5])]
+                    + (1 if self.is_lvis else 0),  # index starts from 1 if it's lvis
                     "bbox": [round(x, 3) for x in b],
                     "score": round(p[4], 5),
                 }
@@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):
 
     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
-            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements("pycocotools>=2.0.6")
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
-
-                for x in anno_json, pred_json:
+                for x in pred_json, anno_json:
                     assert x.is_file(), f"{x} file not found"
-                anno = COCO(str(anno_json))  # init annotations api
-                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                eval = COCOeval(anno, pred, "bbox")
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
                 if self.is_coco:
-                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    eval = COCOeval(anno, pred, "bbox")
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))  # init annotations api
+                    pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
+                    eval = LVISEval(anno, pred, "bbox")
+                eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
                 eval.evaluate()
                 eval.accumulate()
                 eval.summarize()
-                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
+                if self.is_lvis:
+                    eval.print_results()  # explicitly call print_results
+                # update mAP50-95 and mAP50
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
+                    eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
+                )
             except Exception as e:
-                LOGGER.warning(f"pycocotools unable to run: {e}")
+                LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats
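The eval_json() changes above extend COCO-style JSON evaluation to LVIS datasets via the lvis package. A hedged, user-level sketch of a run that would exercise this path (it assumes an LVIS-style dataset YAML named lvis.yaml is available to this release; treat that file name as an assumption):

```python
from ultralytics import YOLO

# save_json=True makes DetectionValidator write predictions.json; with an LVIS-style
# dataset the new is_lvis branch scores it with LVISEval instead of COCOeval.
model = YOLO("yolov8s.pt")
metrics = model.val(data="lvis.yaml", save_json=True)
print(metrics.box.map)  # mAP50-95, populated from eval.results / eval.stats above
```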
ultralytics/models/yolo/model.py CHANGED
@@ -83,6 +83,7 @@ class YOLOWorld(Model):
                 "model": WorldModel,
                 "validator": yolo.detect.DetectionValidator,
                 "predictor": yolo.detect.DetectionPredictor,
+                "trainer": yolo.world.WorldTrainer,
             }
         }
 
ultralytics/models/yolo/world/__init__.py ADDED
@@ -0,0 +1,5 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from .train import WorldTrainer
+
+__all__ = ["WorldTrainer"]
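With the new ultralytics.models.yolo.world subpackage and WorldTrainer registered under "trainer" in YOLOWorld's task_map above, calling train() on a YOLO-World model now dispatches to the new trainer. A hedged usage sketch (the weights name and the tiny coco8 dataset are illustrative choices):

```python
from ultralytics import YOLOWorld

# Fine-tune a YOLO-World model on a closed-set dataset; task_map routes this call
# to the ultralytics.models.yolo.world.WorldTrainer added in this release.
model = YOLOWorld("yolov8s-world.pt")  # illustrative weights name
model.train(data="coco8.yaml", epochs=1, imgsz=640)
```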
ultralytics/models/yolo/world/train.py ADDED
@@ -0,0 +1,91 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import WorldModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+from ultralytics.data import build_yolo_dataset
+from ultralytics.utils.torch_utils import de_parallel
+from ultralytics.utils.checks import check_requirements
+import itertools
+
+try:
+    import clip
+except ImportError:
+    check_requirements("git+https://github.com/ultralytics/CLIP.git")
+    import clip
+
+
+def on_pretrain_routine_end(trainer):
+    """Callback."""
+    if RANK in (-1, 0):
+        # NOTE: for evaluation
+        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
+        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
+    device = next(trainer.model.parameters()).device
+    text_model, _ = clip.load("ViT-B/32", device=device)
+    for p in text_model.parameters():
+        p.requires_grad_(False)
+    trainer.text_model = text_model
+
+
+class WorldTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class to fine-tune a world model on a close-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world import WorldTrainer
+
+        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
+        trainer = WorldTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return WorldModel initialized with specified config and weights."""
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now.
+        model = WorldModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
+
+        return model
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(
+            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+        )
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
+        batch = super().preprocess_batch(batch)
+
+        # NOTE: add text features
+        texts = list(itertools.chain(*batch["texts"]))
+        text_token = clip.tokenize(texts).to(batch["img"].device)
+        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
+        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
+        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
+        return batch
ultralytics/models/yolo/world/train_world.py ADDED
@@ -0,0 +1,108 @@
+from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.world import WorldTrainer
+from ultralytics.utils.torch_utils import de_parallel
+from ultralytics.utils import DEFAULT_CFG
+
+
+class WorldTrainerFromScratch(WorldTrainer):
+    """
+    A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+        from ultralytics import YOLOWorld
+
+        data = dict(
+            train=dict(
+                yolo_data=["Objects365.yaml"],
+                grounding_data=[
+                    dict(
+                        img_path="../datasets/flickr30k/images",
+                        json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+                    ),
+                    dict(
+                        img_path="../datasets/GQA/images",
+                        json_file="../datasets/GQA/final_mixed_train_no_coco.json",
+                    ),
+                ],
+            ),
+            val=dict(yolo_data=["lvis.yaml"]),
+        )
+
+        model = YOLOWorld("yolov8s-worldv2.yaml")
+        model.train(data=data, trainer=WorldTrainerFromScratch)
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (List[str] | str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        if mode == "train":
+            dataset = [
+                build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
+                if isinstance(im_path, str)
+                else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+                for im_path in img_path
+            ]
+            return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
+        else:
+            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+
+    def get_dataset(self):
+        """
+        Get train, val path from data dict if it exists.
+
+        Returns None if data format is not recognized.
+        """
+        final_data = dict()
+        data_yaml = self.args.data
+        assert data_yaml.get("train", False)  # object365.yaml
+        assert data_yaml.get("val", False)  # lvis.yaml
+        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
+        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
+        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
+        for d in data["val"]:
+            if d.get("minival") is None:  # for lvis dataset
+                continue
+            d["minival"] = str(d["path"] / d["minival"])
+        for s in ["train", "val"]:
+            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
+            # save grounding data if there's one
+            grounding_data = data_yaml[s].get("grounding_data")
+            if grounding_data is None:
+                continue
+            grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
+            for g in grounding_data:
+                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+            final_data[s] += grounding_data
+        # NOTE: to make training work properly, set `nc` and `names`
+        final_data["nc"] = data["val"][0]["nc"]
+        final_data["names"] = data["val"][0]["names"]
+        self.data = final_data
+        return final_data["train"], final_data["val"][0]
+
+    def plot_training_labels(self):
+        """DO NOT plot labels."""
+        pass
+
+    def final_eval(self):
+        """Performs final evaluation and validation for object detection YOLO-World model."""
+        val = self.args.data["val"]["yolo_data"][0]
+        self.validator.args.data = val
+        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
+        return super().final_eval()
ultralytics/nn/modules/head.py CHANGED
@@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
     def __init__(self):
         """Initializes ContrastiveHead with specified region-text similarity parameters."""
         super().__init__()
-        self.bias = nn.Parameter(torch.zeros([]))
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
 
     def forward(self, x, w):
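The NOTE above (and the matching BNContrastiveHead hunk below) moves the bias init from 0.0 to -10.0 so the initial classification loss stays in line with the other losses: a -10 logit sigmoids to roughly 4.5e-5, so region-text scores start near zero rather than 0.5. A quick check of that arithmetic, as a sketch:

```python
import torch

# With the previous zeros([]) init every class logit starts at 0 -> probability 0.5,
# which inflates the initial BCE-style classification loss. A -10.0 bias gives a
# near-zero starting probability, similar in spirit to Detect's cls bias init.
print(torch.sigmoid(torch.tensor(0.0)))    # tensor(0.5000)
print(torch.sigmoid(torch.tensor(-10.0)))  # tensor(4.5398e-05)
```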
@@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
         """Initialize ContrastiveHead with region-text similarity parameters."""
         super().__init__()
         self.norm = nn.BatchNorm2d(embed_dims)
-        self.bias = nn.Parameter(torch.zeros([]))
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         # use -1.0 is more stable
         self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
 
@@ -250,6 +250,15 @@ class WorldDetect(Detect):
         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)
 
+    def bias_init(self):
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+
 
 class RTDETRDecoder(nn.Module):
     """
ultralytics/nn/tasks.py CHANGED
@@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
         self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
-    def set_classes(self, text):
-        """Perform a forward pass with optional profiling, visualization, and embedding extraction."""
+    def set_classes(self, text, batch=80, cache_clip_model=True):
+        """Set classes in advance so that model could do offline-inference without clip model."""
         try:
             import clip
         except ImportError:
-            check_requirements("git+https://github.com/openai/CLIP.git")
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
 
-        if not getattr(self, "clip_model", None):  # for backwards compatibility of models lacking clip_model attribute
+        if (
+            not getattr(self, "clip_model", None) and cache_clip_model
+        ):  # for backwards compatibility of models lacking clip_model attribute
             self.clip_model = clip.load("ViT-B/32")[0]
-        device = next(self.clip_model.parameters()).device
+        model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
+        device = next(model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
+        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
+        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
+        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
         self.model[-1].nc = len(text)
 
-    def init_criterion(self):
-        """Initialize the loss criterion for the model."""
-        raise NotImplementedError
-
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.
 
@@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
             x (torch.Tensor): The input tensor.
             profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
+            txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
             embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
+        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
         if len(txt_feats) != len(x):
             txt_feats = txt_feats.repeat(len(x), 1, 1)
         ori_txt_feats = txt_feats.clone()
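Together, the set_classes() and predict() changes above let a YOLO-World model embed its prompt vocabulary once with CLIP (in chunks of up to 80 prompts, optionally without caching the CLIP model) and then run from the cached txt_feats, or accept per-batch text features during training. A hedged sketch of the public-API side of that flow (the weights file name is illustrative):

```python
from ultralytics import YOLOWorld

# Embed a custom vocabulary once via CLIP; the embeddings are cached on the model
# as txt_feats, so subsequent predictions reuse them instead of re-running CLIP.
model = YOLOWorld("yolov8s-world.pt")            # illustrative weights name
model.set_classes(["person", "bus", "bicycle"])  # runs CLIP once for these prompts
results = model.predict("https://ultralytics.com/images/bus.jpg")
results[0].show()
```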
@@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
                     return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
+    def loss(self, batch, preds=None):
+        """
+        Compute loss.
+
+        Args:
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+        """
+        if not hasattr(self, "criterion"):
+            self.criterion = self.init_criterion()
+
+        if preds is None:
+            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
+        return self.criterion(preds, batch)
+
 
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""