ultralytics 8.3.189__py3-none-any.whl → 8.3.191__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_cuda.py +6 -5
- tests/test_exports.py +1 -6
- tests/test_python.py +1 -4
- tests/test_solutions.py +1 -1
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -14
- ultralytics/cfg/datasets/VisDrone.yaml +4 -4
- ultralytics/data/annotator.py +6 -6
- ultralytics/data/augment.py +53 -51
- ultralytics/data/base.py +15 -13
- ultralytics/data/build.py +7 -4
- ultralytics/data/converter.py +9 -10
- ultralytics/data/dataset.py +24 -22
- ultralytics/data/loaders.py +13 -11
- ultralytics/data/split.py +4 -3
- ultralytics/data/split_dota.py +14 -12
- ultralytics/data/utils.py +31 -25
- ultralytics/engine/exporter.py +7 -4
- ultralytics/engine/model.py +16 -14
- ultralytics/engine/predictor.py +9 -7
- ultralytics/engine/results.py +59 -57
- ultralytics/engine/trainer.py +7 -0
- ultralytics/engine/tuner.py +4 -3
- ultralytics/engine/validator.py +3 -1
- ultralytics/hub/__init__.py +6 -2
- ultralytics/hub/auth.py +2 -2
- ultralytics/hub/google/__init__.py +9 -8
- ultralytics/hub/session.py +11 -11
- ultralytics/hub/utils.py +8 -9
- ultralytics/models/fastsam/model.py +8 -6
- ultralytics/models/nas/model.py +5 -3
- ultralytics/models/rtdetr/train.py +4 -3
- ultralytics/models/rtdetr/val.py +6 -4
- ultralytics/models/sam/amg.py +13 -10
- ultralytics/models/sam/model.py +3 -2
- ultralytics/models/sam/modules/blocks.py +21 -21
- ultralytics/models/sam/modules/decoders.py +11 -11
- ultralytics/models/sam/modules/encoders.py +25 -25
- ultralytics/models/sam/modules/memory_attention.py +9 -8
- ultralytics/models/sam/modules/sam.py +8 -10
- ultralytics/models/sam/modules/tiny_encoder.py +21 -20
- ultralytics/models/sam/modules/transformer.py +6 -5
- ultralytics/models/sam/modules/utils.py +7 -5
- ultralytics/models/sam/predict.py +32 -31
- ultralytics/models/utils/loss.py +29 -27
- ultralytics/models/utils/ops.py +10 -8
- ultralytics/models/yolo/classify/train.py +7 -5
- ultralytics/models/yolo/classify/val.py +10 -8
- ultralytics/models/yolo/detect/predict.py +3 -3
- ultralytics/models/yolo/detect/train.py +8 -6
- ultralytics/models/yolo/detect/val.py +23 -21
- ultralytics/models/yolo/model.py +14 -14
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +13 -10
- ultralytics/models/yolo/pose/train.py +7 -5
- ultralytics/models/yolo/pose/val.py +11 -9
- ultralytics/models/yolo/segment/train.py +4 -5
- ultralytics/models/yolo/segment/val.py +12 -10
- ultralytics/models/yolo/world/train.py +9 -7
- ultralytics/models/yolo/yoloe/train.py +7 -6
- ultralytics/models/yolo/yoloe/val.py +10 -8
- ultralytics/nn/autobackend.py +40 -52
- ultralytics/nn/modules/__init__.py +3 -3
- ultralytics/nn/modules/block.py +12 -12
- ultralytics/nn/modules/conv.py +4 -3
- ultralytics/nn/modules/head.py +46 -38
- ultralytics/nn/modules/transformer.py +22 -21
- ultralytics/nn/tasks.py +2 -2
- ultralytics/nn/text_model.py +6 -5
- ultralytics/solutions/analytics.py +7 -5
- ultralytics/solutions/config.py +12 -10
- ultralytics/solutions/distance_calculation.py +3 -3
- ultralytics/solutions/heatmap.py +4 -2
- ultralytics/solutions/object_counter.py +5 -3
- ultralytics/solutions/parking_management.py +4 -2
- ultralytics/solutions/region_counter.py +7 -5
- ultralytics/solutions/similarity_search.py +5 -3
- ultralytics/solutions/solutions.py +38 -36
- ultralytics/solutions/streamlit_inference.py +8 -7
- ultralytics/trackers/bot_sort.py +11 -9
- ultralytics/trackers/byte_tracker.py +17 -15
- ultralytics/trackers/utils/gmc.py +4 -3
- ultralytics/utils/__init__.py +27 -77
- ultralytics/utils/autobatch.py +3 -2
- ultralytics/utils/autodevice.py +10 -10
- ultralytics/utils/benchmarks.py +11 -10
- ultralytics/utils/callbacks/comet.py +9 -9
- ultralytics/utils/callbacks/platform.py +2 -1
- ultralytics/utils/checks.py +20 -29
- ultralytics/utils/downloads.py +2 -2
- ultralytics/utils/export.py +12 -11
- ultralytics/utils/files.py +8 -7
- ultralytics/utils/git.py +139 -0
- ultralytics/utils/instance.py +8 -7
- ultralytics/utils/logger.py +7 -6
- ultralytics/utils/loss.py +15 -13
- ultralytics/utils/metrics.py +62 -62
- ultralytics/utils/nms.py +346 -0
- ultralytics/utils/ops.py +83 -251
- ultralytics/utils/patches.py +6 -4
- ultralytics/utils/plotting.py +18 -16
- ultralytics/utils/tal.py +1 -1
- ultralytics/utils/torch_utils.py +4 -2
- ultralytics/utils/tqdm.py +47 -33
- ultralytics/utils/triton.py +3 -2
- {ultralytics-8.3.189.dist-info → ultralytics-8.3.191.dist-info}/METADATA +1 -1
- {ultralytics-8.3.189.dist-info → ultralytics-8.3.191.dist-info}/RECORD +111 -109
- {ultralytics-8.3.189.dist-info → ultralytics-8.3.191.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.189.dist-info → ultralytics-8.3.191.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.189.dist-info → ultralytics-8.3.191.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.189.dist-info → ultralytics-8.3.191.dist-info}/top_level.txt +0 -0
ultralytics/data/utils.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import json
|
4
6
|
import os
|
5
7
|
import random
|
@@ -9,7 +11,7 @@ import zipfile
|
|
9
11
|
from multiprocessing.pool import ThreadPool
|
10
12
|
from pathlib import Path
|
11
13
|
from tarfile import is_tarfile
|
12
|
-
from typing import Any
|
14
|
+
from typing import Any
|
13
15
|
|
14
16
|
import cv2
|
15
17
|
import numpy as np
|
@@ -39,14 +41,14 @@ VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "
|
|
39
41
|
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
40
42
|
|
41
43
|
|
42
|
-
def img2label_paths(img_paths:
|
44
|
+
def img2label_paths(img_paths: list[str]) -> list[str]:
|
43
45
|
"""Convert image paths to label paths by replacing 'images' with 'labels' and extension with '.txt'."""
|
44
46
|
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
|
45
47
|
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
|
46
48
|
|
47
49
|
|
48
50
|
def check_file_speeds(
|
49
|
-
files:
|
51
|
+
files: list[str], threshold_ms: float = 10, threshold_mb: float = 50, max_files: int = 5, prefix: str = ""
|
50
52
|
):
|
51
53
|
"""
|
52
54
|
Check dataset file access speed and provide performance feedback.
|
@@ -66,7 +68,7 @@ def check_file_speeds(
|
|
66
68
|
>>> image_files = list(Path("dataset/images").glob("*.jpg"))
|
67
69
|
>>> check_file_speeds(image_files, threshold_ms=15)
|
68
70
|
"""
|
69
|
-
if not files
|
71
|
+
if not files:
|
70
72
|
LOGGER.warning(f"{prefix}Image speed checks: No files to check")
|
71
73
|
return
|
72
74
|
|
@@ -123,7 +125,7 @@ def check_file_speeds(
|
|
123
125
|
)
|
124
126
|
|
125
127
|
|
126
|
-
def get_hash(paths:
|
128
|
+
def get_hash(paths: list[str]) -> str:
|
127
129
|
"""Return a single hash value of a list of paths (files or dirs)."""
|
128
130
|
size = 0
|
129
131
|
for p in paths:
|
@@ -136,7 +138,7 @@ def get_hash(paths: List[str]) -> str:
|
|
136
138
|
return h.hexdigest() # return hash
|
137
139
|
|
138
140
|
|
139
|
-
def exif_size(img: Image.Image) ->
|
141
|
+
def exif_size(img: Image.Image) -> tuple[int, int]:
|
140
142
|
"""Return exif-corrected PIL size."""
|
141
143
|
s = img.size # (width, height)
|
142
144
|
if img.format == "JPEG": # only support JPEG images
|
@@ -150,7 +152,7 @@ def exif_size(img: Image.Image) -> Tuple[int, int]:
|
|
150
152
|
return s
|
151
153
|
|
152
154
|
|
153
|
-
def verify_image(args:
|
155
|
+
def verify_image(args: tuple) -> tuple:
|
154
156
|
"""Verify one image."""
|
155
157
|
(im_file, cls), prefix = args
|
156
158
|
# Number (found, corrupt), message
|
@@ -175,7 +177,7 @@ def verify_image(args: Tuple) -> Tuple:
|
|
175
177
|
return (im_file, cls), nf, nc, msg
|
176
178
|
|
177
179
|
|
178
|
-
def verify_image_label(args:
|
180
|
+
def verify_image_label(args: tuple) -> list:
|
179
181
|
"""Verify one image-label pair."""
|
180
182
|
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls = args
|
181
183
|
# Number (missing, found, empty, corrupt), message, segments, keypoints
|
@@ -247,7 +249,7 @@ def verify_image_label(args: Tuple) -> List:
|
|
247
249
|
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
248
250
|
|
249
251
|
|
250
|
-
def visualize_image_annotations(image_path: str, txt_path: str, label_map:
|
252
|
+
def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[int, str]):
|
251
253
|
"""
|
252
254
|
Visualize YOLO annotations (bounding boxes and class labels) on an image.
|
253
255
|
|
@@ -292,7 +294,7 @@ def visualize_image_annotations(image_path: str, txt_path: str, label_map: Dict[
|
|
292
294
|
|
293
295
|
|
294
296
|
def polygon2mask(
|
295
|
-
imgsz:
|
297
|
+
imgsz: tuple[int, int], polygons: list[np.ndarray], color: int = 1, downsample_ratio: int = 1
|
296
298
|
) -> np.ndarray:
|
297
299
|
"""
|
298
300
|
Convert a list of polygons to a binary mask of the specified image size.
|
@@ -317,7 +319,7 @@ def polygon2mask(
|
|
317
319
|
|
318
320
|
|
319
321
|
def polygons2masks(
|
320
|
-
imgsz:
|
322
|
+
imgsz: tuple[int, int], polygons: list[np.ndarray], color: int, downsample_ratio: int = 1
|
321
323
|
) -> np.ndarray:
|
322
324
|
"""
|
323
325
|
Convert a list of polygons to a set of binary masks of the specified image size.
|
@@ -336,8 +338,8 @@ def polygons2masks(
|
|
336
338
|
|
337
339
|
|
338
340
|
def polygons2masks_overlap(
|
339
|
-
imgsz:
|
340
|
-
) ->
|
341
|
+
imgsz: tuple[int, int], segments: list[np.ndarray], downsample_ratio: int = 1
|
342
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
341
343
|
"""Return a (640, 640) overlap mask."""
|
342
344
|
masks = np.zeros(
|
343
345
|
(imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
|
@@ -345,8 +347,13 @@ def polygons2masks_overlap(
|
|
345
347
|
)
|
346
348
|
areas = []
|
347
349
|
ms = []
|
348
|
-
for
|
349
|
-
mask = polygon2mask(
|
350
|
+
for segment in segments:
|
351
|
+
mask = polygon2mask(
|
352
|
+
imgsz,
|
353
|
+
[segment.reshape(-1)],
|
354
|
+
downsample_ratio=downsample_ratio,
|
355
|
+
color=1,
|
356
|
+
)
|
350
357
|
ms.append(mask.astype(masks.dtype))
|
351
358
|
areas.append(mask.sum())
|
352
359
|
areas = np.asarray(areas)
|
@@ -380,7 +387,7 @@ def find_dataset_yaml(path: Path) -> Path:
|
|
380
387
|
return files[0]
|
381
388
|
|
382
389
|
|
383
|
-
def check_det_dataset(dataset: str, autodownload: bool = True) ->
|
390
|
+
def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]:
|
384
391
|
"""
|
385
392
|
Download, verify, and/or unzip a dataset if not found locally.
|
386
393
|
|
@@ -464,7 +471,7 @@ def check_det_dataset(dataset: str, autodownload: bool = True) -> Dict[str, Any]
|
|
464
471
|
safe_download(url=s, dir=DATASETS_DIR, delete=True)
|
465
472
|
elif s.startswith("bash "): # bash script
|
466
473
|
LOGGER.info(f"Running {s} ...")
|
467
|
-
|
474
|
+
subprocess.run(s.split(), check=True)
|
468
475
|
else: # python script
|
469
476
|
exec(s, {"yaml": data})
|
470
477
|
dt = f"({round(time.time() - t, 1)}s)"
|
@@ -475,7 +482,7 @@ def check_det_dataset(dataset: str, autodownload: bool = True) -> Dict[str, Any]
|
|
475
482
|
return data # dictionary
|
476
483
|
|
477
484
|
|
478
|
-
def check_cls_dataset(dataset:
|
485
|
+
def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
|
479
486
|
"""
|
480
487
|
Check a classification dataset such as Imagenet.
|
481
488
|
|
@@ -509,7 +516,7 @@ def check_cls_dataset(dataset: Union[str, Path], split: str = "") -> Dict[str, A
|
|
509
516
|
LOGGER.warning(f"Dataset not found, missing path {data_dir}, attempting download...")
|
510
517
|
t = time.time()
|
511
518
|
if str(dataset) == "imagenet":
|
512
|
-
subprocess.run(
|
519
|
+
subprocess.run(["bash", str(ROOT / "data/scripts/get_imagenet.sh")], check=True)
|
513
520
|
else:
|
514
521
|
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
|
515
522
|
download(url, dir=data_dir.parent)
|
@@ -517,8 +524,7 @@ def check_cls_dataset(dataset: Union[str, Path], split: str = "") -> Dict[str, A
|
|
517
524
|
train_set = data_dir / "train"
|
518
525
|
if not train_set.is_dir():
|
519
526
|
LOGGER.warning(f"Dataset 'split=train' not found at {train_set}")
|
520
|
-
image_files
|
521
|
-
if image_files:
|
527
|
+
if image_files := list(data_dir.rglob("*.jpg")) + list(data_dir.rglob("*.png")):
|
522
528
|
from ultralytics.data.split import split_classify_dataset
|
523
529
|
|
524
530
|
LOGGER.info(f"Found {len(image_files)} images in subdirectories. Attempting to split...")
|
@@ -632,7 +638,7 @@ class HUBDatasetStats:
|
|
632
638
|
self.data = data
|
633
639
|
|
634
640
|
@staticmethod
|
635
|
-
def _unzip(path: Path) ->
|
641
|
+
def _unzip(path: Path) -> tuple[bool, str, Path]:
|
636
642
|
"""Unzip data.zip."""
|
637
643
|
if not str(path).endswith(".zip"): # path is data.yaml
|
638
644
|
return False, None, path
|
@@ -646,7 +652,7 @@ class HUBDatasetStats:
|
|
646
652
|
"""Save a compressed image for HUB previews."""
|
647
653
|
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
|
648
654
|
|
649
|
-
def get_json(self, save: bool = False, verbose: bool = False) ->
|
655
|
+
def get_json(self, save: bool = False, verbose: bool = False) -> dict:
|
650
656
|
"""Return dataset JSON for Ultralytics HUB."""
|
651
657
|
|
652
658
|
def _round(labels):
|
@@ -773,7 +779,7 @@ def compress_one_image(f: str, f_new: str = None, max_dim: int = 1920, quality:
|
|
773
779
|
cv2.imwrite(str(f_new or f), im)
|
774
780
|
|
775
781
|
|
776
|
-
def load_dataset_cache_file(path: Path) ->
|
782
|
+
def load_dataset_cache_file(path: Path) -> dict:
|
777
783
|
"""Load an Ultralytics *.cache dictionary from path."""
|
778
784
|
import gc
|
779
785
|
|
@@ -783,7 +789,7 @@ def load_dataset_cache_file(path: Path) -> Dict:
|
|
783
789
|
return cache
|
784
790
|
|
785
791
|
|
786
|
-
def save_dataset_cache_file(prefix: str, path: Path, x:
|
792
|
+
def save_dataset_cache_file(prefix: str, path: Path, x: dict, version: str):
|
787
793
|
"""Save an Ultralytics dataset *.cache dictionary x to path."""
|
788
794
|
x["version"] = version # add cache version
|
789
795
|
if is_dir_writeable(path.parent):
|
ultralytics/engine/exporter.py
CHANGED
@@ -107,7 +107,9 @@ from ultralytics.utils.checks import (
|
|
107
107
|
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
|
108
108
|
from ultralytics.utils.export import export_engine, export_onnx
|
109
109
|
from ultralytics.utils.files import file_size, spaces_in_path
|
110
|
-
from ultralytics.utils.
|
110
|
+
from ultralytics.utils.metrics import batch_probiou
|
111
|
+
from ultralytics.utils.nms import TorchNMS
|
112
|
+
from ultralytics.utils.ops import Profile
|
111
113
|
from ultralytics.utils.patches import arange_patch
|
112
114
|
from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
|
113
115
|
|
@@ -347,7 +349,7 @@ class Exporter:
|
|
347
349
|
assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
|
348
350
|
if self.args.nms:
|
349
351
|
assert not isinstance(model, ClassificationModel), "'nms=True' is not valid for classification models."
|
350
|
-
assert not
|
352
|
+
assert not tflite or not ARM64 or not LINUX, "TFLite export with NMS unsupported on ARM64 Linux"
|
351
353
|
if getattr(model, "end2end", False):
|
352
354
|
LOGGER.warning("'nms=True' is not available for end2end models. Forcing 'nms=False'.")
|
353
355
|
self.args.nms = False
|
@@ -434,7 +436,7 @@ class Exporter:
|
|
434
436
|
|
435
437
|
y = None
|
436
438
|
for _ in range(2): # dry runs
|
437
|
-
y = NMSModel(model, self.args)(im) if self.args.nms and not
|
439
|
+
y = NMSModel(model, self.args)(im) if self.args.nms and not coreml and not imx else model(im)
|
438
440
|
if self.args.half and onnx and self.device.type != "cpu":
|
439
441
|
im, model = im.half(), model.half() # to FP16
|
440
442
|
|
@@ -1562,12 +1564,13 @@ class NMSModel(torch.nn.Module):
|
|
1562
1564
|
nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
|
1563
1565
|
nms_fn = (
|
1564
1566
|
partial(
|
1565
|
-
|
1567
|
+
TorchNMS.fast_nms,
|
1566
1568
|
use_triu=not (
|
1567
1569
|
self.is_tf
|
1568
1570
|
or (self.args.opset or 14) < 14
|
1569
1571
|
or (self.args.format == "openvino" and self.args.int8) # OpenVINO int8 error with triu
|
1570
1572
|
),
|
1573
|
+
iou_func=batch_probiou,
|
1571
1574
|
)
|
1572
1575
|
if self.obb
|
1573
1576
|
else nms
|
ultralytics/engine/model.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import inspect
|
4
6
|
from pathlib import Path
|
5
|
-
from typing import Any
|
7
|
+
from typing import Any
|
6
8
|
|
7
9
|
import numpy as np
|
8
10
|
import torch
|
@@ -79,7 +81,7 @@ class Model(torch.nn.Module):
|
|
79
81
|
|
80
82
|
def __init__(
|
81
83
|
self,
|
82
|
-
model:
|
84
|
+
model: str | Path | Model = "yolo11n.pt",
|
83
85
|
task: str = None,
|
84
86
|
verbose: bool = False,
|
85
87
|
) -> None:
|
@@ -155,7 +157,7 @@ class Model(torch.nn.Module):
|
|
155
157
|
|
156
158
|
def __call__(
|
157
159
|
self,
|
158
|
-
source:
|
160
|
+
source: str | Path | int | Image.Image | list | tuple | np.ndarray | torch.Tensor = None,
|
159
161
|
stream: bool = False,
|
160
162
|
**kwargs: Any,
|
161
163
|
) -> list:
|
@@ -333,7 +335,7 @@ class Model(torch.nn.Module):
|
|
333
335
|
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
|
334
336
|
)
|
335
337
|
|
336
|
-
def reset_weights(self) ->
|
338
|
+
def reset_weights(self) -> Model:
|
337
339
|
"""
|
338
340
|
Reset the model's weights to their initial state.
|
339
341
|
|
@@ -359,7 +361,7 @@ class Model(torch.nn.Module):
|
|
359
361
|
p.requires_grad = True
|
360
362
|
return self
|
361
363
|
|
362
|
-
def load(self, weights:
|
364
|
+
def load(self, weights: str | Path = "yolo11n.pt") -> Model:
|
363
365
|
"""
|
364
366
|
Load parameters from the specified weights file into the model.
|
365
367
|
|
@@ -387,7 +389,7 @@ class Model(torch.nn.Module):
|
|
387
389
|
self.model.load(weights)
|
388
390
|
return self
|
389
391
|
|
390
|
-
def save(self, filename:
|
392
|
+
def save(self, filename: str | Path = "saved_model.pt") -> None:
|
391
393
|
"""
|
392
394
|
Save the current model state to a file.
|
393
395
|
|
@@ -464,7 +466,7 @@ class Model(torch.nn.Module):
|
|
464
466
|
|
465
467
|
def embed(
|
466
468
|
self,
|
467
|
-
source:
|
469
|
+
source: str | Path | int | list | tuple | np.ndarray | torch.Tensor = None,
|
468
470
|
stream: bool = False,
|
469
471
|
**kwargs: Any,
|
470
472
|
) -> list:
|
@@ -495,11 +497,11 @@ class Model(torch.nn.Module):
|
|
495
497
|
|
496
498
|
def predict(
|
497
499
|
self,
|
498
|
-
source:
|
500
|
+
source: str | Path | int | Image.Image | list | tuple | np.ndarray | torch.Tensor = None,
|
499
501
|
stream: bool = False,
|
500
502
|
predictor=None,
|
501
503
|
**kwargs: Any,
|
502
|
-
) ->
|
504
|
+
) -> list[Results]:
|
503
505
|
"""
|
504
506
|
Perform predictions on the given image source using the YOLO model.
|
505
507
|
|
@@ -556,11 +558,11 @@ class Model(torch.nn.Module):
|
|
556
558
|
|
557
559
|
def track(
|
558
560
|
self,
|
559
|
-
source:
|
561
|
+
source: str | Path | int | list | tuple | np.ndarray | torch.Tensor = None,
|
560
562
|
stream: bool = False,
|
561
563
|
persist: bool = False,
|
562
564
|
**kwargs: Any,
|
563
|
-
) ->
|
565
|
+
) -> list[Results]:
|
564
566
|
"""
|
565
567
|
Conduct object tracking on the specified input source using the registered trackers.
|
566
568
|
|
@@ -853,7 +855,7 @@ class Model(torch.nn.Module):
|
|
853
855
|
args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
854
856
|
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
855
857
|
|
856
|
-
def _apply(self, fn) ->
|
858
|
+
def _apply(self, fn) -> Model:
|
857
859
|
"""
|
858
860
|
Apply a function to model tensors that are not parameters or registered buffers.
|
859
861
|
|
@@ -882,7 +884,7 @@ class Model(torch.nn.Module):
|
|
882
884
|
return self
|
883
885
|
|
884
886
|
@property
|
885
|
-
def names(self) ->
|
887
|
+
def names(self) -> dict[int, str]:
|
886
888
|
"""
|
887
889
|
Retrieve the class names associated with the loaded model.
|
888
890
|
|
@@ -1036,7 +1038,7 @@ class Model(torch.nn.Module):
|
|
1036
1038
|
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
1037
1039
|
|
1038
1040
|
@staticmethod
|
1039
|
-
def _reset_ckpt_args(args:
|
1041
|
+
def _reset_ckpt_args(args: dict[str, Any]) -> dict[str, Any]:
|
1040
1042
|
"""
|
1041
1043
|
Reset specific arguments when loading a PyTorch model checkpoint.
|
1042
1044
|
|
ultralytics/engine/predictor.py
CHANGED
@@ -32,11 +32,13 @@ Usage - formats:
|
|
32
32
|
yolo11n_rknn_model # Rockchip RKNN
|
33
33
|
"""
|
34
34
|
|
35
|
+
from __future__ import annotations
|
36
|
+
|
35
37
|
import platform
|
36
38
|
import re
|
37
39
|
import threading
|
38
40
|
from pathlib import Path
|
39
|
-
from typing import Any
|
41
|
+
from typing import Any
|
40
42
|
|
41
43
|
import cv2
|
42
44
|
import numpy as np
|
@@ -109,8 +111,8 @@ class BasePredictor:
|
|
109
111
|
def __init__(
|
110
112
|
self,
|
111
113
|
cfg=DEFAULT_CFG,
|
112
|
-
overrides:
|
113
|
-
_callbacks:
|
114
|
+
overrides: dict[str, Any] | None = None,
|
115
|
+
_callbacks: dict[str, list[callable]] | None = None,
|
114
116
|
):
|
115
117
|
"""
|
116
118
|
Initialize the BasePredictor class.
|
@@ -147,7 +149,7 @@ class BasePredictor:
|
|
147
149
|
self._lock = threading.Lock() # for automatic thread-safe inference
|
148
150
|
callbacks.add_integration_callbacks(self)
|
149
151
|
|
150
|
-
def preprocess(self, im:
|
152
|
+
def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor:
|
151
153
|
"""
|
152
154
|
Prepare input image before inference.
|
153
155
|
|
@@ -181,7 +183,7 @@ class BasePredictor:
|
|
181
183
|
)
|
182
184
|
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
183
185
|
|
184
|
-
def pre_transform(self, im:
|
186
|
+
def pre_transform(self, im: list[np.ndarray]) -> list[np.ndarray]:
|
185
187
|
"""
|
186
188
|
Pre-transform input image before inference.
|
187
189
|
|
@@ -389,7 +391,7 @@ class BasePredictor:
|
|
389
391
|
verbose (bool): Whether to print verbose output.
|
390
392
|
"""
|
391
393
|
self.model = AutoBackend(
|
392
|
-
|
394
|
+
model=model or self.args.model,
|
393
395
|
device=select_device(self.args.device, verbose=verbose),
|
394
396
|
dnn=self.args.dnn,
|
395
397
|
data=self.args.data,
|
@@ -404,7 +406,7 @@ class BasePredictor:
|
|
404
406
|
self.args.imgsz = self.model.imgsz # reuse imgsz from export metadata
|
405
407
|
self.model.eval()
|
406
408
|
|
407
|
-
def write_results(self, i: int, p: Path, im: torch.Tensor, s:
|
409
|
+
def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
|
408
410
|
"""
|
409
411
|
Write inference results to a file or directory.
|
410
412
|
|