ultralytics 8.1.42__py3-none-any.whl → 8.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +1 -1
- ultralytics/data/augment.py +12 -9
- ultralytics/data/dataset.py +147 -142
- ultralytics/data/explorer/explorer.py +4 -6
- ultralytics/data/explorer/gui/dash.py +3 -3
- ultralytics/data/explorer/utils.py +3 -2
- ultralytics/engine/exporter.py +3 -2
- ultralytics/engine/trainer.py +1 -1
- ultralytics/models/fastsam/prompt.py +4 -1
- ultralytics/models/sam/predict.py +4 -1
- ultralytics/models/yolo/classify/train.py +2 -1
- ultralytics/solutions/heatmap.py +14 -27
- ultralytics/solutions/object_counter.py +12 -23
- ultralytics/utils/__init__.py +4 -1
- ultralytics/utils/benchmarks.py +1 -2
- ultralytics/utils/callbacks/clearml.py +4 -3
- ultralytics/utils/callbacks/wb.py +5 -5
- ultralytics/utils/checks.py +6 -9
- ultralytics/utils/metrics.py +3 -3
- ultralytics/utils/ops.py +1 -1
- ultralytics/utils/plotting.py +67 -40
- ultralytics/utils/torch_utils.py +13 -6
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/METADATA +1 -1
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/RECORD +28 -28
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/WHEEL +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.43.dist-info}/top_level.txt +0 -0
ultralytics/solutions/heatmap.py
CHANGED
|
@@ -30,6 +30,7 @@ class Heatmap:
|
|
|
30
30
|
self.imw = None
|
|
31
31
|
self.imh = None
|
|
32
32
|
self.im0 = None
|
|
33
|
+
self.tf = 2
|
|
33
34
|
self.view_in_counts = True
|
|
34
35
|
self.view_out_counts = True
|
|
35
36
|
|
|
@@ -56,11 +57,9 @@ class Heatmap:
|
|
|
56
57
|
self.out_counts = 0
|
|
57
58
|
self.count_ids = []
|
|
58
59
|
self.class_wise_count = {}
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
61
|
-
self.line_color = (255, 255, 255)
|
|
60
|
+
self.count_txt_color = (0, 0, 0)
|
|
61
|
+
self.count_bg_color = (255, 255, 255)
|
|
62
62
|
self.cls_txtdisplay_gap = 50
|
|
63
|
-
self.fontsize = 0.6
|
|
64
63
|
|
|
65
64
|
# Decay factor
|
|
66
65
|
self.decay_factor = 0.99
|
|
@@ -79,16 +78,14 @@ class Heatmap:
|
|
|
79
78
|
view_in_counts=True,
|
|
80
79
|
view_out_counts=True,
|
|
81
80
|
count_reg_pts=None,
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
fontsize=0.8,
|
|
85
|
-
line_color=(255, 255, 255),
|
|
81
|
+
count_txt_color=(0, 0, 0),
|
|
82
|
+
count_bg_color=(255, 255, 255),
|
|
86
83
|
count_reg_color=(255, 0, 255),
|
|
87
84
|
region_thickness=5,
|
|
88
85
|
line_dist_thresh=15,
|
|
86
|
+
line_thickness=2,
|
|
89
87
|
decay_factor=0.99,
|
|
90
88
|
shape="circle",
|
|
91
|
-
cls_txtdisplay_gap=50,
|
|
92
89
|
):
|
|
93
90
|
"""
|
|
94
91
|
Configures the heatmap colormap, width, height and display parameters.
|
|
@@ -98,22 +95,21 @@ class Heatmap:
|
|
|
98
95
|
imw (int): The width of the frame.
|
|
99
96
|
imh (int): The height of the frame.
|
|
100
97
|
classes_names (dict): Classes names
|
|
98
|
+
line_thickness (int): Line thickness for bounding boxes.
|
|
101
99
|
heatmap_alpha (float): alpha value for heatmap display
|
|
102
100
|
view_img (bool): Flag indicating frame display
|
|
103
101
|
view_in_counts (bool): Flag to control whether to display the incounts on video stream.
|
|
104
102
|
view_out_counts (bool): Flag to control whether to display the outcounts on video stream.
|
|
105
103
|
count_reg_pts (list): Object counting region points
|
|
106
|
-
count_txt_thickness (int): Text thickness for object counting display
|
|
107
104
|
count_txt_color (RGB color): count text color value
|
|
108
|
-
|
|
109
|
-
line_color (RGB color): count highlighter line color
|
|
105
|
+
count_bg_color (RGB color): count highlighter line color
|
|
110
106
|
count_reg_color (RGB color): Color of object counting region
|
|
111
107
|
region_thickness (int): Object counting Region thickness
|
|
112
108
|
line_dist_thresh (int): Euclidean Distance threshold for line counter
|
|
113
109
|
decay_factor (float): value for removing heatmap area after object passed
|
|
114
110
|
shape (str): Heatmap shape, rect or circle shape supported
|
|
115
|
-
cls_txtdisplay_gap (int): Display gap between each class count
|
|
116
111
|
"""
|
|
112
|
+
self.tf = line_thickness
|
|
117
113
|
self.names = classes_names
|
|
118
114
|
self.imw = imw
|
|
119
115
|
self.imh = imh
|
|
@@ -141,16 +137,13 @@ class Heatmap:
|
|
|
141
137
|
# Heatmap new frame
|
|
142
138
|
self.heatmap = np.zeros((int(self.imh), int(self.imw)), dtype=np.float32)
|
|
143
139
|
|
|
144
|
-
self.count_txt_thickness = count_txt_thickness
|
|
145
140
|
self.count_txt_color = count_txt_color
|
|
146
|
-
self.
|
|
147
|
-
self.line_color = line_color
|
|
141
|
+
self.count_bg_color = count_bg_color
|
|
148
142
|
self.region_color = count_reg_color
|
|
149
143
|
self.region_thickness = region_thickness
|
|
150
144
|
self.decay_factor = decay_factor
|
|
151
145
|
self.line_dist_thresh = line_dist_thresh
|
|
152
146
|
self.shape = shape
|
|
153
|
-
self.cls_txtdisplay_gap = cls_txtdisplay_gap
|
|
154
147
|
|
|
155
148
|
# shape of heatmap, if not selected
|
|
156
149
|
if self.shape not in {"circle", "rect"}:
|
|
@@ -185,7 +178,7 @@ class Heatmap:
|
|
|
185
178
|
return im0
|
|
186
179
|
self.heatmap *= self.decay_factor # decay factor
|
|
187
180
|
self.extract_results(tracks)
|
|
188
|
-
self.annotator = Annotator(self.im0, self.
|
|
181
|
+
self.annotator = Annotator(self.im0, self.tf, None)
|
|
189
182
|
|
|
190
183
|
if self.count_reg_pts is not None:
|
|
191
184
|
# Draw counting region
|
|
@@ -239,11 +232,8 @@ class Heatmap:
|
|
|
239
232
|
|
|
240
233
|
# Count objects using line
|
|
241
234
|
elif len(self.count_reg_pts) == 2:
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
if prev_position is not None and is_inside and track_id not in self.count_ids:
|
|
235
|
+
if prev_position is not None and track_id not in self.count_ids:
|
|
245
236
|
distance = Point(track_line[-1]).distance(self.counting_region)
|
|
246
|
-
|
|
247
237
|
if distance < self.line_dist_thresh and track_id not in self.count_ids:
|
|
248
238
|
self.count_ids.append(track_id)
|
|
249
239
|
|
|
@@ -293,11 +283,8 @@ class Heatmap:
|
|
|
293
283
|
if self.count_reg_pts is not None and label is not None:
|
|
294
284
|
self.annotator.display_counts(
|
|
295
285
|
counts=label,
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
txt_color=self.count_txt_color,
|
|
299
|
-
line_color=self.line_color,
|
|
300
|
-
classwise_txtgap=self.cls_txtdisplay_gap,
|
|
286
|
+
count_txt_color=self.count_txt_color,
|
|
287
|
+
count_bg_color=self.count_bg_color,
|
|
301
288
|
)
|
|
302
289
|
|
|
303
290
|
self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
3
|
from collections import defaultdict
|
|
4
|
-
|
|
5
4
|
import cv2
|
|
6
5
|
|
|
7
6
|
from ultralytics.utils.checks import check_imshow, check_requirements
|
|
@@ -47,7 +46,7 @@ class ObjectCounter:
|
|
|
47
46
|
self.class_wise_count = {}
|
|
48
47
|
self.count_txt_thickness = 0
|
|
49
48
|
self.count_txt_color = (255, 255, 255)
|
|
50
|
-
self.
|
|
49
|
+
self.count_bg_color = (255, 255, 255)
|
|
51
50
|
self.cls_txtdisplay_gap = 50
|
|
52
51
|
self.fontsize = 0.6
|
|
53
52
|
|
|
@@ -65,16 +64,14 @@ class ObjectCounter:
|
|
|
65
64
|
classes_names,
|
|
66
65
|
reg_pts,
|
|
67
66
|
count_reg_color=(255, 0, 255),
|
|
67
|
+
count_txt_color=(0, 0, 0),
|
|
68
|
+
count_bg_color=(255, 255, 255),
|
|
68
69
|
line_thickness=2,
|
|
69
70
|
track_thickness=2,
|
|
70
71
|
view_img=False,
|
|
71
72
|
view_in_counts=True,
|
|
72
73
|
view_out_counts=True,
|
|
73
74
|
draw_tracks=False,
|
|
74
|
-
count_txt_thickness=3,
|
|
75
|
-
count_txt_color=(255, 255, 255),
|
|
76
|
-
fontsize=0.8,
|
|
77
|
-
line_color=(255, 255, 255),
|
|
78
75
|
track_color=None,
|
|
79
76
|
region_thickness=5,
|
|
80
77
|
line_dist_thresh=15,
|
|
@@ -92,10 +89,8 @@ class ObjectCounter:
|
|
|
92
89
|
classes_names (dict): Classes names
|
|
93
90
|
track_thickness (int): Track thickness
|
|
94
91
|
draw_tracks (Bool): draw tracks
|
|
95
|
-
count_txt_thickness (int): Text thickness for object counting display
|
|
96
92
|
count_txt_color (RGB color): count text color value
|
|
97
|
-
|
|
98
|
-
line_color (RGB color): count highlighter line color
|
|
93
|
+
count_bg_color (RGB color): count highlighter line color
|
|
99
94
|
count_reg_color (RGB color): Color of object counting region
|
|
100
95
|
track_color (RGB color): color for tracks
|
|
101
96
|
region_thickness (int): Object counting Region thickness
|
|
@@ -125,10 +120,8 @@ class ObjectCounter:
|
|
|
125
120
|
|
|
126
121
|
self.names = classes_names
|
|
127
122
|
self.track_color = track_color
|
|
128
|
-
self.count_txt_thickness = count_txt_thickness
|
|
129
123
|
self.count_txt_color = count_txt_color
|
|
130
|
-
self.
|
|
131
|
-
self.line_color = line_color
|
|
124
|
+
self.count_bg_color = count_bg_color
|
|
132
125
|
self.region_color = count_reg_color
|
|
133
126
|
self.region_thickness = region_thickness
|
|
134
127
|
self.line_dist_thresh = line_dist_thresh
|
|
@@ -172,6 +165,9 @@ class ObjectCounter:
|
|
|
172
165
|
# Annotator Init and region drawing
|
|
173
166
|
self.annotator = Annotator(self.im0, self.tf, self.names)
|
|
174
167
|
|
|
168
|
+
# Draw region or line
|
|
169
|
+
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
|
170
|
+
|
|
175
171
|
if tracks[0].boxes.id is not None:
|
|
176
172
|
boxes = tracks[0].boxes.xyxy.cpu()
|
|
177
173
|
clss = tracks[0].boxes.cls.cpu().tolist()
|
|
@@ -220,17 +216,14 @@ class ObjectCounter:
|
|
|
220
216
|
|
|
221
217
|
# Count objects using line
|
|
222
218
|
elif len(self.reg_pts) == 2:
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
if prev_position is not None and is_inside and track_id not in self.count_ids:
|
|
219
|
+
if prev_position is not None and track_id not in self.count_ids:
|
|
226
220
|
distance = Point(track_line[-1]).distance(self.counting_region)
|
|
227
|
-
|
|
228
221
|
if distance < self.line_dist_thresh and track_id not in self.count_ids:
|
|
229
222
|
self.count_ids.append(track_id)
|
|
230
223
|
|
|
231
224
|
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
|
|
232
225
|
self.in_counts += 1
|
|
233
|
-
self.class_wise_count[self.names[cls]]["in"] +=
|
|
226
|
+
self.class_wise_count[self.names[cls]]["in"] += 2
|
|
234
227
|
else:
|
|
235
228
|
self.out_counts += 1
|
|
236
229
|
self.class_wise_count[self.names[cls]]["out"] += 1
|
|
@@ -254,17 +247,13 @@ class ObjectCounter:
|
|
|
254
247
|
if label is not None:
|
|
255
248
|
self.annotator.display_counts(
|
|
256
249
|
counts=label,
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
txt_color=self.count_txt_color,
|
|
260
|
-
line_color=self.line_color,
|
|
261
|
-
classwise_txtgap=self.cls_txtdisplay_gap,
|
|
250
|
+
count_txt_color=self.count_txt_color,
|
|
251
|
+
count_bg_color=self.count_bg_color,
|
|
262
252
|
)
|
|
263
253
|
|
|
264
254
|
def display_frames(self):
|
|
265
255
|
"""Display frame."""
|
|
266
256
|
if self.env_check:
|
|
267
|
-
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
|
268
257
|
cv2.namedWindow(self.window_name)
|
|
269
258
|
if len(self.reg_pts) == 4: # only add mouse event If user drawn region
|
|
270
259
|
cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
|
ultralytics/utils/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
3
|
import contextlib
|
|
4
|
+
import importlib.metadata
|
|
4
5
|
import inspect
|
|
5
6
|
import logging.config
|
|
6
7
|
import os
|
|
@@ -42,6 +43,8 @@ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar form
|
|
|
42
43
|
LOGGING_NAME = "ultralytics"
|
|
43
44
|
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
|
|
44
45
|
ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
|
|
46
|
+
PYTHON_VERSION = platform.python_version()
|
|
47
|
+
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
|
|
45
48
|
HELP_MSG = """
|
|
46
49
|
Usage examples for running YOLOv8:
|
|
47
50
|
|
|
@@ -476,7 +479,7 @@ def is_online() -> bool:
|
|
|
476
479
|
|
|
477
480
|
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
|
|
478
481
|
try:
|
|
479
|
-
test_connection = socket.create_connection(address=(host,
|
|
482
|
+
test_connection = socket.create_connection(address=(host, 80), timeout=2)
|
|
480
483
|
except (socket.timeout, socket.gaierror, OSError):
|
|
481
484
|
continue
|
|
482
485
|
else:
|
ultralytics/utils/benchmarks.py
CHANGED
|
@@ -7,8 +7,6 @@ try:
|
|
|
7
7
|
assert SETTINGS["clearml"] is True # verify integration is enabled
|
|
8
8
|
import clearml
|
|
9
9
|
from clearml import Task
|
|
10
|
-
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
|
11
|
-
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
|
12
10
|
|
|
13
11
|
assert hasattr(clearml, "__version__") # verify package is not directory
|
|
14
12
|
|
|
@@ -61,8 +59,11 @@ def on_pretrain_routine_start(trainer):
|
|
|
61
59
|
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
|
|
62
60
|
try:
|
|
63
61
|
if task := Task.current_task():
|
|
64
|
-
#
|
|
62
|
+
# WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
|
|
65
63
|
# We are logging these plots and model files manually in the integration
|
|
64
|
+
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
|
65
|
+
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
|
66
|
+
|
|
66
67
|
PatchPyTorchModelIO.update_current_task(None)
|
|
67
68
|
PatchedMatplotlib.update_current_task(None)
|
|
68
69
|
else:
|
|
@@ -9,10 +9,6 @@ try:
|
|
|
9
9
|
import wandb as wb
|
|
10
10
|
|
|
11
11
|
assert hasattr(wb, "__version__") # verify package is not directory
|
|
12
|
-
|
|
13
|
-
import numpy as np
|
|
14
|
-
import pandas as pd
|
|
15
|
-
|
|
16
12
|
_processed_plots = {}
|
|
17
13
|
|
|
18
14
|
except (ImportError, AssertionError):
|
|
@@ -38,7 +34,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
|
|
|
38
34
|
Returns:
|
|
39
35
|
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
|
|
40
36
|
"""
|
|
41
|
-
|
|
37
|
+
import pandas # scope for faster 'import ultralytics'
|
|
38
|
+
|
|
39
|
+
df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
|
|
42
40
|
fields = {"x": "x", "y": "y", "class": "class"}
|
|
43
41
|
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
|
|
44
42
|
return wb.plot_table(
|
|
@@ -77,6 +75,8 @@ def _plot_curve(
|
|
|
77
75
|
Note:
|
|
78
76
|
The function leverages the '_custom_table' function to generate the actual visualization.
|
|
79
77
|
"""
|
|
78
|
+
import numpy as np
|
|
79
|
+
|
|
80
80
|
# Create new x
|
|
81
81
|
if names is None:
|
|
82
82
|
names = []
|
ultralytics/utils/checks.py
CHANGED
|
@@ -18,15 +18,16 @@ import cv2
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
import requests
|
|
20
20
|
import torch
|
|
21
|
-
from matplotlib import font_manager
|
|
22
21
|
|
|
23
22
|
from ultralytics.utils import (
|
|
24
23
|
ASSETS,
|
|
25
24
|
AUTOINSTALL,
|
|
26
25
|
LINUX,
|
|
27
26
|
LOGGER,
|
|
27
|
+
PYTHON_VERSION,
|
|
28
28
|
ONLINE,
|
|
29
29
|
ROOT,
|
|
30
|
+
TORCHVISION_VERSION,
|
|
30
31
|
USER_CONFIG_DIR,
|
|
31
32
|
Retry,
|
|
32
33
|
SimpleNamespace,
|
|
@@ -41,13 +42,10 @@ from ultralytics.utils import (
|
|
|
41
42
|
is_github_action_running,
|
|
42
43
|
is_jupyter,
|
|
43
44
|
is_kaggle,
|
|
44
|
-
is_online,
|
|
45
45
|
is_pip_package,
|
|
46
46
|
url2file,
|
|
47
47
|
)
|
|
48
48
|
|
|
49
|
-
PYTHON_VERSION = platform.python_version()
|
|
50
|
-
|
|
51
49
|
|
|
52
50
|
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
|
|
53
51
|
"""
|
|
@@ -304,9 +302,10 @@ def check_font(font="Arial.ttf"):
|
|
|
304
302
|
Returns:
|
|
305
303
|
file (Path): Resolved font file path.
|
|
306
304
|
"""
|
|
307
|
-
|
|
305
|
+
from matplotlib import font_manager
|
|
308
306
|
|
|
309
307
|
# Check USER_CONFIG_DIR
|
|
308
|
+
name = Path(font).name
|
|
310
309
|
file = USER_CONFIG_DIR / name
|
|
311
310
|
if file.exists():
|
|
312
311
|
return file
|
|
@@ -390,7 +389,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
|
|
|
390
389
|
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
|
|
391
390
|
try:
|
|
392
391
|
t = time.time()
|
|
393
|
-
assert
|
|
392
|
+
assert ONLINE, "AutoUpdate skipped (offline)"
|
|
394
393
|
with Retry(times=2, delay=1): # run up to 2 times with 1-second retry delay
|
|
395
394
|
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
|
|
396
395
|
dt = time.time() - t
|
|
@@ -419,14 +418,12 @@ def check_torchvision():
|
|
|
419
418
|
Torchvision versions.
|
|
420
419
|
"""
|
|
421
420
|
|
|
422
|
-
import torchvision
|
|
423
|
-
|
|
424
421
|
# Compatibility table
|
|
425
422
|
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
|
|
426
423
|
|
|
427
424
|
# Extract only the major and minor versions
|
|
428
425
|
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
|
|
429
|
-
v_torchvision = ".".join(
|
|
426
|
+
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
|
|
430
427
|
|
|
431
428
|
if v_torch in compatibility_table:
|
|
432
429
|
compatible_versions = compatibility_table[v_torch]
|
ultralytics/utils/metrics.py
CHANGED
|
@@ -395,19 +395,19 @@ class ConfusionMatrix:
|
|
|
395
395
|
names (tuple): Names of classes, used as labels on the plot.
|
|
396
396
|
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
|
397
397
|
"""
|
|
398
|
-
import seaborn
|
|
398
|
+
import seaborn # scope for faster 'import ultralytics'
|
|
399
399
|
|
|
400
400
|
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
|
|
401
401
|
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
|
402
402
|
|
|
403
403
|
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
|
404
404
|
nc, nn = self.nc, len(names) # number of classes, names
|
|
405
|
-
|
|
405
|
+
seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
|
406
406
|
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
|
407
407
|
ticklabels = (list(names) + ["background"]) if labels else "auto"
|
|
408
408
|
with warnings.catch_warnings():
|
|
409
409
|
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
|
410
|
-
|
|
410
|
+
seaborn.heatmap(
|
|
411
411
|
array,
|
|
412
412
|
ax=ax,
|
|
413
413
|
annot=nc < 30,
|
ultralytics/utils/ops.py
CHANGED
|
@@ -9,7 +9,6 @@ import cv2
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import torch
|
|
11
11
|
import torch.nn.functional as F
|
|
12
|
-
import torchvision
|
|
13
12
|
|
|
14
13
|
from ultralytics.utils import LOGGER
|
|
15
14
|
from ultralytics.utils.metrics import batch_probiou
|
|
@@ -206,6 +205,7 @@ def non_max_suppression(
|
|
|
206
205
|
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
|
|
207
206
|
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
|
|
208
207
|
"""
|
|
208
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
209
209
|
|
|
210
210
|
# Checks
|
|
211
211
|
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
ultralytics/utils/plotting.py
CHANGED
|
@@ -339,6 +339,21 @@ class Annotator:
|
|
|
339
339
|
"""Save the annotated image to 'filename'."""
|
|
340
340
|
cv2.imwrite(filename, np.asarray(self.im))
|
|
341
341
|
|
|
342
|
+
def get_bbox_dimension(self, bbox=None):
|
|
343
|
+
"""
|
|
344
|
+
Calculate the area of a bounding box.
|
|
345
|
+
|
|
346
|
+
Args:
|
|
347
|
+
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
|
|
348
|
+
|
|
349
|
+
Returns:
|
|
350
|
+
angle (degree): Degree value of angle between three points
|
|
351
|
+
"""
|
|
352
|
+
x_min, y_min, x_max, y_max = bbox
|
|
353
|
+
width = x_max - x_min
|
|
354
|
+
height = y_max - y_min
|
|
355
|
+
return width, height, width * height
|
|
356
|
+
|
|
342
357
|
def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
|
|
343
358
|
"""
|
|
344
359
|
Draw region line.
|
|
@@ -364,13 +379,22 @@ class Annotator:
|
|
|
364
379
|
cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
|
|
365
380
|
|
|
366
381
|
def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0), fontsize=0.7):
|
|
367
|
-
"""
|
|
382
|
+
"""
|
|
383
|
+
Displays queue counts on an image centered at the points with customizable font size and colors.
|
|
384
|
+
|
|
385
|
+
Args:
|
|
386
|
+
label (str): queue counts label
|
|
387
|
+
points (tuple): region points for center point calculation to display text
|
|
388
|
+
region_color (RGB): queue region color
|
|
389
|
+
txt_color (RGB): text display color
|
|
390
|
+
fontsize (float): text fontsize
|
|
391
|
+
"""
|
|
368
392
|
x_values = [point[0] for point in points]
|
|
369
393
|
y_values = [point[1] for point in points]
|
|
370
394
|
center_x = sum(x_values) // len(points)
|
|
371
395
|
center_y = sum(y_values) // len(points)
|
|
372
396
|
|
|
373
|
-
text_size = cv2.getTextSize(label,
|
|
397
|
+
text_size = cv2.getTextSize(label, 0, fontScale=fontsize, thickness=self.tf)[0]
|
|
374
398
|
text_width = text_size[0]
|
|
375
399
|
text_height = text_size[1]
|
|
376
400
|
|
|
@@ -388,56 +412,63 @@ class Annotator:
|
|
|
388
412
|
self.im,
|
|
389
413
|
label,
|
|
390
414
|
(text_x, text_y),
|
|
391
|
-
|
|
415
|
+
0,
|
|
392
416
|
fontScale=fontsize,
|
|
393
417
|
color=txt_color,
|
|
394
418
|
thickness=self.tf,
|
|
395
419
|
lineType=cv2.LINE_AA,
|
|
396
420
|
)
|
|
397
421
|
|
|
398
|
-
def display_counts(
|
|
399
|
-
self, counts=None, tf=2, fontScale=0.6, line_color=(0, 0, 0), txt_color=(255, 255, 255), classwise_txtgap=55
|
|
400
|
-
):
|
|
422
|
+
def display_counts(self, counts=None, count_bg_color=(0, 0, 0), count_txt_color=(255, 255, 255)):
|
|
401
423
|
"""
|
|
402
|
-
Display counts on im0.
|
|
424
|
+
Display counts on im0 with text background and border.
|
|
403
425
|
|
|
404
426
|
Args:
|
|
405
427
|
counts (str): objects count data
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
line_color (RGB Color): counts highlighter color
|
|
409
|
-
txt_color (RGB Color): counts display color
|
|
410
|
-
classwise_txtgap (int): Gap between each class count data
|
|
428
|
+
count_bg_color (RGB Color): counts highlighter color
|
|
429
|
+
count_txt_color (RGB Color): counts display color
|
|
411
430
|
"""
|
|
412
431
|
|
|
413
|
-
tl = tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
|
|
432
|
+
tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
|
|
414
433
|
tf = max(tl - 1, 1)
|
|
415
434
|
|
|
416
|
-
t_sizes = [cv2.getTextSize(str(count), 0, fontScale=
|
|
435
|
+
t_sizes = [cv2.getTextSize(str(count), 0, fontScale=self.sf, thickness=self.tf)[0] for count in counts]
|
|
417
436
|
|
|
418
437
|
max_text_width = max([size[0] for size in t_sizes])
|
|
419
438
|
max_text_height = max([size[1] for size in t_sizes])
|
|
420
439
|
|
|
421
|
-
text_x = self.im.shape[1] -
|
|
422
|
-
text_y =
|
|
440
|
+
text_x = self.im.shape[1] - int(self.im.shape[1] * 0.025 + max_text_width)
|
|
441
|
+
text_y = int(self.im.shape[0] * 0.025)
|
|
442
|
+
|
|
443
|
+
# Calculate dynamic gap between each count value based on the width of the image
|
|
444
|
+
dynamic_gap = max(1, self.im.shape[1] // 100) * tf
|
|
423
445
|
|
|
424
446
|
for i, count in enumerate(counts):
|
|
425
447
|
text_x_pos = text_x
|
|
426
|
-
text_y_pos = text_y + i *
|
|
448
|
+
text_y_pos = text_y + i * dynamic_gap # Adjust vertical position with dynamic gap
|
|
427
449
|
|
|
450
|
+
# Draw the border
|
|
451
|
+
cv2.rectangle(
|
|
452
|
+
self.im,
|
|
453
|
+
(text_x_pos - (10 * tf), text_y_pos - (10 * tf)),
|
|
454
|
+
(text_x_pos + max_text_width + (10 * tf), text_y_pos + max_text_height + (10 * tf)),
|
|
455
|
+
count_bg_color,
|
|
456
|
+
-1,
|
|
457
|
+
)
|
|
458
|
+
|
|
459
|
+
# Draw the count text
|
|
428
460
|
cv2.putText(
|
|
429
461
|
self.im,
|
|
430
462
|
str(count),
|
|
431
|
-
(text_x_pos, text_y_pos),
|
|
432
|
-
|
|
433
|
-
fontScale=
|
|
434
|
-
color=
|
|
435
|
-
thickness=tf,
|
|
463
|
+
(text_x_pos, text_y_pos + max_text_height),
|
|
464
|
+
0,
|
|
465
|
+
fontScale=self.sf,
|
|
466
|
+
color=count_txt_color,
|
|
467
|
+
thickness=self.tf,
|
|
436
468
|
lineType=cv2.LINE_AA,
|
|
437
469
|
)
|
|
438
470
|
|
|
439
|
-
|
|
440
|
-
cv2.line(self.im, (text_x_pos, line_y_pos), (text_x_pos + max_text_width, line_y_pos), line_color, tf)
|
|
471
|
+
text_y_pos += tf * max_text_height
|
|
441
472
|
|
|
442
473
|
@staticmethod
|
|
443
474
|
def estimate_pose_angle(a, b, c):
|
|
@@ -588,30 +619,26 @@ class Annotator:
|
|
|
588
619
|
line_color (RGB): Distance line color.
|
|
589
620
|
centroid_color (RGB): Bounding box centroid color.
|
|
590
621
|
"""
|
|
591
|
-
(text_width_m, text_height_m), _ = cv2.getTextSize(
|
|
592
|
-
f"Distance M: {distance_m:.2f}m", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
|
|
593
|
-
)
|
|
622
|
+
(text_width_m, text_height_m), _ = cv2.getTextSize(f"Distance M: {distance_m:.2f}m", 0, 0.8, 2)
|
|
594
623
|
cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), (255, 255, 255), -1)
|
|
595
624
|
cv2.putText(
|
|
596
625
|
self.im,
|
|
597
626
|
f"Distance M: {distance_m:.2f}m",
|
|
598
627
|
(20, 50),
|
|
599
|
-
|
|
628
|
+
0,
|
|
600
629
|
0.8,
|
|
601
630
|
(0, 0, 0),
|
|
602
631
|
2,
|
|
603
632
|
cv2.LINE_AA,
|
|
604
633
|
)
|
|
605
634
|
|
|
606
|
-
(text_width_mm, text_height_mm), _ = cv2.getTextSize(
|
|
607
|
-
f"Distance MM: {distance_mm:.2f}mm", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
|
|
608
|
-
)
|
|
635
|
+
(text_width_mm, text_height_mm), _ = cv2.getTextSize(f"Distance MM: {distance_mm:.2f}mm", 0, 0.8, 2)
|
|
609
636
|
cv2.rectangle(self.im, (15, 75), (15 + text_width_mm + 10, 75 + text_height_mm + 20), (255, 255, 255), -1)
|
|
610
637
|
cv2.putText(
|
|
611
638
|
self.im,
|
|
612
639
|
f"Distance MM: {distance_mm:.2f}mm",
|
|
613
640
|
(20, 100),
|
|
614
|
-
|
|
641
|
+
0,
|
|
615
642
|
0.8,
|
|
616
643
|
(0, 0, 0),
|
|
617
644
|
2,
|
|
@@ -644,8 +671,8 @@ class Annotator:
|
|
|
644
671
|
@plt_settings()
|
|
645
672
|
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
646
673
|
"""Plot training labels including class histograms and box statistics."""
|
|
647
|
-
import pandas
|
|
648
|
-
import seaborn
|
|
674
|
+
import pandas # scope for faster 'import ultralytics'
|
|
675
|
+
import seaborn # scope for faster 'import ultralytics'
|
|
649
676
|
|
|
650
677
|
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
|
|
651
678
|
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
|
|
@@ -655,10 +682,10 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
655
682
|
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
|
656
683
|
nc = int(cls.max() + 1) # number of classes
|
|
657
684
|
boxes = boxes[:1000000] # limit to 1M boxes
|
|
658
|
-
x =
|
|
685
|
+
x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
|
|
659
686
|
|
|
660
687
|
# Seaborn correlogram
|
|
661
|
-
|
|
688
|
+
seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
|
662
689
|
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
|
|
663
690
|
plt.close()
|
|
664
691
|
|
|
@@ -673,8 +700,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
673
700
|
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
|
674
701
|
else:
|
|
675
702
|
ax[0].set_xlabel("classes")
|
|
676
|
-
|
|
677
|
-
|
|
703
|
+
seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
|
|
704
|
+
seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
|
|
678
705
|
|
|
679
706
|
# Rectangles
|
|
680
707
|
boxes[:, 0:2] = 0.5 # center
|
|
@@ -906,7 +933,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|
|
906
933
|
plot_results('path/to/results.csv', segment=True)
|
|
907
934
|
```
|
|
908
935
|
"""
|
|
909
|
-
import pandas as pd
|
|
936
|
+
import pandas as pd # scope for faster 'import ultralytics'
|
|
910
937
|
from scipy.ndimage import gaussian_filter1d
|
|
911
938
|
|
|
912
939
|
save_dir = Path(file).parent if file else Path(dir)
|
|
@@ -992,7 +1019,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
|
|
|
992
1019
|
>>> plot_tune_results('path/to/tune_results.csv')
|
|
993
1020
|
"""
|
|
994
1021
|
|
|
995
|
-
import pandas as pd
|
|
1022
|
+
import pandas as pd # scope for faster 'import ultralytics'
|
|
996
1023
|
from scipy.ndimage import gaussian_filter1d
|
|
997
1024
|
|
|
998
1025
|
# Scatter plots for each hyperparameter
|