ultralytics 8.2.36__py3-none-any.whl → 8.2.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
@@ -1,6 +1,7 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
  """Model head modules."""
 
+ import copy
  import math
 
  import torch
@@ -14,7 +15,7 @@ from .conv import Conv
  from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
  from .utils import bias_init_with_prob, linear_init
 
- __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"
+ __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
 
 
  class Detect(nn.Module):
@@ -22,6 +23,8 @@ class Detect(nn.Module):
 
      dynamic = False  # force grid reconstruction
      export = False  # export mode
+     end2end = False  # end2end
+     max_det = 300  # max_det
      shape = None
      anchors = torch.empty(0)  # init
      strides = torch.empty(0)  # init
@@ -41,13 +44,48 @@ class Detect(nn.Module):
          self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
          self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
 
+         if self.end2end:
+             self.one2one_cv2 = copy.deepcopy(self.cv2)
+             self.one2one_cv3 = copy.deepcopy(self.cv3)
+
      def forward(self, x):
          """Concatenates and returns predicted bounding boxes and class probabilities."""
+         if self.end2end:
+             return self.forward_end2end(x)
+
          for i in range(self.nl):
              x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
          if self.training:  # Training path
              return x
+         y = self._inference(x)
+         return y if self.export else (y, x)
+
+     def forward_end2end(self, x):
+         """
+         Performs forward pass of the v10Detect module.
+
+         Args:
+             x (tensor): Input tensor.
+
+         Returns:
+             (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
+                 If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
+         """
+         x_detach = [xi.detach() for xi in x]
+         one2one = [
+             torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
+         ]
+         for i in range(self.nl):
+             x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
+         if self.training:  # Training path
+             return {"one2many": x, "one2one": one2one}
 
+         y = self._inference(one2one)
+         y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
+         return y if self.export else (y, {"one2many": x, "one2one": one2one})
+
+     def _inference(self, x):
+         """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
          # Inference path
          shape = x[0].shape  # BCHW
          x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
@@ -73,7 +111,7 @@ class Detect(nn.Module):
          dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
 
          y = torch.cat((dbox, cls.sigmoid()), 1)
-         return y if self.export else (y, x)
+         return y
 
      def bias_init(self):
          """Initialize Detect() biases, WARNING: requires stride availability."""
@@ -83,10 +121,47 @@ class Detect(nn.Module):
          for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
              a[-1].bias.data[:] = 1.0  # box
              b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+         if self.end2end:
+             for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
+                 a[-1].bias.data[:] = 1.0  # box
+                 b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
 
      def decode_bboxes(self, bboxes, anchors):
          """Decode bounding boxes."""
-         return dist2bbox(bboxes, anchors, xywh=True, dim=1)
+         return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
+
+     @staticmethod
+     def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
+         """
+         Post-processes the predictions obtained from a YOLOv10 model.
+
+         Args:
+             preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).
+             max_det (int): The maximum number of detections to keep.
+             nc (int, optional): The number of classes. Defaults to 80.
+
+         Returns:
+             (torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),
+                 including bounding boxes, scores and cls.
+         """
+         assert 4 + nc == preds.shape[-1]
+         boxes, scores = preds.split([4, nc], dim=-1)
+         max_scores = scores.amax(dim=-1)
+         max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1)
+         index = index.unsqueeze(-1)
+         boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
+         scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
+
+         # NOTE: simplify but result slightly lower mAP
+         # scores, labels = scores.max(dim=-1)
+         # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+
+         scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
+         labels = index % nc
+         index = index // nc
+         boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+
+         return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
 
 
  class Segment(Detect):
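
The new `Detect.postprocess` static method is what turns the dense end-to-end head output into a fixed-size detection tensor without NMS. A minimal shape check, assuming only the signature shown above (tensor values are illustrative):

```python
import torch
from ultralytics.nn.modules.head import Detect

preds = torch.rand(2, 8400, 4 + 80)              # (batch, num_boxes, 4 + nc), scores already in [0, 1]
out = Detect.postprocess(preds, max_det=300, nc=80)
print(out.shape)                                 # torch.Size([2, 300, 6]) -> 4 box coords, score, class
```

The double top-k keeps the `max_det` highest-scoring (box, class) pairs per image, which is why no NMS step is needed downstream.
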
@@ -487,3 +562,39 @@ class RTDETRDecoder(nn.Module):
          xavier_uniform_(self.query_pos_head.layers[1].weight)
          for layer in self.input_proj:
              xavier_uniform_(layer[0].weight)
+
+
+ class v10Detect(Detect):
+     """
+     v10 Detection head from https://arxiv.org/pdf/2405.14458
+
+     Args:
+         nc (int): Number of classes.
+         ch (tuple): Tuple of channel sizes.
+
+     Attributes:
+         max_det (int): Maximum number of detections.
+
+     Methods:
+         __init__(self, nc=80, ch=()): Initializes the v10Detect object.
+         forward(self, x): Performs forward pass of the v10Detect module.
+         bias_init(self): Initializes biases of the Detect module.
+
+     """
+
+     end2end = True
+
+     def __init__(self, nc=80, ch=()):
+         """Initializes the v10Detect object with the specified number of classes and input channels."""
+         super().__init__(nc, ch)
+         c3 = max(ch[0], min(self.nc, 100))  # channels
+         # Light cls head
+         self.cv3 = nn.ModuleList(
+             nn.Sequential(
+                 nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)),
+                 nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),
+                 nn.Conv2d(c3, self.nc, 1),
+             )
+             for x in ch
+         )
+         self.one2one_cv3 = copy.deepcopy(self.cv3)
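
A hedged sketch of the two forward paths of the new `v10Detect` head. The P3/P4/P5 channel counts and strides below are hypothetical stand-ins (they are normally set by `DetectionModel` during build):

```python
import torch
from ultralytics.nn.modules.head import v10Detect

head = v10Detect(nc=80, ch=(64, 128, 256))
head.stride = torch.tensor([8.0, 16.0, 32.0])   # normally computed by DetectionModel

def make_feats():
    """Fake multi-scale feature maps for a 640x640 input."""
    return [torch.rand(1, c, s, s) for c, s in zip((64, 128, 256), (80, 40, 20))]

head.train()
print(sorted(head(make_feats())))   # ['one2many', 'one2one'] -> dict consumed by E2EDetectLoss

head.eval()
y, raw = head(make_feats())         # inference returns postprocessed one2one predictions
print(y.shape)                      # torch.Size([1, 300, 6])
```
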
ultralytics/nn/tasks.py CHANGED
@@ -15,6 +15,7 @@ from ultralytics.nn.modules import (
      C3TR,
      ELAN1,
      OBB,
+     PSA,
      SPP,
      SPPELAN,
      SPPF,
@@ -24,6 +25,7 @@ from ultralytics.nn.modules import (
      BottleneckCSP,
      C2f,
      C2fAttn,
+     C2fCIB,
      C3Ghost,
      C3x,
      CBFuse,
@@ -46,14 +48,24 @@ from ultralytics.nn.modules import (
      RepC3,
      RepConv,
      RepNCSPELAN4,
+     RepVGGDW,
      ResNetLayer,
      RTDETRDecoder,
+     SCDown,
      Segment,
      WorldDetect,
+     v10Detect,
  )
  from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
  from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
- from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
+ from ultralytics.utils.loss import (
+     E2EDetectLoss,
+     v8ClassificationLoss,
+     v8DetectionLoss,
+     v8OBBLoss,
+     v8PoseLoss,
+     v8SegmentationLoss,
+ )
  from ultralytics.utils.plotting import feature_visualization
  from ultralytics.utils.torch_utils import (
      fuse_conv_and_bn,
@@ -192,6 +204,9 @@ class BaseModel(nn.Module):
                  if isinstance(m, RepConv):
                      m.fuse_convs()
                      m.forward = m.forward_fuse  # update forward
+                 if isinstance(m, RepVGGDW):
+                     m.fuse()
+                     m.forward = m.forward_fuse
              self.info(verbose=verbose)
 
          return self
@@ -294,6 +309,7 @@ class DetectionModel(BaseModel):
          self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
          self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
          self.inplace = self.yaml.get("inplace", True)
+         self.end2end = getattr(self.model[-1], "end2end", False)
 
          # Build strides
          m = self.model[-1]  # Detect()
@@ -303,6 +319,8 @@ class DetectionModel(BaseModel):
 
              def _forward(x):
                  """Performs a forward pass through the model, handling different Detect subclass types accordingly."""
+                 if self.end2end:
+                     return self.forward(x)["one2many"]
                  return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
 
              m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
@@ -355,7 +373,7 @@ class DetectionModel(BaseModel):
 
      def init_criterion(self):
          """Initialize the loss criterion for the DetectionModel."""
-         return v8DetectionLoss(self)
+         return E2EDetectLoss(self) if self.end2end else v8DetectionLoss(self)
 
 
  class OBBModel(DetectionModel):
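
A small sketch of what the new `end2end` attribute means at the model level, assuming the bundled `yolov10n.yaml` config is available in this release:

```python
from ultralytics.nn.tasks import DetectionModel

model = DetectionModel("yolov10n.yaml", ch=3, nc=80, verbose=False)
print(model.end2end)  # True, because the last module (v10Detect) defines end2end = True
# init_criterion() would therefore return E2EDetectLoss instead of v8DetectionLoss during training.
```
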
@@ -675,7 +693,7 @@ class Ensemble(nn.ModuleList):
 
 
  @contextlib.contextmanager
- def temporary_modules(modules=None):
+ def temporary_modules(modules={}, attributes={}):
      """
      Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
 
@@ -685,11 +703,13 @@ def temporary_modules(modules=None):
 
      Args:
          modules (dict, optional): A dictionary mapping old module paths to new module paths.
+         attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.
 
      Example:
          ```python
-         with temporary_modules({'old.module.path': 'new.module.path'}):
-             import old.module.path  # this will now import new.module.path
+         with temporary_modules({'old.module': 'new.module'}, {'old.module.attribute': 'new.module.attribute'}):
+             import old.module  # this will now import new.module
+             from old.module import attribute  # this will now import new.module.attribute
          ```
 
      Note:
@@ -697,16 +717,20 @@ def temporary_modules(modules=None):
      Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
      applications or libraries. Use this function with caution.
      """
-     if not modules:
-         modules = {}
 
-     import importlib
      import sys
+     from importlib import import_module
 
      try:
+         # Set attributes in sys.modules under their old name
+         for old, new in attributes.items():
+             old_module, old_attr = old.rsplit(".", 1)
+             new_module, new_attr = new.rsplit(".", 1)
+             setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))
+
          # Set modules in sys.modules under their old name
          for old, new in modules.items():
-             sys.modules[old] = importlib.import_module(new)
+             sys.modules[old] = import_module(new)
 
          yield
      finally:
734
758
  file = attempt_download_asset(weight) # search online if missing locally
735
759
  try:
736
760
  with temporary_modules(
737
- {
761
+ modules={
738
762
  "ultralytics.yolo.utils": "ultralytics.utils",
739
763
  "ultralytics.yolo.v8": "ultralytics.models.yolo",
740
764
  "ultralytics.yolo.data": "ultralytics.data",
741
- }
742
- ): # for legacy 8.0 Classify and Pose models
765
+ },
766
+ attributes={
767
+ "ultralytics.nn.modules.block.Silence": "torch.nn.Identity", # YOLOv9e
768
+ "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel", # YOLOv10
769
+ },
770
+ ):
743
771
  ckpt = torch.load(file, map_location="cpu")
744
772
 
745
773
  except ModuleNotFoundError as e: # e.name is missing module name
@@ -898,6 +926,9 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
898
926
  DWConvTranspose2d,
899
927
  C3x,
900
928
  RepC3,
929
+ PSA,
930
+ SCDown,
931
+ C2fCIB,
901
932
  }:
902
933
  c1, c2 = ch[f], args[0]
903
934
  if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
@@ -909,7 +940,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
909
940
  ) # num heads
910
941
 
911
942
  args = [c1, c2, *args[1:]]
912
- if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}:
943
+ if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB}:
913
944
  args.insert(2, n) # number of repeats
914
945
  n = 1
915
946
  elif m is AIFI:
@@ -926,7 +957,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
926
957
  args = [ch[f]]
927
958
  elif m is Concat:
928
959
  c2 = sum(ch[x] for x in f)
929
- elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
960
+ elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
930
961
  args.append([ch[x] for x in f])
931
962
  if m is Segment:
932
963
  args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@@ -1011,7 +1042,7 @@ def guess_model_task(model):
1011
1042
  m = cfg["head"][-1][-2].lower() # output module name
1012
1043
  if m in {"classify", "classifier", "cls", "fc"}:
1013
1044
  return "classify"
1014
- if m == "detect":
1045
+ if "detect" in m:
1015
1046
  return "detect"
1016
1047
  if m == "segment":
1017
1048
  return "segment"
@@ -1043,7 +1074,7 @@ def guess_model_task(model):
1043
1074
  return "pose"
1044
1075
  elif isinstance(m, OBB):
1045
1076
  return "obb"
1046
- elif isinstance(m, (Detect, WorldDetect)):
1077
+ elif isinstance(m, (Detect, WorldDetect, v10Detect)):
1047
1078
  return "detect"
1048
1079
 
1049
1080
  # Guess from model filename
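
The relaxed head-name check is what lets YOLOv10 configs resolve to the detect task; a minimal sketch with a hypothetical last head row from a model YAML:

```python
from ultralytics.nn.tasks import guess_model_task

cfg = {"head": [[[16, 19, 22], 1, "v10Detect", ["nc"]]]}  # hypothetical, only the module name matters here
print(guess_model_task(cfg))  # "detect", because "detect" in "v10detect"
```
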
@@ -81,6 +81,7 @@ def benchmark(
      device = select_device(device, verbose=False)
      if isinstance(model, (str, Path)):
          model = YOLO(model)
+     is_end2end = getattr(model.model.model[-1], "end2end", False)
 
      y = []
      t0 = time.time()
@@ -96,14 +97,18 @@ def benchmark(
                  assert MACOS or LINUX, "CoreML and TF.js export only supported on macOS and Linux"
                  assert not IS_RASPBERRYPI, "CoreML and TF.js export not supported on Raspberry Pi"
                  assert not IS_JETSON, "CoreML and TF.js export not supported on NVIDIA Jetson"
+                 assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet"
              if i in {3, 5}:  # CoreML and OpenVINO
                  assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
              if i in {6, 7, 8, 9, 10}:  # All TF formats
                  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
+                 assert not is_end2end, "End-to-end models not supported by onnx2tf yet"
              if i in {11}:  # Paddle
                  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
+                 assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
              if i in {12}:  # NCNN
                  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
+                 assert not is_end2end, "End-to-end models not supported by NCNN yet"
              if "cpu" in device.type:
                  assert cpu, "inference not supported on CPU"
              if "cuda" in device.type:
@@ -23,6 +23,7 @@ GITHUB_ASSETS_NAMES = (
      + [f"yolov8{k}-world.pt" for k in "smlx"]
      + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
      + [f"yolov9{k}.pt" for k in "ce"]
+     + [f"yolov10{k}.pt" for k in "nsmblx"]
      + [f"yolo_nas_{k}.pt" for k in "sml"]
      + [f"sam_{k}.pt" for k in "bl"]
      + [f"FastSAM-{k}.pt" for k in "sx"]
ultralytics/utils/loss.py CHANGED
@@ -148,7 +148,7 @@ class KeypointLoss(nn.Module):
  class v8DetectionLoss:
      """Criterion class for computing training losses."""
 
-     def __init__(self, model):  # model must be de-paralleled
+     def __init__(self, model, tal_topk=10):  # model must be de-paralleled
          """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
          device = next(model.parameters()).device  # get model device
          h = model.args  # hyperparameters
@@ -164,7 +164,7 @@ class v8DetectionLoss:
 
          self.use_dfl = m.reg_max > 1
 
-         self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+         self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
          self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
          self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
 
@@ -714,3 +714,21 @@ class v8OBBLoss(v8DetectionLoss):
          b, a, c = pred_dist.shape  # batch, anchors, channels
          pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
          return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
+
+
+ class E2EDetectLoss:
+     """Criterion class for computing training losses."""
+
+     def __init__(self, model):
+         """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
+         self.one2many = v8DetectionLoss(model, tal_topk=10)
+         self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+     def __call__(self, preds, batch):
+         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+         preds = preds[1] if isinstance(preds, tuple) else preds
+         one2many = preds["one2many"]
+         loss_one2many = self.one2many(one2many, batch)
+         one2one = preds["one2one"]
+         loss_one2one = self.one2one(one2one, batch)
+         return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
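
Each branch returns the usual `(total_loss * batch_size, detached_loss_items)` pair, and `E2EDetectLoss` simply sums the two branches element-wise. A stand-in illustration of that return contract (the tensors below are placeholders, not real losses):

```python
import torch

loss_one2many = (torch.tensor(2.0), torch.tensor([1.0, 0.6, 0.4]))  # (total, [box, cls, dfl]) from tal_topk=10
loss_one2one = (torch.tensor(1.5), torch.tensor([0.8, 0.4, 0.3]))   # same structure from tal_topk=1

total = loss_one2many[0] + loss_one2one[0]
items = loss_one2many[1] + loss_one2one[1]
print(total, items)  # tensor(3.5000) tensor([1.8000, 1.0000, 0.7000])
```
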
@@ -64,8 +64,9 @@ def box_iou(box1, box2, eps=1e-7):
          (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
      """
 
+     # NOTE: Need .float() to get accurate iou values
      # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
-     (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
+     (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
      inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
 
      # IoU = inter / (area1 + area2 - inter)
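
The `.float()` upcast matters because intersections and areas computed in half precision can overflow; a quick illustration:

```python
import torch

wh = torch.tensor([500.0, 500.0], dtype=torch.float16)
print(wh[0] * wh[1])                  # tensor(inf, dtype=torch.float16) -> 250000 exceeds the fp16 max (~65504)
print(wh[0].float() * wh[1].float())  # tensor(250000.)
```
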
ultralytics/utils/ops.py CHANGED
@@ -213,6 +213,9 @@ def non_max_suppression(
      if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
          prediction = prediction[0]  # select only inference output
 
+     if prediction.shape[-1] == 6:  # end-to-end model
+         return [pred[pred[:, 4] > conf_thres] for pred in prediction]
+
      bs = prediction.shape[0]  # batch size
      nc = nc or (prediction.shape[1] - 4)  # number of classes
      nm = prediction.shape[1] - nc - 4  # number of masks
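
Sketch of the new early return: predictions that already have the end-to-end shape `(batch, max_det, 6)` skip NMS entirely and are only filtered per image by the confidence in column 4 (random input used for illustration):

```python
import torch
from ultralytics.utils.ops import non_max_suppression

preds = torch.rand(2, 300, 6)                      # stand-in for a v10Detect end-to-end output
results = non_max_suppression(preds, conf_thres=0.25)
print(len(results), results[0].shape[-1])          # 2 images, each an (n, 6) tensor: xyxy, conf, cls
```
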
@@ -183,11 +183,108 @@ class Annotator:
          (104, 31, 17),
      }
 
-     def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
-         """Add one xyxy box to image with label."""
-         txt_color = (
-             (104, 31, 17) if color in self.dark_colors else (255, 255, 255) if color in self.light_colors else txt_color
+     def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
+         """Assign text color based on background color."""
+         if color in self.dark_colors:
+             return 104, 31, 17
+         elif color in self.light_colors:
+             return 255, 255, 255
+         else:
+             return txt_color
+
+     def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):
+         """
+         Draws a label with a background rectangle centered within a given bounding box.
+
+         Args:
+             box (tuple): The bounding box coordinates (x1, y1, x2, y2).
+             label (str): The text label to be displayed.
+             color (tuple, optional): The background color of the rectangle (R, G, B).
+             txt_color (tuple, optional): The color of the text (R, G, B).
+             margin (int, optional): The margin between the text and the rectangle border.
+         """
+
+         # If label have more than 3 characters, skip other characters, due to circle size
+         if len(label) > 3:
+             print(
+                 f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!"
+             )
+             label = label[:3]
+
+         # Calculate the center of the box
+         x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
+         # Get the text size
+         text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
+         # Calculate the required radius to fit the text with the margin
+         required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
+         # Draw the circle with the required radius
+         cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
+         # Calculate the position for the text
+         text_x = x_center - text_size[0] // 2
+         text_y = y_center + text_size[1] // 2
+         # Draw the text
+         cv2.putText(
+             self.im,
+             str(label),
+             (text_x, text_y),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             self.sf - 0.15,
+             self.get_txt_color(color, txt_color),
+             self.tf,
+             lineType=cv2.LINE_AA,
+         )
+
+     def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):
+         """
+         Draws a label with a background rectangle centered within a given bounding box.
+
+         Args:
+             box (tuple): The bounding box coordinates (x1, y1, x2, y2).
+             label (str): The text label to be displayed.
+             color (tuple, optional): The background color of the rectangle (R, G, B).
+             txt_color (tuple, optional): The color of the text (R, G, B).
+             margin (int, optional): The margin between the text and the rectangle border.
+         """
+
+         # Calculate the center of the bounding box
+         x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
+         # Get the size of the text
+         text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
+         # Calculate the top-left corner of the text (to center it)
+         text_x = x_center - text_size[0] // 2
+         text_y = y_center + text_size[1] // 2
+         # Calculate the coordinates of the background rectangle
+         rect_x1 = text_x - margin
+         rect_y1 = text_y - text_size[1] - margin
+         rect_x2 = text_x + text_size[0] + margin
+         rect_y2 = text_y + margin
+         # Draw the background rectangle
+         cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
+         # Draw the text on top of the rectangle
+         cv2.putText(
+             self.im,
+             label,
+             (text_x, text_y),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             self.sf - 0.1,
+             self.get_txt_color(color, txt_color),
+             self.tf,
+             lineType=cv2.LINE_AA,
          )
+
+     def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
+         """
+         Draws a bounding box to image with label.
+
+         Args:
+             box (tuple): The bounding box coordinates (x1, y1, x2, y2).
+             label (str): The text label to be displayed.
+             color (tuple, optional): The background color of the rectangle (R, G, B).
+             txt_color (tuple, optional): The color of the text (R, G, B).
+             rotated (bool, optional): Variable used to check if task is OBB
+         """
+
+         txt_color = self.get_txt_color(color, txt_color)
          if isinstance(box, torch.Tensor):
              box = box.tolist()
          if self.pil or not is_ascii(label):
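
A hedged usage sketch of the two new label helpers on the OpenCV (non-PIL) path; the image and box coordinates below are made up:

```python
import numpy as np
from ultralytics.utils.plotting import Annotator

im = np.zeros((480, 640, 3), dtype=np.uint8)
annotator = Annotator(im, line_width=2)                                        # ASCII example -> cv2 drawing path
annotator.circle_label((100, 100, 200, 200), label="7", color=(255, 0, 0))     # e.g. a track ID in a filled circle
annotator.text_label((300, 100, 400, 200), label="car", color=(0, 128, 255))   # centered text on a filled rectangle
annotated = annotator.result()
```
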
@@ -242,6 +339,7 @@ class Annotator:
              alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
              retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
          """
+
          if self.pil:
              # Convert to numpy first
              self.im = np.asarray(self.im).copy()
@@ -281,6 +379,7 @@ class Annotator:
          Note:
              `kpt_line=True` currently only supports human pose plotting.
          """
+
          if self.pil:
              # Convert to numpy first
              self.im = np.asarray(self.im).copy()
@@ -376,6 +475,7 @@ class Annotator:
          Returns:
              angle (degree): Degree value of angle between three points
          """
+
          x_min, y_min, x_max, y_max = bbox
          width = x_max - x_min
          height = y_max - y_min
@@ -390,6 +490,7 @@ class Annotator:
              color (tuple): Region Color value
              thickness (int): Region area thickness value
          """
+
          cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
 
      def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
@@ -401,6 +502,7 @@ class Annotator:
              color (tuple): tracks line color
              track_thickness (int): track line thickness value
          """
+
          points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
          cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
          cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
@@ -513,6 +615,7 @@ class Annotator:
          Returns:
              angle (degree): Degree value of angle between three points
          """
+
          a, b, c = np.array(a), np.array(b), np.array(c)
          radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
          angle = np.abs(radians * 180.0 / np.pi)
@@ -530,6 +633,7 @@ class Annotator:
              shape (tuple): imgsz for model inference
              radius (int): Keypoint radius value
          """
+
          if indices is None:
              indices = [2, 5, 7]
          for i, k in enumerate(keypoints):
@@ -626,6 +730,7 @@ class Annotator:
              det_label (str): Detection label text
              track_label (str): Tracking label text
          """
+
          cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
 
          label = f"Track ID: {track_label}" if track_label else det_label
@@ -695,6 +800,7 @@ class Annotator:
              color (tuple): object centroid and line color value
              pin_color (tuple): visioneye point color value
          """
+
          center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
          cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
          cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)