ultralytics 8.1.42__py3-none-any.whl → 8.1.43__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. See the registry's advisory page for this release for more details.

ultralytics/solutions/heatmap.py CHANGED
@@ -30,6 +30,7 @@ class Heatmap:
30
30
  self.imw = None
31
31
  self.imh = None
32
32
  self.im0 = None
33
+ self.tf = 2
33
34
  self.view_in_counts = True
34
35
  self.view_out_counts = True
35
36
 
@@ -56,11 +57,9 @@ class Heatmap:
56
57
  self.out_counts = 0
57
58
  self.count_ids = []
58
59
  self.class_wise_count = {}
59
- self.count_txt_thickness = 0
60
- self.count_txt_color = (255, 255, 255)
61
- self.line_color = (255, 255, 255)
60
+ self.count_txt_color = (0, 0, 0)
61
+ self.count_bg_color = (255, 255, 255)
62
62
  self.cls_txtdisplay_gap = 50
63
- self.fontsize = 0.6
64
63
 
65
64
  # Decay factor
66
65
  self.decay_factor = 0.99
@@ -79,16 +78,14 @@ class Heatmap:
79
78
  view_in_counts=True,
80
79
  view_out_counts=True,
81
80
  count_reg_pts=None,
82
- count_txt_thickness=2,
83
- count_txt_color=(255, 255, 255),
84
- fontsize=0.8,
85
- line_color=(255, 255, 255),
81
+ count_txt_color=(0, 0, 0),
82
+ count_bg_color=(255, 255, 255),
86
83
  count_reg_color=(255, 0, 255),
87
84
  region_thickness=5,
88
85
  line_dist_thresh=15,
86
+ line_thickness=2,
89
87
  decay_factor=0.99,
90
88
  shape="circle",
91
- cls_txtdisplay_gap=50,
92
89
  ):
93
90
  """
94
91
  Configures the heatmap colormap, width, height and display parameters.
@@ -98,22 +95,21 @@ class Heatmap:
98
95
  imw (int): The width of the frame.
99
96
  imh (int): The height of the frame.
100
97
  classes_names (dict): Classes names
98
+ line_thickness (int): Line thickness for bounding boxes.
101
99
  heatmap_alpha (float): alpha value for heatmap display
102
100
  view_img (bool): Flag indicating frame display
103
101
  view_in_counts (bool): Flag to control whether to display the incounts on video stream.
104
102
  view_out_counts (bool): Flag to control whether to display the outcounts on video stream.
105
103
  count_reg_pts (list): Object counting region points
106
- count_txt_thickness (int): Text thickness for object counting display
107
104
  count_txt_color (RGB color): count text color value
108
- fontsize (float): Text display font size
109
- line_color (RGB color): count highlighter line color
105
+ count_bg_color (RGB color): count highlighter line color
110
106
  count_reg_color (RGB color): Color of object counting region
111
107
  region_thickness (int): Object counting Region thickness
112
108
  line_dist_thresh (int): Euclidean Distance threshold for line counter
113
109
  decay_factor (float): value for removing heatmap area after object passed
114
110
  shape (str): Heatmap shape, rect or circle shape supported
115
- cls_txtdisplay_gap (int): Display gap between each class count
116
111
  """
112
+ self.tf = line_thickness
117
113
  self.names = classes_names
118
114
  self.imw = imw
119
115
  self.imh = imh
@@ -141,16 +137,13 @@ class Heatmap:
141
137
  # Heatmap new frame
142
138
  self.heatmap = np.zeros((int(self.imh), int(self.imw)), dtype=np.float32)
143
139
 
144
- self.count_txt_thickness = count_txt_thickness
145
140
  self.count_txt_color = count_txt_color
146
- self.fontsize = fontsize
147
- self.line_color = line_color
141
+ self.count_bg_color = count_bg_color
148
142
  self.region_color = count_reg_color
149
143
  self.region_thickness = region_thickness
