dgenerate-ultralytics-headless 8.3.134 (dgenerate_ultralytics_headless-8.3.134-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
ultralytics/models/__init__.py
@@ -0,0 +1,9 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .fastsam import FastSAM
+from .nas import NAS
+from .rtdetr import RTDETR
+from .sam import SAM
+from .yolo import YOLO, YOLOE, YOLOWorld
+
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "YOLOE"  # allow simpler import
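The `__all__` tuple above is what enables the flat, package-level imports used throughout the rest of the diff. A minimal usage sketch, assuming this headless wheel is installed and provides the `ultralytics` package listed above (weight file name is illustrative; Ultralytics auto-downloads known checkpoints on first use):

    # Minimal sketch of the simpler import path exposed by ultralytics/models/__init__.py
    from ultralytics import FastSAM, NAS, RTDETR, SAM, YOLO

    model = YOLO("yolo11n.pt")  # assumed checkpoint name; downloads on first use
    results = model("ultralytics/assets/bus.jpg")
    print(results[0].boxes.xyxy)  # predicted boxes in xyxy format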
ultralytics/models/fastsam/model.py
@@ -0,0 +1,61 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from pathlib import Path
+
+from ultralytics.engine.model import Model
+
+from .predict import FastSAMPredictor
+from .val import FastSAMValidator
+
+
+class FastSAM(Model):
+    """
+    FastSAM model interface for segment anything tasks.
+
+    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything Model)
+    implementation, allowing for efficient and accurate image segmentation.
+
+    Attributes:
+        model (str): Path to the pre-trained FastSAM model file.
+        task (str): The task type, set to "segment" for FastSAM models.
+
+    Examples:
+        >>> from ultralytics import FastSAM
+        >>> model = FastSAM("last.pt")
+        >>> results = model.predict("ultralytics/assets/bus.jpg")
+    """
+
+    def __init__(self, model="FastSAM-x.pt"):
+        """Initialize the FastSAM model with the specified pre-trained weights."""
+        if str(model) == "FastSAM.pt":
+            model = "FastSAM-x.pt"
+        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
+        super().__init__(model=model, task="segment")
+
+    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
+        """
+        Perform segmentation prediction on image or video source.
+
+        Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
+        prompts and passes them to the parent class predict method.
+
+        Args:
+            source (str | PIL.Image | numpy.ndarray): Input source for prediction, can be a file path, URL, PIL image,
+                or numpy array.
+            stream (bool): Whether to enable real-time streaming mode for video inputs.
+            bboxes (list): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2], ...].
+            points (list): Point coordinates for prompted segmentation in format [[x, y], ...].
+            labels (list): Class labels for prompted segmentation.
+            texts (list): Text prompts for segmentation guidance.
+            **kwargs (Any): Additional keyword arguments passed to the predictor.
+
+        Returns:
+            (list): List of Results objects containing the prediction results.
+        """
+        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
+        return super().predict(source, stream, prompts=prompts, **kwargs)
+
+    @property
+    def task_map(self):
+        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
+        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
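The `predict` override above only packs the prompt arguments into a `prompts` dict before delegating to `Model.predict`; the actual filtering happens in `FastSAMPredictor` (next file). A short usage sketch of prompted FastSAM inference, with coordinates and text chosen purely for illustration:

    # Illustrative prompted-segmentation calls; box/point/text values are placeholders.
    from ultralytics import FastSAM

    model = FastSAM("FastSAM-x.pt")

    # Everything mode: no prompts, segment all objects
    everything = model.predict("ultralytics/assets/bus.jpg")

    # Prompted modes: forwarded to FastSAMPredictor via the prompts dict
    by_box = model.predict("ultralytics/assets/bus.jpg", bboxes=[[100, 200, 400, 600]])
    by_point = model.predict("ultralytics/assets/bus.jpg", points=[[300, 400]], labels=[1])
    by_text = model.predict("ultralytics/assets/bus.jpg", texts="a bus")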
ultralytics/models/fastsam/predict.py
@@ -0,0 +1,181 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+from PIL import Image
+
+from ultralytics.models.yolo.segment import SegmentationPredictor
+from ultralytics.utils import DEFAULT_CFG, checks
+from ultralytics.utils.metrics import box_iou
+from ultralytics.utils.ops import scale_masks
+
+from .utils import adjust_bboxes_to_image_border
+
+
+class FastSAMPredictor(SegmentationPredictor):
+    """
+    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
+
+    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
+    adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
+    single-class segmentation.
+
+    Attributes:
+        prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
+        device (torch.device): Device on which model and tensors are processed.
+        clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
+        clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.
+
+    Methods:
+        postprocess: Applies box postprocessing for FastSAM predictions.
+        prompt: Performs image segmentation inference based on various prompt types.
+        _clip_inference: Performs CLIP inference to calculate similarity between images and text prompts.
+        set_prompts: Sets prompts to be used during inference.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initialize the FastSAMPredictor with configuration and callbacks.
+
+        This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
+        extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
+        optimized for single-class segmentation.
+
+        Args:
+            cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides.
+            _callbacks (list, optional): List of callback functions.
+        """
+        super().__init__(cfg, overrides, _callbacks)
+        self.prompts = {}
+
+    def postprocess(self, preds, img, orig_imgs):
+        """
+        Apply postprocessing to FastSAM predictions and handle prompts.
+
+        Args:
+            preds (List[torch.Tensor]): Raw predictions from the model.
+            img (torch.Tensor): Input image tensor that was fed to the model.
+            orig_imgs (List[numpy.ndarray]): Original images before preprocessing.
+
+        Returns:
+            (List[Results]): Processed results with prompts applied.
+        """
+        bboxes = self.prompts.pop("bboxes", None)
+        points = self.prompts.pop("points", None)
+        labels = self.prompts.pop("labels", None)
+        texts = self.prompts.pop("texts", None)
+        results = super().postprocess(preds, img, orig_imgs)
+        for result in results:
+            full_box = torch.tensor(
+                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
+            )
+            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
+            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
+            if idx.numel() != 0:
+                result.boxes.xyxy[idx] = full_box
+
+        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
+
+    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
+        """
+        Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
+
+        Args:
+            results (Results | List[Results]): Original inference results from FastSAM models without any prompts.
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+            texts (str | List[str], optional): Textual prompts, a list containing string objects.
+
+        Returns:
+            (List[Results]): Output results filtered and determined by the provided prompts.
+        """
+        if bboxes is None and points is None and texts is None:
+            return results
+        prompt_results = []
+        if not isinstance(results, list):
+            results = [results]
+        for result in results:
+            if len(result) == 0:
+                prompt_results.append(result)
+                continue
+            masks = result.masks.data
+            if masks.shape[1:] != result.orig_shape:
+                masks = scale_masks(masks[None], result.orig_shape)[0]
+            # bboxes prompt
+            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
+            if bboxes is not None:
+                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
+                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
+                full_mask_areas = torch.sum(masks, dim=(1, 2))
+
+                union = bbox_areas[:, None] + full_mask_areas - mask_areas
+                idx[torch.argmax(mask_areas / union, dim=1)] = True
+            if points is not None:
+                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
+                points = points[None] if points.ndim == 1 else points
+                if labels is None:
+                    labels = torch.ones(points.shape[0])
+                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+                assert len(labels) == len(points), (
+                    f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
+                )
+                point_idx = (
+                    torch.ones(len(result), dtype=torch.bool, device=self.device)
+                    if labels.sum() == 0  # all negative points
+                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
+                )
+                for point, label in zip(points, labels):
+                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
+                idx |= point_idx
+            if texts is not None:
+                if isinstance(texts, str):
+                    texts = [texts]
+                crop_ims, filter_idx = [], []
+                for i, b in enumerate(result.boxes.xyxy.tolist()):
+                    x1, y1, x2, y2 = (int(x) for x in b)
+                    if masks[i].sum() <= 100:
+                        filter_idx.append(i)
+                        continue
+                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
+                similarity = self._clip_inference(crop_ims, texts)
+                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
+                if len(filter_idx):
+                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
+                idx[text_idx] = True
+
+            prompt_results.append(result[idx])
+
+        return prompt_results
+
+    def _clip_inference(self, images, texts):
+        """
+        Perform CLIP inference to calculate similarity between images and text prompts.
+
+        Args:
+            images (List[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
+            texts (List[str]): List of prompt texts, each should be a string object.
+
+        Returns:
+            (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
+        """
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
+            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
+        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
+        tokenized_text = clip.tokenize(texts).to(self.device)
+        image_features = self.clip_model.encode_image(images)
+        text_features = self.clip_model.encode_text(tokenized_text)
+        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
+        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
+        return (image_features * text_features[:, None]).sum(-1)  # (M, N)
+
+    def set_prompts(self, prompts):
+        """Set prompts to be used during inference."""
+        self.prompts = prompts
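For box prompts, `prompt()` scores every predicted mask against every prompt box with an IoU-like ratio: the mask area inside the box divided by (box area + total mask area - mask area inside the box), then keeps the best-scoring mask per box. A standalone sketch of just that scoring step on dummy tensors (shapes and values are illustrative, no model needed):

    # Standalone sketch of the box-prompt scoring used in FastSAMPredictor.prompt(); dummy data.
    import torch

    masks = torch.zeros(3, 480, 640)        # 3 candidate masks, H=480, W=640
    masks[0, 100:300, 100:300] = 1
    masks[1, 50:120, 500:600] = 1
    masks[2, 350:470, 50:250] = 1

    bboxes = torch.tensor([[90, 90, 310, 310]])  # one prompt box, xyxy

    bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
    mask_areas = torch.stack([masks[:, b[1]:b[3], b[0]:b[2]].sum(dim=(1, 2)) for b in bboxes])
    full_mask_areas = masks.sum(dim=(1, 2))

    union = bbox_areas[:, None] + full_mask_areas - mask_areas  # per (box, mask) pair
    best = torch.argmax(mask_areas / union, dim=1)              # best mask index per box
    print(best)  # tensor([0]) -> mask 0 overlaps the prompt box best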
ultralytics/models/fastsam/utils.py
@@ -0,0 +1,24 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+
+def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
+    """
+    Adjust bounding boxes to stick to image border if they are within a certain threshold.
+
+    Args:
+        boxes (torch.Tensor): Bounding boxes with shape (n, 4) in xyxy format.
+        image_shape (Tuple[int, int]): Image dimensions as (height, width).
+        threshold (int): Pixel threshold for considering a box close to the border.
+
+    Returns:
+        boxes (torch.Tensor): Adjusted bounding boxes with shape (n, 4).
+    """
+    # Image dimensions
+    h, w = image_shape
+
+    # Adjust boxes that are close to image borders
+    boxes[boxes[:, 0] < threshold, 0] = 0  # x1
+    boxes[boxes[:, 1] < threshold, 1] = 0  # y1
+    boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
+    boxes[boxes[:, 3] > h - threshold, 3] = h  # y2
+    return boxes
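A quick sketch of the snapping behaviour with the default `threshold=20` on a 480x640 image (box values chosen for illustration):

    # Boxes within 20 px of a border are snapped onto it; others are untouched.
    import torch
    from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

    boxes = torch.tensor([[5.0, 30.0, 630.0, 470.0],     # x1, x2, y2 near borders -> snapped
                          [100.0, 100.0, 200.0, 200.0]])  # well inside -> unchanged
    print(adjust_bboxes_to_image_border(boxes, image_shape=(480, 640)))
    # tensor([[  0.,  30., 640., 480.],
    #         [100., 100., 200., 200.]])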
ultralytics/models/fastsam/val.py
@@ -0,0 +1,40 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.models.yolo.segment import SegmentationValidator
+from ultralytics.utils.metrics import SegmentMetrics
+
+
+class FastSAMValidator(SegmentationValidator):
+    """
+    Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
+
+    Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class
+    sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
+    to avoid errors during validation.
+
+    Attributes:
+        dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
+        save_dir (Path): The directory where validation results will be saved.
+        pbar (tqdm.tqdm): A progress bar object for displaying validation progress.
+        args (SimpleNamespace): Additional arguments for customization of the validation process.
+        _callbacks (list): List of callback functions to be invoked during validation.
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """
+        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (tqdm.tqdm): Progress bar for displaying progress.
+            args (SimpleNamespace): Configuration for the validator.
+            _callbacks (list): List of callback functions to be invoked during validation.
+
+        Notes:
+            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
+        """
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.args.task = "segment"
+        self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
+        self.metrics = SegmentMetrics(save_dir=self.save_dir)
ultralytics/models/nas/model.py
@@ -0,0 +1,102 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+YOLO-NAS model interface.
+
+Examples:
+    >>> from ultralytics import NAS
+    >>> model = NAS("yolo_nas_s")
+    >>> results = model.predict("ultralytics/assets/bus.jpg")
+"""
+
+from pathlib import Path
+
+import torch
+
+from ultralytics.engine.model import Model
+from ultralytics.utils import DEFAULT_CFG_DICT
+from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.torch_utils import model_info
+
+from .predict import NASPredictor
+from .val import NASValidator
+
+
+class NAS(Model):
+    """
+    YOLO NAS model for object detection.
+
+    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
+    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
+
+    Attributes:
+        model (torch.nn.Module): The loaded YOLO-NAS model.
+        task (str): The task type for the model, defaults to 'detect'.
+        predictor (NASPredictor): The predictor instance for making predictions.
+        validator (NASValidator): The validator instance for model validation.
+
+    Examples:
+        >>> from ultralytics import NAS
+        >>> model = NAS("yolo_nas_s")
+        >>> results = model.predict("ultralytics/assets/bus.jpg")
+
+    Notes:
+        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
+    """
+
+    def __init__(self, model: str = "yolo_nas_s.pt") -> None:
+        """Initialize the NAS model with the provided or default model."""
+        assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
+        super().__init__(model, task="detect")
+
+    def _load(self, weights: str, task=None) -> None:
+        """
+        Load an existing NAS model weights or create a new NAS model with pretrained weights.
+
+        Args:
+            weights (str): Path to the model weights file or model name.
+            task (str, optional): Task type for the model.
+        """
+        import super_gradients
+
+        suffix = Path(weights).suffix
+        if suffix == ".pt":
+            self.model = torch.load(attempt_download_asset(weights))
+        elif suffix == "":
+            self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
+
+        # Override the forward method to ignore additional arguments
+        def new_forward(x, *args, **kwargs):
+            """Ignore additional __call__ arguments."""
+            return self.model._original_forward(x)
+
+        self.model._original_forward = self.model.forward
+        self.model.forward = new_forward
+
+        # Standardize model
+        self.model.fuse = lambda verbose=True: self.model
+        self.model.stride = torch.tensor([32])
+        self.model.names = dict(enumerate(self.model._class_names))
+        self.model.is_fused = lambda: False  # for info()
+        self.model.yaml = {}  # for info()
+        self.model.pt_path = weights  # for export()
+        self.model.task = "detect"  # for export()
+        self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
+        self.model.eval()
+
+    def info(self, detailed: bool = False, verbose: bool = True):
+        """
+        Log model information.
+
+        Args:
+            detailed (bool): Show detailed information about model.
+            verbose (bool): Controls verbosity.
+
+        Returns:
+            (dict): Model information dictionary.
+        """
+        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
+
+    @property
+    def task_map(self):
+        """Return a dictionary mapping tasks to respective predictor and validator classes."""
+        return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
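`_load` pulls YOLO-NAS weights either from a `.pt` checkpoint or by name through `super_gradients`, then patches attributes (`stride`, `names`, `fuse`, `task`, ...) so the wrapped model behaves like a standard Ultralytics model. A hedged usage sketch, assuming the optional `super-gradients` dependency is installed and using the weight names from the docstring examples:

    # Usage sketch for the NAS wrapper; requires the optional super-gradients package.
    from ultralytics import NAS

    model = NAS("yolo_nas_s.pt")   # or NAS("yolo_nas_s") to fetch COCO-pretrained weights by name
    model.info()                   # layer/parameter summary via model_info()
    results = model.predict("ultralytics/assets/bus.jpg", conf=0.35)
    metrics = model.val(data="coco8.yaml")  # validation path goes through NASValidator from task_map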
ultralytics/models/nas/predict.py
@@ -0,0 +1,58 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
+from ultralytics.utils import ops
+
+
+class NASPredictor(DetectionPredictor):
+    """
+    Ultralytics YOLO NAS Predictor for object detection.
+
+    This class extends the `DetectionPredictor` from Ultralytics engine and is responsible for post-processing the
+    raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
+    scaling the bounding boxes to fit the original image dimensions.
+
+    Attributes:
+        args (Namespace): Namespace containing various configurations for post-processing including confidence threshold,
+            IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
+        model (torch.nn.Module): The YOLO NAS model used for inference.
+        batch (list): Batch of inputs for processing.
+
+    Examples:
+        >>> from ultralytics import NAS
+        >>> model = NAS("yolo_nas_s")
+        >>> predictor = model.predictor
+
+        Assume that raw_preds, img, orig_imgs are available
+        >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
+
+    Notes:
+        Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
+    """
+
+    def postprocess(self, preds_in, img, orig_imgs):
+        """
+        Postprocess NAS model predictions to generate final detection results.
+
+        This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
+        post-processing operations to generate the final detection results compatible with Ultralytics
+        result visualization and analysis tools.
+
+        Args:
+            preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.
+            img (torch.Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W).
+            orig_imgs (list | torch.Tensor | np.ndarray): Original images before preprocessing, used for scaling
+                coordinates back to original dimensions.
+
+        Returns:
+            (list): List of Results objects containing the processed predictions for each image in the batch.
+
+        Examples:
+            >>> predictor = NAS("yolo_nas_s").predictor
+            >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
+        """
+        boxes = ops.xyxy2xywh(preds_in[0][0])
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # concatenate with class scores
+        return super().postprocess(preds, img, orig_imgs)
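The two lines at the end of `postprocess` just reshape the raw NAS head output into the layout the standard YOLO detection postprocessing expects: boxes converted from xyxy to xywh, concatenated with the per-class scores, then permuted to (batch, 4 + num_classes, num_boxes). A shape-only sketch with dummy tensors (sizes are illustrative):

    # Shape-only sketch of the tensor reshaping in NASPredictor.postprocess(); dummy data.
    import torch
    from ultralytics.utils import ops

    num_boxes, num_classes = 1000, 80
    raw_boxes = torch.rand(1, num_boxes, 4) * 640       # stand-in for preds_in[0][0] (xyxy)
    raw_scores = torch.rand(1, num_boxes, num_classes)  # stand-in for preds_in[0][1]

    boxes = ops.xyxy2xywh(raw_boxes)                              # (1, 1000, 4)
    preds = torch.cat((boxes, raw_scores), -1).permute(0, 2, 1)
    print(preds.shape)  # torch.Size([1, 84, 1000]) -> layout expected by DetectionPredictor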
ultralytics/models/nas/val.py
@@ -0,0 +1,39 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import ops
+
+__all__ = ["NASValidator"]
+
+
+class NASValidator(DetectionValidator):
+    """
+    Ultralytics YOLO NAS Validator for object detection.
+
+    Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
+    generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
+    ultimately producing the final detections.
+
+    Attributes:
+        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU
+            thresholds.
+        lb (torch.Tensor): Optional tensor for multilabel NMS.
+
+    Examples:
+        >>> from ultralytics import NAS
+        >>> model = NAS("yolo_nas_s")
+        >>> validator = model.validator
+        Assumes that raw_preds are available
+        >>> final_preds = validator.postprocess(raw_preds)
+
+    Notes:
+        This class is generally not instantiated directly but is used internally within the `NAS` class.
+    """
+
+    def postprocess(self, preds_in):
+        """Apply Non-maximum suppression to prediction outputs."""
+        boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding box format from xyxy to xywh
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with scores and permute
+        return super().postprocess(preds)
ultralytics/models/rtdetr/model.py
@@ -0,0 +1,63 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector.
+
+RT-DETR offers real-time performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT.
+It features an efficient hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
+
+References:
+    https://arxiv.org/pdf/2304.08069.pdf
+"""
+
+from ultralytics.engine.model import Model
+from ultralytics.nn.tasks import RTDETRDetectionModel
+
+from .predict import RTDETRPredictor
+from .train import RTDETRTrainer
+from .val import RTDETRValidator
+
+
+class RTDETR(Model):
+    """
+    Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
+
+    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
+    selection, and adaptable inference speed.
+
+    Attributes:
+        model (str): Path to the pre-trained model.
+
+    Examples:
+        >>> from ultralytics import RTDETR
+        >>> model = RTDETR("rtdetr-l.pt")
+        >>> results = model("image.jpg")
+    """
+
+    def __init__(self, model: str = "rtdetr-l.pt") -> None:
+        """
+        Initialize the RT-DETR model with the given pre-trained model file.
+
+        Args:
+            model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
+
+        Raises:
+            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
+        """
+        super().__init__(model=model, task="detect")
+
+    @property
+    def task_map(self) -> dict:
+        """
+        Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+
+        Returns:
+            (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
+        """
+        return {
+            "detect": {
+                "predictor": RTDETRPredictor,
+                "validator": RTDETRValidator,
+                "trainer": RTDETRTrainer,
+                "model": RTDETRDetectionModel,
+            }
+        }
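Unlike the FastSAM and NAS wrappers above, the RT-DETR task map registers a trainer and a model class as well, so the wrapper supports training in addition to prediction and validation. A hedged end-to-end sketch (dataset, epoch count, and image size are illustrative placeholders):

    # Usage sketch for the RTDETR wrapper; dataset and epochs are placeholders.
    from ultralytics import RTDETR

    model = RTDETR("rtdetr-l.pt")
    model.train(data="coco8.yaml", epochs=3, imgsz=640)  # routed to RTDETRTrainer / RTDETRDetectionModel
    metrics = model.val()                                # RTDETRValidator on the training dataset
    results = model("ultralytics/assets/bus.jpg")        # RTDETRPredictor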