dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/patches.py
CHANGED
|
@@ -18,8 +18,7 @@ _imshow = cv2.imshow # copy to avoid recursion errors
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
|
|
21
|
-
"""
|
|
22
|
-
Read an image from a file with multilanguage filename support.
|
|
21
|
+
"""Read an image from a file with multilanguage filename support.
|
|
23
22
|
|
|
24
23
|
Args:
|
|
25
24
|
filename (str): Path to the file to read.
|
|
@@ -36,17 +35,58 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
|
|
|
36
35
|
if filename.endswith((".tiff", ".tif")):
|
|
37
36
|
success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
|
|
38
37
|
if success:
|
|
39
|
-
# Handle
|
|
38
|
+
# Handle multi-frame TIFFs and color images
|
|
40
39
|
return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
|
|
41
40
|
return None
|
|
42
41
|
else:
|
|
43
42
|
im = cv2.imdecode(file_bytes, flags)
|
|
43
|
+
# Fallback for formats OpenCV imdecode may not support (AVIF, HEIC)
|
|
44
|
+
if im is None and filename.lower().endswith((".avif", ".heic")):
|
|
45
|
+
im = _imread_pil(filename, flags)
|
|
44
46
|
return im[..., None] if im is not None and im.ndim == 2 else im # Always ensure 3 dimensions
|
|
45
47
|
|
|
46
48
|
|
|
47
|
-
|
|
49
|
+
_pil_plugins_registered = False
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _imread_pil(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
|
|
53
|
+
"""Read an image using PIL as fallback for formats not supported by OpenCV.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
filename (str): Path to the file to read.
|
|
57
|
+
flags (int, optional): OpenCV imread flags (used to determine grayscale conversion).
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
(np.ndarray | None): The read image array in BGR format, or None if reading fails.
|
|
48
61
|
"""
|
|
49
|
-
|
|
62
|
+
global _pil_plugins_registered
|
|
63
|
+
try:
|
|
64
|
+
from PIL import Image
|
|
65
|
+
|
|
66
|
+
# Register HEIF/AVIF plugins once
|
|
67
|
+
if not _pil_plugins_registered:
|
|
68
|
+
try:
|
|
69
|
+
import pillow_heif
|
|
70
|
+
|
|
71
|
+
pillow_heif.register_heif_opener()
|
|
72
|
+
except ImportError:
|
|
73
|
+
pass
|
|
74
|
+
try:
|
|
75
|
+
import pillow_avif # noqa: F401
|
|
76
|
+
except ImportError:
|
|
77
|
+
pass
|
|
78
|
+
_pil_plugins_registered = True
|
|
79
|
+
|
|
80
|
+
with Image.open(filename) as img:
|
|
81
|
+
if flags == cv2.IMREAD_GRAYSCALE:
|
|
82
|
+
return np.asarray(img.convert("L"))
|
|
83
|
+
return cv2.cvtColor(np.asarray(img.convert("RGB")), cv2.COLOR_RGB2BGR)
|
|
84
|
+
except Exception:
|
|
85
|
+
return None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
|
|
89
|
+
"""Write an image to a file with multilanguage filename support.
|
|
50
90
|
|
|
51
91
|
Args:
|
|
52
92
|
filename (str): Path to the file to write.
|
|
@@ -71,15 +111,14 @@ def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) ->
|
|
|
71
111
|
|
|
72
112
|
|
|
73
113
|
def imshow(winname: str, mat: np.ndarray) -> None:
|
|
74
|
-
"""
|
|
75
|
-
Display an image in the specified window with multilanguage window name support.
|
|
114
|
+
"""Display an image in the specified window with multilanguage window name support.
|
|
76
115
|
|
|
77
116
|
This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
|
|
78
117
|
multilanguage window names by encoding them properly for OpenCV compatibility.
|
|
79
118
|
|
|
80
119
|
Args:
|
|
81
|
-
winname (str): Name of the window where the image will be displayed. If a window with this name already
|
|
82
|
-
|
|
120
|
+
winname (str): Name of the window where the image will be displayed. If a window with this name already exists,
|
|
121
|
+
the image will be displayed in that window.
|
|
83
122
|
mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
|
|
84
123
|
|
|
85
124
|
Examples:
|
|
@@ -96,8 +135,7 @@ _torch_save = torch.save
|
|
|
96
135
|
|
|
97
136
|
|
|
98
137
|
def torch_load(*args, **kwargs):
|
|
99
|
-
"""
|
|
100
|
-
Load a PyTorch model with updated arguments to avoid warnings.
|
|
138
|
+
"""Load a PyTorch model with updated arguments to avoid warnings.
|
|
101
139
|
|
|
102
140
|
This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
|
|
103
141
|
|
|
@@ -109,8 +147,8 @@ def torch_load(*args, **kwargs):
|
|
|
109
147
|
(Any): The loaded PyTorch object.
|
|
110
148
|
|
|
111
149
|
Notes:
|
|
112
|
-
For PyTorch versions
|
|
113
|
-
|
|
150
|
+
For PyTorch versions 1.13 and above, this function automatically sets `weights_only=False` if the argument is
|
|
151
|
+
not provided, to avoid deprecation warnings.
|
|
114
152
|
"""
|
|
115
153
|
from ultralytics.utils.torch_utils import TORCH_1_13
|
|
116
154
|
|
|
@@ -121,11 +159,10 @@ def torch_load(*args, **kwargs):
|
|
|
121
159
|
|
|
122
160
|
|
|
123
161
|
def torch_save(*args, **kwargs):
|
|
124
|
-
"""
|
|
125
|
-
Save PyTorch objects with retry mechanism for robustness.
|
|
162
|
+
"""Save PyTorch objects with retry mechanism for robustness.
|
|
126
163
|
|
|
127
|
-
This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
|
|
128
|
-
|
|
164
|
+
This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur due
|
|
165
|
+
to device flushing delays or antivirus scanning.
|
|
129
166
|
|
|
130
167
|
Args:
|
|
131
168
|
*args (Any): Positional arguments to pass to torch.save.
|
|
@@ -146,8 +183,7 @@ def torch_save(*args, **kwargs):
|
|
|
146
183
|
|
|
147
184
|
@contextmanager
|
|
148
185
|
def arange_patch(args):
|
|
149
|
-
"""
|
|
150
|
-
Workaround for ONNX torch.arange incompatibility with FP16.
|
|
186
|
+
"""Workaround for ONNX torch.arange incompatibility with FP16.
|
|
151
187
|
|
|
152
188
|
https://github.com/pytorch/pytorch/issues/148041.
|
|
153
189
|
"""
|
|
@@ -165,10 +201,28 @@ def arange_patch(args):
|
|
|
165
201
|
yield
|
|
166
202
|
|
|
167
203
|
|
|
204
|
+
@contextmanager
|
|
205
|
+
def onnx_export_patch():
|
|
206
|
+
"""Workaround for ONNX export issues in PyTorch 2.9+ with Dynamo enabled."""
|
|
207
|
+
from ultralytics.utils.torch_utils import TORCH_2_9
|
|
208
|
+
|
|
209
|
+
if TORCH_2_9:
|
|
210
|
+
func = torch.onnx.export
|
|
211
|
+
|
|
212
|
+
def torch_export(*args, **kwargs):
|
|
213
|
+
"""Call the original torch.onnx.export with the Dynamo exporter disabled (dynamo=False)."""
|
|
214
|
+
return func(*args, **kwargs, dynamo=False) # always export with dynamo=False
|
|
215
|
+
|
|
216
|
+
torch.onnx.export = torch_export # patch
|
|
217
|
+
yield
|
|
218
|
+
torch.onnx.export = func # unpatch
|
|
219
|
+
else:
|
|
220
|
+
yield
|
|
221
|
+
|
|
222
|
+
|
|
168
223
|
@contextmanager
|
|
169
224
|
def override_configs(args, overrides: dict[str, Any] | None = None):
|
|
170
|
-
"""
|
|
171
|
-
Context manager to temporarily override configurations in args.
|
|
225
|
+
"""Context manager to temporarily override configurations in args.
|
|
172
226
|
|
|
173
227
|
Args:
|
|
174
228
|
args (IterableSimpleNamespace): Original configuration arguments.
|
ultralytics/utils/plotting.py
CHANGED
|
@@ -3,9 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import math
|
|
6
|
-
import warnings
|
|
6
|
+
from collections.abc import Callable
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import Any, Callable
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
import cv2
|
|
11
11
|
import numpy as np
|
|
@@ -19,22 +19,10 @@ from ultralytics.utils.files import increment_path
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class Colors:
|
|
22
|
-
"""
|
|
23
|
-
Ultralytics color palette for visualization and plotting.
|
|
24
|
-
|
|
25
|
-
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
|
|
26
|
-
RGB values and accessing predefined color schemes for object detection and pose estimation.
|
|
22
|
+
"""Ultralytics color palette for visualization and plotting.
|
|
27
23
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
n (int): The number of colors in the palette.
|
|
31
|
-
pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
|
|
32
|
-
|
|
33
|
-
Examples:
|
|
34
|
-
>>> from ultralytics.utils.plotting import Colors
|
|
35
|
-
>>> colors = Colors()
|
|
36
|
-
>>> colors(5, True) # Returns BGR format: (221, 111, 255)
|
|
37
|
-
>>> colors(5, False) # Returns RGB format: (255, 111, 221)
|
|
24
|
+
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
|
|
25
|
+
values and accessing predefined color schemes for object detection and pose estimation.
|
|
38
26
|
|
|
39
27
|
## Ultralytics Color Palette
|
|
40
28
|
|
|
@@ -90,6 +78,17 @@ class Colors:
|
|
|
90
78
|
|
|
91
79
|
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
|
|
92
80
|
Please use the official Ultralytics colors for all marketing materials.
|
|
81
|
+
|
|
82
|
+
Attributes:
|
|
83
|
+
palette (list[tuple]): List of RGB color tuples for general use.
|
|
84
|
+
n (int): The number of colors in the palette.
|
|
85
|
+
pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
|
|
86
|
+
|
|
87
|
+
Examples:
|
|
88
|
+
>>> from ultralytics.utils.plotting import Colors
|
|
89
|
+
>>> colors = Colors()
|
|
90
|
+
>>> colors(5, True) # Returns BGR format: (221, 111, 255)
|
|
91
|
+
>>> colors(5, False) # Returns RGB format: (255, 111, 221)
|
|
93
92
|
"""
|
|
94
93
|
|
|
95
94
|
def __init__(self):
|
|
@@ -145,8 +144,7 @@ class Colors:
|
|
|
145
144
|
)
|
|
146
145
|
|
|
147
146
|
def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
|
|
148
|
-
"""
|
|
149
|
-
Convert hex color codes to RGB values.
|
|
147
|
+
"""Return the palette color at the given index, in BGR order when `bgr` is True.
|
|
150
148
|
|
|
151
149
|
Args:
|
|
152
150
|
i (int | torch.Tensor): Color index.
|
|
@@ -168,8 +166,7 @@ colors = Colors() # create instance for 'from utils.plots import colors'
|
|
|
168
166
|
|
|
169
167
|
|
|
170
168
|
class Annotator:
|
|
171
|
-
"""
|
|
172
|
-
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
|
169
|
+
"""Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
|
173
170
|
|
|
174
171
|
Attributes:
|
|
175
172
|
im (Image.Image | np.ndarray): The image to annotate.
|
|
@@ -206,10 +203,12 @@ class Annotator:
|
|
|
206
203
|
if not input_is_pil:
|
|
207
204
|
if im.shape[2] == 1: # handle grayscale
|
|
208
205
|
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
|
|
206
|
+
elif im.shape[2] == 2: # handle 2-channel images
|
|
207
|
+
im = np.ascontiguousarray(np.dstack((im, np.zeros_like(im[..., :1]))))
|
|
209
208
|
elif im.shape[2] > 3: # multispectral
|
|
210
209
|
im = np.ascontiguousarray(im[..., :3])
|
|
211
210
|
if self.pil: # use PIL
|
|
212
|
-
self.im = im if input_is_pil else Image.fromarray(im)
|
|
211
|
+
self.im = im if input_is_pil else Image.fromarray(im) # stay in BGR since color palette is in BGR
|
|
213
212
|
if self.im.mode not in {"RGB", "RGBA"}: # multispectral
|
|
214
213
|
self.im = self.im.convert("RGB")
|
|
215
214
|
self.draw = ImageDraw.Draw(self.im, "RGBA")
|
|
@@ -278,8 +277,7 @@ class Annotator:
|
|
|
278
277
|
}
|
|
279
278
|
|
|
280
279
|
def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
|
|
281
|
-
"""
|
|
282
|
-
Assign text color based on background color.
|
|
280
|
+
"""Assign text color based on background color.
|
|
283
281
|
|
|
284
282
|
Args:
|
|
285
283
|
color (tuple, optional): The background color of the rectangle for text (B, G, R).
|
|
@@ -302,8 +300,7 @@ class Annotator:
|
|
|
302
300
|
return txt_color
|
|
303
301
|
|
|
304
302
|
def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
|
|
305
|
-
"""
|
|
306
|
-
Draw a bounding box on an image with a given label.
|
|
303
|
+
"""Draw a bounding box on an image with a given label.
|
|
307
304
|
|
|
308
305
|
Args:
|
|
309
306
|
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
|
|
@@ -364,8 +361,7 @@ class Annotator:
|
|
|
364
361
|
)
|
|
365
362
|
|
|
366
363
|
def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
|
|
367
|
-
"""
|
|
368
|
-
Plot masks on image.
|
|
364
|
+
"""Plot masks on image.
|
|
369
365
|
|
|
370
366
|
Args:
|
|
371
367
|
masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
|
|
@@ -384,25 +380,32 @@ class Annotator:
|
|
|
384
380
|
overlay[mask.astype(bool)] = colors[i]
|
|
385
381
|
self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
|
|
386
382
|
else:
|
|
387
|
-
assert isinstance(masks, torch.Tensor), "
|
|
383
|
+
assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
|
|
388
384
|
if len(masks) == 0:
|
|
389
385
|
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
|
386
|
+
return
|
|
390
387
|
if im_gpu.device != masks.device:
|
|
391
388
|
im_gpu = im_gpu.to(masks.device)
|
|
389
|
+
|
|
390
|
+
ih, iw = self.im.shape[:2]
|
|
391
|
+
if not retina_masks:
|
|
392
|
+
# Use scale_masks to properly remove padding and upsample, convert bool to float first
|
|
393
|
+
masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
|
|
394
|
+
# Convert original BGR image to RGB tensor
|
|
395
|
+
im_gpu = (
|
|
396
|
+
torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
|
|
397
|
+
)
|
|
398
|
+
|
|
392
399
|
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
|
|
393
400
|
colors = colors[:, None, None] # shape(n,1,1,3)
|
|
394
401
|
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
|
395
402
|
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
|
396
|
-
|
|
397
403
|
inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
|
398
|
-
mcs = masks_color.max(dim=0).values # shape(
|
|
404
|
+
mcs = masks_color.max(dim=0).values # shape(h,w,3)
|
|
399
405
|
|
|
400
|
-
im_gpu = im_gpu.flip(dims=[0]) #
|
|
401
|
-
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
|
406
|
+
im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous() # shape(h,w,3)
|
|
402
407
|
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
|
|
403
|
-
|
|
404
|
-
im_mask_np = im_mask.byte().cpu().numpy()
|
|
405
|
-
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
|
|
408
|
+
self.im[:] = (im_gpu * 255).byte().cpu().numpy()
|
|
406
409
|
if self.pil:
|
|
407
410
|
# Convert im back to PIL and update draw
|
|
408
411
|
self.fromarray(self.im)
|
|
@@ -416,8 +419,7 @@ class Annotator:
|
|
|
416
419
|
conf_thres: float = 0.25,
|
|
417
420
|
kpt_color: tuple | None = None,
|
|
418
421
|
):
|
|
419
|
-
"""
|
|
420
|
-
Plot keypoints on the image.
|
|
422
|
+
"""Plot keypoints on the image.
|
|
421
423
|
|
|
422
424
|
Args:
|
|
423
425
|
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
|
|
@@ -427,7 +429,7 @@ class Annotator:
|
|
|
427
429
|
conf_thres (float, optional): Confidence threshold.
|
|
428
430
|
kpt_color (tuple, optional): Keypoint color (B, G, R).
|
|
429
431
|
|
|
430
|
-
|
|
432
|
+
Notes:
|
|
431
433
|
- `kpt_line=True` currently only supports human pose plotting.
|
|
432
434
|
- Modifies self.im in-place.
|
|
433
435
|
- If self.pil is True, converts image to numpy array and back to PIL.
|
|
@@ -480,8 +482,7 @@ class Annotator:
|
|
|
480
482
|
self.draw.rectangle(xy, fill, outline, width)
|
|
481
483
|
|
|
482
484
|
def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
|
|
483
|
-
"""
|
|
484
|
-
Add text to an image using PIL or cv2.
|
|
485
|
+
"""Add text to an image using PIL or cv2.
|
|
485
486
|
|
|
486
487
|
Args:
|
|
487
488
|
xy (list[int]): Top-left coordinates for text placement.
|
|
@@ -511,18 +512,19 @@ class Annotator:
|
|
|
511
512
|
cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
|
|
512
513
|
|
|
513
514
|
def fromarray(self, im):
|
|
514
|
-
"""Update self.im from a
|
|
515
|
+
"""Update `self.im` from a NumPy array or PIL image."""
|
|
515
516
|
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
|
516
517
|
self.draw = ImageDraw.Draw(self.im)
|
|
517
518
|
|
|
518
|
-
def result(self):
|
|
519
|
-
"""Return annotated image as array."""
|
|
520
|
-
|
|
519
|
+
def result(self, pil=False):
|
|
520
|
+
"""Return annotated image as array or PIL image."""
|
|
521
|
+
im = np.asarray(self.im) # self.im is in BGR
|
|
522
|
+
return Image.fromarray(im[..., ::-1]) if pil else im
|
|
521
523
|
|
|
522
524
|
def show(self, title: str | None = None):
|
|
523
525
|
"""Show the annotated image."""
|
|
524
|
-
im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert
|
|
525
|
-
if IS_COLAB or IS_KAGGLE: #
|
|
526
|
+
im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert BGR NumPy array to RGB PIL Image
|
|
527
|
+
if IS_COLAB or IS_KAGGLE: # cannot use IS_JUPYTER as it runs for all IPython environments
|
|
526
528
|
try:
|
|
527
529
|
display(im) # noqa - display() function only available in ipython environments
|
|
528
530
|
except ImportError as e:
|
|
@@ -535,12 +537,11 @@ class Annotator:
|
|
|
535
537
|
cv2.imwrite(filename, np.asarray(self.im))
|
|
536
538
|
|
|
537
539
|
@staticmethod
|
|
538
|
-
def get_bbox_dimension(bbox: tuple |
|
|
539
|
-
"""
|
|
540
|
-
Calculate the dimensions and area of a bounding box.
|
|
540
|
+
def get_bbox_dimension(bbox: tuple | list):
|
|
541
|
+
"""Calculate the dimensions and area of a bounding box.
|
|
541
542
|
|
|
542
543
|
Args:
|
|
543
|
-
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
|
|
544
|
+
bbox (tuple | list): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
|
|
544
545
|
|
|
545
546
|
Returns:
|
|
546
547
|
width (float): Width of the bounding box.
|
|
@@ -562,8 +563,7 @@ class Annotator:
|
|
|
562
563
|
@TryExcept()
|
|
563
564
|
@plt_settings()
|
|
564
565
|
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
565
|
-
"""
|
|
566
|
-
Plot training labels including class histograms and box statistics.
|
|
566
|
+
"""Plot training labels including class histograms and box statistics.
|
|
567
567
|
|
|
568
568
|
Args:
|
|
569
569
|
boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
|
|
@@ -576,10 +576,6 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
576
576
|
import polars
|
|
577
577
|
from matplotlib.colors import LinearSegmentedColormap
|
|
578
578
|
|
|
579
|
-
# Filter matplotlib>=3.7.2 warning
|
|
580
|
-
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
|
|
581
|
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
582
|
-
|
|
583
579
|
# Plot dataset labels
|
|
584
580
|
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
|
585
581
|
nc = int(cls.max() + 1) # number of classes
|
|
@@ -601,8 +597,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
|
601
597
|
ax[0].set_xlabel("classes")
|
|
602
598
|
boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
|
|
603
599
|
img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
|
|
604
|
-
for
|
|
605
|
-
ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(
|
|
600
|
+
for class_id, box in zip(cls[:500], boxes[:500]):
|
|
601
|
+
ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(class_id)) # plot
|
|
606
602
|
ax[1].imshow(img)
|
|
607
603
|
ax[1].axis("off")
|
|
608
604
|
|
|
@@ -633,12 +629,11 @@ def save_one_box(
|
|
|
633
629
|
BGR: bool = False,
|
|
634
630
|
save: bool = True,
|
|
635
631
|
):
|
|
636
|
-
"""
|
|
637
|
-
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
|
632
|
+
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
|
638
633
|
|
|
639
|
-
This function takes a bounding box and an image, and then saves a cropped portion of the image according
|
|
640
|
-
|
|
641
|
-
|
|
634
|
+
This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
|
|
635
|
+
bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
|
|
636
|
+
bounding box.
|
|
642
637
|
|
|
643
638
|
Args:
|
|
644
639
|
xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
|
|
@@ -691,11 +686,11 @@ def plot_images(
|
|
|
691
686
|
save: bool = True,
|
|
692
687
|
conf_thres: float = 0.25,
|
|
693
688
|
) -> np.ndarray | None:
|
|
694
|
-
"""
|
|
695
|
-
Plot image grid with labels, bounding boxes, masks, and keypoints.
|
|
689
|
+
"""Plot image grid with labels, bounding boxes, masks, and keypoints.
|
|
696
690
|
|
|
697
691
|
Args:
|
|
698
|
-
labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
|
|
692
|
+
labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
|
|
693
|
+
'keypoints', 'batch_idx', 'img'.
|
|
699
694
|
images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
|
|
700
695
|
paths (Optional[list[str]]): List of file paths for each image in the batch.
|
|
701
696
|
fname (str): Output filename for the plotted image grid.
|
|
@@ -709,12 +704,12 @@ def plot_images(
|
|
|
709
704
|
Returns:
|
|
710
705
|
(np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
|
|
711
706
|
|
|
712
|
-
|
|
707
|
+
Notes:
|
|
713
708
|
This function supports both tensor and numpy array inputs. It will automatically
|
|
714
709
|
convert tensor inputs to numpy arrays for processing.
|
|
715
710
|
|
|
716
711
|
Channel Support:
|
|
717
|
-
- 1 channel:
|
|
712
|
+
- 1 channel: Grayscale
|
|
718
713
|
- 2 channels: Third channel added as zeros
|
|
719
714
|
- 3 channels: Used as-is (standard RGB)
|
|
720
715
|
- 4+ channels: Cropped to first 3 channels
|
|
@@ -791,7 +786,6 @@ def plot_images(
|
|
|
791
786
|
boxes[..., 0] += x
|
|
792
787
|
boxes[..., 1] += y
|
|
793
788
|
is_obb = boxes.shape[-1] == 5 # xywhr
|
|
794
|
-
# TODO: this transformation might be unnecessary
|
|
795
789
|
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
|
|
796
790
|
for j, box in enumerate(boxes.astype(np.int64).tolist()):
|
|
797
791
|
c = classes[j]
|
|
@@ -860,9 +854,9 @@ def plot_images(
|
|
|
860
854
|
|
|
861
855
|
@plt_settings()
|
|
862
856
|
def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
|
|
863
|
-
"""
|
|
864
|
-
|
|
865
|
-
|
|
857
|
+
"""Plot training results from a results CSV file. The function supports various types of data including
|
|
858
|
+
segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
|
|
859
|
+
CSV is located.
|
|
866
860
|
|
|
867
861
|
Args:
|
|
868
862
|
file (str, optional): Path to the CSV file containing the training results.
|
|
@@ -914,8 +908,7 @@ def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Call
|
|
|
914
908
|
|
|
915
909
|
|
|
916
910
|
def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
|
|
917
|
-
"""
|
|
918
|
-
Plot a scatter plot with points colored based on a 2D histogram.
|
|
911
|
+
"""Plot a scatter plot with points colored based on a 2D histogram.
|
|
919
912
|
|
|
920
913
|
Args:
|
|
921
914
|
v (array-like): Values for the x-axis.
|
|
@@ -948,9 +941,9 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
|
|
|
948
941
|
|
|
949
942
|
@plt_settings()
|
|
950
943
|
def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
|
|
951
|
-
"""
|
|
952
|
-
|
|
953
|
-
|
|
944
|
+
"""Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
|
|
945
|
+
key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
|
|
946
|
+
the plots.
|
|
954
947
|
|
|
955
948
|
Args:
|
|
956
949
|
csv_file (str, optional): Path to the CSV file containing the tuning results.
|
|
@@ -979,6 +972,9 @@ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_p
|
|
|
979
972
|
if exclude_zero_fitness_points:
|
|
980
973
|
mask = fitness > 0 # exclude zero-fitness points
|
|
981
974
|
x, fitness = x[mask], fitness[mask]
|
|
975
|
+
if len(fitness) == 0:
|
|
976
|
+
LOGGER.warning("No valid fitness values to plot (all iterations may have failed)")
|
|
977
|
+
return
|
|
982
978
|
# Iterative sigma rejection on lower bound only
|
|
983
979
|
for _ in range(3): # max 3 iterations
|
|
984
980
|
mean, std = fitness.mean(), fitness.std()
|
|
@@ -1017,8 +1013,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_p
|
|
|
1017
1013
|
|
|
1018
1014
|
@plt_settings()
|
|
1019
1015
|
def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
|
|
1020
|
-
"""
|
|
1021
|
-
Visualize feature maps of a given model module during inference.
|
|
1016
|
+
"""Visualize feature maps of a given model module during inference.
|
|
1022
1017
|
|
|
1023
1018
|
Args:
|
|
1024
1019
|
x (torch.Tensor): Features to be visualized.
|