ultralytics-opencv-headless 8.3.251__py3-none-any.whl → 8.4.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- tests/__init__.py +2 -2
- tests/conftest.py +1 -1
- tests/test_cuda.py +8 -2
- tests/test_engine.py +8 -8
- tests/test_exports.py +13 -4
- tests/test_integrations.py +9 -9
- tests/test_python.py +14 -14
- tests/test_solutions.py +3 -3
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +6 -6
- ultralytics/cfg/default.yaml +3 -1
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/data/augment.py +7 -0
- ultralytics/data/dataset.py +1 -1
- ultralytics/engine/exporter.py +11 -4
- ultralytics/engine/model.py +1 -1
- ultralytics/engine/trainer.py +40 -15
- ultralytics/engine/tuner.py +15 -7
- ultralytics/models/fastsam/predict.py +1 -1
- ultralytics/models/yolo/detect/train.py +3 -2
- ultralytics/models/yolo/detect/val.py +6 -0
- ultralytics/models/yolo/model.py +1 -1
- ultralytics/models/yolo/obb/predict.py +1 -1
- ultralytics/models/yolo/obb/train.py +1 -1
- ultralytics/models/yolo/pose/train.py +1 -1
- ultralytics/models/yolo/segment/predict.py +1 -1
- ultralytics/models/yolo/segment/train.py +1 -1
- ultralytics/models/yolo/segment/val.py +3 -1
- ultralytics/models/yolo/yoloe/train.py +6 -1
- ultralytics/models/yolo/yoloe/train_seg.py +6 -1
- ultralytics/nn/autobackend.py +11 -5
- ultralytics/nn/modules/__init__.py +8 -0
- ultralytics/nn/modules/block.py +128 -8
- ultralytics/nn/modules/head.py +789 -204
- ultralytics/nn/tasks.py +74 -29
- ultralytics/nn/text_model.py +5 -2
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/utils/callbacks/platform.py +30 -11
- ultralytics/utils/downloads.py +3 -1
- ultralytics/utils/export/engine.py +19 -10
- ultralytics/utils/export/imx.py +23 -12
- ultralytics/utils/export/tensorflow.py +21 -21
- ultralytics/utils/loss.py +587 -203
- ultralytics/utils/metrics.py +1 -0
- ultralytics/utils/ops.py +11 -2
- ultralytics/utils/tal.py +100 -20
- ultralytics/utils/torch_utils.py +1 -1
- ultralytics/utils/tqdm.py +4 -1
- {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/METADATA +31 -39
- {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/RECORD +63 -52
- {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/WHEEL +0 -0
- {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/entry_points.txt +0 -0
- {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/licenses/LICENSE +0 -0
- {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/top_level.txt +0 -0
ultralytics/utils/callbacks/platform.py
CHANGED

@@ -2,6 +2,7 @@

 import os
 import platform
+import re
 import socket
 import sys
 from concurrent.futures import ThreadPoolExecutor
@@ -12,6 +13,14 @@ from ultralytics.utils import ENVIRONMENT, GIT, LOGGER, PYTHON_VERSION, RANK, SE

 PREFIX = colorstr("Platform: ")

+
+def slugify(text):
+    """Convert text to URL-safe slug (e.g., 'My Project 1' -> 'my-project-1')."""
+    if not text:
+        return text
+    return re.sub(r"-+", "-", re.sub(r"[^a-z0-9\s-]", "", str(text).lower()).replace(" ", "-")).strip("-")[:128]
+
+
 try:
     assert not TESTS_RUNNING  # do not log pytest
     assert SETTINGS.get("platform", False) is True or os.getenv("ULTRALYTICS_API_KEY") or SETTINGS.get("api_key")
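For quick reference, the new helper above can be exercised standalone; a minimal sketch with illustrative inputs (not part of the package tests):

```python
import re


def slugify(text):
    """Convert text to URL-safe slug (e.g., 'My Project 1' -> 'my-project-1')."""
    if not text:
        return text
    return re.sub(r"-+", "-", re.sub(r"[^a-z0-9\s-]", "", str(text).lower()).replace(" ", "-")).strip("-")[:128]


print(slugify("My Project 1"))    # my-project-1
print(slugify("YOLO26: Pose?!"))  # yolo26-pose
print(slugify(""))                # empty/falsy input is returned unchanged
```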
@@ -57,9 +66,11 @@ def resolve_platform_uri(uri, hard=True):

     api_key = os.getenv("ULTRALYTICS_API_KEY") or SETTINGS.get("api_key")
     if not api_key:
-        raise ValueError(
+        raise ValueError(
+            f"ULTRALYTICS_API_KEY required for '{uri}'. Get key at https://platform.ultralytics.com/settings"
+        )

-    base = "https://
+    base = "https://platform.ultralytics.com/api/webhooks"
     headers = {"Authorization": f"Bearer {api_key}"}

     # ul://username/datasets/slug
@@ -141,7 +152,7 @@ def _send(event, data, project, name, model_id=None):
     if model_id:
         payload["modelId"] = model_id
     r = requests.post(
-        "https://
+        "https://platform.ultralytics.com/api/webhooks/training/metrics",
         json=payload,
         headers={"Authorization": f"Bearer {_api_key}"},
         timeout=10,
@@ -167,7 +178,7 @@ def _upload_model(model_path, project, name):

     # Get signed upload URL
     response = requests.post(
-        "https://
+        "https://platform.ultralytics.com/api/webhooks/models/upload",
         json={"project": project, "name": name, "filename": model_path.name},
         headers={"Authorization": f"Bearer {_api_key}"},
         timeout=10,
@@ -184,7 +195,7 @@ def _upload_model(model_path, project, name):
         timeout=600,  # 10 min timeout for large models
     ).raise_for_status()

-    # url = f"https://
+    # url = f"https://platform.ultralytics.com/{project}/{name}"
     # LOGGER.info(f"{PREFIX}Model uploaded to {url}")
     return data.get("gcsPath")

@@ -249,6 +260,14 @@ def _get_environment_info():
     return env


+def _get_project_name(trainer):
+    """Get slugified project and name from trainer args."""
+    raw = str(trainer.args.project)
+    parts = raw.split("/", 1)
+    project = f"{parts[0]}/{slugify(parts[1])}" if len(parts) == 2 else slugify(raw)
+    return project, slugify(str(trainer.args.name or "train"))
+
+
 def on_pretrain_routine_start(trainer):
     """Initialize Platform logging at training start."""
     if RANK not in {-1, 0} or not trainer.args.project:
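A rough sketch of how the new helper combines with slugify; the trainer stub below is a hypothetical stand-in exposing only the two attributes the helper reads, and it assumes slugify is importable from the module shown in this diff:

```python
from types import SimpleNamespace

from ultralytics.utils.callbacks.platform import slugify  # added in this release, per the diff above

# Hypothetical trainer stub with only args.project and args.name
trainer = SimpleNamespace(args=SimpleNamespace(project="acme/Warehouse Robots", name="Run #3"))

raw = str(trainer.args.project)
parts = raw.split("/", 1)
project = f"{parts[0]}/{slugify(parts[1])}" if len(parts) == 2 else slugify(raw)
name = slugify(str(trainer.args.name or "train"))
print(project, name)  # acme/warehouse-robots run-3
```

Note that only the segment after the first "/" is slugified, so a username or org prefix passes through untouched.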
@@ -258,8 +277,8 @@ def on_pretrain_routine_start(trainer):
     trainer._platform_model_id = None
     trainer._platform_last_upload = time()

-    project, name =
-    url = f"https://
+    project, name = _get_project_name(trainer)
+    url = f"https://platform.ultralytics.com/{project}/{name}"
     LOGGER.info(f"{PREFIX}Streaming to {url}")

     # Create callback to send console output to Platform
@@ -305,7 +324,7 @@ def on_fit_epoch_end(trainer):
     if RANK not in {-1, 0} or not trainer.args.project:
         return

-    project, name =
+    project, name = _get_project_name(trainer)
     metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics}

     if trainer.optimizer and trainer.optimizer.param_groups:
@@ -365,7 +384,7 @@ def on_model_save(trainer):
     if not model_path:
         return

-    project, name =
+    project, name = _get_project_name(trainer)
     _upload_model_async(model_path, project, name)
     trainer._platform_last_upload = time()

@@ -375,7 +394,7 @@ def on_train_end(trainer):
     if RANK not in {-1, 0} or not trainer.args.project:
         return

-    project, name =
+    project, name = _get_project_name(trainer)

     # Stop console capture
     if hasattr(trainer, "_platform_console_logger") and trainer._platform_console_logger:
@@ -420,7 +439,7 @@ def on_train_end(trainer):
         name,
         getattr(trainer, "_platform_model_id", None),
     )
-    url = f"https://
+    url = f"https://platform.ultralytics.com/{project}/{name}"
     LOGGER.info(f"{PREFIX}View results at {url}")


ultralytics/utils/downloads.py
CHANGED
@@ -18,12 +18,14 @@ GITHUB_ASSETS_NAMES = frozenset(
     [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")]
     + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
     + [f"yolo12{k}{suffix}.pt" for k in "nsmlx" for suffix in ("",)]  # detect models only currently
+    + [f"yolo26{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
     + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
     + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
     + [f"yolov8{k}-world.pt" for k in "smlx"]
     + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
     + [f"yoloe-v8{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
     + [f"yoloe-11{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
+    + [f"yoloe-26{k}{suffix}.pt" for k in "nsmlx" for suffix in ("-seg", "-seg-pf")]
     + [f"yolov9{k}.pt" for k in "tsmce"]
     + [f"yolov10{k}.pt" for k in "nsmblx"]
     + [f"yolo_nas_{k}.pt" for k in "sml"]
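For a sense of what the two new comprehensions add to GITHUB_ASSETS_NAMES, the same expressions evaluated standalone:

```python
yolo26 = [f"yolo26{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
yoloe26 = [f"yoloe-26{k}{suffix}.pt" for k in "nsmlx" for suffix in ("-seg", "-seg-pf")]

print(len(yolo26), yolo26[:5])    # 25 ['yolo26n.pt', 'yolo26n-cls.pt', 'yolo26n-seg.pt', 'yolo26n-pose.pt', 'yolo26n-obb.pt']
print(len(yoloe26), yoloe26[:2])  # 10 ['yoloe-26n-seg.pt', 'yoloe-26n-seg-pf.pt']
```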
@@ -424,7 +426,7 @@ def get_github_assets(
 def attempt_download_asset(
     file: str | Path,
     repo: str = "ultralytics/assets",
-    release: str = "v8.
+    release: str = "v8.4.0",
     **kwargs,
 ) -> str:
     """Attempt to download a file from GitHub release assets if it is not found locally.
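Assuming the signature stays as shown above, a caller relying on the default release now resolves missing weights against the v8.4.0 assets; a minimal usage sketch (actual download behavior depends on network access and local cache):

```python
from ultralytics.utils.downloads import attempt_download_asset

# Returns a local path, downloading from the ultralytics/assets v8.4.0 release if the file is absent
weights = attempt_download_asset("yolo26n.pt")
print(weights)
```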
ultralytics/utils/export/engine.py
CHANGED

@@ -143,7 +143,7 @@ def onnx2engine(
         for inp in inputs:
             profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
         config.add_optimization_profile(profile)
-        if int8:
+        if int8 and not is_trt10:  # deprecated in TensorRT 10, causes internal errors
            config.set_calibration_profile(profile)

     LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
@@ -226,12 +226,21 @@ def onnx2engine(
         config.set_flag(trt.BuilderFlag.FP16)

     # Write file
-
-
-
-    if
-
-
-
-
-
+    if is_trt10:
+        # TensorRT 10+ returns bytes directly, not a context manager
+        engine = builder.build_serialized_network(network, config)
+        if engine is None:
+            raise RuntimeError("TensorRT engine build failed, check logs for errors")
+        with open(engine_file, "wb") as t:
+            if metadata is not None:
+                meta = json.dumps(metadata)
+                t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
+                t.write(meta.encode())
+            t.write(engine)
+    else:
+        with builder.build_engine(network, config) as engine, open(engine_file, "wb") as t:
+            if metadata is not None:
+                meta = json.dumps(metadata)
+                t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
+                t.write(meta.encode())
+            t.write(engine.serialize())
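The write path above prefixes the serialized engine with a 4-byte little-endian metadata length followed by the JSON metadata. A minimal reader sketch for that framing, assuming the prefix was written (the writer skips it when metadata is None); read_engine_metadata is a hypothetical helper, not package code:

```python
import json
from pathlib import Path


def read_engine_metadata(engine_file: str) -> tuple[dict, bytes]:
    """Split an engine file written with the metadata prefix into (metadata, engine_bytes)."""
    blob = Path(engine_file).read_bytes()
    meta_len = int.from_bytes(blob[:4], byteorder="little", signed=True)  # 4-byte LE length written above
    metadata = json.loads(blob[4 : 4 + meta_len])
    return metadata, blob[4 + meta_len :]
```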
ultralytics/utils/export/imx.py
CHANGED
@@ -21,27 +21,27 @@ from ultralytics.utils.torch_utils import copy_attr
 MCT_CONFIG = {
     "YOLO11": {
         "detect": {
-            "layer_names": ["sub", "mul_2", "add_14", "
+            "layer_names": ["sub", "mul_2", "add_14", "cat_19"],
             "weights_memory": 2585350.2439,
             "n_layers": 238,
         },
         "pose": {
-            "layer_names": ["sub", "mul_2", "add_14", "
+            "layer_names": ["sub", "mul_2", "add_14", "cat_21", "cat_22", "mul_4", "add_15"],
             "weights_memory": 2437771.67,
             "n_layers": 257,
         },
         "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": 112},
-        "segment": {"layer_names": ["sub", "mul_2", "add_14", "
+        "segment": {"layer_names": ["sub", "mul_2", "add_14", "cat_21"], "weights_memory": 2466604.8, "n_layers": 265},
     },
     "YOLOv8": {
-        "detect": {"layer_names": ["sub", "mul", "add_6", "
+        "detect": {"layer_names": ["sub", "mul", "add_6", "cat_15"], "weights_memory": 2550540.8, "n_layers": 168},
         "pose": {
-            "layer_names": ["add_7", "mul_2", "
+            "layer_names": ["add_7", "mul_2", "cat_17", "mul", "sub", "add_6", "cat_18"],
             "weights_memory": 2482451.85,
             "n_layers": 187,
         },
         "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": 73},
-        "segment": {"layer_names": ["sub", "mul", "add_6", "
+        "segment": {"layer_names": ["sub", "mul", "add_6", "cat_17"], "weights_memory": 2580060.0, "n_layers": 195},
     },
 }

@@ -104,10 +104,13 @@ class FXModel(torch.nn.Module):
         return x


-def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
+def _inference(self, x: list[torch.Tensor] | dict[str, torch.Tensor]) -> tuple[torch.Tensor]:
     """Decode boxes and cls scores for imx object detection."""
-
-
+    if isinstance(x, dict):
+        box, cls = x["boxes"], x["scores"]
+    else:
+        x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
+        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
     dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
     return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)

@@ -115,9 +118,17 @@ def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
 def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Forward pass for imx pose estimation, including keypoint decoding."""
     bs = x[0].shape[0]  # batch size
-
+    nk_out = getattr(self, "nk_output", self.nk)
+    kpt = torch.cat([self.cv4[i](x[i]).view(bs, nk_out, -1) for i in range(self.nl)], -1)
+
+    # If using Pose26 with 5 dims, convert to 3 dims for export
+    if hasattr(self, "nk_output") and self.nk_output != self.nk:
+        spatial = kpt.shape[-1]
+        kpt = kpt.view(bs, self.kpt_shape[0], self.kpt_shape[1] + 2, spatial)
+        kpt = kpt[:, :, :-2, :]  # Remove sigma_x, sigma_y
+        kpt = kpt.view(bs, self.nk, spatial)
     x = Detect.forward(self, x)
-    pred_kpt = self.kpts_decode(
+    pred_kpt = self.kpts_decode(kpt)
     return *x, pred_kpt.permute(0, 2, 1)


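To make the 5-channel to 3-channel keypoint conversion above concrete, a shape-only sketch on a dummy tensor; kpt_shape=(17, 3) and the batch/anchor sizes are illustrative assumptions:

```python
import torch

bs, spatial = 2, 8400
kpt_shape = (17, 3)                             # (num keypoints, dims per keypoint) expected downstream
nk = kpt_shape[0] * kpt_shape[1]                # 51 channels in the 3-dim layout
nk_output = kpt_shape[0] * (kpt_shape[1] + 2)   # 85 channels when sigma_x, sigma_y are also predicted

kpt = torch.randn(bs, nk_output, spatial)       # stand-in for the raw Pose26-style head output
kpt = kpt.view(bs, kpt_shape[0], kpt_shape[1] + 2, spatial)
kpt = kpt[:, :, :-2, :]                         # drop sigma_x, sigma_y per keypoint
kpt = kpt.view(bs, nk, spatial)
print(kpt.shape)  # torch.Size([2, 51, 8400])
```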
@@ -219,7 +230,7 @@ def torch2imx(
     Examples:
         >>> from ultralytics import YOLO
         >>> model = YOLO("yolo11n.pt")
-        >>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.
+        >>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.7, max_det=300)

     Notes:
         - Requires model_compression_toolkit, onnx, edgemdt_tpc, and edge-mdt-cl packages
ultralytics/utils/export/tensorflow.py
CHANGED

@@ -2,12 +2,13 @@

 from __future__ import annotations

+from functools import partial
 from pathlib import Path

 import numpy as np
 import torch

-from ultralytics.nn.modules import Detect, Pose
+from ultralytics.nn.modules import Detect, Pose, Pose26
 from ultralytics.utils import LOGGER
 from ultralytics.utils.downloads import attempt_download_asset
 from ultralytics.utils.files import spaces_in_path
@@ -15,43 +16,42 @@ from ultralytics.utils.tal import make_anchors


 def tf_wrapper(model: torch.nn.Module) -> torch.nn.Module:
-    """A wrapper
+    """A wrapper for TensorFlow export compatibility (TF-specific handling is now in head modules)."""
     for m in model.modules():
         if not isinstance(m, Detect):
             continue
         import types

-        m.
-        if
-            m.kpts_decode = types.MethodType(
+        m._get_decode_boxes = types.MethodType(_tf_decode_boxes, m)
+        if isinstance(m, Pose):
+            m.kpts_decode = types.MethodType(partial(_tf_kpts_decode, is_pose26=type(m) is Pose26), m)
     return model


-def
-    """Decode boxes
-    shape = x[0].shape  # BCHW
-
-
-
-        self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+def _tf_decode_boxes(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+    """Decode bounding boxes for TensorFlow export."""
+    shape = x["feats"][0].shape  # BCHW
+    boxes = x["boxes"]
+    if self.format != "imx" and (self.dynamic or self.shape != shape):
+        self.anchors, self.strides = (a.transpose(0, 1) for a in make_anchors(x["feats"], self.stride, 0.5))
         self.shape = shape
-    grid_h, grid_w = shape[2]
-    grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=
+    grid_h, grid_w = shape[2:4]
+    grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=boxes.device).reshape(1, 4, 1)
     norm = self.strides / (self.stride[0] * grid_size)
-    dbox = self.decode_bboxes(self.dfl(
-    return
+    dbox = self.decode_bboxes(self.dfl(boxes) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
+    return dbox


-def
-    """Decode keypoints for
+def _tf_kpts_decode(self, kpts: torch.Tensor, is_pose26: bool = False) -> torch.Tensor:
+    """Decode keypoints for TensorFlow export."""
     ndim = self.kpt_shape[1]
-
+    bs = kpts.shape[0]
     # Precompute normalization factor to increase numerical stability
     y = kpts.view(bs, *self.kpt_shape, -1)
-    grid_h, grid_w = self.shape[2]
+    grid_h, grid_w = self.shape[2:4]
     grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
     norm = self.strides / (self.stride[0] * grid_size)
-    a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
+    a = ((y[:, :, :2] + self.anchors) if is_pose26 else (y[:, :, :2] * 2.0 + (self.anchors - 0.5))) * norm
     if ndim == 3:
         a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
     return a.view(bs, self.nk, -1)