ultralytics 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +52 -51
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +191 -161
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +4 -6
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +2 -2
- ultralytics/solutions/instance_segmentation.py +7 -4
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -11
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +189 -79
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +45 -29
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/METADATA +2 -2
- ultralytics-8.3.145.dist-info/RECORD +272 -0
- ultralytics-8.3.143.dist-info/RECORD +0 -272
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/utils/patches.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
|
4
4
|
import time
|
5
5
|
from pathlib import Path
|
6
|
+
from typing import List, Optional
|
6
7
|
|
7
8
|
import cv2
|
8
9
|
import numpy as np
|
@@ -12,16 +13,16 @@ import torch
|
|
12
13
|
_imshow = cv2.imshow # copy to avoid recursion errors
|
13
14
|
|
14
15
|
|
15
|
-
def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
|
16
|
+
def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> Optional[np.ndarray]:
|
16
17
|
"""
|
17
|
-
Read an image from a file.
|
18
|
+
Read an image from a file with multilanguage filename support.
|
18
19
|
|
19
20
|
Args:
|
20
21
|
filename (str): Path to the file to read.
|
21
|
-
flags (int): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
|
22
|
+
flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
|
22
23
|
|
23
24
|
Returns:
|
24
|
-
(np.ndarray): The read image.
|
25
|
+
(np.ndarray | None): The read image array, or None if reading fails.
|
25
26
|
|
26
27
|
Examples:
|
27
28
|
>>> img = imread("path/to/image.jpg")
|
@@ -31,17 +32,17 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
|
|
31
32
|
if filename.endswith((".tiff", ".tif")):
|
32
33
|
success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
|
33
34
|
if success:
|
34
|
-
#
|
35
|
+
# Handle RGB images in tif/tiff format
|
35
36
|
return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
|
36
37
|
return None
|
37
38
|
else:
|
38
39
|
im = cv2.imdecode(file_bytes, flags)
|
39
|
-
return im[..., None] if im.ndim == 2 else im #
|
40
|
+
return im[..., None] if im.ndim == 2 else im # Always ensure 3 dimensions
|
40
41
|
|
41
42
|
|
42
|
-
def imwrite(filename: str, img: np.ndarray, params=None):
|
43
|
+
def imwrite(filename: str, img: np.ndarray, params: Optional[List[int]] = None) -> bool:
|
43
44
|
"""
|
44
|
-
Write an image to a file.
|
45
|
+
Write an image to a file with multilanguage filename support.
|
45
46
|
|
46
47
|
Args:
|
47
48
|
filename (str): Path to the file to write.
|
@@ -65,12 +66,12 @@ def imwrite(filename: str, img: np.ndarray, params=None):
|
|
65
66
|
return False
|
66
67
|
|
67
68
|
|
68
|
-
def imshow(winname: str, mat: np.ndarray):
|
69
|
+
def imshow(winname: str, mat: np.ndarray) -> None:
|
69
70
|
"""
|
70
|
-
Display an image in the specified window.
|
71
|
+
Display an image in the specified window with multilanguage window name support.
|
71
72
|
|
72
|
-
This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It
|
73
|
-
|
73
|
+
This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
|
74
|
+
multilanguage window names by encoding them properly for OpenCV compatibility.
|
74
75
|
|
75
76
|
Args:
|
76
77
|
winname (str): Name of the window where the image will be displayed. If a window with this name already
|
@@ -127,9 +128,6 @@ def torch_save(*args, **kwargs):
|
|
127
128
|
*args (Any): Positional arguments to pass to torch.save.
|
128
129
|
**kwargs (Any): Keyword arguments to pass to torch.save.
|
129
130
|
|
130
|
-
Returns:
|
131
|
-
(Any): Result of torch.save operation if successful, None otherwise.
|
132
|
-
|
133
131
|
Examples:
|
134
132
|
>>> model = torch.nn.Linear(10, 1)
|
135
133
|
>>> torch_save(model.state_dict(), "model.pt")
|
@@ -137,7 +135,7 @@ def torch_save(*args, **kwargs):
|
|
137
135
|
for i in range(4): # 3 retries
|
138
136
|
try:
|
139
137
|
return _torch_save(*args, **kwargs)
|
140
|
-
except RuntimeError as e: #
|
138
|
+
except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan
|
141
139
|
if i == 3:
|
142
140
|
raise e
|
143
|
-
time.sleep((2**i) / 2) #
|
141
|
+
time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s
|
ultralytics/utils/plotting.py
CHANGED
@@ -18,20 +18,21 @@ from ultralytics.utils.files import increment_path
|
|
18
18
|
|
19
19
|
class Colors:
|
20
20
|
"""
|
21
|
-
Ultralytics color palette
|
21
|
+
Ultralytics color palette for visualization and plotting.
|
22
22
|
|
23
23
|
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
|
24
|
-
RGB values.
|
24
|
+
RGB values and accessing predefined color schemes for object detection and pose estimation.
|
25
25
|
|
26
26
|
Attributes:
|
27
|
-
palette (List[
|
27
|
+
palette (List[tuple]): List of RGB color tuples for general use.
|
28
28
|
n (int): The number of colors in the palette.
|
29
29
|
pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
|
30
30
|
|
31
31
|
Examples:
|
32
32
|
>>> from ultralytics.utils.plotting import Colors
|
33
33
|
>>> colors = Colors()
|
34
|
-
>>> colors(5, True) #
|
34
|
+
>>> colors(5, True) # Returns BGR format: (221, 111, 255)
|
35
|
+
>>> colors(5, False) # Returns RGB format: (255, 111, 221)
|
35
36
|
|
36
37
|
## Ultralytics Color Palette
|
37
38
|
|
@@ -85,7 +86,8 @@ class Colors:
|
|
85
86
|
|
86
87
|
!!! note "Ultralytics Brand Colors"
|
87
88
|
|
88
|
-
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
|
89
|
+
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
|
90
|
+
Please use the official Ultralytics colors for all marketing materials.
|
89
91
|
"""
|
90
92
|
|
91
93
|
def __init__(self):
|
@@ -140,13 +142,22 @@ class Colors:
|
|
140
142
|
dtype=np.uint8,
|
141
143
|
)
|
142
144
|
|
143
|
-
def __call__(self, i, bgr=False):
|
144
|
-
"""
|
145
|
+
def __call__(self, i: int, bgr: bool = False) -> tuple:
|
146
|
+
"""
|
147
|
+
Convert hex color codes to RGB values.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
i (int): Color index.
|
151
|
+
bgr (bool, optional): Whether to return BGR format instead of RGB.
|
152
|
+
|
153
|
+
Returns:
|
154
|
+
(tuple): RGB or BGR color tuple.
|
155
|
+
"""
|
145
156
|
c = self.palette[int(i) % self.n]
|
146
157
|
return (c[2], c[1], c[0]) if bgr else c
|
147
158
|
|
148
159
|
@staticmethod
|
149
|
-
def hex2rgb(h):
|
160
|
+
def hex2rgb(h: str) -> tuple:
|
150
161
|
"""Convert hex color codes to RGB values (i.e. default PIL order)."""
|
151
162
|
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
|
152
163
|
|
@@ -159,9 +170,9 @@ class Annotator:
|
|
159
170
|
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
160
171
|
|
161
172
|
Attributes:
|
162
|
-
im (Image.Image
|
173
|
+
im (Image.Image | np.ndarray): The image to annotate.
|
163
174
|
pil (bool): Whether to use PIL or cv2 for drawing annotations.
|
164
|
-
font (ImageFont.truetype
|
175
|
+
font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.
|
165
176
|
lw (float): Line width for drawing.
|
166
177
|
skeleton (List[List[int]]): Skeleton structure for keypoints.
|
167
178
|
limb_color (List[int]): Color palette for limbs.
|
@@ -173,9 +184,18 @@ class Annotator:
|
|
173
184
|
>>> from ultralytics.utils.plotting import Annotator
|
174
185
|
>>> im0 = cv2.imread("test.png")
|
175
186
|
>>> annotator = Annotator(im0, line_width=10)
|
187
|
+
>>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0))
|
176
188
|
"""
|
177
189
|
|
178
|
-
def __init__(
|
190
|
+
def __init__(
|
191
|
+
self,
|
192
|
+
im,
|
193
|
+
line_width: Optional[int] = None,
|
194
|
+
font_size: Optional[int] = None,
|
195
|
+
font: str = "Arial.ttf",
|
196
|
+
pil: bool = False,
|
197
|
+
example: str = "abc",
|
198
|
+
):
|
179
199
|
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
|
180
200
|
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
181
201
|
input_is_pil = isinstance(im, Image.Image)
|
@@ -254,7 +274,7 @@ class Annotator:
|
|
254
274
|
(104, 31, 17),
|
255
275
|
}
|
256
276
|
|
257
|
-
def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
|
277
|
+
def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
|
258
278
|
"""
|
259
279
|
Assign text color based on background color.
|
260
280
|
|
@@ -278,7 +298,7 @@ class Annotator:
|
|
278
298
|
else:
|
279
299
|
return txt_color
|
280
300
|
|
281
|
-
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)):
|
301
|
+
def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
|
282
302
|
"""
|
283
303
|
Draw a bounding box on an image with a given label.
|
284
304
|
|
@@ -340,7 +360,7 @@ class Annotator:
|
|
340
360
|
lineType=cv2.LINE_AA,
|
341
361
|
)
|
342
362
|
|
343
|
-
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
|
363
|
+
def masks(self, masks, colors, im_gpu, alpha: float = 0.5, retina_masks: bool = False):
|
344
364
|
"""
|
345
365
|
Plot masks on image.
|
346
366
|
|
@@ -376,7 +396,15 @@ class Annotator:
|
|
376
396
|
# Convert im back to PIL and update draw
|
377
397
|
self.fromarray(self.im)
|
378
398
|
|
379
|
-
def kpts(
|
399
|
+
def kpts(
|
400
|
+
self,
|
401
|
+
kpts,
|
402
|
+
shape: tuple = (640, 640),
|
403
|
+
radius: Optional[int] = None,
|
404
|
+
kpt_line: bool = True,
|
405
|
+
conf_thres: float = 0.25,
|
406
|
+
kpt_color: Optional[tuple] = None,
|
407
|
+
):
|
380
408
|
"""
|
381
409
|
Plot keypoints on the image.
|
382
410
|
|
@@ -436,11 +464,11 @@ class Annotator:
|
|
436
464
|
# Convert im back to PIL and update draw
|
437
465
|
self.fromarray(self.im)
|
438
466
|
|
439
|
-
def rectangle(self, xy, fill=None, outline=None, width=1):
|
467
|
+
def rectangle(self, xy, fill=None, outline=None, width: int = 1):
|
440
468
|
"""Add rectangle to image (PIL-only)."""
|
441
469
|
self.draw.rectangle(xy, fill, outline, width)
|
442
470
|
|
443
|
-
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_color=()):
|
471
|
+
def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
|
444
472
|
"""
|
445
473
|
Add text to an image using PIL or cv2.
|
446
474
|
|
@@ -480,7 +508,7 @@ class Annotator:
|
|
480
508
|
"""Return annotated image as array."""
|
481
509
|
return np.asarray(self.im)
|
482
510
|
|
483
|
-
def show(self, title=None):
|
511
|
+
def show(self, title: Optional[str] = None):
|
484
512
|
"""Show the annotated image."""
|
485
513
|
im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
|
486
514
|
if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
|
@@ -491,12 +519,12 @@ class Annotator:
|
|
491
519
|
else:
|
492
520
|
im.show(title=title)
|
493
521
|
|
494
|
-
def save(self, filename="image.jpg"):
|
522
|
+
def save(self, filename: str = "image.jpg"):
|
495
523
|
"""Save the annotated image to 'filename'."""
|
496
524
|
cv2.imwrite(filename, np.asarray(self.im))
|
497
525
|
|
498
526
|
@staticmethod
|
499
|
-
def get_bbox_dimension(bbox=None):
|
527
|
+
def get_bbox_dimension(bbox: Optional[tuple] = None):
|
500
528
|
"""
|
501
529
|
Calculate the dimensions and area of a bounding box.
|
502
530
|
|
@@ -592,7 +620,16 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
592
620
|
on_plot(fname)
|
593
621
|
|
594
622
|
|
595
|
-
def save_one_box(
|
623
|
+
def save_one_box(
|
624
|
+
xyxy,
|
625
|
+
im,
|
626
|
+
file: Path = Path("im.jpg"),
|
627
|
+
gain: float = 1.02,
|
628
|
+
pad: int = 10,
|
629
|
+
square: bool = False,
|
630
|
+
BGR: bool = False,
|
631
|
+
save: bool = True,
|
632
|
+
):
|
596
633
|
"""
|
597
634
|
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
598
635
|
|
@@ -808,7 +845,14 @@ def plot_images(
|
|
808
845
|
|
809
846
|
|
810
847
|
@plt_settings()
|
811
|
-
def plot_results(
|
848
|
+
def plot_results(
|
849
|
+
file: str = "path/to/results.csv",
|
850
|
+
dir: str = "",
|
851
|
+
segment: bool = False,
|
852
|
+
pose: bool = False,
|
853
|
+
classify: bool = False,
|
854
|
+
on_plot: Optional[Callable] = None,
|
855
|
+
):
|
812
856
|
"""
|
813
857
|
Plot training results from a results CSV file. The function supports various types of data including segmentation,
|
814
858
|
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
|
@@ -868,7 +912,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|
868
912
|
on_plot(fname)
|
869
913
|
|
870
914
|
|
871
|
-
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
|
915
|
+
def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
|
872
916
|
"""
|
873
917
|
Plot a scatter plot with points colored based on a 2D histogram.
|
874
918
|
|
@@ -901,7 +945,7 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
|
|
901
945
|
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
|
902
946
|
|
903
947
|
|
904
|
-
def plot_tune_results(csv_file="tune_results.csv"):
|
948
|
+
def plot_tune_results(csv_file: str = "tune_results.csv"):
|
905
949
|
"""
|
906
950
|
Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
|
907
951
|
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
|
@@ -957,7 +1001,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
|
|
957
1001
|
_save_one_file(csv_file.with_name("tune_fitness.png"))
|
958
1002
|
|
959
1003
|
|
960
|
-
def output_to_target(output, max_det=300):
|
1004
|
+
def output_to_target(output, max_det: int = 300):
|
961
1005
|
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
|
962
1006
|
targets = []
|
963
1007
|
for i, o in enumerate(output):
|
@@ -968,7 +1012,7 @@ def output_to_target(output, max_det=300):
|
|
968
1012
|
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
|
969
1013
|
|
970
1014
|
|
971
|
-
def output_to_rotated_target(output, max_det=300):
|
1015
|
+
def output_to_rotated_target(output, max_det: int = 300):
|
972
1016
|
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
|
973
1017
|
targets = []
|
974
1018
|
for i, o in enumerate(output):
|
@@ -979,7 +1023,7 @@ def output_to_rotated_target(output, max_det=300):
|
|
979
1023
|
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
|
980
1024
|
|
981
1025
|
|
982
|
-
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
|
1026
|
+
def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
|
983
1027
|
"""
|
984
1028
|
Visualize feature maps of a given model module during inference.
|
985
1029
|
|
ultralytics/utils/tal.py
CHANGED
@@ -26,8 +26,17 @@ class TaskAlignedAssigner(nn.Module):
|
|
26
26
|
eps (float): A small value to prevent division by zero.
|
27
27
|
"""
|
28
28
|
|
29
|
-
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
|
30
|
-
"""
|
29
|
+
def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
|
30
|
+
"""
|
31
|
+
Initialize a TaskAlignedAssigner object with customizable hyperparameters.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
topk (int, optional): The number of top candidates to consider.
|
35
|
+
num_classes (int, optional): The number of object classes.
|
36
|
+
alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
|
37
|
+
beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
|
38
|
+
eps (float, optional): A small value to prevent division by zero.
|
39
|
+
"""
|
31
40
|
super().__init__()
|
32
41
|
self.topk = topk
|
33
42
|
self.num_classes = num_classes
|
@@ -196,12 +205,11 @@ class TaskAlignedAssigner(nn.Module):
|
|
196
205
|
Select the top-k candidates based on the given metrics.
|
197
206
|
|
198
207
|
Args:
|
199
|
-
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
the top-k values are automatically computed based on the given metrics.
|
208
|
+
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
|
209
|
+
the maximum number of objects, and h*w represents the total number of anchor points.
|
210
|
+
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
211
|
+
topk is the number of top candidates to consider. If not provided, the top-k values are automatically
|
212
|
+
computed based on the given metrics.
|
205
213
|
|
206
214
|
Returns:
|
207
215
|
(torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
@@ -239,11 +247,9 @@ class TaskAlignedAssigner(nn.Module):
|
|
239
247
|
(foreground) anchor points.
|
240
248
|
|
241
249
|
Returns:
|
242
|
-
target_labels (torch.Tensor):
|
243
|
-
target_bboxes (torch.Tensor):
|
244
|
-
|
245
|
-
target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
|
246
|
-
anchor points.
|
250
|
+
target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
|
251
|
+
target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
|
252
|
+
target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
|
247
253
|
"""
|
248
254
|
# Assigned target labels, (b, 1)
|
249
255
|
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
|
@@ -277,7 +283,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
277
283
|
Args:
|
278
284
|
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
|
279
285
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
|
280
|
-
eps (float, optional): Small value for numerical stability.
|
286
|
+
eps (float, optional): Small value for numerical stability.
|
281
287
|
|
282
288
|
Returns:
|
283
289
|
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
|
@@ -399,7 +405,7 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
|
399
405
|
pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
|
400
406
|
pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
|
401
407
|
anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
|
402
|
-
dim (int, optional): Dimension along which to split.
|
408
|
+
dim (int, optional): Dimension along which to split.
|
403
409
|
|
404
410
|
Returns:
|
405
411
|
(torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
|