dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/hub/session.py
CHANGED
@@ -19,8 +19,7 @@ AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version_


 class HUBTrainingSession:
-    """
-    HUB training session for Ultralytics HUB YOLO models.
+    """HUB training session for Ultralytics HUB YOLO models.

     This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
     model creation, metrics tracking, and checkpoint uploading.
@@ -45,12 +44,11 @@ class HUBTrainingSession:
     """

     def __init__(self, identifier: str):
-        """
-        Initialize the HUBTrainingSession with the provided model identifier.
+        """Initialize the HUBTrainingSession with the provided model identifier.

         Args:
-            identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string
-                or a model key with specific format.
+            identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string or a
+                model key with specific format.

         Raises:
             ValueError: If the provided model identifier is invalid.
@@ -93,8 +91,7 @@ class HUBTrainingSession:

     @classmethod
     def create_session(cls, identifier: str, args: dict[str, Any] | None = None):
-        """
-        Create an authenticated HUBTrainingSession or return None.
+        """Create an authenticated HUBTrainingSession or return None.

         Args:
             identifier (str): Model identifier used to initialize the HUB training session.
@@ -114,8 +111,7 @@ class HUBTrainingSession:
             return None

     def load_model(self, model_id: str):
-        """
-        Load an existing model from Ultralytics HUB using the provided model identifier.
+        """Load an existing model from Ultralytics HUB using the provided model identifier.

         Args:
             model_id (str): The identifier of the model to load.
@@ -140,8 +136,7 @@ class HUBTrainingSession:
         LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

     def create_model(self, model_args: dict[str, Any]):
-        """
-        Initialize a HUB training session with the specified model arguments.
+        """Initialize a HUB training session with the specified model arguments.

         Args:
             model_args (dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,
@@ -186,8 +181,7 @@ class HUBTrainingSession:

     @staticmethod
     def _parse_identifier(identifier: str):
-        """
-        Parse the given identifier to determine the type and extract relevant components.
+        """Parse the given identifier to determine the type and extract relevant components.

         The method supports different identifier formats:
             - A HUB model URL https://hub.ultralytics.com/models/MODEL
@@ -218,12 +212,11 @@ class HUBTrainingSession:
         return api_key, model_id, filename

     def _set_train_args(self):
-        """
-        Initialize training arguments and create a model entry on the Ultralytics HUB.
+        """Initialize training arguments and create a model entry on the Ultralytics HUB.

-        This method sets up training arguments based on the model's state and updates them with any additional
-        arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
-        or requires specific file setup.
+        This method sets up training arguments based on the model's state and updates them with any additional arguments
+        provided. It handles different states of the model, such as whether it's resumable, pretrained, or requires
+        specific file setup.

         Raises:
             ValueError: If the model is already trained, if required dataset information is missing, or if there are
@@ -261,8 +254,7 @@ class HUBTrainingSession:
         *args,
         **kwargs,
     ):
-        """
-        Execute request_func with retries, timeout handling, optional threading, and progress tracking.
+        """Execute request_func with retries, timeout handling, optional threading, and progress tracking.

         Args:
             request_func (callable): The function to execute.
@@ -342,8 +334,7 @@ class HUBTrainingSession:
         return status_code in retry_codes

     def _get_failure_message(self, response, retry: int, timeout: int) -> str:
-        """
-        Generate a retry message based on the response status code.
+        """Generate a retry message based on the response status code.

         Args:
             response (requests.Response): The HTTP response object.
@@ -379,8 +370,7 @@ class HUBTrainingSession:
         map: float = 0.0,
         final: bool = False,
     ) -> None:
-        """
-        Upload a model checkpoint to Ultralytics HUB.
+        """Upload a model checkpoint to Ultralytics HUB.

         Args:
             epoch (int): The current training epoch.
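For orientation, the methods reflowed above are normally driven through HUBTrainingSession.create_session. A minimal sketch (the model URL is a placeholder, and a configured Ultralytics HUB API key is assumed):

from ultralytics.hub.session import HUBTrainingSession

# Placeholder identifier; a real HUB model URL or model key is required, and
# authentication must already be set up for a session to be returned.
session = HUBTrainingSession.create_session("https://hub.ultralytics.com/models/MODEL")
if session is None:
    print("No authenticated HUB session could be created")
else:
    print(f"View model at {session.model_url}")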
ultralytics/hub/utils.py
CHANGED
@@ -21,8 +21,7 @@ HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/h


 def request_with_credentials(url: str) -> Any:
-    """
-    Make an AJAX request with cookies attached in a Google Colab environment.
+    """Make an AJAX request with cookies attached in a Google Colab environment.

     Args:
         url (str): The URL to make the request to.
@@ -35,8 +34,8 @@ def request_with_credentials(url: str) -> Any:
     """
     if not IS_COLAB:
         raise OSError("request_with_credentials() must run in a Colab environment")
-    from google.colab import output
-    from IPython import display
+    from google.colab import output
+    from IPython import display

     display.display(
         display.Javascript(
@@ -62,8 +61,7 @@ def request_with_credentials(url: str) -> Any:


 def requests_with_progress(method: str, url: str, **kwargs):
-    """
-    Make an HTTP request using the specified method and URL, with an optional progress bar.
+    """Make an HTTP request using the specified method and URL, with an optional progress bar.

     Args:
         method (str): The HTTP method to use (e.g. 'GET', 'POST').
@@ -106,8 +104,7 @@ def smart_request(
     progress: bool = False,
     **kwargs,
 ):
-    """
-    Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
+    """Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.

     Args:
         method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
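The smart_request docstring above describes retries with exponential backoff bounded by a timeout. As a rough illustration of that pattern only (not the library's actual implementation), a backoff loop over the requests package looks like:

import random
import time

import requests


def retry_request(method: str, url: str, retries: int = 3, timeout: float = 30.0, **kwargs):
    # Conceptual sketch: retry failed requests with exponentially growing sleeps
    # (1s, 2s, 4s, ... plus jitter) until retries or the time budget run out.
    t0 = time.time()
    for attempt in range(retries + 1):
        response = requests.request(method, url, **kwargs)
        if response.status_code < 400:
            return response  # success
        if attempt == retries or (time.time() - t0) > timeout:
            return response  # give up and surface the last response
        time.sleep(2**attempt + random.random())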
ultralytics/models/__init__.py
CHANGED
@@ -6,4 +6,4 @@ from .rtdetr import RTDETR
 from .sam import SAM
 from .yolo import YOLO, YOLOE, YOLOWorld

-__all__ = "
+__all__ = "NAS", "RTDETR", "SAM", "YOLO", "YOLOE", "FastSAM", "YOLOWorld"  # allow simpler import
ultralytics/models/fastsam/model.py
CHANGED

@@ -12,8 +12,7 @@ from .val import FastSAMValidator


 class FastSAM(Model):
-    """
-    FastSAM model interface for segment anything tasks.
+    """FastSAM model interface for Segment Anything tasks.

     This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
     Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.
@@ -36,11 +35,11 @@ class FastSAM(Model):
         >>> results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])
     """

-    def __init__(self, model: str = "FastSAM-x.pt"):
+    def __init__(self, model: str | Path = "FastSAM-x.pt"):
         """Initialize the FastSAM model with the specified pre-trained weights."""
         if str(model) == "FastSAM.pt":
             model = "FastSAM-x.pt"
-        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM
+        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM only supports pre-trained weights."
         super().__init__(model=model, task="segment")

     def predict(
@@ -53,15 +52,14 @@ class FastSAM(Model):
         texts: list | None = None,
         **kwargs: Any,
     ):
-        """
-        Perform segmentation prediction on image or video source.
+        """Perform segmentation prediction on image or video source.

-        Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
-        prompts and passes them to the parent class predict method for processing.
+        Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these prompts
+        and passes them to the parent class predict method for processing.

         Args:
-            source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,
-                or numpy array.
+            source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image, or
+                numpy array.
             stream (bool): Whether to enable real-time streaming mode for video inputs.
             bboxes (list, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
             points (list, optional): Point coordinates for prompted segmentation in format [[x, y]].
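The predict signature above accepts box, point, and text prompts, matching the docstring example it carries over. A short usage sketch ("image.jpg" is a placeholder source and the weights download on first use):

from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")
results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])  # box prompt
results = model.predict("image.jpg", points=[[200, 200]], labels=[1])  # point prompt
results = model.predict("image.jpg", texts=["a photo of a dog"])  # text prompt (routed through CLIP)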
ultralytics/models/fastsam/predict.py
CHANGED

@@ -4,16 +4,16 @@ import torch
 from PIL import Image

 from ultralytics.models.yolo.segment import SegmentationPredictor
-from ultralytics.utils import DEFAULT_CFG
+from ultralytics.utils import DEFAULT_CFG
 from ultralytics.utils.metrics import box_iou
 from ultralytics.utils.ops import scale_masks
+from ultralytics.utils.torch_utils import TORCH_1_10

 from .utils import adjust_bboxes_to_image_border


 class FastSAMPredictor(SegmentationPredictor):
-    """
-    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
+    """FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.

     This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
     adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
@@ -22,8 +22,7 @@ class FastSAMPredictor(SegmentationPredictor):
     Attributes:
         prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
         device (torch.device): Device on which model and tensors are processed.
-        clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.
+        clip (Any, optional): CLIP model used for text-based prompting, loaded on demand.

     Methods:
         postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
@@ -32,8 +31,7 @@ class FastSAMPredictor(SegmentationPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the FastSAMPredictor with configuration and callbacks.
+        """Initialize the FastSAMPredictor with configuration and callbacks.

         This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
         extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
@@ -48,8 +46,7 @@ class FastSAMPredictor(SegmentationPredictor):
         self.prompts = {}

     def postprocess(self, preds, img, orig_imgs):
-        """
-        Apply postprocessing to FastSAM predictions and handle prompts.
+        """Apply postprocessing to FastSAM predictions and handle prompts.

         Args:
             preds (list[torch.Tensor]): Raw predictions from the model.
@@ -76,8 +73,7 @@ class FastSAMPredictor(SegmentationPredictor):
         return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)

     def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
-        """
-        Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
+        """Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

         Args:
             results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
@@ -100,7 +96,7 @@ class FastSAMPredictor(SegmentationPredictor):
                 continue
             masks = result.masks.data
             if masks.shape[1:] != result.orig_shape:
-                masks = scale_masks(masks[None], result.orig_shape)[0]
+                masks = (scale_masks(masks[None].float(), result.orig_shape)[0] > 0.5).byte()
             # bboxes prompt
             idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
             if bboxes is not None:
@@ -119,7 +115,7 @@ class FastSAMPredictor(SegmentationPredictor):
                     labels = torch.ones(points.shape[0])
                 labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
                 assert len(labels) == len(points), (
-                    f"Expected `labels`
+                    f"Expected `labels` to have the same length as `points`, but got {len(labels)} and {len(points)}."
                 )
                 point_idx = (
                     torch.ones(len(result), dtype=torch.bool, device=self.device)
@@ -135,7 +131,7 @@ class FastSAMPredictor(SegmentationPredictor):
                 crop_ims, filter_idx = [], []
                 for i, b in enumerate(result.boxes.xyxy.tolist()):
                     x1, y1, x2, y2 = (int(x) for x in b)
-                    if masks[i].sum() <= 100:
+                    if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100:  # torch 1.9 bug workaround
                         filter_idx.append(i)
                         continue
                     crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
@@ -150,8 +146,7 @@ class FastSAMPredictor(SegmentationPredictor):
         return prompt_results

     def _clip_inference(self, images, texts):
-        """
-        Perform CLIP inference to calculate similarity between images and text prompts.
+        """Perform CLIP inference to calculate similarity between images and text prompts.

         Args:
             images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
@@ -160,20 +155,14 @@ class FastSAMPredictor(SegmentationPredictor):
         Returns:
             (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
         """
-        tokenized_text = clip.tokenize(texts).to(self.device)
-        image_features = self.clip_model.encode_image(images)
-        text_features = self.clip_model.encode_text(tokenized_text)
-        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
-        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
-        return (image_features * text_features[:, None]).sum(-1)  # (M, N)
+        from ultralytics.nn.text_model import CLIP
+
+        if not hasattr(self, "clip"):
+            self.clip = CLIP("ViT-B/32", device=self.device)
+        images = torch.stack([self.clip.image_preprocess(image).to(self.device) for image in images])
+        image_features = self.clip.encode_image(images)
+        text_features = self.clip.encode_text(self.clip.tokenize(texts))
+        return text_features @ image_features.T  # (M, N)

     def set_prompts(self, prompts):
         """Set prompts to be used during inference."""
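The rewritten _clip_inference delegates preprocessing and encoding to ultralytics.nn.text_model.CLIP and scores text prompts against mask crops with a single matrix product. Conceptually (toy tensors, assuming the encoders return L2-normalized features), that similarity step is:

import torch
import torch.nn.functional as F

text_features = F.normalize(torch.randn(2, 512), dim=-1)   # (M, 512) text embeddings
image_features = F.normalize(torch.randn(5, 512), dim=-1)  # (N, 512) image embeddings
similarity = text_features @ image_features.T              # (M, N) cosine similarities
best_image_per_text = similarity.argmax(dim=-1)             # best-matching crop per prompt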
ultralytics/models/fastsam/utils.py
CHANGED

@@ -2,8 +2,7 @@


 def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
-    """
-    Adjust bounding boxes to stick to image border if they are within a certain threshold.
+    """Adjust bounding boxes to stick to image border if they are within a certain threshold.

     Args:
         boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
ultralytics/models/fastsam/val.py
CHANGED

@@ -4,10 +4,9 @@ from ultralytics.models.yolo.segment import SegmentationValidator


 class FastSAMValidator(SegmentationValidator):
-    """
-    Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
+    """Custom validation class for FastSAM (Segment Anything Model) segmentation in the Ultralytics YOLO framework.

-    Extends the SegmentationValidator class, customizing the validation process specifically for
+    Extends the SegmentationValidator class, customizing the validation process specifically for FastSAM. This class
     sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
     to avoid errors during validation.
@@ -19,15 +18,14 @@ class FastSAMValidator(SegmentationValidator):
         metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.

     Methods:
-        __init__: Initialize the FastSAMValidator with custom settings for
+        __init__: Initialize the FastSAMValidator with custom settings for FastSAM.
     """

     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
-        """
-        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
+        """Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.

         Args:
-            dataloader (torch.utils.data.DataLoader, optional):
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
             save_dir (Path, optional): Directory to save results.
             args (SimpleNamespace, optional): Configuration for the validator.
             _callbacks (list, optional): List of callback functions to be invoked during validation.
ultralytics/models/nas/model.py
CHANGED
@@ -18,11 +18,10 @@ from .val import NASValidator


 class NAS(Model):
-    """
-    YOLO-NAS model for object detection.
+    """YOLO-NAS model for object detection.

-    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
-    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
+    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. It
+    is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.

     Attributes:
         model (torch.nn.Module): The loaded YOLO-NAS model.
@@ -48,8 +47,7 @@ class NAS(Model):
         super().__init__(model, task="detect")

     def _load(self, weights: str, task=None) -> None:
-        """
-        Load an existing NAS model weights or create a new NAS model with pretrained weights.
+        """Load an existing NAS model weights or create a new NAS model with pretrained weights.

         Args:
             weights (str): Path to the model weights file or model name.
@@ -83,8 +81,7 @@ class NAS(Model):
         self.model.eval()

     def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
-        """
-        Log model information.
+        """Log model information.

         Args:
             detailed (bool): Show detailed information about model.
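As a quick reference for the NAS interface documented above, a usage sketch (the weights filename is illustrative and "image.jpg" is a placeholder source):

from ultralytics import NAS

model = NAS("yolo_nas_s.pt")  # illustrative YOLO-NAS weights name
model.info(detailed=False, verbose=True)  # log model information
results = model.predict("image.jpg")  # standard Ultralytics predict API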
ultralytics/models/nas/predict.py
CHANGED

@@ -7,12 +7,11 @@ from ultralytics.utils import ops


 class NASPredictor(DetectionPredictor):
-    """
-    Ultralytics YOLO NAS Predictor for object detection.
+    """Ultralytics YOLO NAS Predictor for object detection.

-    This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the
-    raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
-    scaling the bounding boxes to fit the original image dimensions.
+    This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the raw
+    predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and scaling the
+    bounding boxes to fit the original image dimensions.

     Attributes:
         args (Namespace): Namespace containing various configurations for post-processing including confidence
@@ -33,12 +32,11 @@ class NASPredictor(DetectionPredictor):
     """

     def postprocess(self, preds_in, img, orig_imgs):
-        """
-        Postprocess NAS model predictions to generate final detection results.
+        """Postprocess NAS model predictions to generate final detection results.

         This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
-        post-processing operations to generate the final detection results compatible with Ultralytics
-        result visualization and analysis tools.
+        post-processing operations to generate the final detection results compatible with Ultralytics result
+        visualization and analysis tools.

         Args:
             preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.
ultralytics/models/nas/val.py
CHANGED
@@ -9,8 +9,7 @@ __all__ = ["NASValidator"]


 class NASValidator(DetectionValidator):
-    """
-    Ultralytics YOLO NAS Validator for object detection.
+    """Ultralytics YOLO NAS Validator for object detection.

     Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
     generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
ultralytics/models/rtdetr/model.py
CHANGED

@@ -11,6 +11,7 @@ References:

 from ultralytics.engine.model import Model
 from ultralytics.nn.tasks import RTDETRDetectionModel
+from ultralytics.utils.torch_utils import TORCH_1_11

 from .predict import RTDETRPredictor
 from .train import RTDETRTrainer
@@ -18,11 +19,10 @@ from .val import RTDETRValidator


 class RTDETR(Model):
-    """
-    Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
+    """Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.

-    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware
-    query selection, and adaptable inference speed.
+    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
+    selection, and adaptable inference speed.

     Attributes:
         model (str): Path to the pre-trained model.
@@ -38,18 +38,17 @@ class RTDETR(Model):
     """

     def __init__(self, model: str = "rtdetr-l.pt") -> None:
-        """
-        Initialize the RT-DETR model with the given pre-trained model file.
+        """Initialize the RT-DETR model with the given pre-trained model file.

         Args:
             model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
         """
+        assert TORCH_1_11, "RTDETR requires torch>=1.11"
         super().__init__(model=model, task="detect")

     @property
     def task_map(self) -> dict:
-        """
-        Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+        """Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.

         Returns:
             (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
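A minimal usage sketch for the class above; note that the new assertion means torch>=1.11 is required ("image.jpg" is a placeholder source):

from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")  # default pre-trained weights per the signature above
results = model.predict("image.jpg")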
ultralytics/models/rtdetr/predict.py
CHANGED

@@ -9,11 +9,10 @@ from ultralytics.utils import ops


 class RTDETRPredictor(BasePredictor):
-    """
-    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
+    """RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.

-    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
-    It supports key features like efficient hybrid encoding and IoU-aware query selection.
+    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy. It
+    supports key features like efficient hybrid encoding and IoU-aware query selection.

     Attributes:
         imgsz (int): Image size for inference (must be square and scale-filled).
@@ -34,21 +33,20 @@ class RTDETRPredictor(BasePredictor):
     """

     def postprocess(self, preds, img, orig_imgs):
-        """
-        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
+        """Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.

-        The method filters detections based on confidence and class if specified in `self.args`. It converts
-        model predictions to Results objects containing properly scaled bounding boxes.
+        The method filters detections based on confidence and class if specified in `self.args`. It converts model
+        predictions to Results objects containing properly scaled bounding boxes.

         Args:
-            preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
-                bounding boxes and scores.
+            preds (list | tuple): List of [predictions, extra] from the model, where predictions contain bounding boxes
+                and scores.
             img (torch.Tensor): Processed input images with shape (N, 3, H, W).
             orig_imgs (list | torch.Tensor): Original, unprocessed images.

         Returns:
-            results (list[Results]): A list of Results objects containing the post-processed bounding boxes,
-                confidence scores, and class labels.
+            results (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
+                scores, and class labels.
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -57,7 +55,7 @@ class RTDETRPredictor(BasePredictor):
         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]

         results = []
         for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
@@ -75,15 +73,13 @@ class RTDETRPredictor(BasePredictor):
         return results

     def pre_transform(self, im):
-        """
-        Pre-transform input images before feeding them into the model for inference.
+        """Pre-transform input images before feeding them into the model for inference.

-        The input images are letterboxed to ensure a square aspect ratio and scale-filled.
-        (640) and scale_filled.
+        The input images are letterboxed to ensure a square aspect ratio and scale-filled.

         Args:
-            im (list[np.ndarray]
+            im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
+                list.

         Returns:
             (list): List of pre-transformed images ready for model inference.
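The one functional change in postprocess above appends [..., ::-1] when converting tensor inputs back to a numpy batch. For an NHWC batch that slice simply reverses the channel axis (RGB ↔ BGR), as a tiny numpy illustration shows:

import numpy as np

batch = np.zeros((2, 4, 4, 3), dtype=np.uint8)  # NHWC image batch
batch[..., 0] = 255  # paint channel 0 only
flipped = batch[..., ::-1]  # reverse the channel axis
assert flipped[..., 2].max() == 255 and flipped[..., 0].max() == 0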
ultralytics/models/rtdetr/train.py
CHANGED

@@ -12,12 +12,11 @@ from .val import RTDETRDataset, RTDETRValidator


 class RTDETRTrainer(DetectionTrainer):
-    """
-    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
+    """Trainer class for the RT-DETR model developed by Baidu for real-time object detection.

-    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
-    The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
-    speed.
+    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
+    RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
+    inference speed.

     Attributes:
         loss_names (tuple): Names of the loss components used for training.
@@ -31,20 +30,19 @@ class RTDETRTrainer(DetectionTrainer):
         build_dataset: Build and return an RT-DETR dataset for training or validation.
         get_validator: Return a DetectionValidator suitable for RT-DETR model validation.

-    Notes:
-        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
-        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
-
     Examples:
         >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
         >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
        >>> trainer = RTDETRTrainer(overrides=args)
         >>> trainer.train()
+
+    Notes:
+        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
+        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
     """

     def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
-        """
-        Initialize and return an RT-DETR model for object detection tasks.
+        """Initialize and return an RT-DETR model for object detection tasks.

         Args:
             cfg (dict, optional): Model configuration.
@@ -60,8 +58,7 @@ class RTDETRTrainer(DetectionTrainer):
         return model

     def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
-        """
-        Build and return an RT-DETR dataset for training or validation.
+        """Build and return an RT-DETR dataset for training or validation.

         Args:
             img_path (str): Path to the folder containing images.