ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_exports.py +2 -2
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +15 -19
- ultralytics/engine/exporter.py +24 -23
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +42 -24
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +65 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
- ultralytics-8.3.91.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/models/fastsam/model.py CHANGED
@@ -10,7 +10,14 @@ from .val import FastSAMValidator
 
 class FastSAM(Model):
     """
-    FastSAM model interface.
+    FastSAM model interface for segment anything tasks.
+
+    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything Model)
+    implementation, allowing for efficient and accurate image segmentation.
+
+    Attributes:
+        model (str): Path to the pre-trained FastSAM model file.
+        task (str): The task type, set to "segment" for FastSAM models.
 
     Examples:
         >>> from ultralytics import FastSAM
@@ -19,7 +26,7 @@ class FastSAM(Model):
     """
 
     def __init__(self, model="FastSAM-x.pt"):
-        """
+        """Initialize the FastSAM model with the specified pre-trained weights."""
         if str(model) == "FastSAM.pt":
             model = "FastSAM-x.pt"
         assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
@@ -29,19 +36,21 @@
         """
         Perform segmentation prediction on image or video source.
 
-        Supports prompted segmentation with bounding boxes, points, labels, and texts.
+        Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
+        prompts and passes them to the parent class predict method.
 
         Args:
-            source (str | PIL.Image | numpy.ndarray): Input source
-
-
-
-
-
-
+            source (str | PIL.Image | numpy.ndarray): Input source for prediction, can be a file path, URL, PIL image,
+                or numpy array.
+            stream (bool): Whether to enable real-time streaming mode for video inputs.
+            bboxes (List): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2], ...].
+            points (List): Point coordinates for prompted segmentation in format [[x, y], ...].
+            labels (List): Class labels for prompted segmentation.
+            texts (List): Text prompts for segmentation guidance.
+            **kwargs (Any): Additional keyword arguments passed to the predictor.
 
         Returns:
-            (
+            (List): List of Results objects containing the prediction results.
         """
         prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
         return super().predict(source, stream, prompts=prompts, **kwargs)
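To make the prompted predict() signature above concrete, here is a short usage sketch based only on the arguments documented in the new docstring; the image path and prompt values are illustrative, not taken from the package.

    from ultralytics import FastSAM

    # Load the pre-trained FastSAM weights (FastSAM models only ship as .pt weights)
    model = FastSAM("FastSAM-x.pt")

    # Prompted segmentation: bboxes/points/labels/texts are packaged into a
    # `prompts` dict and forwarded to the parent Model.predict() call
    results = model.predict(
        "bus.jpg",                      # illustrative source path
        bboxes=[[100, 100, 400, 500]],  # [[x1, y1, x2, y2], ...]
        points=[[250, 300]],            # [[x, y], ...]
        labels=[1],                     # 1 = foreground, 0 = background
        texts="a bus",                  # text prompt for CLIP-guided selection
    )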
ultralytics/models/fastsam/predict.py CHANGED
@@ -13,21 +13,42 @@ from .utils import adjust_bboxes_to_image_border
 
 class FastSAMPredictor(SegmentationPredictor):
     """
-    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks
-    YOLO framework.
+    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
 
     This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
-    adjusts post-processing steps to incorporate mask prediction and non-
-    class segmentation.
+    adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
+    single-class segmentation.
+
+    Attributes:
+        prompts (Dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
+        device (torch.device): Device on which model and tensors are processed.
+        clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
+        clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.
+
+    Methods:
+        postprocess: Applies box postprocessing for FastSAM predictions.
+        prompt: Performs image segmentation inference based on various prompt types.
+        _clip_inference: Performs CLIP inference to calculate similarity between images and text prompts.
+        set_prompts: Sets prompts to be used during inference.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """Initialize the FastSAMPredictor with configuration and callbacks."""
         super().__init__(cfg, overrides, _callbacks)
         self.prompts = {}
 
     def postprocess(self, preds, img, orig_imgs):
-        """
+        """
+        Apply postprocessing to FastSAM predictions and handle prompts.
+
+        Args:
+            preds (List[torch.Tensor]): Raw predictions from the model.
+            img (torch.Tensor): Input image tensor that was fed to the model.
+            orig_imgs (List[numpy.ndarray]): Original images before preprocessing.
+
+        Returns:
+            (List[Results]): Processed results with prompts applied.
+        """
         bboxes = self.prompts.pop("bboxes", None)
         points = self.prompts.pop("points", None)
         labels = self.prompts.pop("labels", None)
@@ -46,18 +67,17 @@ class FastSAMPredictor(SegmentationPredictor):
 
     def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
         """
-
-        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
+        Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
 
         Args:
-            results (Results | List[Results]):
+            results (Results | List[Results]): Original inference results from FastSAM models without any prompts.
             bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
             points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
             labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
-            texts (str | List[str], optional): Textual prompts, a list
+            texts (str | List[str], optional): Textual prompts, a list containing string objects.
 
         Returns:
-            (List[Results]):
+            (List[Results]): Output results filtered and determined by the provided prompts.
         """
         if bboxes is None and points is None and texts is None:
             return results
@@ -121,14 +141,14 @@ class FastSAMPredictor(SegmentationPredictor):
 
     def _clip_inference(self, images, texts):
         """
-        CLIP
+        Perform CLIP inference to calculate similarity between images and text prompts.
 
         Args:
-            images (List[PIL.Image]):
-            texts (List[str]):
+            images (List[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
+            texts (List[str]): List of prompt texts, each should be a string object.
 
         Returns:
-            (torch.Tensor):
+            (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
         """
         try:
             import clip
@@ -146,5 +166,5 @@ class FastSAMPredictor(SegmentationPredictor):
         return (image_features * text_features[:, None]).sum(-1)  # (M, N)
 
     def set_prompts(self, prompts):
-        """Set prompts
+        """Set prompts to be used during inference."""
         self.prompts = prompts
ultralytics/models/fastsam/utils.py CHANGED
@@ -6,17 +6,17 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
     Adjust bounding boxes to stick to image border if they are within a certain threshold.
 
     Args:
-        boxes (torch.Tensor): (n, 4)
-        image_shape (
-        threshold (int):
+        boxes (torch.Tensor): Bounding boxes with shape (n, 4) in xyxy format.
+        image_shape (Tuple[int, int]): Image dimensions as (height, width).
+        threshold (int): Pixel threshold for considering a box close to the border.
 
     Returns:
-
+        boxes (torch.Tensor): Adjusted bounding boxes with shape (n, 4).
     """
     # Image dimensions
     h, w = image_shape
 
-    # Adjust boxes
+    # Adjust boxes that are close to image borders
     boxes[boxes[:, 0] < threshold, 0] = 0  # x1
     boxes[boxes[:, 1] < threshold, 1] = 0  # y1
     boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
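As a self-contained illustration of the behavior documented above, the sketch below snaps box coordinates that lie within `threshold` pixels of an edge onto that edge; the function name is illustrative, and the y2 line is an assumption since the hunk shown ends at x2.

    import torch

    def snap_boxes_to_border(boxes: torch.Tensor, image_shape, threshold: int = 20) -> torch.Tensor:
        """Sketch: snap (n, 4) xyxy boxes onto the image border when within `threshold` pixels of it."""
        h, w = image_shape  # (height, width)
        boxes[boxes[:, 0] < threshold, 0] = 0      # x1 -> left edge
        boxes[boxes[:, 1] < threshold, 1] = 0      # y1 -> top edge
        boxes[boxes[:, 2] > w - threshold, 2] = w  # x2 -> right edge
        boxes[boxes[:, 3] > h - threshold, 3] = h  # y2 -> bottom edge (assumed by symmetry)
        return boxes

    # Example: a 640x480 image, one box 5 px from the left edge and 4 px from the bottom edge
    print(snap_boxes_to_border(torch.tensor([[5.0, 100.0, 300.0, 476.0]]), (480, 640)))
    # tensor([[  0., 100., 300., 480.]])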
ultralytics/models/fastsam/val.py CHANGED
@@ -13,11 +13,11 @@ class FastSAMValidator(SegmentationValidator):
     to avoid errors during validation.
 
     Attributes:
-        dataloader: The data loader object used for validation.
-        save_dir (
-        pbar: A progress bar object.
-        args: Additional arguments for customization.
-        _callbacks: List of callback functions to be invoked during validation.
+        dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
+        save_dir (Path): The directory where validation results will be saved.
+        pbar (tqdm.tqdm): A progress bar object for displaying validation progress.
+        args (SimpleNamespace): Additional arguments for customization of the validation process.
+        _callbacks (List): List of callback functions to be invoked during validation.
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
@@ -29,7 +29,7 @@ class FastSAMValidator(SegmentationValidator):
             save_dir (Path, optional): Directory to save results.
             pbar (tqdm.tqdm): Progress bar for displaying progress.
             args (SimpleNamespace): Configuration for the validator.
-            _callbacks (
+            _callbacks (List): List of callback functions to be invoked during validation.
 
         Notes:
             Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
ultralytics/models/nas/model.py
CHANGED
@@ -28,31 +28,39 @@ class NAS(Model):
     This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
     It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
 
+    Attributes:
+        model (torch.nn.Module): The loaded YOLO-NAS model.
+        task (str): The task type for the model, defaults to 'detect'.
+        predictor (NASPredictor): The predictor instance for making predictions.
+        validator (NASValidator): The validator instance for model validation.
+
     Examples:
         >>> from ultralytics import NAS
        >>> model = NAS("yolo_nas_s")
         >>> results = model.predict("ultralytics/assets/bus.jpg")
 
-
-        model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
-
-    Note:
+    Notes:
         YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
     """
 
-    def __init__(self, model="yolo_nas_s.pt") -> None:
-        """
+    def __init__(self, model: str = "yolo_nas_s.pt") -> None:
+        """Initialize the NAS model with the provided or default model."""
         assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
         super().__init__(model, task="detect")
 
     def _load(self, weights: str, task=None) -> None:
-        """
+        """
+        Load an existing NAS model weights or create a new NAS model with pretrained weights.
+
+        Args:
+            weights (str): Path to the model weights file or model name.
+            task (str, optional): Task type for the model.
+        """
         import super_gradients
 
         suffix = Path(weights).suffix
         if suffix == ".pt":
             self.model = torch.load(attempt_download_asset(weights))
-
         elif suffix == "":
             self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
 
@@ -74,17 +82,20 @@ class NAS(Model):
         self.model.task = "detect"  # for export()
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
 
-    def info(self, detailed=False, verbose=True):
+    def info(self, detailed: bool = False, verbose: bool = True):
         """
-
+        Log model information.
 
         Args:
             detailed (bool): Show detailed information about model.
             verbose (bool): Controls verbosity.
+
+        Returns:
+            (dict): Model information dictionary.
         """
         return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
 
     @property
     def task_map(self):
-        """
+        """Return a dictionary mapping tasks to respective predictor and validator classes."""
         return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
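Pulled together from the Examples and docstrings in this diff, a short usage sketch (loading YOLO-NAS weights requires the `super-gradients` package; the asset path is the one shown in the docstring):

    from ultralytics import NAS

    model = NAS("yolo_nas_s.pt")  # pre-trained weights only; YAML configs are rejected
    results = model.predict("ultralytics/assets/bus.jpg")
    model.info(detailed=False, verbose=True)  # log model information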
ultralytics/models/nas/predict.py CHANGED
@@ -16,26 +16,30 @@ class NASPredictor(BasePredictor):
     scaling the bounding boxes to fit the original image dimensions.
 
     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing
+        args (Namespace): Namespace containing various configurations for post-processing including confidence threshold,
+            IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
+        model (torch.nn.Module): The YOLO NAS model used for inference.
+        batch (List): Batch of inputs for processing.
 
     Examples:
         >>> from ultralytics import NAS
         >>> model = NAS("yolo_nas_s")
         >>> predictor = model.predictor
 
-
+        Assume that raw_preds, img, orig_imgs are available
         >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
 
-
+    Notes:
         Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
     """
 
     def postprocess(self, preds_in, img, orig_imgs):
         """Postprocess predictions and returns a list of Results objects."""
-        #
+        # Convert boxes from xyxy to xywh format and concatenate with class scores
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
 
+        # Apply non-maximum suppression to filter overlapping detections
         preds = ops.non_max_suppression(
             preds,
             self.args.conf,
@@ -50,6 +54,7 @@ class NASPredictor(BasePredictor):
 
         results = []
         for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
+            # Scale bounding boxes to match original image dimensions
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
ultralytics/models/nas/val.py
CHANGED
@@ -17,25 +17,25 @@ class NASValidator(DetectionValidator):
     ultimately producing the final detections.
 
     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU
+        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU
+            thresholds.
         lb (torch.Tensor): Optional tensor for multilabel NMS.
 
     Examples:
         >>> from ultralytics import NAS
         >>> model = NAS("yolo_nas_s")
         >>> validator = model.validator
-
         Assumes that raw_preds are available
         >>> final_preds = validator.postprocess(raw_preds)
 
-
+    Notes:
         This class is generally not instantiated directly but is used internally within the `NAS` class.
     """
 
     def postprocess(self, preds_in):
         """Apply Non-maximum suppression to prediction outputs."""
-        boxes = ops.xyxy2xywh(preds_in[0][0])
-        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
+        boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding box format from xyxy to xywh
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with scores and permute
         return super().postprocess(
             preds,
             max_time_img=0.5,
ultralytics/models/rtdetr/model.py CHANGED
@@ -1,10 +1,12 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
-Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector.
-performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
-hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
+Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector.
 
-
+RT-DETR offers real-time performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT.
+It features an efficient hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
+
+References:
+    https://arxiv.org/pdf/2304.08069.pdf
 """
 
 from ultralytics.engine.model import Model
@@ -17,19 +19,26 @@ from .val import RTDETRValidator
 
 class RTDETR(Model):
     """
-    Interface for Baidu's RT-DETR model
-
+    Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
+
+    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
+    selection, and adaptable inference speed.
 
     Attributes:
-        model (str): Path to the pre-trained model.
+        model (str): Path to the pre-trained model.
+
+    Examples:
+        >>> from ultralytics import RTDETR
+        >>> model = RTDETR("rtdetr-l.pt")
+        >>> results = model("image.jpg")
     """
 
-    def __init__(self, model="rtdetr-l.pt") -> None:
+    def __init__(self, model: str = "rtdetr-l.pt") -> None:
         """
-
+        Initialize the RT-DETR model with the given pre-trained model file.
 
         Args:
-            model (str): Path to the pre-trained model.
+            model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
 
         Raises:
             NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
@@ -42,7 +51,7 @@ class RTDETR(Model):
         Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
 
         Returns:
-
+            (Dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
         """
         return {
             "detect": {
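The Examples block added above translates directly into a runnable snippet; the image path is illustrative:

    from ultralytics import RTDETR

    model = RTDETR("rtdetr-l.pt")  # .pt, .yaml, and .yml are supported
    results = model("image.jpg")   # illustrative image path
    for r in results:
        print(r.boxes)             # post-processed boxes from RTDETRPredictor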
ultralytics/models/rtdetr/predict.py CHANGED
@@ -10,11 +10,16 @@ from ultralytics.utils import ops
 
 class RTDETRPredictor(BasePredictor):
     """
-    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions
-    Baidu's RT-DETR model.
+    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
 
-    This class leverages
-
+    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
+    It supports key features like efficient hybrid encoding and IoU-aware query selection.
+
+    Attributes:
+        imgsz (int): Image size for inference (must be square and scale-filled).
+        args (dict): Argument overrides for the predictor.
+        model (torch.nn.Module): The loaded RT-DETR model.
+        batch (List): Current batch of processed inputs.
 
     Examples:
         >>> from ultralytics.utils import ASSETS
@@ -22,25 +27,23 @@ class RTDETRPredictor(BasePredictor):
         >>> args = dict(model="rtdetr-l.pt", source=ASSETS)
         >>> predictor = RTDETRPredictor(overrides=args)
         >>> predictor.predict_cli()
-
-    Attributes:
-        imgsz (int): Image size for inference (must be square and scale-filled).
-        args (dict): Argument overrides for the predictor.
     """
 
     def postprocess(self, preds, img, orig_imgs):
         """
         Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
 
-        The method filters detections based on confidence and class if specified in `self.args`.
+        The method filters detections based on confidence and class if specified in `self.args`. It converts
+        model predictions to Results objects containing properly scaled bounding boxes.
 
         Args:
-            preds (
-
-
+            preds (List | Tuple): List of [predictions, extra] from the model, where predictions contain
+                bounding boxes and scores.
+            img (torch.Tensor): Processed input images with shape (N, 3, H, W).
+            orig_imgs (List | torch.Tensor): Original, unprocessed images.
 
         Returns:
-            (
+            (List[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
             and class labels.
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
@@ -61,8 +64,8 @@ class RTDETRPredictor(BasePredictor):
             idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
             pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
             oh, ow = orig_img.shape[:2]
-            pred[..., [0, 2]] *= ow
-            pred[..., [1, 3]] *= oh
+            pred[..., [0, 2]] *= ow  # scale x coordinates to original width
+            pred[..., [1, 3]] *= oh  # scale y coordinates to original height
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
 
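A small worked sketch of the scaling step commented above, assuming (as the multiplication by original width/height implies) that the xyxy coordinates coming out of the model are normalized to [0, 1]:

    import torch

    pred = torch.tensor([[0.25, 0.10, 0.75, 0.90, 0.88, 5.0]])  # x1, y1, x2, y2, score, cls (assumed layout)
    oh, ow = 480, 640          # original image height and width
    pred[..., [0, 2]] *= ow    # x1, x2 -> 160.0, 480.0
    pred[..., [1, 3]] *= oh    # y1, y2 -> 48.0, 432.0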
ultralytics/models/rtdetr/train.py CHANGED
@@ -13,9 +13,18 @@ from .val import RTDETRDataset, RTDETRValidator
 
 class RTDETRTrainer(DetectionTrainer):
     """
-    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
-
-
+    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
+
+    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.
+    The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference
+    speed.
+
+    Attributes:
+        loss_names (Tuple[str]): Names of the loss components used for training.
+        data (Dict): Dataset configuration containing class count and other parameters.
+        args (Dict): Training arguments and hyperparameters.
+        save_dir (Path): Directory to save training results.
+        test_loader (DataLoader): DataLoader for validation/testing data.
 
     Notes:
         - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
@@ -33,9 +42,9 @@ class RTDETRTrainer(DetectionTrainer):
         Initialize and return an RT-DETR model for object detection tasks.
 
         Args:
-            cfg (
-            weights (str, optional): Path to pre-trained model weights.
-            verbose (bool): Verbose logging if True.
+            cfg (Dict, optional): Model configuration.
+            weights (str, optional): Path to pre-trained model weights.
+            verbose (bool): Verbose logging if True.
 
         Returns:
             (RTDETRDetectionModel): Initialized model.
@@ -52,7 +61,7 @@ class RTDETRTrainer(DetectionTrainer):
         Args:
             img_path (str): Path to the folder containing images.
             mode (str): Dataset mode, either 'train' or 'val'.
-            batch (int, optional): Batch size for rectangle training.
+            batch (int, optional): Batch size for rectangle training.
 
         Returns:
             (RTDETRDataset): Dataset object for the specific mode.
@@ -73,24 +82,19 @@ class RTDETRTrainer(DetectionTrainer):
         )
 
     def get_validator(self):
-        """
-        Returns a DetectionValidator suitable for RT-DETR model validation.
-
-        Returns:
-            (RTDETRValidator): Validator object for model validation.
-        """
+        """Returns a DetectionValidator suitable for RT-DETR model validation."""
         self.loss_names = "giou_loss", "cls_loss", "l1_loss"
         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
 
     def preprocess_batch(self, batch):
         """
-        Preprocess a batch of images
+        Preprocess a batch of images by scaling and converting to float format.
 
         Args:
-            batch (
+            batch (Dict): Dictionary containing a batch of images, bboxes, and labels.
 
         Returns:
-            (
+            (Dict): Preprocessed batch with ground truth bounding boxes and classes separated by batch index.
         """
         batch = super().preprocess_batch(batch)
         bs = len(batch["img"])
ultralytics/models/rtdetr/val.py
CHANGED
@@ -22,13 +22,20 @@ class RTDETRDataset(YOLODataset):
         """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
         super().__init__(*args, data=data, **kwargs)
 
-    # NOTE: add stretch version load_image for RTDETR mosaic
     def load_image(self, i, rect_mode=False):
         """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
         return super().load_image(i=i, rect_mode=rect_mode)
 
     def build_transforms(self, hyp=None):
-        """
+        """
+        Build transformation pipeline for the dataset.
+
+        Args:
+            hyp (Dict, optional): Hyperparameters for transformations.
+
+        Returns:
+            (Compose): Composition of transformation functions.
+        """
         if self.augment:
             hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
             hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
@@ -75,7 +82,10 @@ class RTDETRValidator(DetectionValidator):
         Args:
             img_path (str): Path to the folder containing images.
             mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (RTDETRDataset): Dataset configured for RT-DETR validation.
         """
         return RTDETRDataset(
             img_path=img_path,
@@ -90,7 +100,15 @@ class RTDETRValidator(DetectionValidator):
         )
 
     def postprocess(self, preds):
-        """
+        """
+        Apply Non-maximum suppression to prediction outputs.
+
+        Args:
+            preds (List | Tuple | torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[torch.Tensor]): List of processed predictions for each image in batch.
+        """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
 
@@ -111,7 +129,16 @@ class RTDETRValidator(DetectionValidator):
         return outputs
 
     def _prepare_batch(self, si, batch):
-        """
+        """
+        Prepares a batch for validation by applying necessary transformations.
+
+        Args:
+            si (int): Batch index.
+            batch (Dict): Batch data containing images and annotations.
+
+        Returns:
+            (Dict): Prepared batch with transformed annotations.
+        """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -125,7 +152,16 @@ class RTDETRValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """
+        Prepares predictions by scaling bounding boxes to original image dimensions.
+
+        Args:
+            pred (torch.Tensor): Raw predictions.
+            pbatch (Dict): Prepared batch information.
+
+        Returns:
+            (torch.Tensor): Predictions scaled to original image dimensions.
+        """
         predn = pred.clone()
         predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
         predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
ultralytics/models/sam/__init__.py CHANGED
@@ -3,4 +3,4 @@
 from .model import SAM
 from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
 
-__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list of exportable items