dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/utils/plotting.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
import math
|
|
4
6
|
import warnings
|
|
7
|
+
from collections.abc import Callable
|
|
5
8
|
from pathlib import Path
|
|
6
|
-
from typing import
|
|
9
|
+
from typing import Any
|
|
7
10
|
|
|
8
11
|
import cv2
|
|
9
12
|
import numpy as np
|
|
@@ -17,21 +20,21 @@ from ultralytics.utils.files import increment_path
|
|
|
17
20
|
|
|
18
21
|
|
|
19
22
|
class Colors:
|
|
20
|
-
"""
|
|
21
|
-
Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
|
|
23
|
+
"""Ultralytics color palette for visualization and plotting.
|
|
22
24
|
|
|
23
|
-
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
|
|
24
|
-
|
|
25
|
+
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
|
|
26
|
+
values and accessing predefined color schemes for object detection and pose estimation.
|
|
25
27
|
|
|
26
28
|
Attributes:
|
|
27
|
-
palette (
|
|
29
|
+
palette (list[tuple]): List of RGB color tuples for general use.
|
|
28
30
|
n (int): The number of colors in the palette.
|
|
29
31
|
pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
|
|
30
32
|
|
|
31
33
|
Examples:
|
|
32
34
|
>>> from ultralytics.utils.plotting import Colors
|
|
33
35
|
>>> colors = Colors()
|
|
34
|
-
>>> colors(5, True) #
|
|
36
|
+
>>> colors(5, True) # Returns BGR format: (221, 111, 255)
|
|
37
|
+
>>> colors(5, False) # Returns RGB format: (255, 111, 221)
|
|
35
38
|
|
|
36
39
|
## Ultralytics Color Palette
|
|
37
40
|
|
|
@@ -85,7 +88,8 @@ class Colors:
|
|
|
85
88
|
|
|
86
89
|
!!! note "Ultralytics Brand Colors"
|
|
87
90
|
|
|
88
|
-
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
|
|
91
|
+
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
|
|
92
|
+
Please use the official Ultralytics colors for all marketing materials.
|
|
89
93
|
"""
|
|
90
94
|
|
|
91
95
|
def __init__(self):
|
|
@@ -140,13 +144,21 @@ class Colors:
|
|
|
140
144
|
dtype=np.uint8,
|
|
141
145
|
)
|
|
142
146
|
|
|
143
|
-
def __call__(self, i, bgr=False):
|
|
144
|
-
"""Convert hex color codes to RGB values.
|
|
147
|
+
def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
|
|
148
|
+
"""Convert hex color codes to RGB values.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
i (int | torch.Tensor): Color index.
|
|
152
|
+
bgr (bool, optional): Whether to return BGR format instead of RGB.
|
|
153
|
+
|
|
154
|
+
Returns:
|
|
155
|
+
(tuple): RGB or BGR color tuple.
|
|
156
|
+
"""
|
|
145
157
|
c = self.palette[int(i) % self.n]
|
|
146
158
|
return (c[2], c[1], c[0]) if bgr else c
|
|
147
159
|
|
|
148
160
|
@staticmethod
|
|
149
|
-
def hex2rgb(h):
|
|
161
|
+
def hex2rgb(h: str) -> tuple:
|
|
150
162
|
"""Convert hex color codes to RGB values (i.e. default PIL order)."""
|
|
151
163
|
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
|
|
152
164
|
|
|
@@ -155,17 +167,16 @@ colors = Colors() # create instance for 'from utils.plots import colors'
|
|
|
155
167
|
|
|
156
168
|
|
|
157
169
|
class Annotator:
|
|
158
|
-
"""
|
|
159
|
-
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
|
170
|
+
"""Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
|
160
171
|
|
|
161
172
|
Attributes:
|
|
162
|
-
im (Image.Image
|
|
173
|
+
im (Image.Image | np.ndarray): The image to annotate.
|
|
163
174
|
pil (bool): Whether to use PIL or cv2 for drawing annotations.
|
|
164
|
-
font (ImageFont.truetype
|
|
175
|
+
font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.
|
|
165
176
|
lw (float): Line width for drawing.
|
|
166
|
-
skeleton (
|
|
167
|
-
limb_color (
|
|
168
|
-
kpt_color (
|
|
177
|
+
skeleton (list[list[int]]): Skeleton structure for keypoints.
|
|
178
|
+
limb_color (list[int]): Color palette for limbs.
|
|
179
|
+
kpt_color (list[int]): Color palette for keypoints.
|
|
169
180
|
dark_colors (set): Set of colors considered dark for text contrast.
|
|
170
181
|
light_colors (set): Set of colors considered light for text contrast.
|
|
171
182
|
|
|
@@ -173,14 +184,28 @@ class Annotator:
|
|
|
173
184
|
>>> from ultralytics.utils.plotting import Annotator
|
|
174
185
|
>>> im0 = cv2.imread("test.png")
|
|
175
186
|
>>> annotator = Annotator(im0, line_width=10)
|
|
187
|
+
>>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0))
|
|
176
188
|
"""
|
|
177
189
|
|
|
178
|
-
def __init__(
|
|
190
|
+
def __init__(
|
|
191
|
+
self,
|
|
192
|
+
im,
|
|
193
|
+
line_width: int | None = None,
|
|
194
|
+
font_size: int | None = None,
|
|
195
|
+
font: str = "Arial.ttf",
|
|
196
|
+
pil: bool = False,
|
|
197
|
+
example: str = "abc",
|
|
198
|
+
):
|
|
179
199
|
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
|
|
180
200
|
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
|
181
201
|
input_is_pil = isinstance(im, Image.Image)
|
|
182
202
|
self.pil = pil or non_ascii or input_is_pil
|
|
183
203
|
self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
|
|
204
|
+
if not input_is_pil:
|
|
205
|
+
if im.shape[2] == 1: # handle grayscale
|
|
206
|
+
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
|
|
207
|
+
elif im.shape[2] > 3: # multispectral
|
|
208
|
+
im = np.ascontiguousarray(im[..., :3])
|
|
184
209
|
if self.pil: # use PIL
|
|
185
210
|
self.im = im if input_is_pil else Image.fromarray(im)
|
|
186
211
|
if self.im.mode not in {"RGB", "RGBA"}: # multispectral
|
|
@@ -196,10 +221,6 @@ class Annotator:
|
|
|
196
221
|
if check_version(pil_version, "9.2.0"):
|
|
197
222
|
self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
|
|
198
223
|
else: # use cv2
|
|
199
|
-
if im.shape[2] == 1: # handle grayscale
|
|
200
|
-
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
|
|
201
|
-
elif im.shape[2] > 3: # multispectral
|
|
202
|
-
im = np.ascontiguousarray(im[..., :3])
|
|
203
224
|
assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
|
|
204
225
|
self.im = im if im.flags.writeable else im.copy()
|
|
205
226
|
self.tf = max(self.lw - 1, 1) # font thickness
|
|
@@ -254,9 +275,8 @@ class Annotator:
|
|
|
254
275
|
(104, 31, 17),
|
|
255
276
|
}
|
|
256
277
|
|
|
257
|
-
def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
|
|
258
|
-
"""
|
|
259
|
-
Assign text color based on background color.
|
|
278
|
+
def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
|
|
279
|
+
"""Assign text color based on background color.
|
|
260
280
|
|
|
261
281
|
Args:
|
|
262
282
|
color (tuple, optional): The background color of the rectangle for text (B, G, R).
|
|
@@ -278,16 +298,14 @@ class Annotator:
|
|
|
278
298
|
else:
|
|
279
299
|
return txt_color
|
|
280
300
|
|
|
281
|
-
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
|
|
282
|
-
"""
|
|
283
|
-
Draw a bounding box on an image with a given label.
|
|
301
|
+
def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
|
|
302
|
+
"""Draw a bounding box on an image with a given label.
|
|
284
303
|
|
|
285
304
|
Args:
|
|
286
305
|
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
|
|
287
306
|
label (str, optional): The text label to be displayed.
|
|
288
307
|
color (tuple, optional): The background color of the rectangle (B, G, R).
|
|
289
308
|
txt_color (tuple, optional): The color of the text (R, G, B).
|
|
290
|
-
rotated (bool, optional): Whether the task is oriented bounding box detection.
|
|
291
309
|
|
|
292
310
|
Examples:
|
|
293
311
|
>>> from ultralytics.utils.plotting import Annotator
|
|
@@ -298,13 +316,13 @@ class Annotator:
|
|
|
298
316
|
txt_color = self.get_txt_color(color, txt_color)
|
|
299
317
|
if isinstance(box, torch.Tensor):
|
|
300
318
|
box = box.tolist()
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
319
|
+
|
|
320
|
+
multi_points = isinstance(box[0], list) # multiple points with shape (n, 2)
|
|
321
|
+
p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1]))
|
|
322
|
+
if self.pil:
|
|
323
|
+
self.draw.polygon(
|
|
324
|
+
[tuple(b) for b in box], width=self.lw, outline=color
|
|
325
|
+
) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color)
|
|
308
326
|
if label:
|
|
309
327
|
w, h = self.font.getsize(label) # text width, height
|
|
310
328
|
outside = p1[1] >= h # label fits outside box
|
|
@@ -317,12 +335,11 @@ class Annotator:
|
|
|
317
335
|
# self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
|
|
318
336
|
self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
|
|
319
337
|
else: # cv2
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
|
|
338
|
+
cv2.polylines(
|
|
339
|
+
self.im, [np.asarray(box, dtype=int)], True, color, self.lw
|
|
340
|
+
) if multi_points else cv2.rectangle(
|
|
341
|
+
self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA
|
|
342
|
+
)
|
|
326
343
|
if label:
|
|
327
344
|
w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
|
|
328
345
|
h += 3 # add pixels to pad text
|
|
@@ -342,45 +359,66 @@ class Annotator:
|
|
|
342
359
|
lineType=cv2.LINE_AA,
|
|
343
360
|
)
|
|
344
361
|
|
|
345
|
-
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
|
|
346
|
-
"""
|
|
347
|
-
Plot masks on image.
|
|
362
|
+
def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
|
|
363
|
+
"""Plot masks on image.
|
|
348
364
|
|
|
349
365
|
Args:
|
|
350
|
-
masks (torch.Tensor): Predicted masks
|
|
351
|
-
colors (
|
|
352
|
-
im_gpu (torch.Tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
|
|
366
|
+
masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
|
|
367
|
+
colors (list[list[int]]): Colors for predicted masks, [[r, g, b] * n]
|
|
368
|
+
im_gpu (torch.Tensor | None): Image is in cuda, shape: [3, h, w], range: [0, 1]
|
|
353
369
|
alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque.
|
|
354
370
|
retina_masks (bool, optional): Whether to use high resolution masks or not.
|
|
355
371
|
"""
|
|
356
372
|
if self.pil:
|
|
357
373
|
# Convert to numpy first
|
|
358
374
|
self.im = np.asarray(self.im).copy()
|
|
359
|
-
if
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
375
|
+
if im_gpu is None:
|
|
376
|
+
assert isinstance(masks, np.ndarray), "`masks` must be a np.ndarray if `im_gpu` is not provided."
|
|
377
|
+
overlay = self.im.copy()
|
|
378
|
+
for i, mask in enumerate(masks):
|
|
379
|
+
overlay[mask.astype(bool)] = colors[i]
|
|
380
|
+
self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
|
|
381
|
+
else:
|
|
382
|
+
assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
|
|
383
|
+
if len(masks) == 0:
|
|
384
|
+
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
|
385
|
+
return
|
|
386
|
+
if im_gpu.device != masks.device:
|
|
387
|
+
im_gpu = im_gpu.to(masks.device)
|
|
388
|
+
|
|
389
|
+
ih, iw = self.im.shape[:2]
|
|
390
|
+
if not retina_masks:
|
|
391
|
+
# Use scale_masks to properly remove padding and upsample, convert bool to float first
|
|
392
|
+
masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
|
|
393
|
+
# Convert original BGR image to RGB tensor
|
|
394
|
+
im_gpu = (
|
|
395
|
+
torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
|
|
399
|
+
colors = colors[:, None, None] # shape(n,1,1,3)
|
|
400
|
+
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
|
401
|
+
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
|
402
|
+
inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
|
403
|
+
mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
|
|
404
|
+
|
|
405
|
+
im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous() # shape(h,w,3)
|
|
406
|
+
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
|
|
407
|
+
self.im[:] = (im_gpu * 255).byte().cpu().numpy()
|
|
377
408
|
if self.pil:
|
|
378
409
|
# Convert im back to PIL and update draw
|
|
379
410
|
self.fromarray(self.im)
|
|
380
411
|
|
|
381
|
-
def kpts(
|
|
382
|
-
|
|
383
|
-
|
|
412
|
+
def kpts(
|
|
413
|
+
self,
|
|
414
|
+
kpts,
|
|
415
|
+
shape: tuple = (640, 640),
|
|
416
|
+
radius: int | None = None,
|
|
417
|
+
kpt_line: bool = True,
|
|
418
|
+
conf_thres: float = 0.25,
|
|
419
|
+
kpt_color: tuple | None = None,
|
|
420
|
+
):
|
|
421
|
+
"""Plot keypoints on the image.
|
|
384
422
|
|
|
385
423
|
Args:
|
|
386
424
|
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
|
|
@@ -390,7 +428,7 @@ class Annotator:
|
|
|
390
428
|
conf_thres (float, optional): Confidence threshold.
|
|
391
429
|
kpt_color (tuple, optional): Keypoint color (B, G, R).
|
|
392
430
|
|
|
393
|
-
|
|
431
|
+
Notes:
|
|
394
432
|
- `kpt_line=True` currently only supports human pose plotting.
|
|
395
433
|
- Modifies self.im in-place.
|
|
396
434
|
- If self.pil is True, converts image to numpy array and back to PIL.
|
|
@@ -438,16 +476,15 @@ class Annotator:
|
|
|
438
476
|
# Convert im back to PIL and update draw
|
|
439
477
|
self.fromarray(self.im)
|
|
440
478
|
|
|
441
|
-
def rectangle(self, xy, fill=None, outline=None, width=1):
|
|
479
|
+
def rectangle(self, xy, fill=None, outline=None, width: int = 1):
|
|
442
480
|
"""Add rectangle to image (PIL-only)."""
|
|
443
481
|
self.draw.rectangle(xy, fill, outline, width)
|
|
444
482
|
|
|
445
|
-
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_color=()):
|
|
446
|
-
"""
|
|
447
|
-
Add text to an image using PIL or cv2.
|
|
483
|
+
def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
|
|
484
|
+
"""Add text to an image using PIL or cv2.
|
|
448
485
|
|
|
449
486
|
Args:
|
|
450
|
-
xy (
|
|
487
|
+
xy (list[int]): Top-left coordinates for text placement.
|
|
451
488
|
text (str): Text to be drawn.
|
|
452
489
|
txt_color (tuple, optional): Text color (R, G, B).
|
|
453
490
|
anchor (str, optional): Text anchor position ('top' or 'bottom').
|
|
@@ -482,7 +519,7 @@ class Annotator:
|
|
|
482
519
|
"""Return annotated image as array."""
|
|
483
520
|
return np.asarray(self.im)
|
|
484
521
|
|
|
485
|
-
def show(self, title=None):
|
|
522
|
+
def show(self, title: str | None = None):
|
|
486
523
|
"""Show the annotated image."""
|
|
487
524
|
im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
|
|
488
525
|
if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
|
|
@@ -493,14 +530,13 @@ class Annotator:
|
|
|
493
530
|
else:
|
|
494
531
|
im.show(title=title)
|
|
495
532
|
|
|
496
|
-
def save(self, filename="image.jpg"):
|
|
533
|
+
def save(self, filename: str = "image.jpg"):
|
|
497
534
|
"""Save the annotated image to 'filename'."""
|
|
498
535
|
cv2.imwrite(filename, np.asarray(self.im))
|
|
499
536
|
|
|
500
537
|
@staticmethod
|
|
501
|
-
def get_bbox_dimension(bbox=None):
|
|
502
|
-
"""
|
|
503
|
-
Calculate the dimensions and area of a bounding box.
|
|
538
|
+
def get_bbox_dimension(bbox: tuple | None = None):
|
|
539
|
+
"""Calculate the dimensions and area of a bounding box.
|
|
504
540
|
|
|
505
541
|
Args:
|
|
506
542
|
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
|
|
@@ -522,11 +558,10 @@ class Annotator:
|
|
|
522
558
|
return width, height, width * height
|
|
523
559
|
|
|
524
560
|
|
|
525
|
-
@TryExcept()
|
|
561
|
+
@TryExcept()
|
|
526
562
|
@plt_settings()
|
|
527
563
|
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
528
|
-
"""
|
|
529
|
-
Plot training labels including class histograms and box statistics.
|
|
564
|
+
"""Plot training labels including class histograms and box statistics.
|
|
530
565
|
|
|
531
566
|
Args:
|
|
532
567
|
boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
|
|
@@ -536,7 +571,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
536
571
|
on_plot (Callable, optional): Function to call after plot is saved.
|
|
537
572
|
"""
|
|
538
573
|
import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
|
|
539
|
-
import
|
|
574
|
+
import polars
|
|
540
575
|
from matplotlib.colors import LinearSegmentedColormap
|
|
541
576
|
|
|
542
577
|
# Filter matplotlib>=3.7.2 warning
|
|
@@ -547,16 +582,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
547
582
|
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
|
548
583
|
nc = int(cls.max() + 1) # number of classes
|
|
549
584
|
boxes = boxes[:1000000] # limit to 1M boxes
|
|
550
|
-
x =
|
|
551
|
-
|
|
552
|
-
try: # Seaborn correlogram
|
|
553
|
-
import seaborn
|
|
554
|
-
|
|
555
|
-
seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
|
556
|
-
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
|
|
557
|
-
plt.close()
|
|
558
|
-
except ImportError:
|
|
559
|
-
pass # Skip if seaborn is not installed
|
|
585
|
+
x = polars.DataFrame(boxes, schema=["x", "y", "width", "height"])
|
|
560
586
|
|
|
561
587
|
# Matplotlib labels
|
|
562
588
|
subplot_3_4_color = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
|
|
@@ -568,12 +594,13 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
568
594
|
if 0 < len(names) < 30:
|
|
569
595
|
ax[0].set_xticks(range(len(names)))
|
|
570
596
|
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
|
597
|
+
ax[0].bar_label(y[2])
|
|
571
598
|
else:
|
|
572
599
|
ax[0].set_xlabel("classes")
|
|
573
600
|
boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
|
|
574
601
|
img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
|
|
575
602
|
for cls, box in zip(cls[:500], boxes[:500]):
|
|
576
|
-
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
|
|
603
|
+
ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(cls)) # plot
|
|
577
604
|
ax[1].imshow(img)
|
|
578
605
|
ax[1].axis("off")
|
|
579
606
|
|
|
@@ -583,8 +610,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
583
610
|
ax[3].hist2d(x["width"], x["height"], bins=50, cmap=subplot_3_4_color)
|
|
584
611
|
ax[3].set_xlabel("width")
|
|
585
612
|
ax[3].set_ylabel("height")
|
|
586
|
-
for a in
|
|
587
|
-
for s in
|
|
613
|
+
for a in {0, 1, 2, 3}:
|
|
614
|
+
for s in {"top", "right", "left", "bottom"}:
|
|
588
615
|
ax[a].spines[s].set_visible(False)
|
|
589
616
|
|
|
590
617
|
fname = save_dir / "labels.jpg"
|
|
@@ -594,13 +621,21 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
594
621
|
on_plot(fname)
|
|
595
622
|
|
|
596
623
|
|
|
597
|
-
def save_one_box(
|
|
598
|
-
|
|
599
|
-
|
|
624
|
+
def save_one_box(
|
|
625
|
+
xyxy,
|
|
626
|
+
im,
|
|
627
|
+
file: Path = Path("im.jpg"),
|
|
628
|
+
gain: float = 1.02,
|
|
629
|
+
pad: int = 10,
|
|
630
|
+
square: bool = False,
|
|
631
|
+
BGR: bool = False,
|
|
632
|
+
save: bool = True,
|
|
633
|
+
):
|
|
634
|
+
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
|
600
635
|
|
|
601
|
-
This function takes a bounding box and an image, and then saves a cropped portion of the image according
|
|
602
|
-
|
|
603
|
-
|
|
636
|
+
This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
|
|
637
|
+
bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
|
|
638
|
+
bounding box.
|
|
604
639
|
|
|
605
640
|
Args:
|
|
606
641
|
xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
|
|
@@ -609,7 +644,7 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
|
|
|
609
644
|
gain (float, optional): A multiplicative factor to increase the size of the bounding box.
|
|
610
645
|
pad (int, optional): The number of pixels to add to the width and height of the bounding box.
|
|
611
646
|
square (bool, optional): If True, the bounding box will be transformed into a square.
|
|
612
|
-
BGR (bool, optional): If True, the image will be
|
|
647
|
+
BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.
|
|
613
648
|
save (bool, optional): If True, the cropped image will be saved to disk.
|
|
614
649
|
|
|
615
650
|
Returns:
|
|
@@ -629,73 +664,83 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
|
|
|
629
664
|
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
|
|
630
665
|
xyxy = ops.xywh2xyxy(b).long()
|
|
631
666
|
xyxy = ops.clip_boxes(xyxy, im.shape)
|
|
632
|
-
|
|
667
|
+
grayscale = im.shape[2] == 1 # grayscale image
|
|
668
|
+
crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)]
|
|
633
669
|
if save:
|
|
634
670
|
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
|
635
671
|
f = str(increment_path(file).with_suffix(".jpg"))
|
|
636
672
|
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
|
637
|
-
|
|
673
|
+
crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop
|
|
674
|
+
Image.fromarray(crop).save(f, quality=95, subsampling=0) # save RGB
|
|
638
675
|
return crop
|
|
639
676
|
|
|
640
677
|
|
|
641
678
|
@threaded
|
|
642
679
|
def plot_images(
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
|
|
647
|
-
confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
|
|
648
|
-
masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
|
|
649
|
-
kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
|
|
650
|
-
paths: Optional[List[str]] = None,
|
|
680
|
+
labels: dict[str, Any],
|
|
681
|
+
images: torch.Tensor | np.ndarray = np.zeros((0, 3, 640, 640), dtype=np.float32),
|
|
682
|
+
paths: list[str] | None = None,
|
|
651
683
|
fname: str = "images.jpg",
|
|
652
|
-
names:
|
|
653
|
-
on_plot:
|
|
684
|
+
names: dict[int, str] | None = None,
|
|
685
|
+
on_plot: Callable | None = None,
|
|
654
686
|
max_size: int = 1920,
|
|
655
687
|
max_subplots: int = 16,
|
|
656
688
|
save: bool = True,
|
|
657
689
|
conf_thres: float = 0.25,
|
|
658
|
-
) ->
|
|
659
|
-
"""
|
|
660
|
-
Plot image grid with labels, bounding boxes, masks, and keypoints.
|
|
690
|
+
) -> np.ndarray | None:
|
|
691
|
+
"""Plot image grid with labels, bounding boxes, masks, and keypoints.
|
|
661
692
|
|
|
662
693
|
Args:
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
max_size: Maximum size of the output image grid.
|
|
675
|
-
max_subplots: Maximum number of subplots in the image grid.
|
|
676
|
-
save: Whether to save the plotted image grid to a file.
|
|
677
|
-
conf_thres: Confidence threshold for displaying detections.
|
|
694
|
+
labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
|
|
695
|
+
'keypoints', 'batch_idx', 'img'.
|
|
696
|
+
images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
|
|
697
|
+
paths (Optional[list[str]]): List of file paths for each image in the batch.
|
|
698
|
+
fname (str): Output filename for the plotted image grid.
|
|
699
|
+
names (Optional[dict[int, str]]): Dictionary mapping class indices to class names.
|
|
700
|
+
on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.
|
|
701
|
+
max_size (int): Maximum size of the output image grid.
|
|
702
|
+
max_subplots (int): Maximum number of subplots in the image grid.
|
|
703
|
+
save (bool): Whether to save the plotted image grid to a file.
|
|
704
|
+
conf_thres (float): Confidence threshold for displaying detections.
|
|
678
705
|
|
|
679
706
|
Returns:
|
|
680
707
|
(np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
|
|
681
708
|
|
|
682
|
-
|
|
709
|
+
Notes:
|
|
683
710
|
This function supports both tensor and numpy array inputs. It will automatically
|
|
684
711
|
convert tensor inputs to numpy arrays for processing.
|
|
712
|
+
|
|
713
|
+
Channel Support:
|
|
714
|
+
- 1 channel: Grayscale
|
|
715
|
+
- 2 channels: Third channel added as zeros
|
|
716
|
+
- 3 channels: Used as-is (standard RGB)
|
|
717
|
+
- 4+ channels: Cropped to first 3 channels
|
|
685
718
|
"""
|
|
686
|
-
|
|
719
|
+
for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}:
|
|
720
|
+
if k not in labels:
|
|
721
|
+
continue
|
|
722
|
+
if k == "cls" and labels[k].ndim == 2:
|
|
723
|
+
labels[k] = labels[k].squeeze(1) # squeeze if shape is (n, 1)
|
|
724
|
+
if isinstance(labels[k], torch.Tensor):
|
|
725
|
+
labels[k] = labels[k].cpu().numpy()
|
|
726
|
+
|
|
727
|
+
cls = labels.get("cls", np.zeros(0, dtype=np.int64))
|
|
728
|
+
batch_idx = labels.get("batch_idx", np.zeros(cls.shape, dtype=np.int64))
|
|
729
|
+
bboxes = labels.get("bboxes", np.zeros(0, dtype=np.float32))
|
|
730
|
+
confs = labels.get("conf", None)
|
|
731
|
+
masks = labels.get("masks", np.zeros(0, dtype=np.uint8))
|
|
732
|
+
kpts = labels.get("keypoints", np.zeros(0, dtype=np.float32))
|
|
733
|
+
images = labels.get("img", images) # default to input images
|
|
734
|
+
|
|
735
|
+
if len(images) and isinstance(images, torch.Tensor):
|
|
687
736
|
images = images.cpu().float().numpy()
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
kpts = kpts.cpu().numpy()
|
|
696
|
-
if isinstance(batch_idx, torch.Tensor):
|
|
697
|
-
batch_idx = batch_idx.cpu().numpy()
|
|
698
|
-
if images.shape[1] > 3:
|
|
737
|
+
|
|
738
|
+
# Handle 2-ch and n-ch images
|
|
739
|
+
c = images.shape[1]
|
|
740
|
+
if c == 2:
|
|
741
|
+
zero = np.zeros_like(images[:, :1])
|
|
742
|
+
images = np.concatenate((images, zero), axis=1) # pad 2-ch with a black channel
|
|
743
|
+
elif c > 3:
|
|
699
744
|
images = images[:, :3] # crop multispectral images to first 3 channels
|
|
700
745
|
|
|
701
746
|
bs, _, h, w = images.shape # batch size, _, height, width
|
|
@@ -730,10 +775,10 @@ def plot_images(
|
|
|
730
775
|
idx = batch_idx == i
|
|
731
776
|
classes = cls[idx].astype("int")
|
|
732
777
|
labels = confs is None
|
|
778
|
+
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
|
|
733
779
|
|
|
734
780
|
if len(bboxes):
|
|
735
781
|
boxes = bboxes[idx]
|
|
736
|
-
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
|
|
737
782
|
if len(boxes):
|
|
738
783
|
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
|
|
739
784
|
boxes[..., [0, 2]] *= w # scale to pixels
|
|
@@ -743,6 +788,7 @@ def plot_images(
|
|
|
743
788
|
boxes[..., 0] += x
|
|
744
789
|
boxes[..., 1] += y
|
|
745
790
|
is_obb = boxes.shape[-1] == 5 # xywhr
|
|
791
|
+
# TODO: this transformation might be unnecessary
|
|
746
792
|
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
|
|
747
793
|
for j, box in enumerate(boxes.astype(np.int64).tolist()):
|
|
748
794
|
c = classes[j]
|
|
@@ -750,13 +796,14 @@ def plot_images(
|
|
|
750
796
|
c = names.get(c, c) if names else c
|
|
751
797
|
if labels or conf[j] > conf_thres:
|
|
752
798
|
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
|
|
753
|
-
annotator.box_label(box, label, color=color
|
|
799
|
+
annotator.box_label(box, label, color=color)
|
|
754
800
|
|
|
755
801
|
elif len(classes):
|
|
756
802
|
for c in classes:
|
|
757
803
|
color = colors(c)
|
|
758
804
|
c = names.get(c, c) if names else c
|
|
759
|
-
|
|
805
|
+
label = f"{c}" if labels else f"{c} {conf[0]:.1f}"
|
|
806
|
+
annotator.text([x, y], label, txt_color=color, box_color=(64, 64, 64, 128))
|
|
760
807
|
|
|
761
808
|
# Plot keypoints
|
|
762
809
|
if len(kpts):
|
|
@@ -775,14 +822,13 @@ def plot_images(
|
|
|
775
822
|
|
|
776
823
|
# Plot masks
|
|
777
824
|
if len(masks):
|
|
778
|
-
if idx.shape[0] == masks.shape[0]: #
|
|
825
|
+
if idx.shape[0] == masks.shape[0] and masks.max() <= 1: # overlap_mask=False
|
|
779
826
|
image_masks = masks[idx]
|
|
780
|
-
else: #
|
|
827
|
+
else: # overlap_mask=True
|
|
781
828
|
image_masks = masks[[i]] # (1, 640, 640)
|
|
782
829
|
nl = idx.sum()
|
|
783
|
-
index = np.arange(nl).reshape((nl, 1, 1))
|
|
784
|
-
image_masks =
|
|
785
|
-
image_masks = np.where(image_masks == index, 1.0, 0.0)
|
|
830
|
+
index = np.arange(1, nl + 1).reshape((nl, 1, 1))
|
|
831
|
+
image_masks = (image_masks == index).astype(np.float32)
|
|
786
832
|
|
|
787
833
|
im = np.asarray(annotator.im).copy()
|
|
788
834
|
for j in range(len(image_masks)):
|
|
@@ -810,17 +856,14 @@ def plot_images(
|
|
|
810
856
|
|
|
811
857
|
|
|
812
858
|
@plt_settings()
|
|
813
|
-
def plot_results(file="path/to/results.csv", dir="",
|
|
814
|
-
"""
|
|
815
|
-
|
|
816
|
-
|
|
859
|
+
def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
|
|
860
|
+
"""Plot training results from a results CSV file. The function supports various types of data including
|
|
861
|
+
segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
|
|
862
|
+
CSV is located.
|
|
817
863
|
|
|
818
864
|
Args:
|
|
819
865
|
file (str, optional): Path to the CSV file containing the training results.
|
|
820
866
|
dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
|
|
821
|
-
segment (bool, optional): Flag to indicate if the data is for segmentation.
|
|
822
|
-
pose (bool, optional): Flag to indicate if the data is for pose estimation.
|
|
823
|
-
classify (bool, optional): Flag to indicate if the data is for classification.
|
|
824
867
|
on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
|
|
825
868
|
|
|
826
869
|
Examples:
|
|
@@ -828,38 +871,35 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|
|
828
871
|
>>> plot_results("path/to/results.csv", segment=True)
|
|
829
872
|
"""
|
|
830
873
|
import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
|
|
831
|
-
import
|
|
874
|
+
import polars as pl
|
|
832
875
|
from scipy.ndimage import gaussian_filter1d
|
|
833
876
|
|
|
834
877
|
save_dir = Path(file).parent if file else Path(dir)
|
|
835
|
-
if classify:
|
|
836
|
-
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
|
|
837
|
-
index = [2, 5, 3, 4]
|
|
838
|
-
elif segment:
|
|
839
|
-
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
|
|
840
|
-
index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
|
|
841
|
-
elif pose:
|
|
842
|
-
fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
|
|
843
|
-
index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
|
|
844
|
-
else:
|
|
845
|
-
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
|
846
|
-
index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
|
|
847
|
-
ax = ax.ravel()
|
|
848
878
|
files = list(save_dir.glob("results*.csv"))
|
|
849
879
|
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
|
|
850
|
-
|
|
880
|
+
|
|
881
|
+
loss_keys, metric_keys = [], []
|
|
882
|
+
for i, f in enumerate(files):
|
|
851
883
|
try:
|
|
852
|
-
data =
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
884
|
+
data = pl.read_csv(f, infer_schema_length=None)
|
|
885
|
+
if i == 0:
|
|
886
|
+
for c in data.columns:
|
|
887
|
+
if "loss" in c:
|
|
888
|
+
loss_keys.append(c)
|
|
889
|
+
elif "metric" in c:
|
|
890
|
+
metric_keys.append(c)
|
|
891
|
+
loss_mid, metric_mid = len(loss_keys) // 2, len(metric_keys) // 2
|
|
892
|
+
columns = (
|
|
893
|
+
loss_keys[:loss_mid] + metric_keys[:metric_mid] + loss_keys[loss_mid:] + metric_keys[metric_mid:]
|
|
894
|
+
)
|
|
895
|
+
fig, ax = plt.subplots(2, len(columns) // 2, figsize=(len(columns) + 2, 6), tight_layout=True)
|
|
896
|
+
ax = ax.ravel()
|
|
897
|
+
x = data.select(data.columns[0]).to_numpy().flatten()
|
|
898
|
+
for i, j in enumerate(columns):
|
|
899
|
+
y = data.select(j).to_numpy().flatten().astype("float")
|
|
858
900
|
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
|
|
859
901
|
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
|
|
860
|
-
ax[i].set_title(
|
|
861
|
-
# if j in {8, 9, 10}: # share train and val loss y axes
|
|
862
|
-
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
|
902
|
+
ax[i].set_title(j, fontsize=12)
|
|
863
903
|
except Exception as e:
|
|
864
904
|
LOGGER.error(f"Plotting error for {f}: {e}")
|
|
865
905
|
ax[1].legend()
|
|
@@ -870,9 +910,8 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|
|
870
910
|
on_plot(fname)
|
|
871
911
|
|
|
872
912
|
|
|
873
|
-
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
|
|
874
|
-
"""
|
|
875
|
-
Plot a scatter plot with points colored based on a 2D histogram.
|
|
913
|
+
def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
|
|
914
|
+
"""Plot a scatter plot with points colored based on a 2D histogram.
|
|
876
915
|
|
|
877
916
|
Args:
|
|
878
917
|
v (array-like): Values for the x-axis.
|
|
@@ -903,19 +942,21 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
|
|
|
903
942
|
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
|
|
904
943
|
|
|
905
944
|
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
|
|
909
|
-
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
|
|
945
|
+
@plt_settings()
|
|
946
|
+
def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
|
|
947
|
+
"""Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
|
|
948
|
+
key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
|
|
949
|
+
the plots.
|
|
910
950
|
|
|
911
951
|
Args:
|
|
912
952
|
csv_file (str, optional): Path to the CSV file containing the tuning results.
|
|
953
|
+
exclude_zero_fitness_points (bool, optional): Don't include points with zero fitness in tuning plots.
|
|
913
954
|
|
|
914
955
|
Examples:
|
|
915
956
|
>>> plot_tune_results("path/to/tune_results.csv")
|
|
916
957
|
"""
|
|
917
958
|
import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
|
|
918
|
-
import
|
|
959
|
+
import polars as pl
|
|
919
960
|
from scipy.ndimage import gaussian_filter1d
|
|
920
961
|
|
|
921
962
|
def _save_one_file(file):
|
|
@@ -926,11 +967,22 @@ def plot_tune_results(csv_file="tune_results.csv"):
|
|
|
926
967
|
|
|
927
968
|
# Scatter plots for each hyperparameter
|
|
928
969
|
csv_file = Path(csv_file)
|
|
929
|
-
data =
|
|
970
|
+
data = pl.read_csv(csv_file, infer_schema_length=None)
|
|
930
971
|
num_metrics_columns = 1
|
|
931
972
|
keys = [x.strip() for x in data.columns][num_metrics_columns:]
|
|
932
|
-
x = data.
|
|
973
|
+
x = data.to_numpy()
|
|
933
974
|
fitness = x[:, 0] # fitness
|
|
975
|
+
if exclude_zero_fitness_points:
|
|
976
|
+
mask = fitness > 0 # exclude zero-fitness points
|
|
977
|
+
x, fitness = x[mask], fitness[mask]
|
|
978
|
+
# Iterative sigma rejection on lower bound only
|
|
979
|
+
for _ in range(3): # max 3 iterations
|
|
980
|
+
mean, std = fitness.mean(), fitness.std()
|
|
981
|
+
lower_bound = mean - 3 * std
|
|
982
|
+
mask = fitness >= lower_bound
|
|
983
|
+
if mask.all(): # no more outliers
|
|
984
|
+
break
|
|
985
|
+
x, fitness = x[mask], fitness[mask]
|
|
934
986
|
j = np.argmax(fitness) # max fitness index
|
|
935
987
|
n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
|
|
936
988
|
plt.figure(figsize=(10, 10), tight_layout=True)
|
|
@@ -959,31 +1011,9 @@ def plot_tune_results(csv_file="tune_results.csv"):
|
|
|
959
1011
|
_save_one_file(csv_file.with_name("tune_fitness.png"))
|
|
960
1012
|
|
|
961
1013
|
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
for i, o in enumerate(output):
|
|
966
|
-
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
|
|
967
|
-
j = torch.full((conf.shape[0], 1), i)
|
|
968
|
-
targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
|
|
969
|
-
targets = torch.cat(targets, 0).numpy()
|
|
970
|
-
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
def output_to_rotated_target(output, max_det=300):
|
|
974
|
-
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
|
|
975
|
-
targets = []
|
|
976
|
-
for i, o in enumerate(output):
|
|
977
|
-
box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
|
|
978
|
-
j = torch.full((conf.shape[0], 1), i)
|
|
979
|
-
targets.append(torch.cat((j, cls, box, angle, conf), 1))
|
|
980
|
-
targets = torch.cat(targets, 0).numpy()
|
|
981
|
-
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
|
|
985
|
-
"""
|
|
986
|
-
Visualize feature maps of a given model module during inference.
|
|
1014
|
+
@plt_settings()
|
|
1015
|
+
def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
|
|
1016
|
+
"""Visualize feature maps of a given model module during inference.
|
|
987
1017
|
|
|
988
1018
|
Args:
|
|
989
1019
|
x (torch.Tensor): Features to be visualized.
|
|
@@ -1000,7 +1030,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detec
|
|
|
1000
1030
|
if isinstance(x, torch.Tensor):
|
|
1001
1031
|
_, channels, height, width = x.shape # batch, channels, height, width
|
|
1002
1032
|
if height > 1 and width > 1:
|
|
1003
|
-
f = save_dir / f"stage{stage}_{module_type.
|
|
1033
|
+
f = save_dir / f"stage{stage}_{module_type.rsplit('.', 1)[-1]}_features.png" # filename
|
|
1004
1034
|
|
|
1005
1035
|
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
|
1006
1036
|
n = min(n, channels) # number of plots
|