150
144
  self.decay_factor = decay_factor
151
145
  self.line_dist_thresh = line_dist_thresh
152
146
  self.shape = shape
153
- self.cls_txtdisplay_gap = cls_txtdisplay_gap
154
147
 
155
148
  # shape of heatmap, if not selected
156
149
  if self.shape not in {"circle", "rect"}:
@@ -185,7 +178,7 @@ class Heatmap:
185
178
  return im0
186
179
  self.heatmap *= self.decay_factor # decay factor
187
180
  self.extract_results(tracks)
188
- self.annotator = Annotator(self.im0, self.count_txt_thickness, None)
181
+ self.annotator = Annotator(self.im0, self.tf, None)
189
182
 
190
183
  if self.count_reg_pts is not None:
191
184
  # Draw counting region
@@ -239,11 +232,8 @@ class Heatmap:
239
232
 
240
233
  # Count objects using line
241
234
  elif len(self.count_reg_pts) == 2:
242
- is_inside = (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0
243
-
244
- if prev_position is not None and is_inside and track_id not in self.count_ids:
235
+ if prev_position is not None and track_id not in self.count_ids:
245
236
  distance = Point(track_line[-1]).distance(self.counting_region)
246
-
247
237
  if distance < self.line_dist_thresh and track_id not in self.count_ids:
248
238
  self.count_ids.append(track_id)
249
239
 
@@ -293,11 +283,8 @@ class Heatmap:
293
283
  if self.count_reg_pts is not None and label is not None:
294
284
  self.annotator.display_counts(
295
285
  counts=label,
296
- tf=self.count_txt_thickness,
297
- fontScale=self.fontsize,
298
- txt_color=self.count_txt_color,
299
- line_color=self.line_color,
300
- classwise_txtgap=self.cls_txtdisplay_gap,
286
+ count_txt_color=self.count_txt_color,
287
+ count_bg_color=self.count_bg_color,
301
288
  )
302
289
 
303
290
  self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)
ultralytics/solutions/object_counter.py CHANGED
@@ -1,7 +1,6 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
3
  from collections import defaultdict
4
-
5
4
  import cv2
6
5
 
7
6
  from ultralytics.utils.checks import check_imshow, check_requirements
@@ -47,7 +46,7 @@ class ObjectCounter:
47
46
  self.class_wise_count = {}
48
47
  self.count_txt_thickness = 0
49
48
  self.count_txt_color = (255, 255, 255)
50
- self.line_color = (255, 255, 255)
49
+ self.count_bg_color = (255, 255, 255)
51
50
  self.cls_txtdisplay_gap = 50
52
51
  self.fontsize = 0.6
53
52
 
@@ -65,16 +64,14 @@ class ObjectCounter:
65
64
  classes_names,
66
65
  reg_pts,
67
66
  count_reg_color=(255, 0, 255),
67
+ count_txt_color=(0, 0, 0),
68
+ count_bg_color=(255, 255, 255),
68
69
  line_thickness=2,
69
70
  track_thickness=2,
70
71
  view_img=False,
71
72
  view_in_counts=True,
72
73
  view_out_counts=True,
73
74
  draw_tracks=False,
74
- count_txt_thickness=3,
75
- count_txt_color=(255, 255, 255),
76
- fontsize=0.8,
77
- line_color=(255, 255, 255),
78
75
  track_color=None,
79
76
  region_thickness=5,
80
77
  line_dist_thresh=15,
@@ -92,10 +89,8 @@ class ObjectCounter:
92
89
  classes_names (dict): Classes names
93
90
  track_thickness (int): Track thickness
94
91
  draw_tracks (Bool): draw tracks
95
- count_txt_thickness (int): Text thickness for object counting display
96
92
  count_txt_color (RGB color): count text color value
97
- fontsize (float): Text display font size
98
- line_color (RGB color): count highlighter line color
93
+ count_bg_color (RGB color): count highlighter line color
99
94
  count_reg_color (RGB color): Color of object counting region
