dgenerate-ultralytics-headless 8.3.143-py3-none-any.whl → 8.3.144-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
- dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -7
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +184 -75
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/hub/session.py
CHANGED
@@ -5,6 +5,7 @@ import threading
 import time
 from http import HTTPStatus
 from pathlib import Path
+from typing import Any, Dict, Optional
 from urllib.parse import parse_qs, urlparse
 
 import requests
@@ -19,7 +20,7 @@ AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version_
 
 class HUBTrainingSession:
     """
-    HUB training session for Ultralytics HUB YOLO models.
+    HUB training session for Ultralytics HUB YOLO models.
 
     This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
     model creation, metrics tracking, and checkpoint uploading.
@@ -27,28 +28,29 @@
     Attributes:
         model_id (str): Identifier for the YOLO model being trained.
        model_url (str): URL for the model in Ultralytics HUB.
-        rate_limits (
-        timers (
-        metrics_queue (
-        metrics_upload_failed_queue (
-        model (
+        rate_limits (Dict[str, int]): Rate limits for different API calls in seconds.
+        timers (Dict[str, Any]): Timers for rate limiting.
+        metrics_queue (Dict[str, Any]): Queue for the model's metrics.
+        metrics_upload_failed_queue (Dict[str, Any]): Queue for metrics that failed to upload.
+        model (Any): Model data fetched from Ultralytics HUB.
         model_file (str): Path to the model file.
-        train_args (
-        client (
+        train_args (Dict[str, Any]): Arguments for training the model.
+        client (Any): Client for interacting with Ultralytics HUB.
         filename (str): Filename of the model.
 
     Examples:
+        Create a training session with a model URL
         >>> session = HUBTrainingSession("https://hub.ultralytics.com/models/example-model")
         >>> session.upload_metrics()
     """
 
-    def __init__(self, identifier):
+    def __init__(self, identifier: str):
         """
         Initialize the HUBTrainingSession with the provided model identifier.
 
         Args:
-            identifier (str): Model identifier used to initialize the HUB training session.
-
+            identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string
+                or a model key with specific format.
 
         Raises:
             ValueError: If the provided model identifier is invalid.
@@ -90,16 +92,16 @@
         )
 
     @classmethod
-    def create_session(cls, identifier, args=None):
+    def create_session(cls, identifier: str, args: Optional[Dict[str, Any]] = None):
         """
         Create an authenticated HUBTrainingSession or return None.
 
         Args:
             identifier (str): Model identifier used to initialize the HUB training session.
-            args (
+            args (Dict[str, Any], optional): Arguments for creating a new model if identifier is not a HUB model URL.
 
         Returns:
-            (HUBTrainingSession | None): An authenticated session or None if creation fails.
+            session (HUBTrainingSession | None): An authenticated session or None if creation fails.
         """
         try:
             session = cls(identifier)
@@ -111,7 +113,7 @@
         except (PermissionError, ModuleNotFoundError, AssertionError):
             return None
 
-    def load_model(self, model_id):
+    def load_model(self, model_id: str):
         """
         Load an existing model from Ultralytics HUB using the provided model identifier.
 
@@ -137,12 +139,13 @@
         self.model.start_heartbeat(self.rate_limits["heartbeat"])
         LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
 
-    def create_model(self, model_args):
+    def create_model(self, model_args: Dict[str, Any]):
         """
         Initialize a HUB training session with the specified model arguments.
 
         Args:
-            model_args (
+            model_args (Dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,
+                etc.
 
         Returns:
             (None): If the model could not be created.
@@ -182,7 +185,7 @@
         LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
 
     @staticmethod
-    def _parse_identifier(identifier):
+    def _parse_identifier(identifier: str):
         """
         Parse the given identifier to determine the type and extract relevant components.
 
@@ -195,7 +198,9 @@
         identifier (str): The identifier string to be parsed.
 
         Returns:
-            (
+            api_key (str | None): Extracted API key if present.
+            model_id (str | None): Extracted model ID if present.
+            filename (str | None): Extracted filename if present.
 
         Raises:
             HUBModelError: If the identifier format is not recognized.
@@ -247,17 +252,17 @@
     def request_queue(
         self,
         request_func,
-        retry=3,
-        timeout=30,
-        thread=True,
-        verbose=True,
-        progress_total=None,
-        stream_response=None,
+        retry: int = 3,
+        timeout: int = 30,
+        thread: bool = True,
+        verbose: bool = True,
+        progress_total: Optional[int] = None,
+        stream_response: Optional[bool] = None,
         *args,
         **kwargs,
     ):
         """
-
+        Execute request_func with retries, timeout handling, optional threading, and progress tracking.
 
         Args:
             request_func (callable): The function to execute.
@@ -275,7 +280,7 @@
         """
 
         def retry_request():
-            """Attempt to call
+            """Attempt to call request_func with retries, timeout, and optional threading."""
             t0 = time.time()  # Record the start time for the timeout
             response = None
             for i in range(retry + 1):
@@ -327,16 +332,8 @@
         return retry_request()
 
     @staticmethod
-    def _should_retry(status_code):
-        """
-        Determine if a request should be retried based on the HTTP status code.
-
-        Args:
-            status_code (int): The HTTP status code from the response.
-
-        Returns:
-            (bool): True if the request should be retried, False otherwise.
-        """
+    def _should_retry(status_code: int) -> bool:
+        """Determine if a request should be retried based on the HTTP status code."""
         retry_codes = {
             HTTPStatus.REQUEST_TIMEOUT,
             HTTPStatus.BAD_GATEWAY,
@@ -344,7 +341,7 @@
         }
         return status_code in retry_codes
 
-    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
+    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int) -> str:
         """
         Generate a retry message based on the response status code.
 
@@ -423,24 +420,13 @@
 
     @staticmethod
     def _show_upload_progress(content_length: int, response: requests.Response) -> None:
-        """
-        Display a progress bar to track the upload progress of a file download.
-
-        Args:
-            content_length (int): The total size of the content to be downloaded in bytes.
-            response (requests.Response): The response object from the file download request.
-        """
+        """Display a progress bar to track the upload progress of a file download."""
         with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
             for data in response.iter_content(chunk_size=1024):
                 pbar.update(len(data))
 
     @staticmethod
     def _iterate_content(response: requests.Response) -> None:
-        """
-        Process the streamed HTTP response data.
-
-        Args:
-            response (requests.Response): The response object from the file download request.
-        """
+        """Process the streamed HTTP response data."""
        for _ in response.iter_content(chunk_size=1024):
            pass  # Do nothing with data chunks
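The session.py changes above are confined to type annotations and docstrings, so runtime behavior is unchanged. A minimal sketch of the documented entry points, assuming HUB credentials are already configured (the model URL is the docstring's placeholder):

```python
from ultralytics.hub.session import HUBTrainingSession

# create_session() returns an authenticated HUBTrainingSession, or None if
# authentication or model creation fails (see the Returns annotation above).
session = HUBTrainingSession.create_session("https://hub.ultralytics.com/models/example-model")
if session is not None:
    session.upload_metrics()  # flush the queued training metrics to HUB
```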
ultralytics/hub/utils.py
CHANGED
@@ -5,6 +5,7 @@ import random
 import threading
 import time
 from pathlib import Path
+from typing import Any, Optional
 
 import requests
 
@@ -36,7 +37,7 @@ PREFIX = colorstr("Ultralytics HUB: ")
 HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."
 
 
-def request_with_credentials(url: str) -> any:
+def request_with_credentials(url: str) -> Any:
     """
     Make an AJAX request with cookies attached in a Google Colab environment.
 
@@ -77,7 +78,7 @@ def request_with_credentials(url: str) -> any:
     return output.eval_js("_hub_tmp")
 
 
-def requests_with_progress(method, url, **kwargs):
+def requests_with_progress(method: str, url: str, **kwargs) -> requests.Response:
     """
     Make an HTTP request using the specified method and URL, with an optional progress bar.
 
@@ -109,7 +110,17 @@ def requests_with_progress(method, url, **kwargs):
     return response
 
 
-def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
+def smart_request(
+    method: str,
+    url: str,
+    retry: int = 3,
+    timeout: int = 30,
+    thread: bool = True,
+    code: int = -1,
+    verbose: bool = True,
+    progress: bool = False,
+    **kwargs,
+) -> Optional[requests.Response]:
     """
     Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
 
@@ -125,7 +136,8 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
         **kwargs (Any): Keyword arguments to be passed to the requests function specified in method.
 
     Returns:
-        (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns
+        (requests.Response | None): The HTTP response object. If the request is executed in a separate thread, returns
+            None.
     """
     retry_codes = (408, 500)  # retry only these codes
 
@@ -177,7 +189,9 @@ class Events:
 
     Attributes:
         url (str): The URL to send anonymous events.
+        events (list): List of collected events to be sent.
         rate_limit (float): The rate limit in seconds for sending events.
+        t (float): Rate limit timer in seconds.
         metadata (dict): A dictionary containing metadata about the environment.
         enabled (bool): A flag to enable or disable Events based on certain conditions.
     """
@@ -214,7 +228,7 @@ class Events:
 
         Args:
             cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
-            device (torch.device | str): The device type (e.g., 'cpu', 'cuda').
+            device (torch.device | str, optional): The device type (e.g., 'cpu', 'cuda').
         """
         if not self.enabled:
            # Events disabled, do nothing
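A hedged sketch of the newly typed smart_request() signature shown above (the URL is a placeholder; per the widened return annotation, a threaded call yields None):

```python
from ultralytics.hub.utils import smart_request

# thread=False runs synchronously and returns a requests.Response;
# thread=True runs in a separate thread and returns None, which is why
# the return type is now Optional[requests.Response].
response = smart_request("GET", "https://example.com/status", retry=3, timeout=30, thread=False)
if response is not None:
    print(response.status_code)
```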
ultralytics/models/fastsam/model.py
CHANGED
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
+from typing import Any, Dict, List, Optional
 
 from ultralytics.engine.model import Model
 
@@ -12,50 +13,67 @@ class FastSAM(Model):
     """
     FastSAM model interface for segment anything tasks.
 
-    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
-    implementation, allowing for efficient and accurate image segmentation.
+    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
+    Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.
 
     Attributes:
         model (str): Path to the pre-trained FastSAM model file.
         task (str): The task type, set to "segment" for FastSAM models.
 
+    Methods:
+        predict: Perform segmentation prediction on image or video source with optional prompts.
+        task_map: Returns mapping of segment task to predictor and validator classes.
+
     Examples:
+        Initialize FastSAM model and run prediction
         >>> from ultralytics import FastSAM
-        >>> model = FastSAM("
+        >>> model = FastSAM("FastSAM-x.pt")
         >>> results = model.predict("ultralytics/assets/bus.jpg")
+
+        Run prediction with bounding box prompts
+        >>> results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])
     """
 
-    def __init__(self, model="FastSAM-x.pt"):
+    def __init__(self, model: str = "FastSAM-x.pt"):
         """Initialize the FastSAM model with the specified pre-trained weights."""
         if str(model) == "FastSAM.pt":
             model = "FastSAM-x.pt"
         assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
         super().__init__(model=model, task="segment")
 
-    def predict(
+    def predict(
+        self,
+        source,
+        stream: bool = False,
+        bboxes: Optional[List] = None,
+        points: Optional[List] = None,
+        labels: Optional[List] = None,
+        texts: Optional[List] = None,
+        **kwargs: Any,
+    ):
         """
         Perform segmentation prediction on image or video source.
 
         Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
-        prompts and passes them to the parent class predict method.
+        prompts and passes them to the parent class predict method for processing.
 
         Args:
             source (str | PIL.Image | numpy.ndarray): Input source for prediction, can be a file path, URL, PIL image,
                 or numpy array.
             stream (bool): Whether to enable real-time streaming mode for video inputs.
-            bboxes (
-            points (
-            labels (
-            texts (
+            bboxes (List, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
+            points (List, optional): Point coordinates for prompted segmentation in format [[x, y]].
+            labels (List, optional): Class labels for prompted segmentation.
+            texts (List, optional): Text prompts for segmentation guidance.
             **kwargs (Any): Additional keyword arguments passed to the predictor.
 
         Returns:
-            (
+            (List): List of Results objects containing the prediction results.
         """
         prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
         return super().predict(source, stream, prompts=prompts, **kwargs)
 
     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
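The new docstring examples above translate directly into the following sketch (image paths are placeholders; box prompts use [[x1, y1, x2, y2]], point prompts use [[x, y]] with per-point labels):

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")

# Unprompted segmentation
results = model.predict("ultralytics/assets/bus.jpg")

# Prompted segmentation with a bounding box, then with a labeled point
results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])
results = model.predict("image.jpg", points=[[320, 240]], labels=[1])
```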
ultralytics/models/fastsam/predict.py
CHANGED
@@ -26,10 +26,9 @@ class FastSAMPredictor(SegmentationPredictor):
         clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.
 
     Methods:
-        postprocess:
-        prompt:
-
-        set_prompts: Sets prompts to be used during inference.
+        postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
+        prompt: Perform image segmentation inference based on various prompt types.
+        set_prompts: Set prompts to be used during inference.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -41,7 +40,7 @@ class FastSAMPredictor(SegmentationPredictor):
         optimized for single-class segmentation.
 
         Args:
-            cfg (dict): Configuration for the predictor.
+            cfg (dict): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides.
             _callbacks (list, optional): List of callback functions.
         """
@@ -120,7 +119,7 @@ class FastSAMPredictor(SegmentationPredictor):
             labels = torch.ones(points.shape[0])
         labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
         assert len(labels) == len(points), (
-            f"
+            f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
         )
         point_idx = (
            torch.ones(len(result), dtype=torch.bool, device=self.device)
ultralytics/models/fastsam/utils.py
CHANGED
@@ -6,12 +6,12 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
     Adjust bounding boxes to stick to image border if they are within a certain threshold.
 
     Args:
-        boxes (torch.Tensor): Bounding boxes with shape (
-        image_shape (
+        boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
+        image_shape (tuple): Image dimensions as (height, width).
         threshold (int): Pixel threshold for considering a box close to the border.
 
     Returns:
-
+        (torch.Tensor): Adjusted bounding boxes with shape (N, 4).
     """
     # Image dimensions
    h, w = image_shape
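Given the reconstructed docstring, the function's behavior can be illustrated as follows (a sketch; the expected output assumes edges within the threshold are snapped to the border, per the summary line):

```python
import torch
from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

boxes = torch.tensor([[5.0, 10.0, 630.0, 470.0]])  # (N, 4) in xyxy format
adjusted = adjust_bboxes_to_image_border(boxes, image_shape=(480, 640), threshold=20)
# All four edges lie within 20 px of the border, so each is snapped to it:
# expected tensor([[  0.,   0., 640., 480.]])
```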
ultralytics/models/fastsam/val.py
CHANGED
@@ -6,9 +6,9 @@ from ultralytics.utils.metrics import SegmentMetrics
 
 class FastSAMValidator(SegmentationValidator):
     """
-    Custom validation class for
+    Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
 
-    Extends the SegmentationValidator class, customizing the validation process specifically for
+    Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
     sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
     to avoid errors during validation.
 
@@ -18,6 +18,10 @@ class FastSAMValidator(SegmentationValidator):
         pbar (tqdm.tqdm): A progress bar object for displaying validation progress.
         args (SimpleNamespace): Additional arguments for customization of the validation process.
         _callbacks (list): List of callback functions to be invoked during validation.
+        metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.
+
+    Methods:
+        __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
@@ -25,11 +29,11 @@ class FastSAMValidator(SegmentationValidator):
         Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
 
         Args:
-            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
             save_dir (Path, optional): Directory to save results.
-            pbar (tqdm.tqdm): Progress bar for displaying progress.
-            args (SimpleNamespace): Configuration for the validator.
-            _callbacks (list): List of callback functions to be invoked during validation.
+            pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
+            args (SimpleNamespace, optional): Configuration for the validator.
+            _callbacks (list, optional): List of callback functions to be invoked during validation.
 
         Notes:
            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
ultralytics/models/nas/model.py
CHANGED
@@ -9,6 +9,7 @@ Examples:
 """
 
 from pathlib import Path
+from typing import Any, Dict
 
 import torch
 
@@ -23,7 +24,7 @@ from .val import NASValidator
 
 class NAS(Model):
     """
-    YOLO
+    YOLO-NAS model for object detection.
 
     This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
     It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
@@ -34,6 +35,9 @@ class NAS(Model):
         predictor (NASPredictor): The predictor instance for making predictions.
         validator (NASValidator): The validator instance for model validation.
 
+    Methods:
+        info: Log model information and return model details.
+
     Examples:
         >>> from ultralytics import NAS
         >>> model = NAS("yolo_nas_s")
@@ -72,7 +76,7 @@ class NAS(Model):
         self.model._original_forward = self.model.forward
         self.model.forward = new_forward
 
-        # Standardize model
+        # Standardize model attributes for compatibility
         self.model.fuse = lambda verbose=True: self.model
         self.model.stride = torch.tensor([32])
         self.model.names = dict(enumerate(self.model._class_names))
@@ -83,7 +87,7 @@ class NAS(Model):
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
         self.model.eval()
 
-    def info(self, detailed: bool = False, verbose: bool = True):
+    def info(self, detailed: bool = False, verbose: bool = True) -> Dict[str, Any]:
         """
         Log model information.
 
@@ -92,11 +96,11 @@ class NAS(Model):
             verbose (bool): Controls verbosity.
 
         Returns:
-            (
+            (Dict[str, Any]): Model information dictionary.
         """
         return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
 
     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
        """Return a dictionary mapping tasks to respective predictor and validator classes."""
        return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
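The NAS changes only annotate existing behavior. A minimal sketch of the documented interface, assuming the super-gradients package that backs YOLO-NAS is installed:

```python
from ultralytics import NAS

model = NAS("yolo_nas_s")                        # pre-trained YOLO-NAS weights
info = model.info(detailed=False, verbose=True)  # now annotated as Dict[str, Any]
results = model.predict("ultralytics/assets/bus.jpg")
```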
ultralytics/models/nas/predict.py
CHANGED
@@ -10,13 +10,13 @@ class NASPredictor(DetectionPredictor):
     """
     Ultralytics YOLO NAS Predictor for object detection.
 
-    This class extends the
+    This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the
     raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
     scaling the bounding boxes to fit the original image dimensions.
 
     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing including confidence
-            IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
+        args (Namespace): Namespace containing various configurations for post-processing including confidence
+            threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
         model (torch.nn.Module): The YOLO NAS model used for inference.
         batch (list): Batch of inputs for processing.
 
@@ -29,7 +29,7 @@ class NASPredictor(DetectionPredictor):
         >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
 
     Notes:
-        Typically, this class is not instantiated directly. It is used internally within the
+        Typically, this class is not instantiated directly. It is used internally within the NAS class.
     """
 
     def postprocess(self, preds_in, img, orig_imgs):
@@ -53,6 +53,6 @@ class NASPredictor(DetectionPredictor):
         >>> predictor = NAS("yolo_nas_s").predictor
         >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
         """
-        boxes = ops.xyxy2xywh(preds_in[0][0])
-        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  #
+        boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding boxes from xyxy to xywh format
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with class scores
        return super().postprocess(preds, img, orig_imgs)
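A shape-level sketch of the postprocess glue commented above, using random stand-in tensors (the batch, anchor, and class counts are illustrative assumptions, not values from the diff):

```python
import torch
from ultralytics.utils import ops

# preds_in[0][0] holds xyxy boxes (B, N, 4); preds_in[0][1] holds class scores (B, N, C)
boxes_xyxy = torch.rand(1, 8400, 4)   # stand-in boxes
scores = torch.rand(1, 8400, 80)      # stand-in class scores
boxes = ops.xyxy2xywh(boxes_xyxy)     # (B, N, 4) in xywh format
preds = torch.cat((boxes, scores), -1).permute(0, 2, 1)  # (B, 4 + C, N) layout for NMS
print(preds.shape)  # torch.Size([1, 84, 8400])
```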
ultralytics/models/nas/val.py
CHANGED
@@ -12,7 +12,7 @@ class NASValidator(DetectionValidator):
     """
     Ultralytics YOLO NAS Validator for object detection.
 
-    Extends
+    Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
     generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
     ultimately producing the final detections.
 
@@ -25,11 +25,11 @@ class NASValidator(DetectionValidator):
         >>> from ultralytics import NAS
         >>> model = NAS("yolo_nas_s")
         >>> validator = model.validator
-        Assumes that raw_preds are available
+        >>> # Assumes that raw_preds are available
         >>> final_preds = validator.postprocess(raw_preds)
 
     Notes:
-        This class is generally not instantiated directly but is used internally within the
+        This class is generally not instantiated directly but is used internally within the NAS class.
     """
 
    def postprocess(self, preds_in):
ultralytics/models/rtdetr/model.py
CHANGED
@@ -21,13 +21,17 @@ class RTDETR(Model):
     """
     Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
 
-    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware
-    selection, and adaptable inference speed.
+    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware
+    query selection, and adaptable inference speed.
 
     Attributes:
         model (str): Path to the pre-trained model.
 
+    Methods:
+        task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+
     Examples:
+        Initialize RT-DETR with a pre-trained model
         >>> from ultralytics import RTDETR
         >>> model = RTDETR("rtdetr-l.pt")
         >>> results = model("image.jpg")
@@ -39,16 +43,13 @@ class RTDETR(Model):
 
         Args:
             model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
-
-        Raises:
-            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
         """
         super().__init__(model=model, task="detect")
 
     @property
     def task_map(self) -> dict:
         """
-
+        Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
 
         Returns:
            (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
ultralytics/models/rtdetr/predict.py
CHANGED
@@ -21,6 +21,10 @@ class RTDETRPredictor(BasePredictor):
         model (torch.nn.Module): The loaded RT-DETR model.
         batch (list): Current batch of processed inputs.
 
+    Methods:
+        postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.
+        pre_transform: Pre-transform input images before feeding them into the model for inference.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.rtdetr import RTDETRPredictor
@@ -37,14 +41,14 @@ class RTDETRPredictor(BasePredictor):
         model predictions to Results objects containing properly scaled bounding boxes.
 
         Args:
-            preds (
+            preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
                 bounding boxes and scores.
             img (torch.Tensor): Processed input images with shape (N, 3, H, W).
-            orig_imgs (
+            orig_imgs (list | torch.Tensor): Original, unprocessed images.
 
         Returns:
-            (List[Results]): A list of Results objects containing the post-processed bounding boxes,
-                and class labels.
+            results (List[Results]): A list of Results objects containing the post-processed bounding boxes,
+                confidence scores, and class labels.
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -71,11 +75,14 @@ class RTDETRPredictor(BasePredictor):
 
     def pre_transform(self, im):
         """
-        Pre-
-
+        Pre-transform input images before feeding them into the model for inference.
+
+        The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square
+        (640) and scale_filled.
 
         Args:
-            im (
+            im (List[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
+                [(H, W, 3) x N] for list.
 
         Returns:
            (list): List of pre-transformed images ready for model inference.