ultralytics 8.1.38__py3-none-any.whl → 8.1.40__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of ultralytics might be problematic.
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -3
- ultralytics/cfg/datasets/lvis.yaml +1239 -0
- ultralytics/data/__init__.py +18 -2
- ultralytics/data/augment.py +124 -3
- ultralytics/data/base.py +2 -2
- ultralytics/data/build.py +25 -3
- ultralytics/data/converter.py +24 -6
- ultralytics/data/dataset.py +142 -27
- ultralytics/data/loaders.py +11 -8
- ultralytics/data/split_dota.py +1 -1
- ultralytics/data/utils.py +33 -8
- ultralytics/engine/exporter.py +3 -3
- ultralytics/engine/model.py +6 -3
- ultralytics/engine/results.py +2 -2
- ultralytics/engine/trainer.py +59 -55
- ultralytics/engine/validator.py +2 -2
- ultralytics/hub/utils.py +1 -1
- ultralytics/models/fastsam/model.py +1 -1
- ultralytics/models/fastsam/prompt.py +4 -5
- ultralytics/models/nas/model.py +1 -1
- ultralytics/models/sam/model.py +1 -1
- ultralytics/models/sam/modules/tiny_encoder.py +1 -1
- ultralytics/models/yolo/__init__.py +2 -2
- ultralytics/models/yolo/classify/train.py +1 -1
- ultralytics/models/yolo/detect/train.py +1 -1
- ultralytics/models/yolo/detect/val.py +36 -17
- ultralytics/models/yolo/model.py +1 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +108 -0
- ultralytics/nn/autobackend.py +5 -5
- ultralytics/nn/modules/block.py +4 -2
- ultralytics/nn/modules/conv.py +1 -1
- ultralytics/nn/modules/head.py +13 -4
- ultralytics/nn/tasks.py +30 -14
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/heatmap.py +85 -47
- ultralytics/solutions/object_counter.py +79 -64
- ultralytics/trackers/byte_tracker.py +1 -1
- ultralytics/trackers/track.py +1 -1
- ultralytics/trackers/utils/gmc.py +1 -1
- ultralytics/utils/__init__.py +4 -4
- ultralytics/utils/benchmarks.py +2 -2
- ultralytics/utils/callbacks/comet.py +1 -1
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/checks.py +3 -3
- ultralytics/utils/downloads.py +2 -2
- ultralytics/utils/loss.py +1 -1
- ultralytics/utils/metrics.py +1 -1
- ultralytics/utils/plotting.py +36 -22
- ultralytics/utils/torch_utils.py +17 -3
- {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/METADATA +1 -1
- {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/RECORD +58 -54
- {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/WHEEL +0 -0
- {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/world/train.py
ADDED
@@ -0,0 +1,92 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import itertools
+
+from ultralytics.data import build_yolo_dataset
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import WorldModel
+from ultralytics.utils import DEFAULT_CFG, RANK, checks
+from ultralytics.utils.torch_utils import de_parallel
+
+
+def on_pretrain_routine_end(trainer):
+    """Callback."""
+    if RANK in {-1, 0}:
+        # NOTE: for evaluation
+        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
+        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
+    device = next(trainer.model.parameters()).device
+    trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
+    for p in trainer.text_model.parameters():
+        p.requires_grad_(False)
+
+
+class WorldTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class to fine-tune a world model on a close-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world import WorldModel
+
+        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
+        trainer = WorldTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+        # Import and assign clip
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        self.clip = clip
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return WorldModel initialized with specified config and weights."""
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now.
+        model = WorldModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
+
+        return model
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(
+            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+        )
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
+        batch = super().preprocess_batch(batch)
+
+        # NOTE: add text features
+        texts = list(itertools.chain(*batch["texts"]))
+        text_token = self.clip.tokenize(texts).to(batch["img"].device)
+        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
+        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
+        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
+        return batch
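The new WorldTrainer wires YOLO-World fine-tuning into the standard detection training loop: CLIP is imported lazily in `__init__`, and the `on_pretrain_routine_end` callback loads a frozen CLIP text encoder for `preprocess_batch`. A minimal sketch of how this trainer is reached from the public API, assuming the YOLOWorld task map now routes training to it (consistent with the `ultralytics/models/yolo/model.py +1 -0` entry above); the checkpoint and dataset names are illustrative:

```python
from ultralytics import YOLOWorld

# Fine-tune a pretrained YOLO-World checkpoint on a closed-set dataset.
# "yolov8s-world.pt" and "coco8.yaml" are placeholders; any detection dataset YAML works.
model = YOLOWorld("yolov8s-world.pt")
model.train(data="coco8.yaml", epochs=3, imgsz=640)
```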
ultralytics/models/yolo/world/train_world.py
ADDED
@@ -0,0 +1,108 @@
+from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.world import WorldTrainer
+from ultralytics.utils.torch_utils import de_parallel
+from ultralytics.utils import DEFAULT_CFG
+
+
+class WorldTrainerFromScratch(WorldTrainer):
+    """
+    A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+        from ultralytics import YOLOWorld
+
+        data = dict(
+            train=dict(
+                yolo_data=["Objects365.yaml"],
+                grounding_data=[
+                    dict(
+                        img_path="../datasets/flickr30k/images",
+                        json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+                    ),
+                    dict(
+                        img_path="../datasets/GQA/images",
+                        json_file="../datasets/GQA/final_mixed_train_no_coco.json",
+                    ),
+                ],
+            ),
+            val=dict(yolo_data=["lvis.yaml"]),
+        )
+
+        model = YOLOWorld("yolov8s-worldv2.yaml")
+        model.train(data=data, trainer=WorldTrainerFromScratch)
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (List[str] | str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        if mode == "train":
+            dataset = [
+                build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
+                if isinstance(im_path, str)
+                else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+                for im_path in img_path
+            ]
+            return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
+        else:
+            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+
+    def get_dataset(self):
+        """
+        Get train, val path from data dict if it exists.
+
+        Returns None if data format is not recognized.
+        """
+        final_data = dict()
+        data_yaml = self.args.data
+        assert data_yaml.get("train", False)  # object365.yaml
+        assert data_yaml.get("val", False)  # lvis.yaml
+        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
+        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
+        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
+        for d in data["val"]:
+            if d.get("minival") is None:  # for lvis dataset
+                continue
+            d["minival"] = str(d["path"] / d["minival"])
+        for s in ["train", "val"]:
+            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
+            # save grounding data if there's one
+            grounding_data = data_yaml[s].get("grounding_data")
+            if grounding_data is None:
+                continue
+            grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
+            for g in grounding_data:
+                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+            final_data[s] += grounding_data
+        # NOTE: to make training work properly, set `nc` and `names`
+        final_data["nc"] = data["val"][0]["nc"]
+        final_data["names"] = data["val"][0]["names"]
+        self.data = final_data
+        return final_data["train"], final_data["val"][0]
+
+    def plot_training_labels(self):
+        """DO NOT plot labels."""
+        pass
+
+    def final_eval(self):
+        """Performs final evaluation and validation for object detection YOLO-World model."""
+        val = self.args.data["val"]["yolo_data"][0]
+        self.validator.args.data = val
+        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
+        return super().final_eval()
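`WorldTrainerFromScratch.get_dataset()` enforces a specific shape for the `data` dict: `train` and `val` keys must both exist, `val` may list exactly one `yolo_data` dataset, and `grounding_data` entries are optional dicts with `img_path`/`json_file` keys. A minimal sketch of a dict that passes those checks without any grounding data (dataset names are illustrative; see the class docstring above for the full variant with grounding datasets):

```python
# Assumed minimal config for WorldTrainerFromScratch.
data = dict(
    train=dict(yolo_data=["Objects365.yaml"]),  # list of YOLO-format dataset YAMLs
    val=dict(yolo_data=["lvis.yaml"]),          # exactly one validation dataset is supported
)
```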
ultralytics/nn/autobackend.py
CHANGED
@@ -374,9 +374,9 @@ class AutoBackend(nn.Module):
             metadata = yaml_load(metadata)
         if metadata:
             for k, v in metadata.items():
-                if k in
+                if k in {"stride", "batch"}:
                     metadata[k] = int(v)
-                elif k in
+                elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
                     metadata[k] = eval(v)
             stride = metadata["stride"]
             task = metadata["task"]
@@ -531,8 +531,8 @@ class AutoBackend(nn.Module):
                 self.names = {i: f"class{i}" for i in range(nc)}
         else:  # Lite or Edge TPU
             details = self.input_details[0]
-
-            if
+            is_int = details["dtype"] in {np.int8, np.int16}  # is TFLite quantized int8 or int16 model
+            if is_int:
                 scale, zero_point = details["quantization"]
                 im = (im / scale + zero_point).astype(details["dtype"])  # de-scale
             self.interpreter.set_tensor(details["index"], im)
@@ -540,7 +540,7 @@ class AutoBackend(nn.Module):
             y = []
             for output in self.output_details:
                 x = self.interpreter.get_tensor(output["index"])
-                if
+                if is_int:
                     scale, zero_point = output["quantization"]
                     x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                 if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
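The recurring edit in this release, visible in the truncated removed lines above and in several files below, is replacing tuple membership tests with set literals (and naming the TFLite quantization flag `is_int`). A general Python note, not code from the package: for a literal of hashable constants, CPython folds the set into a frozenset constant, so the new form is at least as fast and reads as a plain membership test.

```python
# Equivalent membership checks; the set-literal form is what 8.1.40 standardizes on.
fmt = "tflite"
if fmt in ("tflite", "edgetpu"):   # older tuple style
    print("quantized path")
if fmt in {"tflite", "edgetpu"}:   # new set-literal style, same result
    print("quantized path")
```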
ultralytics/nn/modules/block.py
CHANGED
@@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
     def __init__(self):
         """Initializes ContrastiveHead with specified region-text similarity parameters."""
         super().__init__()
-
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())

     def forward(self, x, w):
@@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
         """Initialize ContrastiveHead with region-text similarity parameters."""
         super().__init__()
         self.norm = nn.BatchNorm2d(embed_dims)
-
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         # use -1.0 is more stable
         self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

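The -10.0 bias shifts the region-text similarity logits so that, at initialization, every class probability is tiny, matching the near-zero classification prior used elsewhere in the detection losses. A quick check of the implied prior (plain PyTorch, not package code):

```python
import torch

bias = torch.tensor([-10.0])       # value added to the contrastive logits at init
print(torch.sigmoid(bias).item())  # ~4.54e-05: near-zero initial class probability
```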
ultralytics/nn/modules/conv.py
CHANGED
@@ -296,7 +296,7 @@ class SpatialAttention(nn.Module):
     def __init__(self, kernel_size=7):
         """Initialize Spatial-attention module with kernel size argument."""
         super().__init__()
-        assert kernel_size in
+        assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
         padding = 3 if kernel_size == 7 else 1
         self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
         self.act = nn.Sigmoid()
ultralytics/nn/modules/head.py
CHANGED
@@ -54,13 +54,13 @@ class Detect(nn.Module):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape

-        if self.export and self.format in
+        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
             box = x_cat[:, : self.reg_max * 4]
             cls = x_cat[:, self.reg_max * 4 :]
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

-        if self.export and self.format in
+        if self.export and self.format in {"tflite", "edgetpu"}:
             # Precompute normalization factor to increase numerical stability
             # See https://github.com/ultralytics/ultralytics/issues/7371
             grid_h = shape[2]
@@ -230,13 +230,13 @@ class WorldDetect(Detect):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape

-        if self.export and self.format in
+        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
             box = x_cat[:, : self.reg_max * 4]
             cls = x_cat[:, self.reg_max * 4 :]
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

-        if self.export and self.format in
+        if self.export and self.format in {"tflite", "edgetpu"}:
             # Precompute normalization factor to increase numerical stability
             # See https://github.com/ultralytics/ultralytics/issues/7371
             grid_h = shape[2]
@@ -250,6 +250,15 @@ class WorldDetect(Detect):
         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)

+    def bias_init(self):
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+

 class RTDETRDecoder(nn.Module):
     """
ultralytics/nn/tasks.py
CHANGED
@@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
         self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

-    def set_classes(self, text):
-        """
+    def set_classes(self, text, batch=80, cache_clip_model=True):
+        """Set classes in advance so that model could do offline-inference without clip model."""
         try:
             import clip
         except ImportError:
-            check_requirements("git+https://github.com/
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip

-        if
+        if (
+            not getattr(self, "clip_model", None) and cache_clip_model
+        ):  # for backwards compatibility of models lacking clip_model attribute
             self.clip_model = clip.load("ViT-B/32")[0]
-
+        model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
+        device = next(model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats =
+        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
+        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
+        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
         self.model[-1].nc = len(text)

-    def
-        """Initialize the loss criterion for the model."""
-        raise NotImplementedError
-
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.

@@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
             x (torch.Tensor): The input tensor.
             profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
+            txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
             embed (list, optional): A list of feature vectors/embeddings to return.

         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
+        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
         if len(txt_feats) != len(x):
             txt_feats = txt_feats.repeat(len(x), 1, 1)
         ori_txt_feats = txt_feats.clone()
@@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
             return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x

+    def loss(self, batch, preds=None):
+        """
+        Compute loss.
+
+        Args:
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+        """
+        if not hasattr(self, "criterion"):
+            self.criterion = self.init_criterion()
+
+        if preds is None:
+            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
+        return self.criterion(preds, batch)
+

 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
@@ -880,7 +896,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             )  # num heads

             args = [c1, c2, *args[1:]]
-            if m in
+            if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}:
                 args.insert(2, n)  # number of repeats
                 n = 1
         elif m is AIFI:
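`set_classes()` now encodes the prompt list in batches (`batch=80` by default) and can skip caching the CLIP model, which is what the new WorldTrainer callback relies on. The practical effect is the offline-vocabulary workflow described in its docstring; a short sketch, assuming the usual YOLOWorld wrapper API (model file, image URL, and class list are illustrative):

```python
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")
model.set_classes(["person", "bus"])   # prompts are tokenized and encoded by CLIP once, stored as txt_feats
results = model.predict("https://ultralytics.com/images/bus.jpg")  # inference no longer needs CLIP
model.save("yolov8s-world-custom.pt")  # saved weights keep the custom vocabulary
```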
ultralytics/solutions/ai_gym.py
CHANGED
@@ -81,7 +81,7 @@ class AIGym:
         self.annotator = Annotator(im0, line_width=2)

         for ind, k in enumerate(reversed(self.keypoints)):
-            if self.pose_type in
+            if self.pose_type in {"pushup", "pullup"}:
                 self.angle[ind] = self.annotator.estimate_pose_angle(
                     k[int(self.kpts_to_check[0])].cpu(),
                     k[int(self.kpts_to_check[1])].cpu(),
ultralytics/solutions/heatmap.py
CHANGED
@@ -24,6 +24,8 @@ class Heatmap:
         self.view_img = False
         self.shape = "circle"

+        self.names = None  # Classes names
+
         # Image information
         self.imw = None
         self.imh = None
@@ -52,10 +54,13 @@ class Heatmap:
         # Object Counting Information
         self.in_counts = 0
         self.out_counts = 0
-        self.
+        self.count_ids = []
+        self.class_wise_count = {}
         self.count_txt_thickness = 0
-        self.count_txt_color = (
-        self.
+        self.count_txt_color = (255, 255, 255)
+        self.line_color = (255, 255, 255)
+        self.cls_txtdisplay_gap = 50
+        self.fontsize = 0.6

         # Decay factor
         self.decay_factor = 0.99
@@ -67,6 +72,7 @@ class Heatmap:
         self,
         imw,
         imh,
+        classes_names=None,
         colormap=cv2.COLORMAP_JET,
         heatmap_alpha=0.5,
         view_img=False,
@@ -74,13 +80,15 @@ class Heatmap:
         view_out_counts=True,
         count_reg_pts=None,
         count_txt_thickness=2,
-        count_txt_color=(
-
+        count_txt_color=(255, 255, 255),
+        fontsize=0.8,
+        line_color=(255, 255, 255),
         count_reg_color=(255, 0, 255),
         region_thickness=5,
         line_dist_thresh=15,
         decay_factor=0.99,
         shape="circle",
+        cls_txtdisplay_gap=50,
     ):
         """
         Configures the heatmap colormap, width, height and display parameters.
@@ -89,6 +97,7 @@ class Heatmap:
            colormap (cv2.COLORMAP): The colormap to be set.
            imw (int): The width of the frame.
            imh (int): The height of the frame.
+           classes_names (dict): Classes names
            heatmap_alpha (float): alpha value for heatmap display
            view_img (bool): Flag indicating frame display
            view_in_counts (bool): Flag to control whether to display the incounts on video stream.
@@ -96,13 +105,16 @@ class Heatmap:
            count_reg_pts (list): Object counting region points
            count_txt_thickness (int): Text thickness for object counting display
            count_txt_color (RGB color): count text color value
-
+           fontsize (float): Text display font size
+           line_color (RGB color): count highlighter line color
            count_reg_color (RGB color): Color of object counting region
            region_thickness (int): Object counting Region thickness
            line_dist_thresh (int): Euclidean Distance threshold for line counter
            decay_factor (float): value for removing heatmap area after object passed
            shape (str): Heatmap shape, rect or circle shape supported
+           cls_txtdisplay_gap (int): Display gap between each class count
         """
+        self.names = classes_names
         self.imw = imw
         self.imh = imh
         self.heatmap_alpha = heatmap_alpha
@@ -116,32 +128,32 @@ class Heatmap:
             if len(count_reg_pts) == 2:
                 print("Line Counter Initiated.")
                 self.count_reg_pts = count_reg_pts
-                self.counting_region = LineString(count_reg_pts)
-
-
-                print("Region Counter Initiated.")
+                self.counting_region = LineString(self.count_reg_pts)
+            elif len(count_reg_pts) >= 3:
+                print("Polygon Counter Initiated.")
                 self.count_reg_pts = count_reg_pts
                 self.counting_region = Polygon(self.count_reg_pts)
-
             else:
-                print("Region
+                print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
                 print("Using Line Counter Now")
-                self.counting_region =
+                self.counting_region = LineString(self.count_reg_pts)

         # Heatmap new frame
         self.heatmap = np.zeros((int(self.imh), int(self.imw)), dtype=np.float32)

         self.count_txt_thickness = count_txt_thickness
         self.count_txt_color = count_txt_color
-        self.
+        self.fontsize = fontsize
+        self.line_color = line_color
         self.region_color = count_reg_color
         self.region_thickness = region_thickness
         self.decay_factor = decay_factor
         self.line_dist_thresh = line_dist_thresh
         self.shape = shape
+        self.cls_txtdisplay_gap = cls_txtdisplay_gap

         # shape of heatmap, if not selected
-        if self.shape not in
+        if self.shape not in {"circle", "rect"}:
             print("Unknown shape value provided, 'circle' & 'rect' supported")
             print("Using Circular shape now")
             self.shape = "circle"
@@ -183,6 +195,12 @@ class Heatmap:
                )

            for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
+               # Store class info
+               if self.names[cls] not in self.class_wise_count:
+                   if len(self.names[cls]) > 5:
+                       self.names[cls] = self.names[cls][:5]
+                   self.class_wise_count[self.names[cls]] = {"in": 0, "out": 0}
+
                if self.shape == "circle":
                    center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
                    radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
@@ -203,23 +221,39 @@ class Heatmap:
                if len(track_line) > 30:
                    track_line.pop(0)

-
-
-
-
-
-
-
+               prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
+
+               # Count objects in any polygon
+               if len(self.count_reg_pts) >= 3:
+                   is_inside = self.counting_region.contains(Point(track_line[-1]))
+
+                   if prev_position is not None and is_inside and track_id not in self.count_ids:
+                       self.count_ids.append(track_id)
+
+                       if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
                            self.in_counts += 1
+                           self.class_wise_count[self.names[cls]]["in"] += 1
+                       else:
+                           self.out_counts += 1
+                           self.class_wise_count[self.names[cls]]["out"] += 1

+               # Count objects using line
                elif len(self.count_reg_pts) == 2:
-
-
-
-
-
-
-                   self.
+                   is_inside = (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0
+
+                   if prev_position is not None and is_inside and track_id not in self.count_ids:
+                       distance = Point(track_line[-1]).distance(self.counting_region)
+
+                       if distance < self.line_dist_thresh and track_id not in self.count_ids:
+                           self.count_ids.append(track_id)
+
+                           if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
+                               self.in_counts += 1
+                               self.class_wise_count[self.names[cls]]["in"] += 1
+                           else:
+                               self.out_counts += 1
+                               self.class_wise_count[self.names[cls]]["out"] += 1
+
        else:
            for box, cls in zip(self.boxes, self.clss):
                if self.shape == "circle":
@@ -240,26 +274,30 @@ class Heatmap:
        heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
        heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)

-
-
-
-
-
-
-
-
-
-
-
-
-
+       label = "Ultralytics Analytics \t"
+
+       for key, value in self.class_wise_count.items():
+           if value["in"] != 0 or value["out"] != 0:
+               if not self.view_in_counts and not self.view_out_counts:
+                   label = None
+               elif not self.view_in_counts:
+                   label += f"{str.capitalize(key)}: IN {value['in']} \t"
+               elif not self.view_out_counts:
+                   label += f"{str.capitalize(key)}: OUT {value['out']} \t"
+               else:
+                   label += f"{str.capitalize(key)}: IN {value['in']} OUT {value['out']} \t"
+
+       label = label.rstrip()
+       label = label.split("\t")

-       if self.count_reg_pts is not None and
-           self.annotator.
-               counts=
-
+       if self.count_reg_pts is not None and label is not None:
+           self.annotator.display_counts(
+               counts=label,
+               tf=self.count_txt_thickness,
+               fontScale=self.fontsize,
                txt_color=self.count_txt_color,
-
+               line_color=self.line_color,
+               classwise_txtgap=self.cls_txtdisplay_gap,
           )

       self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)
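Heatmap now tracks per-class IN/OUT counts (`class_wise_count`) and therefore needs the class-name mapping via the new `classes_names` argument to `set_args()`. A usage sketch under the assumption that the rest of the solutions API (`Heatmap()`, `set_args()`, `generate_heatmap()`) is unchanged from 8.1.38; file paths and the counting line are placeholders:

```python
import cv2
from ultralytics import YOLO
from ultralytics.solutions import heatmap

model = YOLO("yolov8n.pt")
cap = cv2.VideoCapture("path/to/video.mp4")
line_points = [(20, 400), (1080, 400)]  # counting line (2 points); >= 3 points defines a polygon

heatmap_obj = heatmap.Heatmap()
heatmap_obj.set_args(
    imw=int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
    imh=int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    classes_names=model.names,    # new in 8.1.40: needed for class-wise counting
    colormap=cv2.COLORMAP_PARULA,
    count_reg_pts=line_points,
    view_img=True,
)

while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    tracks = model.track(frame, persist=True, show=False)
    frame = heatmap_obj.generate_heatmap(frame, tracks)

cap.release()
cv2.destroyAllWindows()
```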