dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/patches.py
CHANGED
|
@@ -18,8 +18,7 @@ _imshow = cv2.imshow # copy to avoid recursion errors
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
|
|
21
|
-
"""
|
|
22
|
-
Read an image from a file with multilanguage filename support.
|
|
21
|
+
"""Read an image from a file with multilanguage filename support.
|
|
23
22
|
|
|
24
23
|
Args:
|
|
25
24
|
filename (str): Path to the file to read.
|
|
@@ -36,7 +35,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
|
|
|
36
35
|
if filename.endswith((".tiff", ".tif")):
|
|
37
36
|
success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
|
|
38
37
|
if success:
|
|
39
|
-
# Handle
|
|
38
|
+
# Handle multi-frame TIFFs and color images
|
|
40
39
|
return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
|
|
41
40
|
return None
|
|
42
41
|
else:
|
|
@@ -45,8 +44,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
|
|
|
45
44
|
|
|
46
45
|
|
|
47
46
|
def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
|
|
48
|
-
"""
|
|
49
|
-
Write an image to a file with multilanguage filename support.
|
|
47
|
+
"""Write an image to a file with multilanguage filename support.
|
|
50
48
|
|
|
51
49
|
Args:
|
|
52
50
|
filename (str): Path to the file to write.
|
|
@@ -71,15 +69,14 @@ def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) ->
|
|
|
71
69
|
|
|
72
70
|
|
|
73
71
|
def imshow(winname: str, mat: np.ndarray) -> None:
|
|
74
|
-
"""
|
|
75
|
-
Display an image in the specified window with multilanguage window name support.
|
|
72
|
+
"""Display an image in the specified window with multilanguage window name support.
|
|
76
73
|
|
|
77
74
|
This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
|
|
78
75
|
multilanguage window names by encoding them properly for OpenCV compatibility.
|
|
79
76
|
|
|
80
77
|
Args:
|
|
81
|
-
winname (str): Name of the window where the image will be displayed. If a window with this name already
|
|
82
|
-
exists, the image will be displayed in that window.
|
|
78
|
+
winname (str): Name of the window where the image will be displayed. If a window with this name already exists,
|
|
79
|
+
the image will be displayed in that window.
|
|
83
80
|
mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
|
|
84
81
|
|
|
85
82
|
Examples:
|
|
@@ -96,8 +93,7 @@ _torch_save = torch.save
|
|
|
96
93
|
|
|
97
94
|
|
|
98
95
|
def torch_load(*args, **kwargs):
|
|
99
|
-
"""
|
|
100
|
-
Load a PyTorch model with updated arguments to avoid warnings.
|
|
96
|
+
"""Load a PyTorch model with updated arguments to avoid warnings.
|
|
101
97
|
|
|
102
98
|
This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
|
|
103
99
|
|
|
@@ -109,8 +105,8 @@ def torch_load(*args, **kwargs):
|
|
|
109
105
|
(Any): The loaded PyTorch object.
|
|
110
106
|
|
|
111
107
|
Notes:
|
|
112
|
-
For PyTorch versions
|
|
113
|
-
|
|
108
|
+
For PyTorch versions 1.13 and above, this function automatically sets `weights_only=False` if the argument is
|
|
109
|
+
not provided, to avoid deprecation warnings.
|
|
114
110
|
"""
|
|
115
111
|
from ultralytics.utils.torch_utils import TORCH_1_13
|
|
116
112
|
|
|
@@ -121,11 +117,10 @@ def torch_load(*args, **kwargs):
|
|
|
121
117
|
|
|
122
118
|
|
|
123
119
|
def torch_save(*args, **kwargs):
|
|
124
|
-
"""
|
|
125
|
-
Save PyTorch objects with retry mechanism for robustness.
|
|
120
|
+
"""Save PyTorch objects with retry mechanism for robustness.
|
|
126
121
|
|
|
127
|
-
This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
|
|
128
|
-
due to device flushing delays or antivirus scanning.
|
|
122
|
+
This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur due
|
|
123
|
+
to device flushing delays or antivirus scanning.
|
|
129
124
|
|
|
130
125
|
Args:
|
|
131
126
|
*args (Any): Positional arguments to pass to torch.save.
|
|
@@ -146,8 +141,7 @@ def torch_save(*args, **kwargs):
|
|
|
146
141
|
|
|
147
142
|
@contextmanager
|
|
148
143
|
def arange_patch(args):
|
|
149
|
-
"""
|
|
150
|
-
Workaround for ONNX torch.arange incompatibility with FP16.
|
|
144
|
+
"""Workaround for ONNX torch.arange incompatibility with FP16.
|
|
151
145
|
|
|
152
146
|
https://github.com/pytorch/pytorch/issues/148041.
|
|
153
147
|
"""
|
|
@@ -165,10 +159,28 @@ def arange_patch(args):
|
|
|
165
159
|
yield
|
|
166
160
|
|
|
167
161
|
|
|
162
|
+
@contextmanager
|
|
163
|
+
def onnx_export_patch():
|
|
164
|
+
"""Workaround for ONNX export issues in PyTorch 2.9+ with Dynamo enabled."""
|
|
165
|
+
from ultralytics.utils.torch_utils import TORCH_2_9
|
|
166
|
+
|
|
167
|
+
if TORCH_2_9:
|
|
168
|
+
func = torch.onnx.export
|
|
169
|
+
|
|
170
|
+
def torch_export(*args, **kwargs):
|
|
171
|
+
"""Return a 1-D tensor of size with values from the interval and common difference."""
|
|
172
|
+
return func(*args, **kwargs, dynamo=False) # cast to dtype instead of passing dtype
|
|
173
|
+
|
|
174
|
+
torch.onnx.export = torch_export # patch
|
|
175
|
+
yield
|
|
176
|
+
torch.onnx.export = func # unpatch
|
|
177
|
+
else:
|
|
178
|
+
yield
|
|
179
|
+
|
|
180
|
+
|
|
168
181
|
@contextmanager
|
|
169
182
|
def override_configs(args, overrides: dict[str, Any] | None = None):
|
|
170
|
-
"""
|
|
171
|
-
Context manager to temporarily override configurations in args.
|
|
183
|
+
"""Context manager to temporarily override configurations in args.
|
|
172
184
|
|
|
173
185
|
Args:
|
|
174
186
|
args (IterableSimpleNamespace): Original configuration arguments.
|
ultralytics/utils/plotting.py
CHANGED
|
@@ -3,9 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import math
|
|
6
|
-
import warnings
|
|
6
|
+
from collections.abc import Callable
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import Any, Callable
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
import cv2
|
|
11
11
|
import numpy as np
|
|
@@ -19,22 +19,10 @@ from ultralytics.utils.files import increment_path
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class Colors:
|
|
22
|
-
"""
|
|
23
|
-
Ultralytics color palette for visualization and plotting.
|
|
22
|
+
"""Ultralytics color palette for visualization and plotting.
|
|
24
23
|
|
|
25
|
-
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
|
|
26
|
-
RGB values and accessing predefined color schemes for object detection and pose estimation.
|
|
27
|
-
|
|
28
|
-
Attributes:
|
|
29
|
-
palette (list[tuple]): List of RGB color tuples for general use.
|
|
30
|
-
n (int): The number of colors in the palette.
|
|
31
|
-
pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
|
|
32
|
-
|
|
33
|
-
Examples:
|
|
34
|
-
>>> from ultralytics.utils.plotting import Colors
|
|
35
|
-
>>> colors = Colors()
|
|
36
|
-
>>> colors(5, True) # Returns BGR format: (221, 111, 255)
|
|
37
|
-
>>> colors(5, False) # Returns RGB format: (255, 111, 221)
|
|
24
|
+
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
|
|
25
|
+
values and accessing predefined color schemes for object detection and pose estimation.
|
|
38
26
|
|
|
39
27
|
## Ultralytics Color Palette
|
|
40
28
|
|
|
@@ -90,6 +78,17 @@ class Colors:
|
|
|
90
78
|
|
|
91
79
|
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
|
|
92
80
|
Please use the official Ultralytics colors for all marketing materials.
|
|
81
|
+
|
|
82
|
+
Attributes:
|
|
83
|
+
palette (list[tuple]): List of RGB color tuples for general use.
|
|
84
|
+
n (int): The number of colors in the palette.
|
|
85
|
+
pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
|
|
86
|
+
|
|
87
|
+
Examples:
|
|
88
|
+
>>> from ultralytics.utils.plotting import Colors
|
|
89
|
+
>>> colors = Colors()
|
|
90
|
+
>>> colors(5, True) # Returns BGR format: (221, 111, 255)
|
|
91
|
+
>>> colors(5, False) # Returns RGB format: (255, 111, 221)
|
|
93
92
|
"""
|
|
94
93
|
|
|
95
94
|
def __init__(self):
|
|
@@ -145,8 +144,7 @@ class Colors:
|
|
|
145
144
|
)
|
|
146
145
|
|
|
147
146
|
def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
|
|
148
|
-
"""
|
|
149
|
-
Convert hex color codes to RGB values.
|
|
147
|
+
"""Convert hex color codes to RGB values.
|
|
150
148
|
|
|
151
149
|
Args:
|
|
152
150
|
i (int | torch.Tensor): Color index.
|
|
@@ -168,8 +166,7 @@ colors = Colors() # create instance for 'from utils.plots import colors'
|
|
|
168
166
|
|
|
169
167
|
|
|
170
168
|
class Annotator:
|
|
171
|
-
"""
|
|
172
|
-
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
|
169
|
+
"""Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
|
173
170
|
|
|
174
171
|
Attributes:
|
|
175
172
|
im (Image.Image | np.ndarray): The image to annotate.
|
|
@@ -206,10 +203,12 @@ class Annotator:
|
|
|
206
203
|
if not input_is_pil:
|
|
207
204
|
if im.shape[2] == 1: # handle grayscale
|
|
208
205
|
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
|
|
206
|
+
elif im.shape[2] == 2: # handle 2-channel images
|
|
207
|
+
im = np.ascontiguousarray(np.dstack((im, np.zeros_like(im[..., :1]))))
|
|
209
208
|
elif im.shape[2] > 3: # multispectral
|
|
210
209
|
im = np.ascontiguousarray(im[..., :3])
|
|
211
210
|
if self.pil: # use PIL
|
|
212
|
-
self.im = im if input_is_pil else Image.fromarray(im)
|
|
211
|
+
self.im = im if input_is_pil else Image.fromarray(im) # stay in BGR since color palette is in BGR
|
|
213
212
|
if self.im.mode not in {"RGB", "RGBA"}: # multispectral
|
|
214
213
|
self.im = self.im.convert("RGB")
|
|
215
214
|
self.draw = ImageDraw.Draw(self.im, "RGBA")
|
|
@@ -278,8 +277,7 @@ class Annotator:
|
|
|
278
277
|
}
|
|
279
278
|
|
|
280
279
|
def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
|
|
281
|
-
"""
|
|
282
|
-
Assign text color based on background color.
|
|
280
|
+
"""Assign text color based on background color.
|
|
283
281
|
|
|
284
282
|
Args:
|
|
285
283
|
color (tuple, optional): The background color of the rectangle for text (B, G, R).
|
|
@@ -302,8 +300,7 @@ class Annotator:
|
|
|
302
300
|
return txt_color
|
|
303
301
|
|
|
304
302
|
def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
|
|
305
|
-
"""
|
|
306
|
-
Draw a bounding box on an image with a given label.
|
|
303
|
+
"""Draw a bounding box on an image with a given label.
|
|
307
304
|
|
|
308
305
|
Args:
|
|
309
306
|
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
|
|
@@ -364,8 +361,7 @@ class Annotator:
|
|
|
364
361
|
)
|
|
365
362
|
|
|
366
363
|
def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
|
|
367
|
-
"""
|
|
368
|
-
Plot masks on image.
|
|
364
|
+
"""Plot masks on image.
|
|
369
365
|
|
|
370
366
|
Args:
|
|
371
367
|
masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
|
|
@@ -384,25 +380,32 @@ class Annotator:
|
|
|
384
380
|
overlay[mask.astype(bool)] = colors[i]
|
|
385
381
|
self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
|
|
386
382
|
else:
|
|
387
|
-
assert isinstance(masks, torch.Tensor), "
|
|
383
|
+
assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
|
|
388
384
|
if len(masks) == 0:
|
|
389
385
|
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
|
386
|
+
return
|
|
390
387
|
if im_gpu.device != masks.device:
|
|
391
388
|
im_gpu = im_gpu.to(masks.device)
|
|
389
|
+
|
|
390
|
+
ih, iw = self.im.shape[:2]
|
|
391
|
+
if not retina_masks:
|
|
392
|
+
# Use scale_masks to properly remove padding and upsample, convert bool to float first
|
|
393
|
+
masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
|
|
394
|
+
# Convert original BGR image to RGB tensor
|
|
395
|
+
im_gpu = (
|
|
396
|
+
torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
|
|
397
|
+
)
|
|
398
|
+
|
|
392
399
|
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
|
|
393
400
|
colors = colors[:, None, None] # shape(n,1,1,3)
|
|
394
401
|
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
|
395
402
|
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
|
396
|
-
|
|
397
403
|
inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
|
398
|
-
mcs = masks_color.max(dim=0).values # shape(
|
|
404
|
+
mcs = masks_color.max(dim=0).values # shape(h,w,3)
|
|
399
405
|
|
|
400
|
-
im_gpu = im_gpu.flip(dims=[0]) #
|
|
401
|
-
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
|
406
|
+
im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous() # shape(h,w,3)
|
|
402
407
|
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
|
|
403
|
-
|
|
404
|
-
im_mask_np = im_mask.byte().cpu().numpy()
|
|
405
|
-
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
|
|
408
|
+
self.im[:] = (im_gpu * 255).byte().cpu().numpy()
|
|
406
409
|
if self.pil:
|
|
407
410
|
# Convert im back to PIL and update draw
|
|
408
411
|
self.fromarray(self.im)
|
|
@@ -416,8 +419,7 @@ class Annotator:
|
|
|
416
419
|
conf_thres: float = 0.25,
|
|
417
420
|
kpt_color: tuple | None = None,
|
|
418
421
|
):
|
|
419
|
-
"""
|
|
420
|
-
Plot keypoints on the image.
|
|
422
|
+
"""Plot keypoints on the image.
|
|
421
423
|
|
|
422
424
|
Args:
|
|
423
425
|
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
|
|
@@ -427,7 +429,7 @@ class Annotator:
|
|
|
427
429
|
conf_thres (float, optional): Confidence threshold.
|
|
428
430
|
kpt_color (tuple, optional): Keypoint color (B, G, R).
|
|
429
431
|
|
|
430
|
-
|
|
432
|
+
Notes:
|
|
431
433
|
- `kpt_line=True` currently only supports human pose plotting.
|
|
432
434
|
- Modifies self.im in-place.
|
|
433
435
|
- If self.pil is True, converts image to numpy array and back to PIL.
|
|
@@ -480,8 +482,7 @@ class Annotator:
|
|
|
480
482
|
self.draw.rectangle(xy, fill, outline, width)
|
|
481
483
|
|
|
482
484
|
def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
|
|
483
|
-
"""
|
|
484
|
-
Add text to an image using PIL or cv2.
|
|
485
|
+
"""Add text to an image using PIL or cv2.
|
|
485
486
|
|
|
486
487
|
Args:
|
|
487
488
|
xy (list[int]): Top-left coordinates for text placement.
|
|
@@ -511,18 +512,19 @@ class Annotator:
|
|
|
511
512
|
cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
|
|
512
513
|
|
|
513
514
|
def fromarray(self, im):
|
|
514
|
-
"""Update self.im from a
|
|
515
|
+
"""Update `self.im` from a NumPy array or PIL image."""
|
|
515
516
|
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
|
516
517
|
self.draw = ImageDraw.Draw(self.im)
|
|
517
518
|
|
|
518
|
-
def result(self):
|
|
519
|
-
"""Return annotated image as array."""
|
|
520
|
-
|
|
519
|
+
def result(self, pil=False):
|
|
520
|
+
"""Return annotated image as array or PIL image."""
|
|
521
|
+
im = np.asarray(self.im) # self.im is in BGR
|
|
522
|
+
return Image.fromarray(im[..., ::-1]) if pil else im
|
|
521
523
|
|
|
522
524
|
def show(self, title: str | None = None):
|
|
523
525
|
"""Show the annotated image."""
|
|
524
|
-
im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert
|
|
525
|
-
if IS_COLAB or IS_KAGGLE: #
|
|
526
|
+
im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert BGR NumPy array to RGB PIL Image
|
|
527
|
+
if IS_COLAB or IS_KAGGLE: # cannot use IS_JUPYTER as it runs for all IPython environments
|
|
526
528
|
try:
|
|
527
529
|
display(im) # noqa - display() function only available in ipython environments
|
|
528
530
|
except ImportError as e:
|
|
@@ -535,12 +537,11 @@ class Annotator:
|
|
|
535
537
|
cv2.imwrite(filename, np.asarray(self.im))
|
|
536
538
|
|
|
537
539
|
@staticmethod
|
|
538
|
-
def get_bbox_dimension(bbox: tuple |
|
|
539
|
-
"""
|
|
540
|
-
Calculate the dimensions and area of a bounding box.
|
|
540
|
+
def get_bbox_dimension(bbox: tuple | list):
|
|
541
|
+
"""Calculate the dimensions and area of a bounding box.
|
|
541
542
|
|
|
542
543
|
Args:
|
|
543
|
-
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
|
|
544
|
+
bbox (tuple | list): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
|
|
544
545
|
|
|
545
546
|
Returns:
|
|
546
547
|
width (float): Width of the bounding box.
|
|
@@ -562,8 +563,7 @@ class Annotator:
|
|
|
562
563
|
@TryExcept()
|
|
563
564
|
@plt_settings()
|
|
564
565
|
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
565
|
-
"""
|
|
566
|
-
Plot training labels including class histograms and box statistics.
|
|
566
|
+
"""Plot training labels including class histograms and box statistics.
|
|
567
567
|
|
|
568
568
|
Args:
|
|
569
569
|
boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
|
|
@@ -576,10 +576,6 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
576
576
|
import polars
|
|
577
577
|
from matplotlib.colors import LinearSegmentedColormap
|
|
578
578
|
|
|
579
|
-
# Filter matplotlib>=3.7.2 warning
|
|
580
|
-
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
|
|
581
|
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
582
|
-
|
|
583
579
|
# Plot dataset labels
|
|
584
580
|
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
|
585
581
|
nc = int(cls.max() + 1) # number of classes
|
|
@@ -601,8 +597,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
601
597
|
ax[0].set_xlabel("classes")
|
|
602
598
|
boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
|
|
603
599
|
img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
|
|
604
|
-
for
|
|
605
|
-
ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(
|
|
600
|
+
for class_id, box in zip(cls[:500], boxes[:500]):
|
|
601
|
+
ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(class_id)) # plot
|
|
606
602
|
ax[1].imshow(img)
|
|
607
603
|
ax[1].axis("off")
|
|
608
604
|
|
|
@@ -633,12 +629,11 @@ def save_one_box(
|
|
|
633
629
|
BGR: bool = False,
|
|
634
630
|
save: bool = True,
|
|
635
631
|
):
|
|
636
|
-
"""
|
|
637
|
-
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
|
632
|
+
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
|
638
633
|
|
|
639
|
-
This function takes a bounding box and an image, and then saves a cropped portion of the image according
|
|
640
|
-
|
|
641
|
-
|
|
634
|
+
This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
|
|
635
|
+
bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
|
|
636
|
+
bounding box.
|
|
642
637
|
|
|
643
638
|
Args:
|
|
644
639
|
xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
|
|
@@ -691,11 +686,11 @@ def plot_images(
|
|
|
691
686
|
save: bool = True,
|
|
692
687
|
conf_thres: float = 0.25,
|
|
693
688
|
) -> np.ndarray | None:
|
|
694
|
-
"""
|
|
695
|
-
Plot image grid with labels, bounding boxes, masks, and keypoints.
|
|
689
|
+
"""Plot image grid with labels, bounding boxes, masks, and keypoints.
|
|
696
690
|
|
|
697
691
|
Args:
|
|
698
|
-
labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
|
|
692
|
+
labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
|
|
693
|
+
'keypoints', 'batch_idx', 'img'.
|
|
699
694
|
images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
|
|
700
695
|
paths (Optional[list[str]]): List of file paths for each image in the batch.
|
|
701
696
|
fname (str): Output filename for the plotted image grid.
|
|
@@ -709,9 +704,15 @@ def plot_images(
|
|
|
709
704
|
Returns:
|
|
710
705
|
(np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
|
|
711
706
|
|
|
712
|
-
|
|
707
|
+
Notes:
|
|
713
708
|
This function supports both tensor and numpy array inputs. It will automatically
|
|
714
709
|
convert tensor inputs to numpy arrays for processing.
|
|
710
|
+
|
|
711
|
+
Channel Support:
|
|
712
|
+
- 1 channel: Grayscale
|
|
713
|
+
- 2 channels: Third channel added as zeros
|
|
714
|
+
- 3 channels: Used as-is (standard RGB)
|
|
715
|
+
- 4+ channels: Cropped to first 3 channels
|
|
715
716
|
"""
|
|
716
717
|
for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}:
|
|
717
718
|
if k not in labels:
|
|
@@ -731,7 +732,13 @@ def plot_images(
|
|
|
731
732
|
|
|
732
733
|
if len(images) and isinstance(images, torch.Tensor):
|
|
733
734
|
images = images.cpu().float().numpy()
|
|
734
|
-
|
|
735
|
+
|
|
736
|
+
# Handle 2-ch and n-ch images
|
|
737
|
+
c = images.shape[1]
|
|
738
|
+
if c == 2:
|
|
739
|
+
zero = np.zeros_like(images[:, :1])
|
|
740
|
+
images = np.concatenate((images, zero), axis=1) # pad 2-ch with a black channel
|
|
741
|
+
elif c > 3:
|
|
735
742
|
images = images[:, :3] # crop multispectral images to first 3 channels
|
|
736
743
|
|
|
737
744
|
bs, _, h, w = images.shape # batch size, _, height, width
|
|
@@ -766,10 +773,10 @@ def plot_images(
|
|
|
766
773
|
idx = batch_idx == i
|
|
767
774
|
classes = cls[idx].astype("int")
|
|
768
775
|
labels = confs is None
|
|
776
|
+
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
|
|
769
777
|
|
|
770
778
|
if len(bboxes):
|
|
771
779
|
boxes = bboxes[idx]
|
|
772
|
-
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
|
|
773
780
|
if len(boxes):
|
|
774
781
|
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
|
|
775
782
|
boxes[..., [0, 2]] *= w # scale to pixels
|
|
@@ -779,7 +786,6 @@ def plot_images(
|
|
|
779
786
|
boxes[..., 0] += x
|
|
780
787
|
boxes[..., 1] += y
|
|
781
788
|
is_obb = boxes.shape[-1] == 5 # xywhr
|
|
782
|
-
# TODO: this transformation might be unnecessary
|
|
783
789
|
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
|
|
784
790
|
for j, box in enumerate(boxes.astype(np.int64).tolist()):
|
|
785
791
|
c = classes[j]
|
|
@@ -793,7 +799,8 @@ def plot_images(
|
|
|
793
799
|
for c in classes:
|
|
794
800
|
color = colors(c)
|
|
795
801
|
c = names.get(c, c) if names else c
|
|
796
|
-
|
|
802
|
+
label = f"{c}" if labels else f"{c} {conf[0]:.1f}"
|
|
803
|
+
annotator.text([x, y], label, txt_color=color, box_color=(64, 64, 64, 128))
|
|
797
804
|
|
|
798
805
|
# Plot keypoints
|
|
799
806
|
if len(kpts):
|
|
@@ -812,14 +819,13 @@ def plot_images(
|
|
|
812
819
|
|
|
813
820
|
# Plot masks
|
|
814
821
|
if len(masks):
|
|
815
|
-
if idx.shape[0] == masks.shape[0]: # overlap_mask=False
|
|
822
|
+
if idx.shape[0] == masks.shape[0] and masks.max() <= 1: # overlap_mask=False
|
|
816
823
|
image_masks = masks[idx]
|
|
817
824
|
else: # overlap_mask=True
|
|
818
825
|
image_masks = masks[[i]] # (1, 640, 640)
|
|
819
826
|
nl = idx.sum()
|
|
820
|
-
index = np.arange(nl).reshape((nl, 1, 1))
|
|
821
|
-
image_masks =
|
|
822
|
-
image_masks = np.where(image_masks == index, 1.0, 0.0)
|
|
827
|
+
index = np.arange(1, nl + 1).reshape((nl, 1, 1))
|
|
828
|
+
image_masks = (image_masks == index).astype(np.float32)
|
|
823
829
|
|
|
824
830
|
im = np.asarray(annotator.im).copy()
|
|
825
831
|
for j in range(len(image_masks)):
|
|
@@ -847,24 +853,14 @@ def plot_images(
|
|
|
847
853
|
|
|
848
854
|
|
|
849
855
|
@plt_settings()
|
|
850
|
-
def plot_results(
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
pose: bool = False,
|
|
855
|
-
classify: bool = False,
|
|
856
|
-
on_plot: Callable | None = None,
|
|
857
|
-
):
|
|
858
|
-
"""
|
|
859
|
-
Plot training results from a results CSV file. The function supports various types of data including segmentation,
|
|
860
|
-
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
|
|
856
|
+
def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
|
|
857
|
+
"""Plot training results from a results CSV file. The function supports various types of data including
|
|
858
|
+
segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
|
|
859
|
+
CSV is located.
|
|
861
860
|
|
|
862
861
|
Args:
|
|
863
862
|
file (str, optional): Path to the CSV file containing the training results.
|
|
864
863
|
dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
|
|
865
|
-
segment (bool, optional): Flag to indicate if the data is for segmentation.
|
|
866
|
-
pose (bool, optional): Flag to indicate if the data is for pose estimation.
|
|
867
|
-
classify (bool, optional): Flag to indicate if the data is for classification.
|
|
868
864
|
on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
|
|
869
865
|
|
|
870
866
|
Examples:
|
|
@@ -876,34 +872,31 @@ def plot_results(
|
|
|
876
872
|
from scipy.ndimage import gaussian_filter1d
|
|
877
873
|
|
|
878
874
|
save_dir = Path(file).parent if file else Path(dir)
|
|
879
|
-
if classify:
|
|
880
|
-
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
|
|
881
|
-
index = [2, 5, 3, 4]
|
|
882
|
-
elif segment:
|
|
883
|
-
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
|
|
884
|
-
index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
|
|
885
|
-
elif pose:
|
|
886
|
-
fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
|
|
887
|
-
index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
|
|
888
|
-
else:
|
|
889
|
-
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
|
890
|
-
index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
|
|
891
|
-
ax = ax.ravel()
|
|
892
875
|
files = list(save_dir.glob("results*.csv"))
|
|
893
876
|
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
|
|
894
|
-
|
|
877
|
+
|
|
878
|
+
loss_keys, metric_keys = [], []
|
|
879
|
+
for i, f in enumerate(files):
|
|
895
880
|
try:
|
|
896
881
|
data = pl.read_csv(f, infer_schema_length=None)
|
|
897
|
-
|
|
882
|
+
if i == 0:
|
|
883
|
+
for c in data.columns:
|
|
884
|
+
if "loss" in c:
|
|
885
|
+
loss_keys.append(c)
|
|
886
|
+
elif "metric" in c:
|
|
887
|
+
metric_keys.append(c)
|
|
888
|
+
loss_mid, metric_mid = len(loss_keys) // 2, len(metric_keys) // 2
|
|
889
|
+
columns = (
|
|
890
|
+
loss_keys[:loss_mid] + metric_keys[:metric_mid] + loss_keys[loss_mid:] + metric_keys[metric_mid:]
|
|
891
|
+
)
|
|
892
|
+
fig, ax = plt.subplots(2, len(columns) // 2, figsize=(len(columns) + 2, 6), tight_layout=True)
|
|
893
|
+
ax = ax.ravel()
|
|
898
894
|
x = data.select(data.columns[0]).to_numpy().flatten()
|
|
899
|
-
for i, j in enumerate(
|
|
900
|
-
y = data.select(
|
|
901
|
-
# y[y == 0] = np.nan # don't show zero values
|
|
895
|
+
for i, j in enumerate(columns):
|
|
896
|
+
y = data.select(j).to_numpy().flatten().astype("float")
|
|
902
897
|
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
|
|
903
898
|
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
|
|
904
|
-
ax[i].set_title(
|
|
905
|
-
# if j in {8, 9, 10}: # share train and val loss y axes
|
|
906
|
-
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
|
899
|
+
ax[i].set_title(j, fontsize=12)
|
|
907
900
|
except Exception as e:
|
|
908
901
|
LOGGER.error(f"Plotting error for {f}: {e}")
|
|
909
902
|
ax[1].legend()
|
|
@@ -915,8 +908,7 @@ def plot_results(
|
|
|
915
908
|
|
|
916
909
|
|
|
917
910
|
def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
|
|
918
|
-
"""
|
|
919
|
-
Plot a scatter plot with points colored based on a 2D histogram.
|
|
911
|
+
"""Plot a scatter plot with points colored based on a 2D histogram.
|
|
920
912
|
|
|
921
913
|
Args:
|
|
922
914
|
v (array-like): Values for the x-axis.
|
|
@@ -948,13 +940,14 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
|
|
|
948
940
|
|
|
949
941
|
|
|
950
942
|
@plt_settings()
|
|
951
|
-
def plot_tune_results(csv_file: str = "tune_results.csv"):
|
|
952
|
-
"""
|
|
953
|
-
|
|
954
|
-
|
|
943
|
+
def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
|
|
944
|
+
"""Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
|
|
945
|
+
key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
|
|
946
|
+
the plots.
|
|
955
947
|
|
|
956
948
|
Args:
|
|
957
949
|
csv_file (str, optional): Path to the CSV file containing the tuning results.
|
|
950
|
+
exclude_zero_fitness_points (bool, optional): Don't include points with zero fitness in tuning plots.
|
|
958
951
|
|
|
959
952
|
Examples:
|
|
960
953
|
>>> plot_tune_results("path/to/tune_results.csv")
|
|
@@ -976,6 +969,17 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
|
|
|
976
969
|
keys = [x.strip() for x in data.columns][num_metrics_columns:]
|
|
977
970
|
x = data.to_numpy()
|
|
978
971
|
fitness = x[:, 0] # fitness
|
|
972
|
+
if exclude_zero_fitness_points:
|
|
973
|
+
mask = fitness > 0 # exclude zero-fitness points
|
|
974
|
+
x, fitness = x[mask], fitness[mask]
|
|
975
|
+
# Iterative sigma rejection on lower bound only
|
|
976
|
+
for _ in range(3): # max 3 iterations
|
|
977
|
+
mean, std = fitness.mean(), fitness.std()
|
|
978
|
+
lower_bound = mean - 3 * std
|
|
979
|
+
mask = fitness >= lower_bound
|
|
980
|
+
if mask.all(): # no more outliers
|
|
981
|
+
break
|
|
982
|
+
x, fitness = x[mask], fitness[mask]
|
|
979
983
|
j = np.argmax(fitness) # max fitness index
|
|
980
984
|
n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
|
|
981
985
|
plt.figure(figsize=(10, 10), tight_layout=True)
|
|
@@ -1006,8 +1010,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
|
|
|
1006
1010
|
|
|
1007
1011
|
@plt_settings()
|
|
1008
1012
|
def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
|
|
1009
|
-
"""
|
|
1010
|
-
Visualize feature maps of a given model module during inference.
|
|
1013
|
+
"""Visualize feature maps of a given model module during inference.
|
|
1011
1014
|
|
|
1012
1015
|
Args:
|
|
1013
1016
|
x (torch.Tensor): Features to be visualized.
|