dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/patches.py
CHANGED

```diff
@@ -18,8 +18,7 @@ _imshow = cv2.imshow  # copy to avoid recursion errors
 
 
 def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
-    """
-    Read an image from a file with multilanguage filename support.
+    """Read an image from a file with multilanguage filename support.
 
     Args:
         filename (str): Path to the file to read.
@@ -36,7 +35,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
     if filename.endswith((".tiff", ".tif")):
         success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
         if success:
-            # Handle
+            # Handle multi-frame TIFFs and color images
            return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
         return None
     else:
@@ -45,8 +44,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
 
 
 def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
-    """
-    Write an image to a file with multilanguage filename support.
+    """Write an image to a file with multilanguage filename support.
 
     Args:
         filename (str): Path to the file to write.
@@ -71,15 +69,14 @@ def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) ->
 
 
 def imshow(winname: str, mat: np.ndarray) -> None:
-    """
-    Display an image in the specified window with multilanguage window name support.
+    """Display an image in the specified window with multilanguage window name support.
 
     This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
     multilanguage window names by encoding them properly for OpenCV compatibility.
 
     Args:
-        winname (str): Name of the window where the image will be displayed. If a window with this name already
-
+        winname (str): Name of the window where the image will be displayed. If a window with this name already exists,
+            the image will be displayed in that window.
         mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
 
     Examples:
@@ -96,8 +93,7 @@ _torch_save = torch.save
 
 
 def torch_load(*args, **kwargs):
-    """
-    Load a PyTorch model with updated arguments to avoid warnings.
+    """Load a PyTorch model with updated arguments to avoid warnings.
 
     This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
 
@@ -109,8 +105,8 @@ def torch_load(*args, **kwargs):
         (Any): The loaded PyTorch object.
 
     Notes:
-        For PyTorch versions
-
+        For PyTorch versions 1.13 and above, this function automatically sets `weights_only=False` if the argument is
+        not provided, to avoid deprecation warnings.
     """
     from ultralytics.utils.torch_utils import TORCH_1_13
 
@@ -121,11 +117,10 @@ def torch_load(*args, **kwargs):
 
 
 def torch_save(*args, **kwargs):
-    """
-    Save PyTorch objects with retry mechanism for robustness.
+    """Save PyTorch objects with retry mechanism for robustness.
 
-    This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
-
+    This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur due
+    to device flushing delays or antivirus scanning.
 
     Args:
         *args (Any): Positional arguments to pass to torch.save.
@@ -146,8 +141,7 @@ def torch_save(*args, **kwargs):
 
 @contextmanager
 def arange_patch(args):
-    """
-    Workaround for ONNX torch.arange incompatibility with FP16.
+    """Workaround for ONNX torch.arange incompatibility with FP16.
 
     https://github.com/pytorch/pytorch/issues/148041.
     """
@@ -165,10 +159,28 @@ def arange_patch(args):
         yield
 
 
+@contextmanager
+def onnx_export_patch():
+    """Workaround for ONNX export issues in PyTorch 2.9+ with Dynamo enabled."""
+    from ultralytics.utils.torch_utils import TORCH_2_9
+
+    if TORCH_2_9:
+        func = torch.onnx.export
+
+        def torch_export(*args, **kwargs):
+            """Return a 1-D tensor of size with values from the interval and common difference."""
+            return func(*args, **kwargs, dynamo=False)  # cast to dtype instead of passing dtype
+
+        torch.onnx.export = torch_export  # patch
+        yield
+        torch.onnx.export = func  # unpatch
+    else:
+        yield
+
+
 @contextmanager
 def override_configs(args, overrides: dict[str, Any] | None = None):
-    """
-    Context manager to temporarily override configurations in args.
+    """Context manager to temporarily override configurations in args.
 
     Args:
         args (IterableSimpleNamespace): Original configuration arguments.
```
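The new `onnx_export_patch` context manager follows the same temporary monkey-patch pattern as the existing `arange_patch`: swap the function out, yield, restore on exit. A minimal standalone sketch of that pattern (not the package's code; the `force_legacy_onnx_export` name is hypothetical, and the `dynamo` keyword is only accepted by recent `torch.onnx.export` versions):

```python
from contextlib import contextmanager

import torch


@contextmanager
def force_legacy_onnx_export():
    """Temporarily force torch.onnx.export onto the legacy (non-Dynamo) exporter."""
    original = torch.onnx.export

    def patched(*args, **kwargs):
        kwargs.setdefault("dynamo", False)  # assumption: torch>=2.5, where export() accepts `dynamo`
        return original(*args, **kwargs)

    torch.onnx.export = patched  # patch
    try:
        yield
    finally:
        torch.onnx.export = original  # unpatch, even if the export raises
```

Unlike this sketch, the packaged version restores the attribute without a `try/finally`; the `finally` here is merely defensive, so a failed export cannot leave the patch in place.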
ultralytics/utils/plotting.py
CHANGED

```diff
@@ -3,9 +3,9 @@
 from __future__ import annotations
 
 import math
-import warnings
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
 
 import cv2
 import numpy as np
@@ -19,22 +19,10 @@ from ultralytics.utils.files import increment_path
 
 
 class Colors:
-    """
-    Ultralytics color palette for visualization and plotting.
-
-    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
-    RGB values and accessing predefined color schemes for object detection and pose estimation.
+    """Ultralytics color palette for visualization and plotting.
 
-    Attributes:
-        palette (list[tuple]): List of RGB color tuples for general use.
-        n (int): The number of colors in the palette.
-        pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
-
-    Examples:
-        >>> from ultralytics.utils.plotting import Colors
-        >>> colors = Colors()
-        >>> colors(5, True)  # Returns BGR format: (221, 111, 255)
-        >>> colors(5, False)  # Returns RGB format: (255, 111, 221)
+    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
+    values and accessing predefined color schemes for object detection and pose estimation.
 
     ## Ultralytics Color Palette
 
@@ -90,6 +78,17 @@ class Colors:
 
     For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
     Please use the official Ultralytics colors for all marketing materials.
+
+    Attributes:
+        palette (list[tuple]): List of RGB color tuples for general use.
+        n (int): The number of colors in the palette.
+        pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
+
+    Examples:
+        >>> from ultralytics.utils.plotting import Colors
+        >>> colors = Colors()
+        >>> colors(5, True)  # Returns BGR format: (221, 111, 255)
+        >>> colors(5, False)  # Returns RGB format: (255, 111, 221)
     """
 
     def __init__(self):
@@ -145,8 +144,7 @@ class Colors:
         )
 
     def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
-        """
-        Convert hex color codes to RGB values.
+        """Convert hex color codes to RGB values.
 
         Args:
             i (int | torch.Tensor): Color index.
@@ -168,8 +166,7 @@ colors = Colors()  # create instance for 'from utils.plots import colors'
 
 
 class Annotator:
-    """
-    Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
+    """Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
 
     Attributes:
         im (Image.Image | np.ndarray): The image to annotate.
@@ -206,10 +203,12 @@ class Annotator:
         if not input_is_pil:
             if im.shape[2] == 1:  # handle grayscale
                 im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
+            elif im.shape[2] == 2:  # handle 2-channel images
+                im = np.ascontiguousarray(np.dstack((im, np.zeros_like(im[..., :1]))))
             elif im.shape[2] > 3:  # multispectral
                 im = np.ascontiguousarray(im[..., :3])
         if self.pil:  # use PIL
-            self.im = im if input_is_pil else Image.fromarray(im)
+            self.im = im if input_is_pil else Image.fromarray(im)  # stay in BGR since color palette is in BGR
             if self.im.mode not in {"RGB", "RGBA"}:  # multispectral
                 self.im = self.im.convert("RGB")
             self.draw = ImageDraw.Draw(self.im, "RGBA")
@@ -278,8 +277,7 @@ class Annotator:
         }
 
     def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
-        """
-        Assign text color based on background color.
+        """Assign text color based on background color.
 
         Args:
             color (tuple, optional): The background color of the rectangle for text (B, G, R).
@@ -302,8 +300,7 @@ class Annotator:
         return txt_color
 
     def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
-        """
-        Draw a bounding box on an image with a given label.
+        """Draw a bounding box on an image with a given label.
 
         Args:
             box (tuple): The bounding box coordinates (x1, y1, x2, y2).
@@ -364,8 +361,7 @@ class Annotator:
         )
 
     def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
-        """
-        Plot masks on image.
+        """Plot masks on image.
 
         Args:
             masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
@@ -384,25 +380,32 @@ class Annotator:
                 overlay[mask.astype(bool)] = colors[i]
             self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
         else:
-            assert isinstance(masks, torch.Tensor), "
+            assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
             if len(masks) == 0:
                 self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
+                return
             if im_gpu.device != masks.device:
                 im_gpu = im_gpu.to(masks.device)
+
+            ih, iw = self.im.shape[:2]
+            if not retina_masks:
+                # Use scale_masks to properly remove padding and upsample, convert bool to float first
+                masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
+                # Convert original BGR image to RGB tensor
+                im_gpu = (
+                    torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
+                )
+
             colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0  # shape(n,3)
             colors = colors[:, None, None]  # shape(n,1,1,3)
             masks = masks.unsqueeze(3)  # shape(n,h,w,1)
             masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
-
             inv_alpha_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
-            mcs = masks_color.max(dim=0).values  # shape(
+            mcs = masks_color.max(dim=0).values  # shape(h,w,3)
 
-            im_gpu = im_gpu.flip(dims=[0])  #
-            im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
+            im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous()  # shape(h,w,3)
             im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
-
-            im_mask_np = im_mask.byte().cpu().numpy()
-            self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
+            self.im[:] = (im_gpu * 255).byte().cpu().numpy()
         if self.pil:
             # Convert im back to PIL and update draw
             self.fromarray(self.im)
@@ -416,8 +419,7 @@ class Annotator:
         conf_thres: float = 0.25,
         kpt_color: tuple | None = None,
     ):
-        """
-        Plot keypoints on the image.
+        """Plot keypoints on the image.
 
         Args:
             kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
@@ -427,7 +429,7 @@ class Annotator:
             conf_thres (float, optional): Confidence threshold.
             kpt_color (tuple, optional): Keypoint color (B, G, R).
 
-
+        Notes:
             - `kpt_line=True` currently only supports human pose plotting.
             - Modifies self.im in-place.
             - If self.pil is True, converts image to numpy array and back to PIL.
@@ -480,8 +482,7 @@ class Annotator:
             self.draw.rectangle(xy, fill, outline, width)
 
     def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
-        """
-        Add text to an image using PIL or cv2.
+        """Add text to an image using PIL or cv2.
 
         Args:
             xy (list[int]): Top-left coordinates for text placement.
@@ -511,18 +512,19 @@ class Annotator:
             cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
 
     def fromarray(self, im):
-        """Update self.im from a
+        """Update `self.im` from a NumPy array or PIL image."""
         self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
         self.draw = ImageDraw.Draw(self.im)
 
-    def result(self):
-        """Return annotated image as array."""
-
+    def result(self, pil=False):
+        """Return annotated image as array or PIL image."""
+        im = np.asarray(self.im)  # self.im is in BGR
+        return Image.fromarray(im[..., ::-1]) if pil else im
 
     def show(self, title: str | None = None):
         """Show the annotated image."""
-        im = Image.fromarray(np.asarray(self.im)[..., ::-1])  # Convert
-        if IS_COLAB or IS_KAGGLE:  #
+        im = Image.fromarray(np.asarray(self.im)[..., ::-1])  # Convert BGR NumPy array to RGB PIL Image
+        if IS_COLAB or IS_KAGGLE:  # cannot use IS_JUPYTER as it runs for all IPython environments
             try:
                 display(im)  # noqa - display() function only available in ipython environments
             except ImportError as e:
@@ -535,12 +537,11 @@ class Annotator:
         cv2.imwrite(filename, np.asarray(self.im))
 
     @staticmethod
-    def get_bbox_dimension(bbox: tuple |
-        """
-        Calculate the dimensions and area of a bounding box.
+    def get_bbox_dimension(bbox: tuple | list):
+        """Calculate the dimensions and area of a bounding box.
 
         Args:
-            bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
+            bbox (tuple | list): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
 
         Returns:
             width (float): Width of the bounding box.
@@ -562,8 +563,7 @@
 @TryExcept()
 @plt_settings()
 def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
-    """
-    Plot training labels including class histograms and box statistics.
+    """Plot training labels including class histograms and box statistics.
 
     Args:
         boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
@@ -576,10 +576,6 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
     import polars
     from matplotlib.colors import LinearSegmentedColormap
 
-    # Filter matplotlib>=3.7.2 warning
-    warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
-    warnings.filterwarnings("ignore", category=FutureWarning)
-
     # Plot dataset labels
     LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
     nc = int(cls.max() + 1)  # number of classes
@@ -601,8 +597,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
     ax[0].set_xlabel("classes")
     boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
     img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
-    for
-        ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(
+    for class_id, box in zip(cls[:500], boxes[:500]):
+        ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(class_id))  # plot
     ax[1].imshow(img)
     ax[1].axis("off")
 
@@ -633,12 +629,11 @@ def save_one_box(
     BGR: bool = False,
     save: bool = True,
 ):
-    """
-    Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
+    """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
 
-    This function takes a bounding box and an image, and then saves a cropped portion of the image according
-
-
+    This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
+    bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
+    bounding box.
 
     Args:
         xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
@@ -691,11 +686,11 @@ def plot_images(
     save: bool = True,
     conf_thres: float = 0.25,
 ) -> np.ndarray | None:
-    """
-    Plot image grid with labels, bounding boxes, masks, and keypoints.
+    """Plot image grid with labels, bounding boxes, masks, and keypoints.
 
     Args:
-        labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
+        labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
+            'keypoints', 'batch_idx', 'img'.
         images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
         paths (Optional[list[str]]): List of file paths for each image in the batch.
         fname (str): Output filename for the plotted image grid.
@@ -709,12 +704,12 @@ def plot_images(
     Returns:
         (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
 
-
+    Notes:
         This function supports both tensor and numpy array inputs. It will automatically
         convert tensor inputs to numpy arrays for processing.
 
     Channel Support:
-        - 1 channel:
+        - 1 channel: Grayscale
         - 2 channels: Third channel added as zeros
         - 3 channels: Used as-is (standard RGB)
         - 4+ channels: Cropped to first 3 channels
@@ -791,7 +786,6 @@ def plot_images(
             boxes[..., 0] += x
             boxes[..., 1] += y
             is_obb = boxes.shape[-1] == 5  # xywhr
-            # TODO: this transformation might be unnecessary
             boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
             for j, box in enumerate(boxes.astype(np.int64).tolist()):
                 c = classes[j]
@@ -860,9 +854,9 @@ def plot_images(
 
 @plt_settings()
 def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
-    """
-
-
+    """Plot training results from a results CSV file. The function supports various types of data including
+    segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
+    CSV is located.
 
     Args:
         file (str, optional): Path to the CSV file containing the training results.
@@ -914,8 +908,7 @@ def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Call
 
 
 def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
-    """
-    Plot a scatter plot with points colored based on a 2D histogram.
+    """Plot a scatter plot with points colored based on a 2D histogram.
 
     Args:
         v (array-like): Values for the x-axis.
@@ -948,9 +941,9 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
 
 @plt_settings()
 def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
-    """
-
-
+    """Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
+    key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
+    the plots.
 
     Args:
         csv_file (str, optional): Path to the CSV file containing the tuning results.
@@ -1017,8 +1010,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_p
 
 @plt_settings()
 def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
-    """
-    Visualize feature maps of a given model module during inference.
+    """Visualize feature maps of a given model module during inference.
 
     Args:
         x (torch.Tensor): Features to be visualized.
```
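Two of the plotting changes above are user-visible: `Annotator` now stays in BGR internally, and `result()` gained a `pil` flag that flips channels to RGB on the way out. A short sketch of how that reads at a call site, assuming only the signatures shown in the diff (the black test image and the "person" label are placeholders):

```python
import numpy as np

from ultralytics.utils.plotting import Annotator, colors

im = np.zeros((480, 640, 3), dtype=np.uint8)  # BGR array, as cv2.imread would return

annotator = Annotator(im, line_width=2)
annotator.box_label((50, 60, 200, 220), label="person", color=colors(0, bgr=True))

bgr = annotator.result()          # np.ndarray in BGR, ready for cv2.imwrite
rgb = annotator.result(pil=True)  # new in this release: PIL.Image converted to RGB
```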
ultralytics/utils/tal.py
CHANGED

```diff
@@ -10,8 +10,7 @@ from .torch_utils import TORCH_1_11
 
 
 class TaskAlignedAssigner(nn.Module):
-    """
-    A task-aligned assigner for object detection.
+    """A task-aligned assigner for object detection.
 
     This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
     classification and localization information.
@@ -25,8 +24,7 @@ class TaskAlignedAssigner(nn.Module):
     """
 
     def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
-        """
-        Initialize a TaskAlignedAssigner object with customizable hyperparameters.
+        """Initialize a TaskAlignedAssigner object with customizable hyperparameters.
 
         Args:
             topk (int, optional): The number of top candidates to consider.
@@ -44,8 +42,7 @@ class TaskAlignedAssigner(nn.Module):
 
     @torch.no_grad()
     def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute the task-aligned assignment.
+        """Compute the task-aligned assignment.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -88,8 +85,7 @@ class TaskAlignedAssigner(nn.Module):
         return tuple(t.to(device) for t in result)
 
     def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute the task-aligned assignment.
+        """Compute the task-aligned assignment.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -125,8 +121,7 @@ class TaskAlignedAssigner(nn.Module):
         return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
 
     def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
-        """
-        Get positive mask for each ground truth box.
+        """Get positive mask for each ground truth box.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -139,7 +134,7 @@ class TaskAlignedAssigner(nn.Module):
         Returns:
             mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
             align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
-            overlaps (torch.Tensor): Overlaps between predicted
+            overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
         """
         mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
         # Get anchor_align metric, (b, max_num_obj, h*w)
@@ -152,8 +147,7 @@ class TaskAlignedAssigner(nn.Module):
         return mask_pos, align_metric, overlaps
 
     def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute alignment metric given predicted and ground truth bounding boxes.
+        """Compute alignment metric given predicted and ground truth bounding boxes.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -186,8 +180,7 @@ class TaskAlignedAssigner(nn.Module):
         return align_metric, overlaps
 
     def iou_calculation(self, gt_bboxes, pd_bboxes):
-        """
-        Calculate IoU for horizontal bounding boxes.
+        """Calculate IoU for horizontal bounding boxes.
 
         Args:
             gt_bboxes (torch.Tensor): Ground truth boxes.
@@ -199,14 +192,13 @@ class TaskAlignedAssigner(nn.Module):
         return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
 
     def select_topk_candidates(self, metrics, topk_mask=None):
-        """
-        Select the top-k candidates based on the given metrics.
+        """Select the top-k candidates based on the given metrics.
 
         Args:
             metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
                 the maximum number of objects, and h*w represents the total number of anchor points.
-            topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
-
+            topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
+                is the number of top candidates to consider. If not provided, the top-k values are automatically
                 computed based on the given metrics.
 
         Returns:
@@ -231,18 +223,16 @@ class TaskAlignedAssigner(nn.Module):
         return count_tensor.to(metrics.dtype)
 
     def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
-        """
-        Compute target labels, target bounding boxes, and target scores for the positive anchor points.
+        """Compute target labels, target bounding boxes, and target scores for the positive anchor points.
 
         Args:
-            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
-
+            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
+                max_num_obj is the maximum number of objects.
             gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
-            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
-
-
-
-                (foreground) anchor points.
+            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
+                shape (b, h*w), where h*w is the total number of anchor points.
+            fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
+                points.
 
         Returns:
             target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
@@ -275,8 +265,7 @@ class TaskAlignedAssigner(nn.Module):
 
     @staticmethod
     def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-        """
-        Select positive anchor centers within ground truth bounding boxes.
+        """Select positive anchor centers within ground truth bounding boxes.
 
         Args:
             xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
@@ -286,9 +275,9 @@ class TaskAlignedAssigner(nn.Module):
         Returns:
             (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
 
-
-            b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
-            Bounding box format: [x_min, y_min, x_max, y_max].
+        Notes:
+            - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
+            - Bounding box format: [x_min, y_min, x_max, y_max].
         """
         n_anchors = xy_centers.shape[0]
         bs, n_boxes, _ = gt_bboxes.shape
@@ -298,8 +287,7 @@ class TaskAlignedAssigner(nn.Module):
 
     @staticmethod
     def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-        """
-        Select anchor boxes with highest IoU when assigned to multiple ground truths.
+        """Select anchor boxes with highest IoU when assigned to multiple ground truths.
 
         Args:
             mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
@@ -336,8 +324,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
 
     @staticmethod
     def select_candidates_in_gts(xy_centers, gt_bboxes):
-        """
-        Select the positive anchor center in gt for rotated bounding boxes.
+        """Select the positive anchor center in gt for rotated bounding boxes.
 
         Args:
             xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
@@ -396,8 +383,7 @@ def bbox2dist(anchor_points, bbox, reg_max):
 
 
 def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
-    """
-    Decode predicted rotated bounding box coordinates from anchor points and distribution.
+    """Decode predicted rotated bounding box coordinates from anchor points and distribution.
 
     Args:
         pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
```