dgenerate-ultralytics-headless 8.3.189__py3-none-any.whl → 8.3.191__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/RECORD +111 -109
- tests/test_cuda.py +6 -5
- tests/test_exports.py +1 -6
- tests/test_python.py +1 -4
- tests/test_solutions.py +1 -1
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -14
- ultralytics/cfg/datasets/VisDrone.yaml +4 -4
- ultralytics/data/annotator.py +6 -6
- ultralytics/data/augment.py +53 -51
- ultralytics/data/base.py +15 -13
- ultralytics/data/build.py +7 -4
- ultralytics/data/converter.py +9 -10
- ultralytics/data/dataset.py +24 -22
- ultralytics/data/loaders.py +13 -11
- ultralytics/data/split.py +4 -3
- ultralytics/data/split_dota.py +14 -12
- ultralytics/data/utils.py +31 -25
- ultralytics/engine/exporter.py +7 -4
- ultralytics/engine/model.py +16 -14
- ultralytics/engine/predictor.py +9 -7
- ultralytics/engine/results.py +59 -57
- ultralytics/engine/trainer.py +7 -0
- ultralytics/engine/tuner.py +4 -3
- ultralytics/engine/validator.py +3 -1
- ultralytics/hub/__init__.py +6 -2
- ultralytics/hub/auth.py +2 -2
- ultralytics/hub/google/__init__.py +9 -8
- ultralytics/hub/session.py +11 -11
- ultralytics/hub/utils.py +8 -9
- ultralytics/models/fastsam/model.py +8 -6
- ultralytics/models/nas/model.py +5 -3
- ultralytics/models/rtdetr/train.py +4 -3
- ultralytics/models/rtdetr/val.py +6 -4
- ultralytics/models/sam/amg.py +13 -10
- ultralytics/models/sam/model.py +3 -2
- ultralytics/models/sam/modules/blocks.py +21 -21
- ultralytics/models/sam/modules/decoders.py +11 -11
- ultralytics/models/sam/modules/encoders.py +25 -25
- ultralytics/models/sam/modules/memory_attention.py +9 -8
- ultralytics/models/sam/modules/sam.py +8 -10
- ultralytics/models/sam/modules/tiny_encoder.py +21 -20
- ultralytics/models/sam/modules/transformer.py +6 -5
- ultralytics/models/sam/modules/utils.py +7 -5
- ultralytics/models/sam/predict.py +32 -31
- ultralytics/models/utils/loss.py +29 -27
- ultralytics/models/utils/ops.py +10 -8
- ultralytics/models/yolo/classify/train.py +7 -5
- ultralytics/models/yolo/classify/val.py +10 -8
- ultralytics/models/yolo/detect/predict.py +3 -3
- ultralytics/models/yolo/detect/train.py +8 -6
- ultralytics/models/yolo/detect/val.py +23 -21
- ultralytics/models/yolo/model.py +14 -14
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +13 -10
- ultralytics/models/yolo/pose/train.py +7 -5
- ultralytics/models/yolo/pose/val.py +11 -9
- ultralytics/models/yolo/segment/train.py +4 -5
- ultralytics/models/yolo/segment/val.py +12 -10
- ultralytics/models/yolo/world/train.py +9 -7
- ultralytics/models/yolo/yoloe/train.py +7 -6
- ultralytics/models/yolo/yoloe/val.py +10 -8
- ultralytics/nn/autobackend.py +40 -52
- ultralytics/nn/modules/__init__.py +3 -3
- ultralytics/nn/modules/block.py +12 -12
- ultralytics/nn/modules/conv.py +4 -3
- ultralytics/nn/modules/head.py +46 -38
- ultralytics/nn/modules/transformer.py +22 -21
- ultralytics/nn/tasks.py +2 -2
- ultralytics/nn/text_model.py +6 -5
- ultralytics/solutions/analytics.py +7 -5
- ultralytics/solutions/config.py +12 -10
- ultralytics/solutions/distance_calculation.py +3 -3
- ultralytics/solutions/heatmap.py +4 -2
- ultralytics/solutions/object_counter.py +5 -3
- ultralytics/solutions/parking_management.py +4 -2
- ultralytics/solutions/region_counter.py +7 -5
- ultralytics/solutions/similarity_search.py +5 -3
- ultralytics/solutions/solutions.py +38 -36
- ultralytics/solutions/streamlit_inference.py +8 -7
- ultralytics/trackers/bot_sort.py +11 -9
- ultralytics/trackers/byte_tracker.py +17 -15
- ultralytics/trackers/utils/gmc.py +4 -3
- ultralytics/utils/__init__.py +27 -77
- ultralytics/utils/autobatch.py +3 -2
- ultralytics/utils/autodevice.py +10 -10
- ultralytics/utils/benchmarks.py +11 -10
- ultralytics/utils/callbacks/comet.py +9 -9
- ultralytics/utils/callbacks/platform.py +2 -1
- ultralytics/utils/checks.py +20 -29
- ultralytics/utils/downloads.py +2 -2
- ultralytics/utils/export.py +12 -11
- ultralytics/utils/files.py +8 -7
- ultralytics/utils/git.py +139 -0
- ultralytics/utils/instance.py +8 -7
- ultralytics/utils/logger.py +7 -6
- ultralytics/utils/loss.py +15 -13
- ultralytics/utils/metrics.py +62 -62
- ultralytics/utils/nms.py +346 -0
- ultralytics/utils/ops.py +83 -251
- ultralytics/utils/patches.py +6 -4
- ultralytics/utils/plotting.py +18 -16
- ultralytics/utils/tal.py +1 -1
- ultralytics/utils/torch_utils.py +4 -2
- ultralytics/utils/tqdm.py +47 -33
- ultralytics/utils/triton.py +3 -2
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/top_level.txt +0 -0
ultralytics/hub/auth.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
-
import requests
|
4
|
-
|
5
3
|
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
|
6
4
|
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis
|
7
5
|
|
@@ -110,6 +108,8 @@ class Auth:
|
|
110
108
|
Returns:
|
111
109
|
(bool): True if authentication is successful, False otherwise.
|
112
110
|
"""
|
111
|
+
import requests # scoped as slow import
|
112
|
+
|
113
113
|
try:
|
114
114
|
if header := self.get_auth_header():
|
115
115
|
r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
|
@@ -1,11 +1,10 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import concurrent.futures
|
4
6
|
import statistics
|
5
7
|
import time
|
6
|
-
from typing import List, Optional, Tuple
|
7
|
-
|
8
|
-
import requests
|
9
8
|
|
10
9
|
|
11
10
|
class GCPRegions:
|
@@ -73,16 +72,16 @@ class GCPRegions:
|
|
73
72
|
"us-west4": (2, "Las Vegas", "United States"),
|
74
73
|
}
|
75
74
|
|
76
|
-
def tier1(self) ->
|
75
|
+
def tier1(self) -> list[str]:
|
77
76
|
"""Return a list of GCP regions classified as tier 1 based on predefined criteria."""
|
78
77
|
return [region for region, info in self.regions.items() if info[0] == 1]
|
79
78
|
|
80
|
-
def tier2(self) ->
|
79
|
+
def tier2(self) -> list[str]:
|
81
80
|
"""Return a list of GCP regions classified as tier 2 based on predefined criteria."""
|
82
81
|
return [region for region, info in self.regions.items() if info[0] == 2]
|
83
82
|
|
84
83
|
@staticmethod
|
85
|
-
def _ping_region(region: str, attempts: int = 1) ->
|
84
|
+
def _ping_region(region: str, attempts: int = 1) -> tuple[str, float, float, float, float]:
|
86
85
|
"""
|
87
86
|
Ping a specified GCP region and measure network latency statistics.
|
88
87
|
|
@@ -101,6 +100,8 @@ class GCPRegions:
|
|
101
100
|
>>> region, mean, std, min_lat, max_lat = GCPRegions._ping_region("us-central1", attempts=3)
|
102
101
|
>>> print(f"Region {region} has mean latency: {mean:.2f}ms")
|
103
102
|
"""
|
103
|
+
import requests # scoped as slow import
|
104
|
+
|
104
105
|
url = f"https://{region}-docker.pkg.dev"
|
105
106
|
latencies = []
|
106
107
|
for _ in range(attempts):
|
@@ -122,9 +123,9 @@ class GCPRegions:
|
|
122
123
|
self,
|
123
124
|
top: int = 1,
|
124
125
|
verbose: bool = False,
|
125
|
-
tier:
|
126
|
+
tier: int | None = None,
|
126
127
|
attempts: int = 1,
|
127
|
-
) ->
|
128
|
+
) -> list[tuple[str, float, float, float, float]]:
|
128
129
|
"""
|
129
130
|
Determine the GCP regions with the lowest latency based on ping tests.
|
130
131
|
|
ultralytics/hub/session.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import shutil
|
4
6
|
import threading
|
5
7
|
import time
|
6
8
|
from http import HTTPStatus
|
7
9
|
from pathlib import Path
|
8
|
-
from typing import Any
|
10
|
+
from typing import Any
|
9
11
|
from urllib.parse import parse_qs, urlparse
|
10
12
|
|
11
|
-
import requests
|
12
|
-
|
13
13
|
from ultralytics import __version__
|
14
14
|
from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX
|
15
15
|
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, TQDM, checks, emojis
|
@@ -92,7 +92,7 @@ class HUBTrainingSession:
|
|
92
92
|
)
|
93
93
|
|
94
94
|
@classmethod
|
95
|
-
def create_session(cls, identifier: str, args:
|
95
|
+
def create_session(cls, identifier: str, args: dict[str, Any] | None = None):
|
96
96
|
"""
|
97
97
|
Create an authenticated HUBTrainingSession or return None.
|
98
98
|
|
@@ -139,7 +139,7 @@ class HUBTrainingSession:
|
|
139
139
|
self.model.start_heartbeat(self.rate_limits["heartbeat"])
|
140
140
|
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
141
141
|
|
142
|
-
def create_model(self, model_args:
|
142
|
+
def create_model(self, model_args: dict[str, Any]):
|
143
143
|
"""
|
144
144
|
Initialize a HUB training session with the specified model arguments.
|
145
145
|
|
@@ -206,7 +206,7 @@ class HUBTrainingSession:
|
|
206
206
|
HUBModelError: If the identifier format is not recognized.
|
207
207
|
"""
|
208
208
|
api_key, model_id, filename = None, None, None
|
209
|
-
if
|
209
|
+
if identifier.endswith((".pt", ".yaml")):
|
210
210
|
filename = identifier
|
211
211
|
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
212
212
|
parsed_url = urlparse(identifier)
|
@@ -256,8 +256,8 @@ class HUBTrainingSession:
|
|
256
256
|
timeout: int = 30,
|
257
257
|
thread: bool = True,
|
258
258
|
verbose: bool = True,
|
259
|
-
progress_total:
|
260
|
-
stream_response:
|
259
|
+
progress_total: int | None = None,
|
260
|
+
stream_response: bool | None = None,
|
261
261
|
*args,
|
262
262
|
**kwargs,
|
263
263
|
):
|
@@ -341,7 +341,7 @@ class HUBTrainingSession:
|
|
341
341
|
}
|
342
342
|
return status_code in retry_codes
|
343
343
|
|
344
|
-
def _get_failure_message(self, response
|
344
|
+
def _get_failure_message(self, response, retry: int, timeout: int) -> str:
|
345
345
|
"""
|
346
346
|
Generate a retry message based on the response status code.
|
347
347
|
|
@@ -419,14 +419,14 @@ class HUBTrainingSession:
|
|
419
419
|
)
|
420
420
|
|
421
421
|
@staticmethod
|
422
|
-
def _show_upload_progress(content_length: int, response
|
422
|
+
def _show_upload_progress(content_length: int, response) -> None:
|
423
423
|
"""Display a progress bar to track the upload progress of a file download."""
|
424
424
|
with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
|
425
425
|
for data in response.iter_content(chunk_size=1024):
|
426
426
|
pbar.update(len(data))
|
427
427
|
|
428
428
|
@staticmethod
|
429
|
-
def _iterate_content(response
|
429
|
+
def _iterate_content(response) -> None:
|
430
430
|
"""Process the streamed HTTP response data."""
|
431
431
|
for _ in response.iter_content(chunk_size=1024):
|
432
432
|
pass # Do nothing with data chunks
|
ultralytics/hub/utils.py
CHANGED
@@ -5,16 +5,14 @@ import random
|
|
5
5
|
import threading
|
6
6
|
import time
|
7
7
|
from pathlib import Path
|
8
|
-
from typing import Any
|
9
|
-
|
10
|
-
import requests
|
8
|
+
from typing import Any
|
11
9
|
|
12
10
|
from ultralytics import __version__
|
13
11
|
from ultralytics.utils import (
|
14
12
|
ARGV,
|
15
13
|
ENVIRONMENT,
|
14
|
+
GIT,
|
16
15
|
IS_COLAB,
|
17
|
-
IS_GIT_DIR,
|
18
16
|
IS_PIP_PACKAGE,
|
19
17
|
LOGGER,
|
20
18
|
ONLINE,
|
@@ -25,7 +23,6 @@ from ultralytics.utils import (
|
|
25
23
|
TQDM,
|
26
24
|
TryExcept,
|
27
25
|
colorstr,
|
28
|
-
get_git_origin_url,
|
29
26
|
)
|
30
27
|
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
|
31
28
|
from ultralytics.utils.torch_utils import get_cpu_info
|
@@ -78,7 +75,7 @@ def request_with_credentials(url: str) -> Any:
|
|
78
75
|
return output.eval_js("_hub_tmp")
|
79
76
|
|
80
77
|
|
81
|
-
def requests_with_progress(method: str, url: str, **kwargs)
|
78
|
+
def requests_with_progress(method: str, url: str, **kwargs):
|
82
79
|
"""
|
83
80
|
Make an HTTP request using the specified method and URL, with an optional progress bar.
|
84
81
|
|
@@ -95,6 +92,8 @@ def requests_with_progress(method: str, url: str, **kwargs) -> requests.Response
|
|
95
92
|
content length.
|
96
93
|
- If 'progress' is a number then progress bar will display assuming content length = progress.
|
97
94
|
"""
|
95
|
+
import requests # scoped as slow import
|
96
|
+
|
98
97
|
progress = kwargs.pop("progress", False)
|
99
98
|
if not progress:
|
100
99
|
return requests.request(method, url, **kwargs)
|
@@ -120,7 +119,7 @@ def smart_request(
|
|
120
119
|
verbose: bool = True,
|
121
120
|
progress: bool = False,
|
122
121
|
**kwargs,
|
123
|
-
)
|
122
|
+
):
|
124
123
|
"""
|
125
124
|
Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
126
125
|
|
@@ -205,7 +204,7 @@ class Events:
|
|
205
204
|
self.t = 0.0 # rate limit timer (seconds)
|
206
205
|
self.metadata = {
|
207
206
|
"cli": Path(ARGV[0]).name == "yolo",
|
208
|
-
"install": "git" if
|
207
|
+
"install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
|
209
208
|
"python": PYTHON_VERSION.rsplit(".", 1)[0], # i.e. 3.13
|
210
209
|
"CPU": get_cpu_info(),
|
211
210
|
# "GPU": get_gpu_info(index=0) if cuda else None,
|
@@ -219,7 +218,7 @@ class Events:
|
|
219
218
|
and RANK in {-1, 0}
|
220
219
|
and not TESTS_RUNNING
|
221
220
|
and ONLINE
|
222
|
-
and (IS_PIP_PACKAGE or
|
221
|
+
and (IS_PIP_PACKAGE or GIT.origin == "https://github.com/ultralytics/ultralytics.git")
|
223
222
|
)
|
224
223
|
|
225
224
|
def __call__(self, cfg, device=None):
|
@@ -1,7 +1,9 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
from pathlib import Path
|
4
|
-
from typing import Any
|
6
|
+
from typing import Any
|
5
7
|
|
6
8
|
from ultralytics.engine.model import Model
|
7
9
|
|
@@ -45,10 +47,10 @@ class FastSAM(Model):
|
|
45
47
|
self,
|
46
48
|
source,
|
47
49
|
stream: bool = False,
|
48
|
-
bboxes:
|
49
|
-
points:
|
50
|
-
labels:
|
51
|
-
texts:
|
50
|
+
bboxes: list | None = None,
|
51
|
+
points: list | None = None,
|
52
|
+
labels: list | None = None,
|
53
|
+
texts: list | None = None,
|
52
54
|
**kwargs: Any,
|
53
55
|
):
|
54
56
|
"""
|
@@ -74,6 +76,6 @@ class FastSAM(Model):
|
|
74
76
|
return super().predict(source, stream, prompts=prompts, **kwargs)
|
75
77
|
|
76
78
|
@property
|
77
|
-
def task_map(self) ->
|
79
|
+
def task_map(self) -> dict[str, dict[str, Any]]:
|
78
80
|
"""Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
|
79
81
|
return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
|
ultralytics/models/nas/model.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
from pathlib import Path
|
4
|
-
from typing import Any
|
6
|
+
from typing import Any
|
5
7
|
|
6
8
|
import torch
|
7
9
|
|
@@ -80,7 +82,7 @@ class NAS(Model):
|
|
80
82
|
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # for export()
|
81
83
|
self.model.eval()
|
82
84
|
|
83
|
-
def info(self, detailed: bool = False, verbose: bool = True) ->
|
85
|
+
def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
|
84
86
|
"""
|
85
87
|
Log model information.
|
86
88
|
|
@@ -94,6 +96,6 @@ class NAS(Model):
|
|
94
96
|
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
|
95
97
|
|
96
98
|
@property
|
97
|
-
def task_map(self) ->
|
99
|
+
def task_map(self) -> dict[str, dict[str, Any]]:
|
98
100
|
"""Return a dictionary mapping tasks to respective predictor and validator classes."""
|
99
101
|
return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
|
@@ -1,7 +1,8 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
from copy import copy
|
4
|
-
from typing import Optional
|
5
6
|
|
6
7
|
from ultralytics.models.yolo.detect import DetectionTrainer
|
7
8
|
from ultralytics.nn.tasks import RTDETRDetectionModel
|
@@ -41,7 +42,7 @@ class RTDETRTrainer(DetectionTrainer):
|
|
41
42
|
>>> trainer.train()
|
42
43
|
"""
|
43
44
|
|
44
|
-
def get_model(self, cfg:
|
45
|
+
def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
|
45
46
|
"""
|
46
47
|
Initialize and return an RT-DETR model for object detection tasks.
|
47
48
|
|
@@ -58,7 +59,7 @@ class RTDETRTrainer(DetectionTrainer):
|
|
58
59
|
model.load(weights)
|
59
60
|
return model
|
60
61
|
|
61
|
-
def build_dataset(self, img_path: str, mode: str = "val", batch:
|
62
|
+
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
|
62
63
|
"""
|
63
64
|
Build and return an RT-DETR dataset for training or validation.
|
64
65
|
|
ultralytics/models/rtdetr/val.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
from pathlib import Path
|
4
|
-
from typing import Any
|
6
|
+
from typing import Any
|
5
7
|
|
6
8
|
import torch
|
7
9
|
|
@@ -155,8 +157,8 @@ class RTDETRValidator(DetectionValidator):
|
|
155
157
|
)
|
156
158
|
|
157
159
|
def postprocess(
|
158
|
-
self, preds:
|
159
|
-
) ->
|
160
|
+
self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
|
161
|
+
) -> list[dict[str, torch.Tensor]]:
|
160
162
|
"""
|
161
163
|
Apply Non-maximum suppression to prediction outputs.
|
162
164
|
|
@@ -187,7 +189,7 @@ class RTDETRValidator(DetectionValidator):
|
|
187
189
|
|
188
190
|
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
|
189
191
|
|
190
|
-
def pred_to_json(self, predn:
|
192
|
+
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
191
193
|
"""
|
192
194
|
Serialize YOLO predictions to COCO json format.
|
193
195
|
|
ultralytics/models/sam/amg.py
CHANGED
@@ -1,15 +1,18 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import math
|
6
|
+
from collections.abc import Generator
|
4
7
|
from itertools import product
|
5
|
-
from typing import Any
|
8
|
+
from typing import Any
|
6
9
|
|
7
10
|
import numpy as np
|
8
11
|
import torch
|
9
12
|
|
10
13
|
|
11
14
|
def is_box_near_crop_edge(
|
12
|
-
boxes: torch.Tensor, crop_box:
|
15
|
+
boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
|
13
16
|
) -> torch.Tensor:
|
14
17
|
"""
|
15
18
|
Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
|
@@ -38,7 +41,7 @@ def is_box_near_crop_edge(
|
|
38
41
|
return torch.any(near_crop_edge, dim=1)
|
39
42
|
|
40
43
|
|
41
|
-
def batch_iterator(batch_size: int, *args) -> Generator[
|
44
|
+
def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
|
42
45
|
"""
|
43
46
|
Yield batches of data from input arguments with specified batch size for efficient processing.
|
44
47
|
|
@@ -106,14 +109,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
|
|
106
109
|
return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
107
110
|
|
108
111
|
|
109
|
-
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) ->
|
112
|
+
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
|
110
113
|
"""Generate point grids for multiple crop layers with varying scales and densities."""
|
111
114
|
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
|
112
115
|
|
113
116
|
|
114
117
|
def generate_crop_boxes(
|
115
|
-
im_size:
|
116
|
-
) ->
|
118
|
+
im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
|
119
|
+
) -> tuple[list[list[int]], list[int]]:
|
117
120
|
"""
|
118
121
|
Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
|
119
122
|
|
@@ -163,7 +166,7 @@ def generate_crop_boxes(
|
|
163
166
|
return crop_boxes, layer_idxs
|
164
167
|
|
165
168
|
|
166
|
-
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box:
|
169
|
+
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
|
167
170
|
"""Uncrop bounding boxes by adding the crop box offset to their coordinates."""
|
168
171
|
x0, y0, _, _ = crop_box
|
169
172
|
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
|
@@ -173,7 +176,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
|
173
176
|
return boxes + offset
|
174
177
|
|
175
178
|
|
176
|
-
def uncrop_points(points: torch.Tensor, crop_box:
|
179
|
+
def uncrop_points(points: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
|
177
180
|
"""Uncrop points by adding the crop box offset to their coordinates."""
|
178
181
|
x0, y0, _, _ = crop_box
|
179
182
|
offset = torch.tensor([[x0, y0]], device=points.device)
|
@@ -183,7 +186,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
|
183
186
|
return points + offset
|
184
187
|
|
185
188
|
|
186
|
-
def uncrop_masks(masks: torch.Tensor, crop_box:
|
189
|
+
def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
|
187
190
|
"""Uncrop masks by padding them to the original image size, handling coordinate transformations."""
|
188
191
|
x0, y0, x1, y1 = crop_box
|
189
192
|
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
@@ -194,7 +197,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
|
|
194
197
|
return torch.nn.functional.pad(masks, pad, value=0)
|
195
198
|
|
196
199
|
|
197
|
-
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) ->
|
200
|
+
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
|
198
201
|
"""
|
199
202
|
Remove small disconnected regions or holes in a mask based on area threshold and mode.
|
200
203
|
|
ultralytics/models/sam/model.py
CHANGED
@@ -14,8 +14,9 @@ Key Features:
|
|
14
14
|
- Trained on SA-1B dataset
|
15
15
|
"""
|
16
16
|
|
17
|
+
from __future__ import annotations
|
18
|
+
|
17
19
|
from pathlib import Path
|
18
|
-
from typing import Dict, Type
|
19
20
|
|
20
21
|
from ultralytics.engine.model import Model
|
21
22
|
from ultralytics.utils.torch_utils import model_info
|
@@ -154,7 +155,7 @@ class SAM(Model):
|
|
154
155
|
return model_info(self.model, detailed=detailed, verbose=verbose)
|
155
156
|
|
156
157
|
@property
|
157
|
-
def task_map(self) ->
|
158
|
+
def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
|
158
159
|
"""
|
159
160
|
Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
|
160
161
|
|
@@ -1,9 +1,9 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
from __future__ import annotations
|
2
3
|
|
3
4
|
import copy
|
4
5
|
import math
|
5
6
|
from functools import partial
|
6
|
-
from typing import Optional, Tuple, Type, Union
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import torch
|
@@ -81,7 +81,7 @@ class MaskDownSampler(nn.Module):
|
|
81
81
|
stride: int = 4,
|
82
82
|
padding: int = 0,
|
83
83
|
total_stride: int = 16,
|
84
|
-
activation:
|
84
|
+
activation: type[nn.Module] = nn.GELU,
|
85
85
|
):
|
86
86
|
"""Initialize a mask downsampler module for progressive downsampling and channel expansion."""
|
87
87
|
super().__init__()
|
@@ -227,7 +227,7 @@ class Fuser(nn.Module):
|
|
227
227
|
torch.Size([1, 256, 32, 32])
|
228
228
|
"""
|
229
229
|
|
230
|
-
def __init__(self, layer: nn.Module, num_layers: int, dim:
|
230
|
+
def __init__(self, layer: nn.Module, num_layers: int, dim: int | None = None, input_projection: bool = False):
|
231
231
|
"""
|
232
232
|
Initialize the Fuser module for feature fusion through multiple layers.
|
233
233
|
|
@@ -295,7 +295,7 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
|
|
295
295
|
embedding_dim: int,
|
296
296
|
num_heads: int,
|
297
297
|
mlp_dim: int = 2048,
|
298
|
-
activation:
|
298
|
+
activation: type[nn.Module] = nn.ReLU,
|
299
299
|
attention_downsample_rate: int = 2,
|
300
300
|
skip_first_layer_pe: bool = False,
|
301
301
|
) -> None:
|
@@ -359,7 +359,7 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
|
|
359
359
|
embedding_dim: int,
|
360
360
|
num_heads: int,
|
361
361
|
mlp_dim: int,
|
362
|
-
activation:
|
362
|
+
activation: type[nn.Module] = nn.ReLU,
|
363
363
|
attention_downsample_rate: int = 2,
|
364
364
|
) -> None:
|
365
365
|
"""
|
@@ -432,7 +432,7 @@ class RoPEAttention(Attention):
|
|
432
432
|
*args,
|
433
433
|
rope_theta: float = 10000.0,
|
434
434
|
rope_k_repeat: bool = False,
|
435
|
-
feat_sizes:
|
435
|
+
feat_sizes: tuple[int, int] = (32, 32), # [w, h] for stride 16 feats at 512 resolution
|
436
436
|
**kwargs,
|
437
437
|
):
|
438
438
|
"""Initialize RoPEAttention with rotary position encoding for enhanced positional awareness."""
|
@@ -618,9 +618,9 @@ class MultiScaleBlock(nn.Module):
|
|
618
618
|
num_heads: int,
|
619
619
|
mlp_ratio: float = 4.0,
|
620
620
|
drop_path: float = 0.0,
|
621
|
-
norm_layer:
|
622
|
-
q_stride:
|
623
|
-
act_layer:
|
621
|
+
norm_layer: nn.Module | str = "LayerNorm",
|
622
|
+
q_stride: tuple[int, int] = None,
|
623
|
+
act_layer: type[nn.Module] = nn.GELU,
|
624
624
|
window_size: int = 0,
|
625
625
|
):
|
626
626
|
"""Initialize a multiscale attention block with window partitioning and optional query pooling."""
|
@@ -728,7 +728,7 @@ class PositionEmbeddingSine(nn.Module):
|
|
728
728
|
num_pos_feats: int,
|
729
729
|
temperature: int = 10000,
|
730
730
|
normalize: bool = True,
|
731
|
-
scale:
|
731
|
+
scale: float | None = None,
|
732
732
|
):
|
733
733
|
"""Initialize sinusoidal position embeddings for 2D image inputs."""
|
734
734
|
super().__init__()
|
@@ -744,7 +744,7 @@ class PositionEmbeddingSine(nn.Module):
|
|
744
744
|
|
745
745
|
self.cache = {}
|
746
746
|
|
747
|
-
def _encode_xy(self, x: torch.Tensor, y: torch.Tensor) ->
|
747
|
+
def _encode_xy(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
748
748
|
"""Encode 2D positions using sine/cosine functions for transformer positional embeddings."""
|
749
749
|
assert len(x) == len(y) and x.ndim == y.ndim == 1
|
750
750
|
x_embed = x * self.scale
|
@@ -833,7 +833,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|
833
833
|
torch.Size([128, 32, 32])
|
834
834
|
"""
|
835
835
|
|
836
|
-
def __init__(self, num_pos_feats: int = 64, scale:
|
836
|
+
def __init__(self, num_pos_feats: int = 64, scale: float | None = None) -> None:
|
837
837
|
"""Initialize random spatial frequency position embedding for transformers."""
|
838
838
|
super().__init__()
|
839
839
|
if scale is None or scale <= 0.0:
|
@@ -853,7 +853,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|
853
853
|
# Outputs d_1 x ... x d_n x C shape
|
854
854
|
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
855
855
|
|
856
|
-
def forward(self, size:
|
856
|
+
def forward(self, size: tuple[int, int]) -> torch.Tensor:
|
857
857
|
"""Generate positional encoding for a grid using random spatial frequencies."""
|
858
858
|
h, w = size
|
859
859
|
grid = torch.ones(
|
@@ -869,7 +869,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|
869
869
|
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
870
870
|
return pe.permute(2, 0, 1) # C x H x W
|
871
871
|
|
872
|
-
def forward_with_coords(self, coords_input: torch.Tensor, image_size:
|
872
|
+
def forward_with_coords(self, coords_input: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
873
873
|
"""Positionally encode input coordinates, normalizing them to [0,1] based on the given image size."""
|
874
874
|
coords = coords_input.clone()
|
875
875
|
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
@@ -910,12 +910,12 @@ class Block(nn.Module):
|
|
910
910
|
num_heads: int,
|
911
911
|
mlp_ratio: float = 4.0,
|
912
912
|
qkv_bias: bool = True,
|
913
|
-
norm_layer:
|
914
|
-
act_layer:
|
913
|
+
norm_layer: type[nn.Module] = nn.LayerNorm,
|
914
|
+
act_layer: type[nn.Module] = nn.GELU,
|
915
915
|
use_rel_pos: bool = False,
|
916
916
|
rel_pos_zero_init: bool = True,
|
917
917
|
window_size: int = 0,
|
918
|
-
input_size:
|
918
|
+
input_size: tuple[int, int] | None = None,
|
919
919
|
) -> None:
|
920
920
|
"""
|
921
921
|
Initialize a transformer block with optional window attention and relative positional embeddings.
|
@@ -1012,7 +1012,7 @@ class REAttention(nn.Module):
|
|
1012
1012
|
qkv_bias: bool = True,
|
1013
1013
|
use_rel_pos: bool = False,
|
1014
1014
|
rel_pos_zero_init: bool = True,
|
1015
|
-
input_size:
|
1015
|
+
input_size: tuple[int, int] | None = None,
|
1016
1016
|
) -> None:
|
1017
1017
|
"""
|
1018
1018
|
Initialize a Relative Position Attention module for transformer-based architectures.
|
@@ -1093,9 +1093,9 @@ class PatchEmbed(nn.Module):
|
|
1093
1093
|
|
1094
1094
|
def __init__(
|
1095
1095
|
self,
|
1096
|
-
kernel_size:
|
1097
|
-
stride:
|
1098
|
-
padding:
|
1096
|
+
kernel_size: tuple[int, int] = (16, 16),
|
1097
|
+
stride: tuple[int, int] = (16, 16),
|
1098
|
+
padding: tuple[int, int] = (0, 0),
|
1099
1099
|
in_chans: int = 3,
|
1100
1100
|
embed_dim: int = 768,
|
1101
1101
|
) -> None:
|