ultralytics 8.1.42__py3-none-any.whl → 8.1.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. (The original page linked to further details here; that link was not preserved.)
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +2 -3
- ultralytics/cfg/models/v9/yolov9e.yaml +2 -3
- ultralytics/data/__init__.py +3 -8
- ultralytics/data/augment.py +14 -11
- ultralytics/data/base.py +1 -1
- ultralytics/data/build.py +1 -1
- ultralytics/data/converter.py +4 -3
- ultralytics/data/dataset.py +149 -144
- ultralytics/data/explorer/explorer.py +10 -11
- ultralytics/data/explorer/gui/dash.py +3 -3
- ultralytics/data/explorer/utils.py +3 -2
- ultralytics/data/loaders.py +3 -3
- ultralytics/data/utils.py +1 -1
- ultralytics/engine/exporter.py +3 -2
- ultralytics/engine/model.py +2 -1
- ultralytics/engine/trainer.py +2 -1
- ultralytics/hub/auth.py +3 -3
- ultralytics/hub/session.py +3 -3
- ultralytics/hub/utils.py +6 -6
- ultralytics/models/fastsam/prompt.py +4 -1
- ultralytics/models/rtdetr/val.py +1 -1
- ultralytics/models/sam/modules/tiny_encoder.py +2 -2
- ultralytics/models/sam/modules/transformer.py +1 -1
- ultralytics/models/sam/predict.py +16 -13
- ultralytics/models/yolo/classify/train.py +2 -1
- ultralytics/models/yolo/detect/val.py +1 -1
- ultralytics/models/yolo/model.py +1 -1
- ultralytics/models/yolo/obb/val.py +1 -1
- ultralytics/models/yolo/world/train_world.py +2 -2
- ultralytics/nn/modules/__init__.py +8 -8
- ultralytics/nn/modules/head.py +1 -1
- ultralytics/nn/tasks.py +7 -7
- ultralytics/solutions/heatmap.py +14 -27
- ultralytics/solutions/object_counter.py +12 -22
- ultralytics/trackers/byte_tracker.py +1 -1
- ultralytics/trackers/utils/kalman_filter.py +4 -4
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +56 -41
- ultralytics/utils/benchmarks.py +1 -2
- ultralytics/utils/callbacks/clearml.py +4 -3
- ultralytics/utils/callbacks/hub.py +1 -4
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +1 -0
- ultralytics/utils/callbacks/wb.py +5 -5
- ultralytics/utils/checks.py +17 -20
- ultralytics/utils/metrics.py +3 -3
- ultralytics/utils/ops.py +1 -1
- ultralytics/utils/plotting.py +67 -40
- ultralytics/utils/torch_utils.py +13 -6
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/METADATA +1 -1
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/RECORD +58 -58
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/WHEEL +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.42.dist-info → ultralytics-8.1.44.dist-info}/top_level.txt +0 -0
|
ultralytics/solutions/object_counter.py
CHANGED
|
@@ -47,7 +47,7 @@ class ObjectCounter:
|
|
|
47
47
|
self.class_wise_count = {}
|
|
48
48
|
self.count_txt_thickness = 0
|
|
49
49
|
self.count_txt_color = (255, 255, 255)
|
|
50
|
-
self.
|
|
50
|
+
self.count_bg_color = (255, 255, 255)
|
|
51
51
|
self.cls_txtdisplay_gap = 50
|
|
52
52
|
self.fontsize = 0.6
|
|
53
53
|
|
|
@@ -65,16 +65,14 @@ class ObjectCounter:
|
|
|
65
65
|
classes_names,
|
|
66
66
|
reg_pts,
|
|
67
67
|
count_reg_color=(255, 0, 255),
|
|
68
|
+
count_txt_color=(0, 0, 0),
|
|
69
|
+
count_bg_color=(255, 255, 255),
|
|
68
70
|
line_thickness=2,
|
|
69
71
|
track_thickness=2,
|
|
70
72
|
view_img=False,
|
|
71
73
|
view_in_counts=True,
|
|
72
74
|
view_out_counts=True,
|
|
73
75
|
draw_tracks=False,
|
|
74
|
-
count_txt_thickness=3,
|
|
75
|
-
count_txt_color=(255, 255, 255),
|
|
76
|
-
fontsize=0.8,
|
|
77
|
-
line_color=(255, 255, 255),
|
|
78
76
|
track_color=None,
|
|
79
77
|
region_thickness=5,
|
|
80
78
|
line_dist_thresh=15,
|
|
@@ -92,10 +90,8 @@ class ObjectCounter:
|
|
|
92
90
|
classes_names (dict): Classes names
|
|
93
91
|
track_thickness (int): Track thickness
|
|
94
92
|
draw_tracks (Bool): draw tracks
|
|
95
|
-
count_txt_thickness (int): Text thickness for object counting display
|
|
96
93
|
count_txt_color (RGB color): count text color value
|
|
97
|
-
|
|
98
|
-
line_color (RGB color): count highlighter line color
|
|
94
|
+
count_bg_color (RGB color): count highlighter line color
|
|
99
95
|
count_reg_color (RGB color): Color of object counting region
|
|
100
96
|
track_color (RGB color): color for tracks
|
|
101
97
|
region_thickness (int): Object counting Region thickness
|
|
@@ -125,10 +121,8 @@ class ObjectCounter:
|
|
|
125
121
|
|
|
126
122
|
self.names = classes_names
|
|
127
123
|
self.track_color = track_color
|
|
128
|
-
self.count_txt_thickness = count_txt_thickness
|
|
129
124
|
self.count_txt_color = count_txt_color
|
|
130
|
-
self.
|
|
131
|
-
self.line_color = line_color
|
|
125
|
+
self.count_bg_color = count_bg_color
|
|
132
126
|
self.region_color = count_reg_color
|
|
133
127
|
self.region_thickness = region_thickness
|
|
134
128
|
self.line_dist_thresh = line_dist_thresh
|
|
@@ -172,6 +166,9 @@ class ObjectCounter:
|
|
|
172
166
|
# Annotator Init and region drawing
|
|
173
167
|
self.annotator = Annotator(self.im0, self.tf, self.names)
|
|
174
168
|
|
|
169
|
+
# Draw region or line
|
|
170
|
+
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
|
171
|
+
|
|
175
172
|
if tracks[0].boxes.id is not None:
|
|
176
173
|
boxes = tracks[0].boxes.xyxy.cpu()
|
|
177
174
|
clss = tracks[0].boxes.cls.cpu().tolist()
|
|
@@ -220,17 +217,14 @@ class ObjectCounter:
|
|
|
220
217
|
|
|
221
218
|
# Count objects using line
|
|
222
219
|
elif len(self.reg_pts) == 2:
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
if prev_position is not None and is_inside and track_id not in self.count_ids:
|
|
220
|
+
if prev_position is not None and track_id not in self.count_ids:
|
|
226
221
|
distance = Point(track_line[-1]).distance(self.counting_region)
|
|
227
|
-
|
|
228
222
|
if distance < self.line_dist_thresh and track_id not in self.count_ids:
|
|
229
223
|
self.count_ids.append(track_id)
|
|
230
224
|
|
|
231
225
|
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
|
|
232
226
|
self.in_counts += 1
|
|
233
|
-
self.class_wise_count[self.names[cls]]["in"] +=
|
|
227
|
+
self.class_wise_count[self.names[cls]]["in"] += 2
|
|
234
228
|
else:
|
|
235
229
|
self.out_counts += 1
|
|
236
230
|
self.class_wise_count[self.names[cls]]["out"] += 1
|
|
@@ -254,17 +248,13 @@ class ObjectCounter:
|
|
|
254
248
|
if label is not None:
|
|
255
249
|
self.annotator.display_counts(
|
|
256
250
|
counts=label,
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
txt_color=self.count_txt_color,
|
|
260
|
-
line_color=self.line_color,
|
|
261
|
-
classwise_txtgap=self.cls_txtdisplay_gap,
|
|
251
|
+
count_txt_color=self.count_txt_color,
|
|
252
|
+
count_bg_color=self.count_bg_color,
|
|
262
253
|
)
|
|
263
254
|
|
|
264
255
|
def display_frames(self):
|
|
265
256
|
"""Display frame."""
|
|
266
257
|
if self.env_check:
|
|
267
|
-
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
|
268
258
|
cv2.namedWindow(self.window_name)
|
|
269
259
|
if len(self.reg_pts) == 4: # only add mouse event If user drawn region
|
|
270
260
|
cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
|
|
ultralytics/trackers/byte_tracker.py
CHANGED
|
@@ -5,8 +5,8 @@ import numpy as np
|
|
|
5
5
|
from .basetrack import BaseTrack, TrackState
|
|
6
6
|
from .utils import matching
|
|
7
7
|
from .utils.kalman_filter import KalmanFilterXYAH
|
|
8
|
-
from ..utils.ops import xywh2ltwh
|
|
9
8
|
from ..utils import LOGGER
|
|
9
|
+
from ..utils.ops import xywh2ltwh
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class STrack(BaseTrack):
|
|
ultralytics/trackers/utils/kalman_filter.py
CHANGED
|
@@ -39,8 +39,8 @@ class KalmanFilterXYAH:
|
|
|
39
39
|
and height h.
|
|
40
40
|
|
|
41
41
|
Returns:
|
|
42
|
-
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
|
43
|
-
the new track. Unobserved velocities are initialized to 0 mean.
|
|
42
|
+
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
|
43
|
+
of the new track. Unobserved velocities are initialized to 0 mean.
|
|
44
44
|
"""
|
|
45
45
|
mean_pos = measurement
|
|
46
46
|
mean_vel = np.zeros_like(mean_pos)
|
|
@@ -235,8 +235,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
|
235
235
|
measurement (ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
|
|
236
236
|
|
|
237
237
|
Returns:
|
|
238
|
-
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
|
239
|
-
the new track. Unobserved velocities are initialized to 0 mean.
|
|
238
|
+
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
|
239
|
+
of the new track. Unobserved velocities are initialized to 0 mean.
|
|
240
240
|
"""
|
|
241
241
|
mean_pos = measurement
|
|
242
242
|
mean_vel = np.zeros_like(mean_pos)
|
ultralytics/utils/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
3
|
import contextlib
|
|
4
|
+
import importlib.metadata
|
|
4
5
|
import inspect
|
|
5
6
|
import logging.config
|
|
6
7
|
import os
|
|
@@ -42,6 +43,8 @@ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar form
|
|
|
42
43
|
LOGGING_NAME = "ultralytics"
|
|
43
44
|
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
|
|
44
45
|
ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
|
|
46
|
+
PYTHON_VERSION = platform.python_version()
|
|
47
|
+
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
|
|
45
48
|
HELP_MSG = """
|
|
46
49
|
Usage examples for running YOLOv8:
|
|
47
50
|
|
|
@@ -457,12 +460,23 @@ def is_docker() -> bool:
|
|
|
457
460
|
Returns:
|
|
458
461
|
(bool): True if the script is running inside a Docker container, False otherwise.
|
|
459
462
|
"""
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
with open(file) as f:
|
|
463
|
+
with contextlib.suppress(Exception):
|
|
464
|
+
with open("/proc/self/cgroup") as f:
|
|
463
465
|
return "docker" in f.read()
|
|
464
|
-
|
|
465
|
-
|
|
466
|
+
return False
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def is_raspberrypi() -> bool:
|
|
470
|
+
"""
|
|
471
|
+
Determines if the Python environment is running on a Raspberry Pi by checking the device model information.
|
|
472
|
+
|
|
473
|
+
Returns:
|
|
474
|
+
(bool): True if running on a Raspberry Pi, False otherwise.
|
|
475
|
+
"""
|
|
476
|
+
with contextlib.suppress(Exception):
|
|
477
|
+
with open("/proc/device-tree/model") as f:
|
|
478
|
+
return "Raspberry Pi" in f.read()
|
|
479
|
+
return False
|
|
466
480
|
|
|
467
481
|
|
|
468
482
|
def is_online() -> bool:
|
|
@@ -472,23 +486,15 @@ def is_online() -> bool:
|
|
|
472
486
|
Returns:
|
|
473
487
|
(bool): True if connection is successful, False otherwise.
|
|
474
488
|
"""
|
|
475
|
-
|
|
489
|
+
with contextlib.suppress(Exception):
|
|
490
|
+
assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True"
|
|
491
|
+
import socket
|
|
476
492
|
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
test_connection = socket.create_connection(address=(host, 53), timeout=2)
|
|
480
|
-
except (socket.timeout, socket.gaierror, OSError):
|
|
481
|
-
continue
|
|
482
|
-
else:
|
|
483
|
-
# If the connection was successful, close it to avoid a ResourceWarning
|
|
484
|
-
test_connection.close()
|
|
485
|
-
return True
|
|
493
|
+
socket.create_connection(address=("1.1.1.1", 80), timeout=1.0).close() # check Cloudflare DNS
|
|
494
|
+
return True
|
|
486
495
|
return False
|
|
487
496
|
|
|
488
497
|
|
|
489
|
-
ONLINE = is_online()
|
|
490
|
-
|
|
491
|
-
|
|
492
498
|
def is_pip_package(filepath: str = __name__) -> bool:
|
|
493
499
|
"""
|
|
494
500
|
Determines if the file at the given filepath is part of a pip package.
|
|
@@ -541,17 +547,6 @@ def is_github_action_running() -> bool:
|
|
|
541
547
|
return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ
|
|
542
548
|
|
|
543
549
|
|
|
544
|
-
def is_git_dir():
|
|
545
|
-
"""
|
|
546
|
-
Determines whether the current file is part of a git repository. If the current file is not part of a git
|
|
547
|
-
repository, returns None.
|
|
548
|
-
|
|
549
|
-
Returns:
|
|
550
|
-
(bool): True if current file is part of a git repository.
|
|
551
|
-
"""
|
|
552
|
-
return get_git_dir() is not None
|
|
553
|
-
|
|
554
|
-
|
|
555
550
|
def get_git_dir():
|
|
556
551
|
"""
|
|
557
552
|
Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
|
|
@@ -565,6 +560,17 @@ def get_git_dir():
|
|
|
565
560
|
return d
|
|
566
561
|
|
|
567
562
|
|
|
563
|
+
def is_git_dir():
|
|
564
|
+
"""
|
|
565
|
+
Determines whether the current file is part of a git repository. If the current file is not part of a git
|
|
566
|
+
repository, returns None.
|
|
567
|
+
|
|
568
|
+
Returns:
|
|
569
|
+
(bool): True if current file is part of a git repository.
|
|
570
|
+
"""
|
|
571
|
+
return GIT_DIR is not None
|
|
572
|
+
|
|
573
|
+
|
|
568
574
|
def get_git_origin_url():
|
|
569
575
|
"""
|
|
570
576
|
Retrieves the origin URL of a git repository.
|
|
@@ -572,7 +578,7 @@ def get_git_origin_url():
|
|
|
572
578
|
Returns:
|
|
573
579
|
(str | None): The origin URL of the git repository or None if not git directory.
|
|
574
580
|
"""
|
|
575
|
-
if
|
|
581
|
+
if IS_GIT_DIR:
|
|
576
582
|
with contextlib.suppress(subprocess.CalledProcessError):
|
|
577
583
|
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
|
|
578
584
|
return origin.decode().strip()
|
|
@@ -585,7 +591,7 @@ def get_git_branch():
|
|
|
585
591
|
Returns:
|
|
586
592
|
(str | None): The current git branch name or None if not a git directory.
|
|
587
593
|
"""
|
|
588
|
-
if
|
|
594
|
+
if IS_GIT_DIR:
|
|
589
595
|
with contextlib.suppress(subprocess.CalledProcessError):
|
|
590
596
|
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
|
591
597
|
return origin.decode().strip()
|
|
@@ -651,6 +657,16 @@ def get_user_config_dir(sub_dir="Ultralytics"):
|
|
|
651
657
|
return path
|
|
652
658
|
|
|
653
659
|
|
|
660
|
+
# Define constants (required below)
|
|
661
|
+
ONLINE = is_online()
|
|
662
|
+
IS_COLAB = is_colab()
|
|
663
|
+
IS_DOCKER = is_docker()
|
|
664
|
+
IS_JUPYTER = is_jupyter()
|
|
665
|
+
IS_KAGGLE = is_kaggle()
|
|
666
|
+
IS_PIP_PACKAGE = is_pip_package()
|
|
667
|
+
IS_RASPBERRYPI = is_raspberrypi()
|
|
668
|
+
GIT_DIR = get_git_dir()
|
|
669
|
+
IS_GIT_DIR = is_git_dir()
|
|
654
670
|
USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir
|
|
655
671
|
SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
|
|
656
672
|
|
|
@@ -877,7 +893,7 @@ def set_sentry():
|
|
|
877
893
|
event["tags"] = {
|
|
878
894
|
"sys_argv": ARGV[0],
|
|
879
895
|
"sys_argv_name": Path(ARGV[0]).name,
|
|
880
|
-
"install": "git" if
|
|
896
|
+
"install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
|
|
881
897
|
"os": ENVIRONMENT,
|
|
882
898
|
}
|
|
883
899
|
return event
|
|
@@ -888,8 +904,8 @@ def set_sentry():
|
|
|
888
904
|
and Path(ARGV[0]).name == "yolo"
|
|
889
905
|
and not TESTS_RUNNING
|
|
890
906
|
and ONLINE
|
|
891
|
-
and
|
|
892
|
-
and not
|
|
907
|
+
and IS_PIP_PACKAGE
|
|
908
|
+
and not IS_GIT_DIR
|
|
893
909
|
):
|
|
894
910
|
# If sentry_sdk package is not installed then return and do not use Sentry
|
|
895
911
|
try:
|
|
@@ -928,9 +944,8 @@ class SettingsManager(dict):
|
|
|
928
944
|
from ultralytics.utils.checks import check_version
|
|
929
945
|
from ultralytics.utils.torch_utils import torch_distributed_zero_first
|
|
930
946
|
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
datasets_root = (root.parent if git_dir and is_dir_writeable(root.parent) else root).resolve()
|
|
947
|
+
root = GIT_DIR or Path()
|
|
948
|
+
datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve()
|
|
934
949
|
|
|
935
950
|
self.file = Path(file)
|
|
936
951
|
self.version = version
|
|
@@ -1034,13 +1049,13 @@ WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory
|
|
|
1034
1049
|
RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory
|
|
1035
1050
|
ENVIRONMENT = (
|
|
1036
1051
|
"Colab"
|
|
1037
|
-
if
|
|
1052
|
+
if IS_COLAB
|
|
1038
1053
|
else "Kaggle"
|
|
1039
|
-
if
|
|
1054
|
+
if IS_KAGGLE
|
|
1040
1055
|
else "Jupyter"
|
|
1041
|
-
if
|
|
1056
|
+
if IS_JUPYTER
|
|
1042
1057
|
else "Docker"
|
|
1043
|
-
if
|
|
1058
|
+
if IS_DOCKER
|
|
1044
1059
|
else platform.system()
|
|
1045
1060
|
)
|
|
1046
1061
|
TESTS_RUNNING = is_pytest_running() or is_github_action_running()
|
ultralytics/utils/callbacks/clearml.py
CHANGED
|
@@ -7,8 +7,6 @@ try:
|
|
|
7
7
|
assert SETTINGS["clearml"] is True # verify integration is enabled
|
|
8
8
|
import clearml
|
|
9
9
|
from clearml import Task
|
|
10
|
-
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
|
11
|
-
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
|
12
10
|
|
|
13
11
|
assert hasattr(clearml, "__version__") # verify package is not directory
|
|
14
12
|
|
|
@@ -61,8 +59,11 @@ def on_pretrain_routine_start(trainer):
|
|
|
61
59
|
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
|
|
62
60
|
try:
|
|
63
61
|
if task := Task.current_task():
|
|
64
|
-
#
|
|
62
|
+
# WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
|
|
65
63
|
# We are logging these plots and model files manually in the integration
|
|
64
|
+
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
|
65
|
+
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
|
66
|
+
|
|
66
67
|
PatchPyTorchModelIO.update_current_task(None)
|
|
67
68
|
PatchedMatplotlib.update_current_task(None)
|
|
68
69
|
else:
|
|
ultralytics/utils/callbacks/hub.py
CHANGED
|
@@ -12,10 +12,7 @@ def on_pretrain_routine_end(trainer):
|
|
|
12
12
|
session = getattr(trainer, "hub_session", None)
|
|
13
13
|
if session:
|
|
14
14
|
# Start timer for upload rate limit
|
|
15
|
-
session.timers = {
|
|
16
|
-
"metrics": time(),
|
|
17
|
-
"ckpt": time(),
|
|
18
|
-
} # start timer on session.rate_limit
|
|
15
|
+
session.timers = {"metrics": time(), "ckpt": time()} # start timer on session.rate_limit
|
|
19
16
|
|
|
20
17
|
|
|
21
18
|
def on_fit_epoch_end(trainer):
|
|
ultralytics/utils/callbacks/mlflow.py
CHANGED
|
@@ -58,7 +58,7 @@ def on_pretrain_routine_end(trainer):
|
|
|
58
58
|
MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
|
|
59
59
|
MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
|
|
60
60
|
MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
|
|
61
|
-
MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of
|
|
61
|
+
MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of training.
|
|
62
62
|
"""
|
|
63
63
|
global mlflow
|
|
64
64
|
|
|
ultralytics/utils/callbacks/wb.py
CHANGED
|
@@ -9,10 +9,6 @@ try:
|
|
|
9
9
|
import wandb as wb
|
|
10
10
|
|
|
11
11
|
assert hasattr(wb, "__version__") # verify package is not directory
|
|
12
|
-
|
|
13
|
-
import numpy as np
|
|
14
|
-
import pandas as pd
|
|
15
|
-
|
|
16
12
|
_processed_plots = {}
|
|
17
13
|
|
|
18
14
|
except (ImportError, AssertionError):
|
|
@@ -38,7 +34,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
|
|
|
38
34
|
Returns:
|
|
39
35
|
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
|
|
40
36
|
"""
|
|
41
|
-
|
|
37
|
+
import pandas # scope for faster 'import ultralytics'
|
|
38
|
+
|
|
39
|
+
df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
|
|
42
40
|
fields = {"x": "x", "y": "y", "class": "class"}
|
|
43
41
|
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
|
|
44
42
|
return wb.plot_table(
|
|
@@ -77,6 +75,8 @@ def _plot_curve(
|
|
|
77
75
|
Note:
|
|
78
76
|
The function leverages the '_custom_table' function to generate the actual visualization.
|
|
79
77
|
"""
|
|
78
|
+
import numpy as np
|
|
79
|
+
|
|
80
80
|
# Create new x
|
|
81
81
|
if names is None:
|
|
82
82
|
names = []
|
ultralytics/utils/checks.py
CHANGED
|
@@ -18,15 +18,21 @@ import cv2
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
import requests
|
|
20
20
|
import torch
|
|
21
|
-
from matplotlib import font_manager
|
|
22
21
|
|
|
23
22
|
from ultralytics.utils import (
|
|
24
23
|
ASSETS,
|
|
25
24
|
AUTOINSTALL,
|
|
25
|
+
IS_COLAB,
|
|
26
|
+
IS_DOCKER,
|
|
27
|
+
IS_JUPYTER,
|
|
28
|
+
IS_KAGGLE,
|
|
29
|
+
IS_PIP_PACKAGE,
|
|
26
30
|
LINUX,
|
|
27
31
|
LOGGER,
|
|
28
32
|
ONLINE,
|
|
33
|
+
PYTHON_VERSION,
|
|
29
34
|
ROOT,
|
|
35
|
+
TORCHVISION_VERSION,
|
|
30
36
|
USER_CONFIG_DIR,
|
|
31
37
|
Retry,
|
|
32
38
|
SimpleNamespace,
|
|
@@ -36,18 +42,10 @@ from ultralytics.utils import (
|
|
|
36
42
|
colorstr,
|
|
37
43
|
downloads,
|
|
38
44
|
emojis,
|
|
39
|
-
is_colab,
|
|
40
|
-
is_docker,
|
|
41
45
|
is_github_action_running,
|
|
42
|
-
is_jupyter,
|
|
43
|
-
is_kaggle,
|
|
44
|
-
is_online,
|
|
45
|
-
is_pip_package,
|
|
46
46
|
url2file,
|
|
47
47
|
)
|
|
48
48
|
|
|
49
|
-
PYTHON_VERSION = platform.python_version()
|
|
50
|
-
|
|
51
49
|
|
|
52
50
|
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
|
|
53
51
|
"""
|
|
@@ -279,7 +277,7 @@ def check_pip_update_available():
|
|
|
279
277
|
Returns:
|
|
280
278
|
(bool): True if an update is available, False otherwise.
|
|
281
279
|
"""
|
|
282
|
-
if ONLINE and
|
|
280
|
+
if ONLINE and IS_PIP_PACKAGE:
|
|
283
281
|
with contextlib.suppress(Exception):
|
|
284
282
|
from ultralytics import __version__
|
|
285
283
|
|
|
@@ -304,9 +302,10 @@ def check_font(font="Arial.ttf"):
|
|
|
304
302
|
Returns:
|
|
305
303
|
file (Path): Resolved font file path.
|
|
306
304
|
"""
|
|
307
|
-
|
|
305
|
+
from matplotlib import font_manager
|
|
308
306
|
|
|
309
307
|
# Check USER_CONFIG_DIR
|
|
308
|
+
name = Path(font).name
|
|
310
309
|
file = USER_CONFIG_DIR / name
|
|
311
310
|
if file.exists():
|
|
312
311
|
return file
|
|
@@ -390,7 +389,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
|
|
|
390
389
|
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
|
|
391
390
|
try:
|
|
392
391
|
t = time.time()
|
|
393
|
-
assert
|
|
392
|
+
assert ONLINE, "AutoUpdate skipped (offline)"
|
|
394
393
|
with Retry(times=2, delay=1): # run up to 2 times with 1-second retry delay
|
|
395
394
|
LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
|
|
396
395
|
dt = time.time() - t
|
|
@@ -419,14 +418,12 @@ def check_torchvision():
|
|
|
419
418
|
Torchvision versions.
|
|
420
419
|
"""
|
|
421
420
|
|
|
422
|
-
import torchvision
|
|
423
|
-
|
|
424
421
|
# Compatibility table
|
|
425
422
|
compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
|
|
426
423
|
|
|
427
424
|
# Extract only the major and minor versions
|
|
428
425
|
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
|
|
429
|
-
v_torchvision = ".".join(
|
|
426
|
+
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
|
|
430
427
|
|
|
431
428
|
if v_torch in compatibility_table:
|
|
432
429
|
compatible_versions = compatibility_table[v_torch]
|
|
@@ -531,7 +528,7 @@ def check_imshow(warn=False):
|
|
|
531
528
|
"""Check if environment supports image displays."""
|
|
532
529
|
try:
|
|
533
530
|
if LINUX:
|
|
534
|
-
assert "DISPLAY" in os.environ and not
|
|
531
|
+
assert "DISPLAY" in os.environ and not IS_DOCKER and not IS_COLAB and not IS_KAGGLE
|
|
535
532
|
cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
|
|
536
533
|
cv2.waitKey(1)
|
|
537
534
|
cv2.destroyAllWindows()
|
|
@@ -549,10 +546,10 @@ def check_yolo(verbose=True, device=""):
|
|
|
549
546
|
|
|
550
547
|
from ultralytics.utils.torch_utils import select_device
|
|
551
548
|
|
|
552
|
-
if
|
|
549
|
+
if IS_JUPYTER:
|
|
553
550
|
if check_requirements("wandb", install=False):
|
|
554
551
|
os.system("pip uninstall -y wandb") # uninstall wandb: unwanted account creation prompt with infinite hang
|
|
555
|
-
if
|
|
552
|
+
if IS_COLAB:
|
|
556
553
|
shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory
|
|
557
554
|
|
|
558
555
|
if verbose:
|
|
@@ -577,7 +574,7 @@ def collect_system_info():
|
|
|
577
574
|
|
|
578
575
|
import psutil
|
|
579
576
|
|
|
580
|
-
from ultralytics.utils import ENVIRONMENT,
|
|
577
|
+
from ultralytics.utils import ENVIRONMENT, IS_GIT_DIR
|
|
581
578
|
from ultralytics.utils.torch_utils import get_cpu_info
|
|
582
579
|
|
|
583
580
|
ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB
|
|
@@ -586,7 +583,7 @@ def collect_system_info():
|
|
|
586
583
|
f"\n{'OS':<20}{platform.platform()}\n"
|
|
587
584
|
f"{'Environment':<20}{ENVIRONMENT}\n"
|
|
588
585
|
f"{'Python':<20}{PYTHON_VERSION}\n"
|
|
589
|
-
f"{'Install':<20}{'git' if
|
|
586
|
+
f"{'Install':<20}{'git' if IS_GIT_DIR else 'pip' if IS_PIP_PACKAGE else 'other'}\n"
|
|
590
587
|
f"{'RAM':<20}{ram_info:.2f} GB\n"
|
|
591
588
|
f"{'CPU':<20}{get_cpu_info()}\n"
|
|
592
589
|
f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
|
ultralytics/utils/metrics.py
CHANGED
|
@@ -395,19 +395,19 @@ class ConfusionMatrix:
|
|
|
395
395
|
names (tuple): Names of classes, used as labels on the plot.
|
|
396
396
|
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
|
397
397
|
"""
|
|
398
|
-
import seaborn
|
|
398
|
+
import seaborn # scope for faster 'import ultralytics'
|
|
399
399
|
|
|
400
400
|
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
|
|
401
401
|
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
|
402
402
|
|
|
403
403
|
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
|
404
404
|
nc, nn = self.nc, len(names) # number of classes, names
|
|
405
|
-
|
|
405
|
+
seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
|
406
406
|
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
|
407
407
|
ticklabels = (list(names) + ["background"]) if labels else "auto"
|
|
408
408
|
with warnings.catch_warnings():
|
|
409
409
|
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
|
410
|
-
|
|
410
|
+
seaborn.heatmap(
|
|
411
411
|
array,
|
|
412
412
|
ax=ax,
|
|
413
413
|
annot=nc < 30,
|
ultralytics/utils/ops.py
CHANGED
|
@@ -9,7 +9,6 @@ import cv2
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import torch
|
|
11
11
|
import torch.nn.functional as F
|
|
12
|
-
import torchvision
|
|
13
12
|
|
|
14
13
|
from ultralytics.utils import LOGGER
|
|
15
14
|
from ultralytics.utils.metrics import batch_probiou
|
|
@@ -206,6 +205,7 @@ def non_max_suppression(
|
|
|
206
205
|
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
|
|
207
206
|
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
|
|
208
207
|
"""
|
|
208
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
209
209
|
|
|
210
210
|
# Checks
|
|
211
211
|
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|