dgenerate-ultralytics-headless 8.3.141__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
- dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +12 -12
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +22 -19
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -158
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +13 -11
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +18 -12
- ultralytics/solutions/object_cropper.py +12 -5
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +215 -85
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +84 -42
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- dgenerate_ultralytics_headless-8.3.141.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/utils/patches.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
|
4
4
|
import time
|
5
5
|
from pathlib import Path
|
6
|
+
from typing import List, Optional
|
6
7
|
|
7
8
|
import cv2
|
8
9
|
import numpy as np
|
@@ -12,16 +13,16 @@ import torch
|
|
12
13
|
_imshow = cv2.imshow # copy to avoid recursion errors
|
13
14
|
|
14
15
|
|
15
|
-
def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
|
16
|
+
def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> Optional[np.ndarray]:
|
16
17
|
"""
|
17
|
-
Read an image from a file.
|
18
|
+
Read an image from a file with multilanguage filename support.
|
18
19
|
|
19
20
|
Args:
|
20
21
|
filename (str): Path to the file to read.
|
21
|
-
flags (int): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
|
22
|
+
flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
|
22
23
|
|
23
24
|
Returns:
|
24
|
-
(np.ndarray): The read image.
|
25
|
+
(np.ndarray | None): The read image array, or None if reading fails.
|
25
26
|
|
26
27
|
Examples:
|
27
28
|
>>> img = imread("path/to/image.jpg")
|
@@ -31,17 +32,17 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
|
|
31
32
|
if filename.endswith((".tiff", ".tif")):
|
32
33
|
success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
|
33
34
|
if success:
|
34
|
-
#
|
35
|
+
# Handle RGB images in tif/tiff format
|
35
36
|
return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
|
36
37
|
return None
|
37
38
|
else:
|
38
39
|
im = cv2.imdecode(file_bytes, flags)
|
39
|
-
return im[..., None] if im.ndim == 2 else im #
|
40
|
+
return im[..., None] if im.ndim == 2 else im # Always ensure 3 dimensions
|
40
41
|
|
41
42
|
|
42
|
-
def imwrite(filename: str, img: np.ndarray, params=None):
|
43
|
+
def imwrite(filename: str, img: np.ndarray, params: Optional[List[int]] = None) -> bool:
|
43
44
|
"""
|
44
|
-
Write an image to a file.
|
45
|
+
Write an image to a file with multilanguage filename support.
|
45
46
|
|
46
47
|
Args:
|
47
48
|
filename (str): Path to the file to write.
|
@@ -65,12 +66,12 @@ def imwrite(filename: str, img: np.ndarray, params=None):
|
|
65
66
|
return False
|
66
67
|
|
67
68
|
|
68
|
-
def imshow(winname: str, mat: np.ndarray):
|
69
|
+
def imshow(winname: str, mat: np.ndarray) -> None:
|
69
70
|
"""
|
70
|
-
Display an image in the specified window.
|
71
|
+
Display an image in the specified window with multilanguage window name support.
|
71
72
|
|
72
|
-
This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It
|
73
|
-
|
73
|
+
This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
|
74
|
+
multilanguage window names by encoding them properly for OpenCV compatibility.
|
74
75
|
|
75
76
|
Args:
|
76
77
|
winname (str): Name of the window where the image will be displayed. If a window with this name already
|
@@ -127,9 +128,6 @@ def torch_save(*args, **kwargs):
|
|
127
128
|
*args (Any): Positional arguments to pass to torch.save.
|
128
129
|
**kwargs (Any): Keyword arguments to pass to torch.save.
|
129
130
|
|
130
|
-
Returns:
|
131
|
-
(Any): Result of torch.save operation if successful, None otherwise.
|
132
|
-
|
133
131
|
Examples:
|
134
132
|
>>> model = torch.nn.Linear(10, 1)
|
135
133
|
>>> torch_save(model.state_dict(), "model.pt")
|
@@ -137,7 +135,7 @@ def torch_save(*args, **kwargs):
|
|
137
135
|
for i in range(4): # 3 retries
|
138
136
|
try:
|
139
137
|
return _torch_save(*args, **kwargs)
|
140
|
-
except RuntimeError as e: #
|
138
|
+
except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan
|
141
139
|
if i == 3:
|
142
140
|
raise e
|
143
|
-
time.sleep((2**i) / 2) #
|
141
|
+
time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s
|
ultralytics/utils/plotting.py
CHANGED
@@ -18,20 +18,21 @@ from ultralytics.utils.files import increment_path
|
|
18
18
|
|
19
19
|
class Colors:
|
20
20
|
"""
|
21
|
-
Ultralytics color palette
|
21
|
+
Ultralytics color palette for visualization and plotting.
|
22
22
|
|
23
23
|
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
|
24
|
-
RGB values.
|
24
|
+
RGB values and accessing predefined color schemes for object detection and pose estimation.
|
25
25
|
|
26
26
|
Attributes:
|
27
|
-
palette (List[
|
27
|
+
palette (List[tuple]): List of RGB color tuples for general use.
|
28
28
|
n (int): The number of colors in the palette.
|
29
29
|
pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
|
30
30
|
|
31
31
|
Examples:
|
32
32
|
>>> from ultralytics.utils.plotting import Colors
|
33
33
|
>>> colors = Colors()
|
34
|
-
>>> colors(5, True) #
|
34
|
+
>>> colors(5, True) # Returns BGR format: (221, 111, 255)
|
35
|
+
>>> colors(5, False) # Returns RGB format: (255, 111, 221)
|
35
36
|
|
36
37
|
## Ultralytics Color Palette
|
37
38
|
|
@@ -85,7 +86,8 @@ class Colors:
|
|
85
86
|
|
86
87
|
!!! note "Ultralytics Brand Colors"
|
87
88
|
|
88
|
-
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
|
89
|
+
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
|
90
|
+
Please use the official Ultralytics colors for all marketing materials.
|
89
91
|
"""
|
90
92
|
|
91
93
|
def __init__(self):
|
@@ -140,13 +142,22 @@ class Colors:
|
|
140
142
|
dtype=np.uint8,
|
141
143
|
)
|
142
144
|
|
143
|
-
def __call__(self, i, bgr=False):
|
144
|
-
"""
|
145
|
+
def __call__(self, i: int, bgr: bool = False) -> tuple:
|
146
|
+
"""
|
147
|
+
Convert hex color codes to RGB values.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
i (int): Color index.
|
151
|
+
bgr (bool, optional): Whether to return BGR format instead of RGB.
|
152
|
+
|
153
|
+
Returns:
|
154
|
+
(tuple): RGB or BGR color tuple.
|
155
|
+
"""
|
145
156
|
c = self.palette[int(i) % self.n]
|
146
157
|
return (c[2], c[1], c[0]) if bgr else c
|
147
158
|
|
148
159
|
@staticmethod
|
149
|
-
def hex2rgb(h):
|
160
|
+
def hex2rgb(h: str) -> tuple:
|
150
161
|
"""Convert hex color codes to RGB values (i.e. default PIL order)."""
|
151
162
|
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
|
152
163
|
|
@@ -159,9 +170,9 @@ class Annotator:
|
|
159
170
|
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
160
171
|
|
161
172
|
Attributes:
|
162
|
-
im (Image.Image
|
173
|
+
im (Image.Image | np.ndarray): The image to annotate.
|
163
174
|
pil (bool): Whether to use PIL or cv2 for drawing annotations.
|
164
|
-
font (ImageFont.truetype
|
175
|
+
font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.
|
165
176
|
lw (float): Line width for drawing.
|
166
177
|
skeleton (List[List[int]]): Skeleton structure for keypoints.
|
167
178
|
limb_color (List[int]): Color palette for limbs.
|
@@ -173,9 +184,18 @@ class Annotator:
|
|
173
184
|
>>> from ultralytics.utils.plotting import Annotator
|
174
185
|
>>> im0 = cv2.imread("test.png")
|
175
186
|
>>> annotator = Annotator(im0, line_width=10)
|
187
|
+
>>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0))
|
176
188
|
"""
|
177
189
|
|
178
|
-
def __init__(
|
190
|
+
def __init__(
|
191
|
+
self,
|
192
|
+
im,
|
193
|
+
line_width: Optional[int] = None,
|
194
|
+
font_size: Optional[int] = None,
|
195
|
+
font: str = "Arial.ttf",
|
196
|
+
pil: bool = False,
|
197
|
+
example: str = "abc",
|
198
|
+
):
|
179
199
|
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
|
180
200
|
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
181
201
|
input_is_pil = isinstance(im, Image.Image)
|
@@ -254,7 +274,7 @@ class Annotator:
|
|
254
274
|
(104, 31, 17),
|
255
275
|
}
|
256
276
|
|
257
|
-
def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
|
277
|
+
def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
|
258
278
|
"""
|
259
279
|
Assign text color based on background color.
|
260
280
|
|
@@ -278,7 +298,7 @@ class Annotator:
|
|
278
298
|
else:
|
279
299
|
return txt_color
|
280
300
|
|
281
|
-
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
|
301
|
+
def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
|
282
302
|
"""
|
283
303
|
Draw a bounding box on an image with a given label.
|
284
304
|
|
@@ -287,7 +307,6 @@ class Annotator:
|
|
287
307
|
label (str, optional): The text label to be displayed.
|
288
308
|
color (tuple, optional): The background color of the rectangle (B, G, R).
|
289
309
|
txt_color (tuple, optional): The color of the text (R, G, B).
|
290
|
-
rotated (bool, optional): Whether the task is oriented bounding box detection.
|
291
310
|
|
292
311
|
Examples:
|
293
312
|
>>> from ultralytics.utils.plotting import Annotator
|
@@ -298,13 +317,13 @@ class Annotator:
|
|
298
317
|
txt_color = self.get_txt_color(color, txt_color)
|
299
318
|
if isinstance(box, torch.Tensor):
|
300
319
|
box = box.tolist()
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
320
|
+
|
321
|
+
multi_points = isinstance(box[0], list) # multiple points with shape (n, 2)
|
322
|
+
p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1]))
|
323
|
+
if self.pil:
|
324
|
+
self.draw.polygon(
|
325
|
+
[tuple(b) for b in box], width=self.lw, outline=color
|
326
|
+
) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color)
|
308
327
|
if label:
|
309
328
|
w, h = self.font.getsize(label) # text width, height
|
310
329
|
outside = p1[1] >= h # label fits outside box
|
@@ -317,12 +336,11 @@ class Annotator:
|
|
317
336
|
# self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
|
318
337
|
self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
|
319
338
|
else: # cv2
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
|
339
|
+
cv2.polylines(
|
340
|
+
self.im, [np.asarray(box, dtype=int)], True, color, self.lw
|
341
|
+
) if multi_points else cv2.rectangle(
|
342
|
+
self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA
|
343
|
+
)
|
326
344
|
if label:
|
327
345
|
w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
|
328
346
|
h += 3 # add pixels to pad text
|
@@ -342,7 +360,7 @@ class Annotator:
|
|
342
360
|
lineType=cv2.LINE_AA,
|
343
361
|
)
|
344
362
|
|
345
|
-
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
|
363
|
+
def masks(self, masks, colors, im_gpu, alpha: float = 0.5, retina_masks: bool = False):
|
346
364
|
"""
|
347
365
|
Plot masks on image.
|
348
366
|
|
@@ -378,7 +396,15 @@ class Annotator:
|
|
378
396
|
# Convert im back to PIL and update draw
|
379
397
|
self.fromarray(self.im)
|
380
398
|
|
381
|
-
def kpts(
|
399
|
+
def kpts(
|
400
|
+
self,
|
401
|
+
kpts,
|
402
|
+
shape: tuple = (640, 640),
|
403
|
+
radius: Optional[int] = None,
|
404
|
+
kpt_line: bool = True,
|
405
|
+
conf_thres: float = 0.25,
|
406
|
+
kpt_color: Optional[tuple] = None,
|
407
|
+
):
|
382
408
|
"""
|
383
409
|
Plot keypoints on the image.
|
384
410
|
|
@@ -438,11 +464,11 @@ class Annotator:
|
|
438
464
|
# Convert im back to PIL and update draw
|
439
465
|
self.fromarray(self.im)
|
440
466
|
|
441
|
-
def rectangle(self, xy, fill=None, outline=None, width=1):
|
467
|
+
def rectangle(self, xy, fill=None, outline=None, width: int = 1):
|
442
468
|
"""Add rectangle to image (PIL-only)."""
|
443
469
|
self.draw.rectangle(xy, fill, outline, width)
|
444
470
|
|
445
|
-
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_color=()):
|
471
|
+
def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
|
446
472
|
"""
|
447
473
|
Add text to an image using PIL or cv2.
|
448
474
|
|
@@ -482,7 +508,7 @@ class Annotator:
|
|
482
508
|
"""Return annotated image as array."""
|
483
509
|
return np.asarray(self.im)
|
484
510
|
|
485
|
-
def show(self, title=None):
|
511
|
+
def show(self, title: Optional[str] = None):
|
486
512
|
"""Show the annotated image."""
|
487
513
|
im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
|
488
514
|
if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
|
@@ -493,12 +519,12 @@ class Annotator:
|
|
493
519
|
else:
|
494
520
|
im.show(title=title)
|
495
521
|
|
496
|
-
def save(self, filename="image.jpg"):
|
522
|
+
def save(self, filename: str = "image.jpg"):
|
497
523
|
"""Save the annotated image to 'filename'."""
|
498
524
|
cv2.imwrite(filename, np.asarray(self.im))
|
499
525
|
|
500
526
|
@staticmethod
|
501
|
-
def get_bbox_dimension(bbox=None):
|
527
|
+
def get_bbox_dimension(bbox: Optional[tuple] = None):
|
502
528
|
"""
|
503
529
|
Calculate the dimensions and area of a bounding box.
|
504
530
|
|
@@ -594,7 +620,16 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
594
620
|
on_plot(fname)
|
595
621
|
|
596
622
|
|
597
|
-
def save_one_box(
|
623
|
+
def save_one_box(
|
624
|
+
xyxy,
|
625
|
+
im,
|
626
|
+
file: Path = Path("im.jpg"),
|
627
|
+
gain: float = 1.02,
|
628
|
+
pad: int = 10,
|
629
|
+
square: bool = False,
|
630
|
+
BGR: bool = False,
|
631
|
+
save: bool = True,
|
632
|
+
):
|
598
633
|
"""
|
599
634
|
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
600
635
|
|
@@ -750,7 +785,7 @@ def plot_images(
|
|
750
785
|
c = names.get(c, c) if names else c
|
751
786
|
if labels or conf[j] > conf_thres:
|
752
787
|
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
|
753
|
-
annotator.box_label(box, label, color=color
|
788
|
+
annotator.box_label(box, label, color=color)
|
754
789
|
|
755
790
|
elif len(classes):
|
756
791
|
for c in classes:
|
@@ -810,7 +845,14 @@ def plot_images(
|
|
810
845
|
|
811
846
|
|
812
847
|
@plt_settings()
|
813
|
-
def plot_results(
|
848
|
+
def plot_results(
|
849
|
+
file: str = "path/to/results.csv",
|
850
|
+
dir: str = "",
|
851
|
+
segment: bool = False,
|
852
|
+
pose: bool = False,
|
853
|
+
classify: bool = False,
|
854
|
+
on_plot: Optional[Callable] = None,
|
855
|
+
):
|
814
856
|
"""
|
815
857
|
Plot training results from a results CSV file. The function supports various types of data including segmentation,
|
816
858
|
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
|
@@ -870,7 +912,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|
870
912
|
on_plot(fname)
|
871
913
|
|
872
914
|
|
873
|
-
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
|
915
|
+
def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
|
874
916
|
"""
|
875
917
|
Plot a scatter plot with points colored based on a 2D histogram.
|
876
918
|
|
@@ -903,7 +945,7 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
|
|
903
945
|
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
|
904
946
|
|
905
947
|
|
906
|
-
def plot_tune_results(csv_file="tune_results.csv"):
|
948
|
+
def plot_tune_results(csv_file: str = "tune_results.csv"):
|
907
949
|
"""
|
908
950
|
Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
|
909
951
|
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
|
@@ -959,7 +1001,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
|
|
959
1001
|
_save_one_file(csv_file.with_name("tune_fitness.png"))
|
960
1002
|
|
961
1003
|
|
962
|
-
def output_to_target(output, max_det=300):
|
1004
|
+
def output_to_target(output, max_det: int = 300):
|
963
1005
|
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
|
964
1006
|
targets = []
|
965
1007
|
for i, o in enumerate(output):
|
@@ -970,7 +1012,7 @@ def output_to_target(output, max_det=300):
|
|
970
1012
|
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
|
971
1013
|
|
972
1014
|
|
973
|
-
def output_to_rotated_target(output, max_det=300):
|
1015
|
+
def output_to_rotated_target(output, max_det: int = 300):
|
974
1016
|
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
|
975
1017
|
targets = []
|
976
1018
|
for i, o in enumerate(output):
|
@@ -981,7 +1023,7 @@ def output_to_rotated_target(output, max_det=300):
|
|
981
1023
|
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
|
982
1024
|
|
983
1025
|
|
984
|
-
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
|
1026
|
+
def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
|
985
1027
|
"""
|
986
1028
|
Visualize feature maps of a given model module during inference.
|
987
1029
|
|
ultralytics/utils/tal.py
CHANGED
@@ -26,8 +26,17 @@ class TaskAlignedAssigner(nn.Module):
|
|
26
26
|
eps (float): A small value to prevent division by zero.
|
27
27
|
"""
|
28
28
|
|
29
|
-
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
|
30
|
-
"""
|
29
|
+
def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
|
30
|
+
"""
|
31
|
+
Initialize a TaskAlignedAssigner object with customizable hyperparameters.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
topk (int, optional): The number of top candidates to consider.
|
35
|
+
num_classes (int, optional): The number of object classes.
|
36
|
+
alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
|
37
|
+
beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
|
38
|
+
eps (float, optional): A small value to prevent division by zero.
|
39
|
+
"""
|
31
40
|
super().__init__()
|
32
41
|
self.topk = topk
|
33
42
|
self.num_classes = num_classes
|
@@ -196,12 +205,11 @@ class TaskAlignedAssigner(nn.Module):
|
|
196
205
|
Select the top-k candidates based on the given metrics.
|
197
206
|
|
198
207
|
Args:
|
199
|
-
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
the top-k values are automatically computed based on the given metrics.
|
208
|
+
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
|
209
|
+
the maximum number of objects, and h*w represents the total number of anchor points.
|
210
|
+
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
211
|
+
topk is the number of top candidates to consider. If not provided, the top-k values are automatically
|
212
|
+
computed based on the given metrics.
|
205
213
|
|
206
214
|
Returns:
|
207
215
|
(torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
@@ -239,11 +247,9 @@ class TaskAlignedAssigner(nn.Module):
|
|
239
247
|
(foreground) anchor points.
|
240
248
|
|
241
249
|
Returns:
|
242
|
-
target_labels (torch.Tensor):
|
243
|
-
target_bboxes (torch.Tensor):
|
244
|
-
|
245
|
-
target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
|
246
|
-
anchor points.
|
250
|
+
target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
|
251
|
+
target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
|
252
|
+
target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
|
247
253
|
"""
|
248
254
|
# Assigned target labels, (b, 1)
|
249
255
|
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
|
@@ -277,7 +283,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
277
283
|
Args:
|
278
284
|
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
|
279
285
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
|
280
|
-
eps (float, optional): Small value for numerical stability.
|
286
|
+
eps (float, optional): Small value for numerical stability.
|
281
287
|
|
282
288
|
Returns:
|
283
289
|
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
|
@@ -399,7 +405,7 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
|
399
405
|
pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
|
400
406
|
pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
|
401
407
|
anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
|
402
|
-
dim (int, optional): Dimension along which to split.
|
408
|
+
dim (int, optional): Dimension along which to split.
|
403
409
|
|
404
410
|
Returns:
|
405
411
|
(torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
|