ultralytics 8.1.38__py3-none-any.whl → 8.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics has been flagged as potentially problematic.

Files changed (58)
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +3 -3
  3. ultralytics/cfg/datasets/lvis.yaml +1239 -0
  4. ultralytics/data/__init__.py +18 -2
  5. ultralytics/data/augment.py +124 -3
  6. ultralytics/data/base.py +2 -2
  7. ultralytics/data/build.py +25 -3
  8. ultralytics/data/converter.py +24 -6
  9. ultralytics/data/dataset.py +142 -27
  10. ultralytics/data/loaders.py +11 -8
  11. ultralytics/data/split_dota.py +1 -1
  12. ultralytics/data/utils.py +33 -8
  13. ultralytics/engine/exporter.py +3 -3
  14. ultralytics/engine/model.py +6 -3
  15. ultralytics/engine/results.py +2 -2
  16. ultralytics/engine/trainer.py +59 -55
  17. ultralytics/engine/validator.py +2 -2
  18. ultralytics/hub/utils.py +1 -1
  19. ultralytics/models/fastsam/model.py +1 -1
  20. ultralytics/models/fastsam/prompt.py +4 -5
  21. ultralytics/models/nas/model.py +1 -1
  22. ultralytics/models/sam/model.py +1 -1
  23. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  24. ultralytics/models/yolo/__init__.py +2 -2
  25. ultralytics/models/yolo/classify/train.py +1 -1
  26. ultralytics/models/yolo/detect/train.py +1 -1
  27. ultralytics/models/yolo/detect/val.py +36 -17
  28. ultralytics/models/yolo/model.py +1 -0
  29. ultralytics/models/yolo/world/__init__.py +5 -0
  30. ultralytics/models/yolo/world/train.py +92 -0
  31. ultralytics/models/yolo/world/train_world.py +108 -0
  32. ultralytics/nn/autobackend.py +5 -5
  33. ultralytics/nn/modules/block.py +4 -2
  34. ultralytics/nn/modules/conv.py +1 -1
  35. ultralytics/nn/modules/head.py +13 -4
  36. ultralytics/nn/tasks.py +30 -14
  37. ultralytics/solutions/ai_gym.py +1 -1
  38. ultralytics/solutions/heatmap.py +85 -47
  39. ultralytics/solutions/object_counter.py +79 -64
  40. ultralytics/trackers/byte_tracker.py +1 -1
  41. ultralytics/trackers/track.py +1 -1
  42. ultralytics/trackers/utils/gmc.py +1 -1
  43. ultralytics/utils/__init__.py +4 -4
  44. ultralytics/utils/benchmarks.py +2 -2
  45. ultralytics/utils/callbacks/comet.py +1 -1
  46. ultralytics/utils/callbacks/mlflow.py +1 -1
  47. ultralytics/utils/checks.py +3 -3
  48. ultralytics/utils/downloads.py +2 -2
  49. ultralytics/utils/loss.py +1 -1
  50. ultralytics/utils/metrics.py +1 -1
  51. ultralytics/utils/plotting.py +36 -22
  52. ultralytics/utils/torch_utils.py +17 -3
  53. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/METADATA +1 -1
  54. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/RECORD +58 -54
  55. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/LICENSE +0 -0
  56. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/WHEEL +0 -0
  57. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/entry_points.txt +0 -0
  58. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/world/train.py ADDED
@@ -0,0 +1,92 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import itertools
+
+from ultralytics.data import build_yolo_dataset
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import WorldModel
+from ultralytics.utils import DEFAULT_CFG, RANK, checks
+from ultralytics.utils.torch_utils import de_parallel
+
+
+def on_pretrain_routine_end(trainer):
+    """Callback."""
+    if RANK in {-1, 0}:
+        # NOTE: for evaluation
+        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
+        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
+    device = next(trainer.model.parameters()).device
+    trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
+    for p in trainer.text_model.parameters():
+        p.requires_grad_(False)
+
+
+class WorldTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class to fine-tune a world model on a close-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world import WorldModel
+
+        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
+        trainer = WorldTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+        # Import and assign clip
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        self.clip = clip
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return WorldModel initialized with specified config and weights."""
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now.
+        model = WorldModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
+
+        return model
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(
+            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+        )
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
+        batch = super().preprocess_batch(batch)
+
+        # NOTE: add text features
+        texts = list(itertools.chain(*batch["texts"]))
+        text_token = self.clip.tokenize(texts).to(batch["img"].device)
+        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
+        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
+        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
+        return batch
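
The docstring example above instantiates `WorldTrainer` directly (even though it imports `WorldModel`); a minimal fine-tuning sketch along those lines, assuming the `yolov8s-world.pt` checkpoint and the bundled `coco8.yaml` dataset are available locally:

```python
# Hedged sketch based on the WorldTrainer docstring above; weights and dataset paths are assumptions.
from ultralytics.models.yolo.world import WorldTrainer

args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
trainer = WorldTrainer(overrides=args)
trainer.train()
```
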
ultralytics/models/yolo/world/train_world.py ADDED
@@ -0,0 +1,108 @@
+from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.world import WorldTrainer
+from ultralytics.utils.torch_utils import de_parallel
+from ultralytics.utils import DEFAULT_CFG
+
+
+class WorldTrainerFromScratch(WorldTrainer):
+    """
+    A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+        from ultralytics import YOLOWorld
+
+        data = dict(
+            train=dict(
+                yolo_data=["Objects365.yaml"],
+                grounding_data=[
+                    dict(
+                        img_path="../datasets/flickr30k/images",
+                        json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+                    ),
+                    dict(
+                        img_path="../datasets/GQA/images",
+                        json_file="../datasets/GQA/final_mixed_train_no_coco.json",
+                    ),
+                ],
+            ),
+            val=dict(yolo_data=["lvis.yaml"]),
+        )
+
+        model = YOLOWorld("yolov8s-worldv2.yaml")
+        model.train(data=data, trainer=WorldTrainerFromScratch)
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (List[str] | str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        if mode == "train":
+            dataset = [
+                build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
+                if isinstance(im_path, str)
+                else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+                for im_path in img_path
+            ]
+            return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
+        else:
+            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+
+    def get_dataset(self):
+        """
+        Get train, val path from data dict if it exists.
+
+        Returns None if data format is not recognized.
+        """
+        final_data = dict()
+        data_yaml = self.args.data
+        assert data_yaml.get("train", False)  # object365.yaml
+        assert data_yaml.get("val", False)  # lvis.yaml
+        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
+        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
+        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
+        for d in data["val"]:
+            if d.get("minival") is None:  # for lvis dataset
+                continue
+            d["minival"] = str(d["path"] / d["minival"])
+        for s in ["train", "val"]:
+            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
+            # save grounding data if there's one
+            grounding_data = data_yaml[s].get("grounding_data")
+            if grounding_data is None:
+                continue
+            grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
+            for g in grounding_data:
+                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+            final_data[s] += grounding_data
+        # NOTE: to make training work properly, set `nc` and `names`
+        final_data["nc"] = data["val"][0]["nc"]
+        final_data["names"] = data["val"][0]["names"]
+        self.data = final_data
+        return final_data["train"], final_data["val"][0]
+
+    def plot_training_labels(self):
+        """DO NOT plot labels."""
+        pass
+
+    def final_eval(self):
+        """Performs final evaluation and validation for object detection YOLO-World model."""
+        val = self.args.data["val"]["yolo_data"][0]
+        self.validator.args.data = val
+        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
+        return super().final_eval()
ultralytics/nn/autobackend.py CHANGED
@@ -374,9 +374,9 @@ class AutoBackend(nn.Module):
             metadata = yaml_load(metadata)
         if metadata:
             for k, v in metadata.items():
-                if k in ("stride", "batch"):
+                if k in {"stride", "batch"}:
                     metadata[k] = int(v)
-                elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
+                elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
                     metadata[k] = eval(v)
             stride = metadata["stride"]
             task = metadata["task"]
@@ -531,8 +531,8 @@ class AutoBackend(nn.Module):
                 self.names = {i: f"class{i}" for i in range(nc)}
             else:  # Lite or Edge TPU
                 details = self.input_details[0]
-                integer = details["dtype"] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model
-                if integer:
+                is_int = details["dtype"] in {np.int8, np.int16}  # is TFLite quantized int8 or int16 model
+                if is_int:
                     scale, zero_point = details["quantization"]
                     im = (im / scale + zero_point).astype(details["dtype"])  # de-scale
                 self.interpreter.set_tensor(details["index"], im)
@@ -540,7 +540,7 @@ class AutoBackend(nn.Module):
                 y = []
                 for output in self.output_details:
                     x = self.interpreter.get_tensor(output["index"])
-                    if integer:
+                    if is_int:
                         scale, zero_point = output["quantization"]
                         x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                     if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
ultralytics/nn/modules/block.py CHANGED
@@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
     def __init__(self):
         """Initializes ContrastiveHead with specified region-text similarity parameters."""
         super().__init__()
-        self.bias = nn.Parameter(torch.zeros([]))
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
 
     def forward(self, x, w):
@@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
         """Initialize ContrastiveHead with region-text similarity parameters."""
         super().__init__()
         self.norm = nn.BatchNorm2d(embed_dims)
-        self.bias = nn.Parameter(torch.zeros([]))
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         # use -1.0 is more stable
         self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
 
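
The new bias value follows the usual low-confidence initialization of detection class logits; a quick illustration of the effect (my own sketch, not part of the diff):

```python
# sigmoid(-10.0) ≈ 4.5e-5: with the similarity term near zero at init, class probabilities
# start close to 0 instead of 0.5, so the initial BCE classification loss stays small and
# comparable to the other loss terms (the stated motivation in the NOTE above).
import torch

print(torch.sigmoid(torch.tensor(-10.0)).item())  # ~4.54e-05
print(torch.sigmoid(torch.tensor(0.0)).item())    # 0.5 (previous zero-initialized bias)
```
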
ultralytics/nn/modules/conv.py CHANGED
@@ -296,7 +296,7 @@ class SpatialAttention(nn.Module):
     def __init__(self, kernel_size=7):
         """Initialize Spatial-attention module with kernel size argument."""
         super().__init__()
-        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
+        assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
         padding = 3 if kernel_size == 7 else 1
         self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
         self.act = nn.Sigmoid()
ultralytics/nn/modules/head.py CHANGED
@@ -54,13 +54,13 @@ class Detect(nn.Module):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
 
-        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
+        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
             box = x_cat[:, : self.reg_max * 4]
             cls = x_cat[:, self.reg_max * 4 :]
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
 
-        if self.export and self.format in ("tflite", "edgetpu"):
+        if self.export and self.format in {"tflite", "edgetpu"}:
             # Precompute normalization factor to increase numerical stability
             # See https://github.com/ultralytics/ultralytics/issues/7371
             grid_h = shape[2]
@@ -230,13 +230,13 @@ class WorldDetect(Detect):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
 
-        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
+        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
             box = x_cat[:, : self.reg_max * 4]
             cls = x_cat[:, self.reg_max * 4 :]
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
 
-        if self.export and self.format in ("tflite", "edgetpu"):
+        if self.export and self.format in {"tflite", "edgetpu"}:
             # Precompute normalization factor to increase numerical stability
             # See https://github.com/ultralytics/ultralytics/issues/7371
             grid_h = shape[2]
@@ -250,6 +250,15 @@ class WorldDetect(Detect):
         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)
 
+    def bias_init(self):
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+
 
 class RTDETRDecoder(nn.Module):
     """
ultralytics/nn/tasks.py CHANGED
@@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
         self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
-    def set_classes(self, text):
-        """Perform a forward pass with optional profiling, visualization, and embedding extraction."""
+    def set_classes(self, text, batch=80, cache_clip_model=True):
+        """Set classes in advance so that model could do offline-inference without clip model."""
         try:
             import clip
         except ImportError:
-            check_requirements("git+https://github.com/openai/CLIP.git")
+            check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
 
-        if not getattr(self, "clip_model", None):  # for backwards compatibility of models lacking clip_model attribute
+        if (
+            not getattr(self, "clip_model", None) and cache_clip_model
+        ):  # for backwards compatibility of models lacking clip_model attribute
             self.clip_model = clip.load("ViT-B/32")[0]
-        device = next(self.clip_model.parameters()).device
+        model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
+        device = next(model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
+        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
+        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
+        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
         self.model[-1].nc = len(text)
 
-    def init_criterion(self):
-        """Initialize the loss criterion for the model."""
-        raise NotImplementedError
-
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.
 
@@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
             x (torch.Tensor): The input tensor.
             profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
+            txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
             embed (list, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
+        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
         if len(txt_feats) != len(x):
             txt_feats = txt_feats.repeat(len(x), 1, 1)
         ori_txt_feats = txt_feats.clone()
@@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
                 return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
+    def loss(self, batch, preds=None):
+        """
+        Compute loss.
+
+        Args:
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+        """
+        if not hasattr(self, "criterion"):
+            self.criterion = self.init_criterion()
+
+        if preds is None:
+            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
+        return self.criterion(preds, batch)
+
 
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
@@ -880,7 +896,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             )  # num heads
 
             args = [c1, c2, *args[1:]]
-            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
+            if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}:
                 args.insert(2, n)  # number of repeats
                 n = 1
         elif m is AIFI:
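
For context on the `set_classes` changes above, a hedged sketch of the offline-inference workflow they enable (checkpoint, class list, and image are placeholder assumptions):

```python
# Illustrative only: embed a custom vocabulary once, then run inference without re-invoking the CLIP text encoder.
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")             # assumed local checkpoint
model.set_classes(["person", "bus", "backpack"])  # text features are computed once and cached on the model
results = model.predict("bus.jpg")                # subsequent predictions reuse the cached features
```
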
ultralytics/solutions/ai_gym.py CHANGED
@@ -81,7 +81,7 @@ class AIGym:
         self.annotator = Annotator(im0, line_width=2)
 
         for ind, k in enumerate(reversed(self.keypoints)):
-            if self.pose_type in ["pushup", "pullup"]:
+            if self.pose_type in {"pushup", "pullup"}:
                 self.angle[ind] = self.annotator.estimate_pose_angle(
                     k[int(self.kpts_to_check[0])].cpu(),
                     k[int(self.kpts_to_check[1])].cpu(),
ultralytics/solutions/heatmap.py CHANGED
@@ -24,6 +24,8 @@ class Heatmap:
         self.view_img = False
         self.shape = "circle"
 
+        self.names = None  # Classes names
+
         # Image information
         self.imw = None
         self.imh = None
@@ -52,10 +54,13 @@ class Heatmap:
         # Object Counting Information
         self.in_counts = 0
         self.out_counts = 0
-        self.counting_list = []
+        self.count_ids = []
+        self.class_wise_count = {}
         self.count_txt_thickness = 0
-        self.count_txt_color = (0, 0, 0)
-        self.count_color = (255, 255, 255)
+        self.count_txt_color = (255, 255, 255)
+        self.line_color = (255, 255, 255)
+        self.cls_txtdisplay_gap = 50
+        self.fontsize = 0.6
 
         # Decay factor
         self.decay_factor = 0.99
@@ -67,6 +72,7 @@ class Heatmap:
         self,
         imw,
         imh,
+        classes_names=None,
         colormap=cv2.COLORMAP_JET,
         heatmap_alpha=0.5,
         view_img=False,
@@ -74,13 +80,15 @@ class Heatmap:
         view_out_counts=True,
         count_reg_pts=None,
         count_txt_thickness=2,
-        count_txt_color=(0, 0, 0),
-        count_color=(255, 255, 255),
+        count_txt_color=(255, 255, 255),
+        fontsize=0.8,
+        line_color=(255, 255, 255),
         count_reg_color=(255, 0, 255),
         region_thickness=5,
         line_dist_thresh=15,
         decay_factor=0.99,
         shape="circle",
+        cls_txtdisplay_gap=50,
     ):
         """
         Configures the heatmap colormap, width, height and display parameters.
@@ -89,6 +97,7 @@ class Heatmap:
             colormap (cv2.COLORMAP): The colormap to be set.
             imw (int): The width of the frame.
             imh (int): The height of the frame.
+            classes_names (dict): Classes names
             heatmap_alpha (float): alpha value for heatmap display
             view_img (bool): Flag indicating frame display
             view_in_counts (bool): Flag to control whether to display the incounts on video stream.
@@ -96,13 +105,16 @@ class Heatmap:
            count_reg_pts (list): Object counting region points
            count_txt_thickness (int): Text thickness for object counting display
            count_txt_color (RGB color): count text color value
-           count_color (RGB color): count text background color value
+           fontsize (float): Text display font size
+           line_color (RGB color): count highlighter line color
            count_reg_color (RGB color): Color of object counting region
            region_thickness (int): Object counting Region thickness
            line_dist_thresh (int): Euclidean Distance threshold for line counter
            decay_factor (float): value for removing heatmap area after object passed
            shape (str): Heatmap shape, rect or circle shape supported
+           cls_txtdisplay_gap (int): Display gap between each class count
         """
+        self.names = classes_names
         self.imw = imw
         self.imh = imh
         self.heatmap_alpha = heatmap_alpha
@@ -116,32 +128,32 @@ class Heatmap:
             if len(count_reg_pts) == 2:
                 print("Line Counter Initiated.")
                 self.count_reg_pts = count_reg_pts
-                self.counting_region = LineString(count_reg_pts)
-
-            elif len(count_reg_pts) == 4:
-                print("Region Counter Initiated.")
+                self.counting_region = LineString(self.count_reg_pts)
+            elif len(count_reg_pts) >= 3:
+                print("Polygon Counter Initiated.")
                 self.count_reg_pts = count_reg_pts
                 self.counting_region = Polygon(self.count_reg_pts)
-
             else:
-                print("Region or line points Invalid, 2 or 4 points supported")
+                print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
                 print("Using Line Counter Now")
-                self.counting_region = Polygon([(20, 400), (1260, 400)])  # dummy points
+                self.counting_region = LineString(self.count_reg_pts)
 
         # Heatmap new frame
         self.heatmap = np.zeros((int(self.imh), int(self.imw)), dtype=np.float32)
 
         self.count_txt_thickness = count_txt_thickness
         self.count_txt_color = count_txt_color
-        self.count_color = count_color
+        self.fontsize = fontsize
+        self.line_color = line_color
         self.region_color = count_reg_color
        self.region_thickness = region_thickness
        self.decay_factor = decay_factor
        self.line_dist_thresh = line_dist_thresh
        self.shape = shape
+        self.cls_txtdisplay_gap = cls_txtdisplay_gap
 
         # shape of heatmap, if not selected
-        if self.shape not in ["circle", "rect"]:
+        if self.shape not in {"circle", "rect"}:
             print("Unknown shape value provided, 'circle' & 'rect' supported")
             print("Using Circular shape now")
             self.shape = "circle"
@@ -183,6 +195,12 @@ class Heatmap:
             )
 
             for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
+                # Store class info
+                if self.names[cls] not in self.class_wise_count:
+                    if len(self.names[cls]) > 5:
+                        self.names[cls] = self.names[cls][:5]
+                    self.class_wise_count[self.names[cls]] = {"in": 0, "out": 0}
+
                 if self.shape == "circle":
                     center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
                     radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
@@ -203,23 +221,39 @@ class Heatmap:
                 if len(track_line) > 30:
                     track_line.pop(0)
 
-                # Count objects
-                if len(self.count_reg_pts) == 4:
-                    if self.counting_region.contains(Point(track_line[-1])) and track_id not in self.counting_list:
-                        self.counting_list.append(track_id)
-                        if box[0] < self.counting_region.centroid.x:
-                            self.out_counts += 1
-                        else:
+                prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
+
+                # Count objects in any polygon
+                if len(self.count_reg_pts) >= 3:
+                    is_inside = self.counting_region.contains(Point(track_line[-1]))
+
+                    if prev_position is not None and is_inside and track_id not in self.count_ids:
+                        self.count_ids.append(track_id)
+
+                        if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
                             self.in_counts += 1
+                            self.class_wise_count[self.names[cls]]["in"] += 1
+                        else:
+                            self.out_counts += 1
+                            self.class_wise_count[self.names[cls]]["out"] += 1
 
+                # Count objects using line
                 elif len(self.count_reg_pts) == 2:
-                    distance = Point(track_line[-1]).distance(self.counting_region)
-                    if distance < self.line_dist_thresh and track_id not in self.counting_list:
-                        self.counting_list.append(track_id)
-                        if box[0] < self.counting_region.centroid.x:
-                            self.out_counts += 1
-                        else:
-                            self.in_counts += 1
+                    is_inside = (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0
+
+                    if prev_position is not None and is_inside and track_id not in self.count_ids:
+                        distance = Point(track_line[-1]).distance(self.counting_region)
+
+                        if distance < self.line_dist_thresh and track_id not in self.count_ids:
+                            self.count_ids.append(track_id)
+
+                            if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
+                                self.in_counts += 1
+                                self.class_wise_count[self.names[cls]]["in"] += 1
+                            else:
+                                self.out_counts += 1
+                                self.class_wise_count[self.names[cls]]["out"] += 1
+
         else:
             for box, cls in zip(self.boxes, self.clss):
                 if self.shape == "circle":
@@ -240,26 +274,30 @@ class Heatmap:
         heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
         heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)
 
-        incount_label = f"In Count : {self.in_counts}"
-        outcount_label = f"OutCount : {self.out_counts}"
-
-        # Display counts based on user choice
-        counts_label = None
-        if not self.view_in_counts and not self.view_out_counts:
-            counts_label = None
-        elif not self.view_in_counts:
-            counts_label = outcount_label
-        elif not self.view_out_counts:
-            counts_label = incount_label
-        else:
-            counts_label = f"{incount_label} {outcount_label}"
+        label = "Ultralytics Analytics \t"
+
+        for key, value in self.class_wise_count.items():
+            if value["in"] != 0 or value["out"] != 0:
+                if not self.view_in_counts and not self.view_out_counts:
+                    label = None
+                elif not self.view_in_counts:
+                    label += f"{str.capitalize(key)}: IN {value['in']} \t"
+                elif not self.view_out_counts:
+                    label += f"{str.capitalize(key)}: OUT {value['out']} \t"
+                else:
+                    label += f"{str.capitalize(key)}: IN {value['in']} OUT {value['out']} \t"
+
+        label = label.rstrip()
+        label = label.split("\t")
 
-        if self.count_reg_pts is not None and counts_label is not None:
-            self.annotator.count_labels(
-                counts=counts_label,
-                count_txt_size=self.count_txt_thickness,
+        if self.count_reg_pts is not None and label is not None:
+            self.annotator.display_counts(
+                counts=label,
+                tf=self.count_txt_thickness,
+                fontScale=self.fontsize,
                 txt_color=self.count_txt_color,
-                color=self.count_color,
+                line_color=self.line_color,
+                classwise_txtgap=self.cls_txtdisplay_gap,
             )
 
         self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)
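
For reference, a hedged usage sketch of the updated heatmap solution with the new `classes_names` argument and class-wise counting (the video path, weights, and region points are placeholder assumptions; the surrounding loop follows the usual solutions pattern for this release):

```python
import cv2
from ultralytics import YOLO
from ultralytics.solutions import heatmap

model = YOLO("yolov8n.pt")  # assumed detection weights
cap = cv2.VideoCapture("path/to/video.mp4")  # placeholder video source
w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

heatmap_obj = heatmap.Heatmap()
heatmap_obj.set_args(
    imw=w,
    imh=h,
    classes_names=model.names,               # new argument: enables per-class IN/OUT counts
    colormap=cv2.COLORMAP_JET,
    count_reg_pts=[(20, 400), (1260, 400)],  # 2 points -> line counter, >= 3 points -> polygon counter
    view_img=True,
)

while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break
    tracks = model.track(im0, persist=True, show=False)
    im0 = heatmap_obj.generate_heatmap(im0, tracks)

cap.release()
cv2.destroyAllWindows()
```
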