100
95
  track_color (RGB color): color for tracks
101
96
  region_thickness (int): Object counting Region thickness
@@ -125,10 +120,8 @@ class ObjectCounter:
125
120
 
126
121
  self.names = classes_names
127
122
  self.track_color = track_color
128
- self.count_txt_thickness = count_txt_thickness
129
123
  self.count_txt_color = count_txt_color
130
- self.fontsize = fontsize
131
- self.line_color = line_color
124
+ self.count_bg_color = count_bg_color
132
125
  self.region_color = count_reg_color
133
126
  self.region_thickness = region_thickness
134
127
  self.line_dist_thresh = line_dist_thresh
@@ -172,6 +165,9 @@ class ObjectCounter:
172
165
  # Annotator Init and region drawing
173
166
  self.annotator = Annotator(self.im0, self.tf, self.names)
174
167
 
168
+ # Draw region or line
169
+ self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
170
+
175
171
  if tracks[0].boxes.id is not None:
176
172
  boxes = tracks[0].boxes.xyxy.cpu()
177
173
  clss = tracks[0].boxes.cls.cpu().tolist()
@@ -220,17 +216,14 @@ class ObjectCounter:
220
216
 
221
217
  # Count objects using line
222
218
  elif len(self.reg_pts) == 2:
223
- is_inside = (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0
224
-
225
- if prev_position is not None and is_inside and track_id not in self.count_ids:
219
+ if prev_position is not None and track_id not in self.count_ids:
226
220
  distance = Point(track_line[-1]).distance(self.counting_region)
227
-
228
221
  if distance < self.line_dist_thresh and track_id not in self.count_ids:
229
222
  self.count_ids.append(track_id)
230
223
 
231
224
  if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
232
225
  self.in_counts += 1
233
- self.class_wise_count[self.names[cls]]["in"] += 1
226
+ self.class_wise_count[self.names[cls]]["in"] += 2
234
227
  else:
235
228
  self.out_counts += 1
236
229
  self.class_wise_count[self.names[cls]]["out"] += 1
@@ -254,17 +247,13 @@ class ObjectCounter:
254
247
  if label is not None:
255
248
  self.annotator.display_counts(
256
249
  counts=label,
257
- tf=self.count_txt_thickness,
258
- fontScale=self.fontsize,
259
- txt_color=self.count_txt_color,
260
- line_color=self.line_color,
261
- classwise_txtgap=self.cls_txtdisplay_gap,
250
+ count_txt_color=self.count_txt_color,
251
+ count_bg_color=self.count_bg_color,
262
252
  )
263
253
 
264
254
  def display_frames(self):
265
255
  """Display frame."""
266
256
  if self.env_check:
267
- self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
268
257
  cv2.namedWindow(self.window_name)
269
258
  if len(self.reg_pts) == 4: # only add mouse event If user drawn region
270
259
  cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
ultralytics/utils/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
3
  import contextlib
4
+ import importlib.metadata
4
5
  import inspect
5
6
  import logging.config
6
7
  import os
@@ -42,6 +43,8 @@ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar form
42
43
  LOGGING_NAME = "ultralytics"
43
44
  MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
44
45
  ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
46
+ PYTHON_VERSION = platform.python_version()
47
+ TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
45
48
  HELP_MSG = """
46
49
  Usage examples for running YOLOv8:
47
50
 
@@ -476,7 +479,7 @@ def is_online() -> bool:
476
479
 
477
480
  for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
478
481
  try:
479
- test_connection = socket.create_connection(address=(host, 53), timeout=2)
482
+ test_connection = socket.create_connection(address=(host, 80), timeout=2)
480
483
  except (socket.timeout, socket.gaierror, OSError):
481
484
  continue
482
485
  else:
ultralytics/utils/benchmarks.py CHANGED
@@ -69,8 +69,7 @@ def benchmark(
69
69
  benchmark(model='yolov8n.pt', imgsz=640)
70
70
  ```
71
71
  """
72
-
73
- import pandas as pd
72
+ import pandas as pd # scope for faster 'import ultralytics'
74
73
 
75
74
  pd.options.display.max_columns = 10
76
75
  pd.options.display.width = 120
ultralytics/utils/callbacks/clearml.py CHANGED
@@ -7,8 +7,6 @@ try:
7
7
  assert SETTINGS["clearml"] is True # verify integration is enabled
8
8
  import clearml
9
9
  from clearml import Task
10
- from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
11
- from clearml.binding.matplotlib_bind import PatchedMatplotlib
12
10
 
13
11
  assert hasattr(clearml, "__version__") # verify package is not directory
14
12
 
@@ -61,8 +59,11 @@ def on_pretrain_routine_start(trainer):
61
59
  """Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
62
60
  try:
63
61
  if task := Task.current_task():
64
- # Make sure the automatic pytorch and matplotlib bindings are disabled!
62
+ # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
65
63
  # We are logging these plots and model files manually in the integration
64
+ from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
65
+ from clearml.binding.matplotlib_bind import PatchedMatplotlib
66
+
66
67
  PatchPyTorchModelIO.update_current_task(None)
67
68
  PatchedMatplotlib.update_current_task(None)
68
69
  else:
ultralytics/utils/callbacks/wb.py CHANGED
@@ -9,10 +9,6 @@ try:
9
9
  import wandb as wb
10
10
 
11
11
  assert hasattr(wb, "__version__") # verify package is not directory
12
-
13
- import numpy as np
14
- import pandas as pd
15
-
16
12
  _processed_plots = {}
17
13
 
18
14
  except (ImportError, AssertionError):
@@ -38,7 +34,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
38
34
  Returns:
39
35
  (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
40
36
  """
41
- df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
37
+ import pandas # scope for faster 'import ultralytics'
38
+
39
+ df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
42
40
  fields = {"x": "x", "y": "y", "class": "class"}
43
41
  string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
44
42
  return wb.plot_table(
@@ -77,6 +75,8 @@ def _plot_curve(
77
75
  Note:
78
76
  The function leverages the '_custom_table' function to generate the actual visualization.
79
77
  """
78
+ import numpy as np
79
+
80
80
  # Create new x
81
81
  if names is None:
82
82
  names = []
ultralytics/utils/checks.py CHANGED
@@ -18,15 +18,16 @@ import cv2
18
18
  import numpy as np
19
19
  import requests
20
20
  import torch
21
- from matplotlib import font_manager
22
21
 
23
22
  from ultralytics.utils import (
24
23
  ASSETS,
25
24
  AUTOINSTALL,
26
25
  LINUX,
27
26
  LOGGER,
27
+ PYTHON_VERSION,
28
28
  ONLINE,
29
29
  ROOT,
30
+ TORCHVISION_VERSION,
30
31
  USER_CONFIG_DIR,
31
32
  Retry,
32
33
  SimpleNamespace,
@@ -41,13 +42,10 @@ from ultralytics.utils import (
41
42
  is_github_action_running,
42
43
  is_jupyter,
43
44
  is_kaggle,
44
- is_online,
45
45
  is_pip_package,
46
46
  url2file,
47
47
  )
48
48
 
49
- PYTHON_VERSION = platform.python_version()
50
-
51
49
 
52
50
  def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
53
51
  """
@@ -304,9 +302,10 @@ def check_font(font="Arial.ttf"):
304
302
  Returns:
305
303
  file (Path): Resolved font file path.
306
304
  """
307
- name = Path(font).name
305
+ from matplotlib import font_manager
308
306
 
309
307
  # Check USER_CONFIG_DIR
308
+ name = Path(font).name
310
309
  file = USER_CONFIG_DIR / name
311
310
  if file.exists():
312
311
  return file
@@ -390,7 +389,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
390
389
  LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
391
390
  try:
392
391
  t = time.time()
393
- assert is_online(), "AutoUpdate skipped (offline)"
392
+ assert ONLINE, "AutoUpdate skipped (offline)"
394
393
  with Retry(times=2, delay=1): # run up to 2 times with 1-second retry delay
395
394
  LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
396
395
  dt = time.time() - t
@@ -419,14 +418,12 @@ def check_torchvision():
419
418
  Torchvision versions.
420
419
  """
421
420
 
422
- import torchvision
423
-
424
421
  # Compatibility table
425
422
  compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
426
423
 
427
424
  # Extract only the major and minor versions
428
425
  v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
429
- v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
426
+ v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
430
427
 
431
428
  if v_torch in compatibility_table:
432
429
  compatible_versions = compatibility_table[v_torch]
ultralytics/utils/metrics.py CHANGED
@@ -395,19 +395,19 @@ class ConfusionMatrix:
395
395
  names (tuple): Names of classes, used as labels on the plot.
396
396
  on_plot (func): An optional callback to pass plots path and data when they are rendered.
397
397
  """
398
- import seaborn as sn
398
+ import seaborn # scope for faster 'import ultralytics'
399
399
 
400
400
  array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
401
401
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
402
402
 
403
403
  fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
404
404
  nc, nn = self.nc, len(names) # number of classes, names
405
- sn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
405
+ seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
406
406
  labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
407
407
  ticklabels = (list(names) + ["background"]) if labels else "auto"
408
408
  with warnings.catch_warnings():
409
409
  warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
410
- sn.heatmap(
410
+ seaborn.heatmap(
411
411
  array,
412
412
  ax=ax,
413
413
  annot=nc < 30,
ultralytics/utils/ops.py CHANGED
@@ -9,7 +9,6 @@ import cv2
9
9
  import numpy as np
10
10
  import torch
11
11
  import torch.nn.functional as F
12
- import torchvision
13
12
 
14
13
  from ultralytics.utils import LOGGER
15
14
  from ultralytics.utils.metrics import batch_probiou
@@ -206,6 +205,7 @@ def non_max_suppression(
206
205
  shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
207
206
  (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
208
207
  """
208
+ import torchvision # scope for faster 'import ultralytics'
209
209
 
210
210
  # Checks
211
211
  assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
ultralytics/utils/plotting.py CHANGED
@@ -339,6 +339,21 @@ class Annotator:
339
339
  """Save the annotated image to 'filename'."""
340
340
  cv2.imwrite(filename, np.asarray(self.im))
341
341
 
342
+ def get_bbox_dimension(self, bbox=None):
343
+ """
344
+ Calculate the area of a bounding box.
345
+
346
+ Args:
347
+ bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
348
+
349
+ Returns:
350
+ angle (degree): Degree value of angle between three points
351
+ """
352
+ x_min, y_min, x_max, y_max = bbox
353
+ width = x_max - x_min
354
+ height = y_max - y_min
355
+ return width, height, width * height
356
+
342
357
  def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
343
358
  """
344
359
  Draw region line.
@@ -364,13 +379,22 @@ class Annotator:
364
379
  cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
365
380
 
366
381
  def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0), fontsize=0.7):
367
- """Displays queue counts on an image centered at the points with customizable font size and colors."""
382
+ """
383
+ Displays queue counts on an image centered at the points with customizable font size and colors.
384
+
385
+ Args:
386
+ label (str): queue counts label
387
+ points (tuple): region points for center point calculation to display text
388
+ region_color (RGB): queue region color
389
+ txt_color (RGB): text display color
390
+ fontsize (float): text fontsize
391
+ """
368
392
  x_values = [point[0] for point in points]
369
393
  y_values = [point[1] for point in points]
370
394
  center_x = sum(x_values) // len(points)
371
395
  center_y = sum(y_values) // len(points)
372
396
 
373
- text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=fontsize, thickness=self.tf)[0]
397
+ text_size = cv2.getTextSize(label, 0, fontScale=fontsize, thickness=self.tf)[0]
374
398
  text_width = text_size[0]
375
399
  text_height = text_size[1]
376
400
 
@@ -388,56 +412,63 @@ class Annotator:
388
412
  self.im,
389
413
  label,
390
414
  (text_x, text_y),
391
- cv2.FONT_HERSHEY_SIMPLEX,
415
+ 0,
392
416
  fontScale=fontsize,
393
417
  color=txt_color,
394
418
  thickness=self.tf,
395
419
  lineType=cv2.LINE_AA,
396
420
  )
397
421
 
398
- def display_counts(
399
- self, counts=None, tf=2, fontScale=0.6, line_color=(0, 0, 0), txt_color=(255, 255, 255), classwise_txtgap=55
400
- ):
422
+ def display_counts(self, counts=None, count_bg_color=(0, 0, 0), count_txt_color=(255, 255, 255)):
401
423
  """
402
- Display counts on im0.
424
+ Display counts on im0 with text background and border.
403
425
 
404
426
  Args:
405
427
  counts (str): objects count data
406
- tf (int): text thickness for display
407
- fontScale (float): text fontsize for display
408
- line_color (RGB Color): counts highlighter color
409
- txt_color (RGB Color): counts display color
410
- classwise_txtgap (int): Gap between each class count data
428
+ count_bg_color (RGB Color): counts highlighter color
429
+ count_txt_color (RGB Color): counts display color
411
430
  """
412
431
 
413
- tl = tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
432
+ tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
414
433
  tf = max(tl - 1, 1)
415
434
 
416
- t_sizes = [cv2.getTextSize(str(count), 0, fontScale=0.8, thickness=tf)[0] for count in counts]
435
+ t_sizes = [cv2.getTextSize(str(count), 0, fontScale=self.sf, thickness=self.tf)[0] for count in counts]
417
436
 
418
437
  max_text_width = max([size[0] for size in t_sizes])
419
438
  max_text_height = max([size[1] for size in t_sizes])
420
439
 
421
- text_x = self.im.shape[1] - max_text_width - 20
422
- text_y = classwise_txtgap
440
+ text_x = self.im.shape[1] - int(self.im.shape[1] * 0.025 + max_text_width)
441
+ text_y = int(self.im.shape[0] * 0.025)
442
+
443
+ # Calculate dynamic gap between each count value based on the width of the image
444
+ dynamic_gap = max(1, self.im.shape[1] // 100) * tf
423
445
 
424
446
  for i, count in enumerate(counts):
425
447
  text_x_pos = text_x
426
- text_y_pos = text_y + i * classwise_txtgap
448
+ text_y_pos = text_y + i * dynamic_gap # Adjust vertical position with dynamic gap
427
449
 
450
+ # Draw the border
451
+ cv2.rectangle(
452
+ self.im,
453
+ (text_x_pos - (10 * tf), text_y_pos - (10 * tf)),
454
+ (text_x_pos + max_text_width + (10 * tf), text_y_pos + max_text_height + (10 * tf)),
455
+ count_bg_color,
456
+ -1,
457
+ )
458
+
459
+ # Draw the count text
428
460
  cv2.putText(
429
461
  self.im,
430
462
  str(count),
431
- (text_x_pos, text_y_pos),
432
- cv2.FONT_HERSHEY_SIMPLEX,
433
- fontScale=fontScale,
434
- color=txt_color,
435
- thickness=tf,
463
+ (text_x_pos, text_y_pos + max_text_height),
464
+ 0,
465
+ fontScale=self.sf,
466
+ color=count_txt_color,
467
+ thickness=self.tf,
436
468
  lineType=cv2.LINE_AA,
437
469
  )
438
470
 
439
- line_y_pos = text_y_pos + max_text_height + 5
440
- cv2.line(self.im, (text_x_pos, line_y_pos), (text_x_pos + max_text_width, line_y_pos), line_color, tf)
471
+ text_y_pos += tf * max_text_height
441
472
 
442
473
  @staticmethod
443
474
  def estimate_pose_angle(a, b, c):
@@ -588,30 +619,26 @@ class Annotator:
588
619
  line_color (RGB): Distance line color.
589
620
  centroid_color (RGB): Bounding box centroid color.
590
621
  """
591
- (text_width_m, text_height_m), _ = cv2.getTextSize(
592
- f"Distance M: {distance_m:.2f}m", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
593
- )
622
+ (text_width_m, text_height_m), _ = cv2.getTextSize(f"Distance M: {distance_m:.2f}m", 0, 0.8, 2)
594
623
  cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), (255, 255, 255), -1)
595
624
  cv2.putText(
596
625
  self.im,
597
626
  f"Distance M: {distance_m:.2f}m",
598
627
  (20, 50),
599
- cv2.FONT_HERSHEY_SIMPLEX,
628
+ 0,
600
629
  0.8,
601
630
  (0, 0, 0),
602
631
  2,
603
632
  cv2.LINE_AA,
604
633
  )
605
634
 
606
- (text_width_mm, text_height_mm), _ = cv2.getTextSize(
607
- f"Distance MM: {distance_mm:.2f}mm", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
608
- )
635
+ (text_width_mm, text_height_mm), _ = cv2.getTextSize(f"Distance MM: {distance_mm:.2f}mm", 0, 0.8, 2)
609
636
  cv2.rectangle(self.im, (15, 75), (15 + text_width_mm + 10, 75 + text_height_mm + 20), (255, 255, 255), -1)
610
637
  cv2.putText(
611
638
  self.im,
612
639
  f"Distance MM: {distance_mm:.2f}mm",
613
640
  (20, 100),
614
- cv2.FONT_HERSHEY_SIMPLEX,
641
+ 0,
615
642
  0.8,
616
643
  (0, 0, 0),
617
644
  2,
@@ -644,8 +671,8 @@ class Annotator:
644
671
  @plt_settings()
645
672
  def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
646
673
  """Plot training labels including class histograms and box statistics."""
647
- import pandas as pd
648
- import seaborn as sn
674
+ import pandas # scope for faster 'import ultralytics'
675
+ import seaborn # scope for faster 'import ultralytics'
649
676
 
650
677
  # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
651
678
  warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
@@ -655,10 +682,10 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
655
682
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
656
683
  nc = int(cls.max() + 1) # number of classes
657
684
  boxes = boxes[:1000000] # limit to 1M boxes
658
- x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
685
+ x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
659
686
 
660
687
  # Seaborn correlogram
661
- sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
688
+ seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
662
689
  plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
663
690
  plt.close()
664
691
 
@@ -673,8 +700,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
673
700
  ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
674
701
  else:
675
702
  ax[0].set_xlabel("classes")
676
- sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
677
- sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
703
+ seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
704
+ seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
678
705
 
679
706
  # Rectangles
680
707
  boxes[:, 0:2] = 0.5 # center
@@ -906,7 +933,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
906
933
  plot_results('path/to/results.csv', segment=True)
907
934
  ```
908
935
  """
909
- import pandas as pd
936
+ import pandas as pd # scope for faster 'import ultralytics'
910
937
  from scipy.ndimage import gaussian_filter1d
911
938
 
912
939
  save_dir = Path(file).parent if file else Path(dir)
@@ -992,7 +1019,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
992
1019
  >>> plot_tune_results('path/to/tune_results.csv')
993
1020
  """
994
1021
 
995
- import pandas as pd
1022
+ import pandas as pd # scope for faster 'import ultralytics'
996
1023
  from scipy.ndimage import gaussian_filter1d
997
1024
 
998
1025
  # Scatter plots for each hyperparameter