ultralytics-opencv-headless 8.3.251__py3-none-any.whl → 8.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. tests/__init__.py +2 -2
  2. tests/conftest.py +1 -1
  3. tests/test_cuda.py +8 -2
  4. tests/test_engine.py +8 -8
  5. tests/test_exports.py +13 -4
  6. tests/test_integrations.py +9 -9
  7. tests/test_python.py +14 -14
  8. tests/test_solutions.py +3 -3
  9. ultralytics/__init__.py +1 -1
  10. ultralytics/cfg/__init__.py +6 -6
  11. ultralytics/cfg/default.yaml +3 -1
  12. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  13. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  14. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  15. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  16. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  17. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  18. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  19. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  20. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  21. ultralytics/data/augment.py +7 -0
  22. ultralytics/data/dataset.py +1 -1
  23. ultralytics/engine/exporter.py +11 -4
  24. ultralytics/engine/model.py +1 -1
  25. ultralytics/engine/trainer.py +40 -15
  26. ultralytics/engine/tuner.py +15 -7
  27. ultralytics/models/fastsam/predict.py +1 -1
  28. ultralytics/models/yolo/detect/train.py +3 -2
  29. ultralytics/models/yolo/detect/val.py +6 -0
  30. ultralytics/models/yolo/model.py +1 -1
  31. ultralytics/models/yolo/obb/predict.py +1 -1
  32. ultralytics/models/yolo/obb/train.py +1 -1
  33. ultralytics/models/yolo/pose/train.py +1 -1
  34. ultralytics/models/yolo/segment/predict.py +1 -1
  35. ultralytics/models/yolo/segment/train.py +1 -1
  36. ultralytics/models/yolo/segment/val.py +3 -1
  37. ultralytics/models/yolo/yoloe/train.py +6 -1
  38. ultralytics/models/yolo/yoloe/train_seg.py +6 -1
  39. ultralytics/nn/autobackend.py +11 -5
  40. ultralytics/nn/modules/__init__.py +8 -0
  41. ultralytics/nn/modules/block.py +128 -8
  42. ultralytics/nn/modules/head.py +789 -204
  43. ultralytics/nn/tasks.py +74 -29
  44. ultralytics/nn/text_model.py +5 -2
  45. ultralytics/optim/__init__.py +5 -0
  46. ultralytics/optim/muon.py +338 -0
  47. ultralytics/utils/callbacks/platform.py +30 -11
  48. ultralytics/utils/downloads.py +3 -1
  49. ultralytics/utils/export/engine.py +19 -10
  50. ultralytics/utils/export/imx.py +23 -12
  51. ultralytics/utils/export/tensorflow.py +21 -21
  52. ultralytics/utils/loss.py +587 -203
  53. ultralytics/utils/metrics.py +1 -0
  54. ultralytics/utils/ops.py +11 -2
  55. ultralytics/utils/tal.py +100 -20
  56. ultralytics/utils/torch_utils.py +1 -1
  57. ultralytics/utils/tqdm.py +4 -1
  58. {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/METADATA +31 -39
  59. {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/RECORD +63 -52
  60. {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/WHEEL +0 -0
  61. {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/entry_points.txt +0 -0
  62. {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/licenses/LICENSE +0 -0
  63. {ultralytics_opencv_headless-8.3.251.dist-info → ultralytics_opencv_headless-8.4.1.dist-info}/top_level.txt +0 -0

ultralytics/utils/callbacks/platform.py
@@ -2,6 +2,7 @@
 
 import os
 import platform
+import re
 import socket
 import sys
 from concurrent.futures import ThreadPoolExecutor
@@ -12,6 +13,14 @@ from ultralytics.utils import ENVIRONMENT, GIT, LOGGER, PYTHON_VERSION, RANK, SE
 
 PREFIX = colorstr("Platform: ")
 
+
+def slugify(text):
+    """Convert text to URL-safe slug (e.g., 'My Project 1' -> 'my-project-1')."""
+    if not text:
+        return text
+    return re.sub(r"-+", "-", re.sub(r"[^a-z0-9\s-]", "", str(text).lower()).replace(" ", "-")).strip("-")[:128]
+
+
 try:
     assert not TESTS_RUNNING  # do not log pytest
     assert SETTINGS.get("platform", False) is True or os.getenv("ULTRALYTICS_API_KEY") or SETTINGS.get("api_key")
@@ -57,9 +66,11 @@ def resolve_platform_uri(uri, hard=True):
 
     api_key = os.getenv("ULTRALYTICS_API_KEY") or SETTINGS.get("api_key")
     if not api_key:
-        raise ValueError(f"ULTRALYTICS_API_KEY required for '{uri}'. Get key at https://alpha.ultralytics.com/settings")
+        raise ValueError(
+            f"ULTRALYTICS_API_KEY required for '{uri}'. Get key at https://platform.ultralytics.com/settings"
+        )
 
-    base = "https://alpha.ultralytics.com/api/webhooks"
+    base = "https://platform.ultralytics.com/api/webhooks"
     headers = {"Authorization": f"Bearer {api_key}"}
 
     # ul://username/datasets/slug
@@ -141,7 +152,7 @@ def _send(event, data, project, name, model_id=None):
     if model_id:
         payload["modelId"] = model_id
     r = requests.post(
-        "https://alpha.ultralytics.com/api/webhooks/training/metrics",
+        "https://platform.ultralytics.com/api/webhooks/training/metrics",
         json=payload,
         headers={"Authorization": f"Bearer {_api_key}"},
         timeout=10,
@@ -167,7 +178,7 @@ def _upload_model(model_path, project, name):
 
     # Get signed upload URL
     response = requests.post(
-        "https://alpha.ultralytics.com/api/webhooks/models/upload",
+        "https://platform.ultralytics.com/api/webhooks/models/upload",
         json={"project": project, "name": name, "filename": model_path.name},
         headers={"Authorization": f"Bearer {_api_key}"},
         timeout=10,
@@ -184,7 +195,7 @@ def _upload_model(model_path, project, name):
         timeout=600,  # 10 min timeout for large models
     ).raise_for_status()
 
-    # url = f"https://alpha.ultralytics.com/{project}/{name}"
+    # url = f"https://platform.ultralytics.com/{project}/{name}"
     # LOGGER.info(f"{PREFIX}Model uploaded to {url}")
     return data.get("gcsPath")
 
@@ -249,6 +260,14 @@ def _get_environment_info():
     return env
 
 
+def _get_project_name(trainer):
+    """Get slugified project and name from trainer args."""
+    raw = str(trainer.args.project)
+    parts = raw.split("/", 1)
+    project = f"{parts[0]}/{slugify(parts[1])}" if len(parts) == 2 else slugify(raw)
+    return project, slugify(str(trainer.args.name or "train"))
+
+
 def on_pretrain_routine_start(trainer):
     """Initialize Platform logging at training start."""
     if RANK not in {-1, 0} or not trainer.args.project:
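
Taken together, slugify() and _get_project_name() (both added above) normalize the project and run name before they are used in Platform URLs; for a "user/Project Name" value only the part after the first "/" is slugified. A small standalone sketch (slugify copied from the hunk above; the inputs are illustrative):

import re

def slugify(text):
    """Convert text to URL-safe slug (e.g., 'My Project 1' -> 'my-project-1')."""
    if not text:
        return text
    return re.sub(r"-+", "-", re.sub(r"[^a-z0-9\s-]", "", str(text).lower()).replace(" ", "-")).strip("-")[:128]

print(slugify("My Project 1"))     # my-project-1
print(slugify("Résumé & Notes!"))  # rsum-notes (non-ASCII letters and punctuation are dropped)

# _get_project_name() splits on the first "/" so only the project part is slugified:
raw = "username/My Project"
parts = raw.split("/", 1)
project = f"{parts[0]}/{slugify(parts[1])}" if len(parts) == 2 else slugify(raw)
print(project)  # username/my-project
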
@@ -258,8 +277,8 @@ def on_pretrain_routine_start(trainer):
     trainer._platform_model_id = None
     trainer._platform_last_upload = time()
 
-    project, name = str(trainer.args.project), str(trainer.args.name or "train")
-    url = f"https://alpha.ultralytics.com/{project}/{name}"
+    project, name = _get_project_name(trainer)
+    url = f"https://platform.ultralytics.com/{project}/{name}"
     LOGGER.info(f"{PREFIX}Streaming to {url}")
 
     # Create callback to send console output to Platform
@@ -305,7 +324,7 @@ def on_fit_epoch_end(trainer):
     if RANK not in {-1, 0} or not trainer.args.project:
         return
 
-    project, name = str(trainer.args.project), str(trainer.args.name or "train")
+    project, name = _get_project_name(trainer)
     metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics}
 
     if trainer.optimizer and trainer.optimizer.param_groups:
@@ -365,7 +384,7 @@ def on_model_save(trainer):
     if not model_path:
         return
 
-    project, name = str(trainer.args.project), str(trainer.args.name or "train")
+    project, name = _get_project_name(trainer)
     _upload_model_async(model_path, project, name)
     trainer._platform_last_upload = time()
 
@@ -375,7 +394,7 @@ def on_train_end(trainer):
     if RANK not in {-1, 0} or not trainer.args.project:
         return
 
-    project, name = str(trainer.args.project), str(trainer.args.name or "train")
+    project, name = _get_project_name(trainer)
 
     # Stop console capture
     if hasattr(trainer, "_platform_console_logger") and trainer._platform_console_logger:
@@ -420,7 +439,7 @@ def on_train_end(trainer):
         name,
         getattr(trainer, "_platform_model_id", None),
     )
-    url = f"https://alpha.ultralytics.com/{project}/{name}"
+    url = f"https://platform.ultralytics.com/{project}/{name}"
     LOGGER.info(f"{PREFIX}View results at {url}")
 
 

ultralytics/utils/downloads.py
@@ -18,12 +18,14 @@ GITHUB_ASSETS_NAMES = frozenset(
     [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")]
     + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
     + [f"yolo12{k}{suffix}.pt" for k in "nsmlx" for suffix in ("",)]  # detect models only currently
+    + [f"yolo26{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
     + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
     + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
     + [f"yolov8{k}-world.pt" for k in "smlx"]
     + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
     + [f"yoloe-v8{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
     + [f"yoloe-11{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
+    + [f"yoloe-26{k}{suffix}.pt" for k in "nsmlx" for suffix in ("-seg", "-seg-pf")]
     + [f"yolov9{k}.pt" for k in "tsmce"]
     + [f"yolov10{k}.pt" for k in "nsmblx"]
     + [f"yolo_nas_{k}.pt" for k in "sml"]
@@ -424,7 +426,7 @@ def get_github_assets(
 def attempt_download_asset(
     file: str | Path,
     repo: str = "ultralytics/assets",
-    release: str = "v8.3.0",
+    release: str = "v8.4.0",
     **kwargs,
 ) -> str:
     """Attempt to download a file from GitHub release assets if it is not found locally.

ultralytics/utils/export/engine.py
@@ -143,7 +143,7 @@ def onnx2engine(
        for inp in inputs:
            profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
        config.add_optimization_profile(profile)
-        if int8:
+        if int8 and not is_trt10:  # deprecated in TensorRT 10, causes internal errors
            config.set_calibration_profile(profile)
 
    LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
@@ -226,12 +226,21 @@ def onnx2engine(
        config.set_flag(trt.BuilderFlag.FP16)
 
    # Write file
-    build = builder.build_serialized_network if is_trt10 else builder.build_engine
-    with build(network, config) as engine, open(engine_file, "wb") as t:
-        # Metadata
-        if metadata is not None:
-            meta = json.dumps(metadata)
-            t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
-            t.write(meta.encode())
-        # Model
-        t.write(engine if is_trt10 else engine.serialize())
+    if is_trt10:
+        # TensorRT 10+ returns bytes directly, not a context manager
+        engine = builder.build_serialized_network(network, config)
+        if engine is None:
+            raise RuntimeError("TensorRT engine build failed, check logs for errors")
+        with open(engine_file, "wb") as t:
+            if metadata is not None:
+                meta = json.dumps(metadata)
+                t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
+                t.write(meta.encode())
+            t.write(engine)
+    else:
+        with builder.build_engine(network, config) as engine, open(engine_file, "wb") as t:
+            if metadata is not None:
+                meta = json.dumps(metadata)
+                t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
+                t.write(meta.encode())
+            t.write(engine.serialize())
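
Both branches write the same on-disk layout: a 4-byte little-endian signed length, the JSON metadata, then the serialized engine bytes. A small sketch of reading that layout back (the helper name is illustrative, not part of the package):

import json
from pathlib import Path

def read_engine_with_metadata(engine_file: str) -> tuple[dict, bytes]:
    """Split an engine file written as above into its JSON metadata header and raw TensorRT engine bytes."""
    data = Path(engine_file).read_bytes()
    meta_len = int.from_bytes(data[:4], byteorder="little", signed=True)  # matches the 4-byte length prefix
    metadata = json.loads(data[4 : 4 + meta_len])
    return metadata, data[4 + meta_len :]
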

ultralytics/utils/export/imx.py
@@ -21,27 +21,27 @@ from ultralytics.utils.torch_utils import copy_attr
 MCT_CONFIG = {
     "YOLO11": {
         "detect": {
-            "layer_names": ["sub", "mul_2", "add_14", "cat_21"],
+            "layer_names": ["sub", "mul_2", "add_14", "cat_19"],
             "weights_memory": 2585350.2439,
             "n_layers": 238,
         },
         "pose": {
-            "layer_names": ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"],
+            "layer_names": ["sub", "mul_2", "add_14", "cat_21", "cat_22", "mul_4", "add_15"],
             "weights_memory": 2437771.67,
             "n_layers": 257,
         },
         "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": 112},
-        "segment": {"layer_names": ["sub", "mul_2", "add_14", "cat_22"], "weights_memory": 2466604.8, "n_layers": 265},
+        "segment": {"layer_names": ["sub", "mul_2", "add_14", "cat_21"], "weights_memory": 2466604.8, "n_layers": 265},
     },
     "YOLOv8": {
-        "detect": {"layer_names": ["sub", "mul", "add_6", "cat_17"], "weights_memory": 2550540.8, "n_layers": 168},
+        "detect": {"layer_names": ["sub", "mul", "add_6", "cat_15"], "weights_memory": 2550540.8, "n_layers": 168},
         "pose": {
-            "layer_names": ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"],
+            "layer_names": ["add_7", "mul_2", "cat_17", "mul", "sub", "add_6", "cat_18"],
             "weights_memory": 2482451.85,
             "n_layers": 187,
         },
         "classify": {"layer_names": [], "weights_memory": np.inf, "n_layers": 73},
-        "segment": {"layer_names": ["sub", "mul", "add_6", "cat_18"], "weights_memory": 2580060.0, "n_layers": 195},
+        "segment": {"layer_names": ["sub", "mul", "add_6", "cat_17"], "weights_memory": 2580060.0, "n_layers": 195},
     },
 }
 
@@ -104,10 +104,13 @@ class FXModel(torch.nn.Module):
         return x
 
 
-def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
+def _inference(self, x: list[torch.Tensor] | dict[str, torch.Tensor]) -> tuple[torch.Tensor]:
     """Decode boxes and cls scores for imx object detection."""
-    x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
-    box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
+    if isinstance(x, dict):
+        box, cls = x["boxes"], x["scores"]
+    else:
+        x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
+        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
     dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
     return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
 
@@ -115,9 +118,17 @@ def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
 def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Forward pass for imx pose estimation, including keypoint decoding."""
     bs = x[0].shape[0]  # batch size
-    kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
+    nk_out = getattr(self, "nk_output", self.nk)
+    kpt = torch.cat([self.cv4[i](x[i]).view(bs, nk_out, -1) for i in range(self.nl)], -1)
+
+    # If using Pose26 with 5 dims, convert to 3 dims for export
+    if hasattr(self, "nk_output") and self.nk_output != self.nk:
+        spatial = kpt.shape[-1]
+        kpt = kpt.view(bs, self.kpt_shape[0], self.kpt_shape[1] + 2, spatial)
+        kpt = kpt[:, :, :-2, :]  # Remove sigma_x, sigma_y
+        kpt = kpt.view(bs, self.nk, spatial)
     x = Detect.forward(self, x)
-    pred_kpt = self.kpts_decode(bs, kpt)
+    pred_kpt = self.kpts_decode(kpt)
     return *x, pred_kpt.permute(0, 2, 1)
 
 
@@ -219,7 +230,7 @@ def torch2imx(
     Examples:
         >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")
-        >>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.45, max_det=300)
+        >>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.7, max_det=300)
 
     Notes:
        - Requires model_compression_toolkit, onnx, edgemdt_tpc, and edge-mdt-cl packages
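
The pose_forward change handles heads whose raw keypoint output carries two extra per-keypoint channels (the "sigma_x, sigma_y" removed in the comment above). A shape-only sketch with dummy tensors; the 17x3 COCO keypoint layout and the 8400-cell anchor grid are assumptions for illustration:

import torch

bs, spatial = 1, 8400                          # assumed batch size and number of anchor cells
kpt_shape = (17, 3)                            # assumed COCO layout: 17 keypoints x (x, y, visibility)
nk = kpt_shape[0] * kpt_shape[1]               # 51 channels kept for export
nk_output = kpt_shape[0] * (kpt_shape[1] + 2)  # 85 raw channels when sigma_x, sigma_y are also predicted

kpt = torch.randn(bs, nk_output, spatial)                    # raw head output
kpt = kpt.view(bs, kpt_shape[0], kpt_shape[1] + 2, spatial)  # (1, 17, 5, 8400)
kpt = kpt[:, :, :-2, :]                                      # drop sigma_x, sigma_y
kpt = kpt.view(bs, nk, spatial)                              # (1, 51, 8400), same layout as older heads
print(kpt.shape)
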

ultralytics/utils/export/tensorflow.py
@@ -2,12 +2,13 @@
 
 from __future__ import annotations
 
+from functools import partial
 from pathlib import Path
 
 import numpy as np
 import torch
 
-from ultralytics.nn.modules import Detect, Pose
+from ultralytics.nn.modules import Detect, Pose, Pose26
 from ultralytics.utils import LOGGER
 from ultralytics.utils.downloads import attempt_download_asset
 from ultralytics.utils.files import spaces_in_path
@@ -15,43 +16,42 @@ from ultralytics.utils.tal import make_anchors
 
 
 def tf_wrapper(model: torch.nn.Module) -> torch.nn.Module:
-    """A wrapper to add TensorFlow compatible inference methods to Detect and Pose layers."""
+    """A wrapper for TensorFlow export compatibility (TF-specific handling is now in head modules)."""
     for m in model.modules():
         if not isinstance(m, Detect):
             continue
         import types
 
-        m._inference = types.MethodType(_tf_inference, m)
-        if type(m) is Pose:
-            m.kpts_decode = types.MethodType(tf_kpts_decode, m)
+        m._get_decode_boxes = types.MethodType(_tf_decode_boxes, m)
+        if isinstance(m, Pose):
+            m.kpts_decode = types.MethodType(partial(_tf_kpts_decode, is_pose26=type(m) is Pose26), m)
     return model
 
 
-def _tf_inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
-    """Decode boxes and cls scores for tf object detection."""
-    shape = x[0].shape  # BCHW
-    x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
-    box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-    if self.dynamic or self.shape != shape:
-        self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+def _tf_decode_boxes(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+    """Decode bounding boxes for TensorFlow export."""
+    shape = x["feats"][0].shape  # BCHW
+    boxes = x["boxes"]
+    if self.format != "imx" and (self.dynamic or self.shape != shape):
+        self.anchors, self.strides = (a.transpose(0, 1) for a in make_anchors(x["feats"], self.stride, 0.5))
         self.shape = shape
-    grid_h, grid_w = shape[2], shape[3]
-    grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
+    grid_h, grid_w = shape[2:4]
+    grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=boxes.device).reshape(1, 4, 1)
     norm = self.strides / (self.stride[0] * grid_size)
-    dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
-    return torch.cat((dbox, cls.sigmoid()), 1)
+    dbox = self.decode_bboxes(self.dfl(boxes) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
+    return dbox
 
 
-def tf_kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:
-    """Decode keypoints for tf pose estimation."""
+def _tf_kpts_decode(self, kpts: torch.Tensor, is_pose26: bool = False) -> torch.Tensor:
+    """Decode keypoints for TensorFlow export."""
     ndim = self.kpt_shape[1]
-    # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
+    bs = kpts.shape[0]
     # Precompute normalization factor to increase numerical stability
     y = kpts.view(bs, *self.kpt_shape, -1)
-    grid_h, grid_w = self.shape[2], self.shape[3]
+    grid_h, grid_w = self.shape[2:4]
     grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
     norm = self.strides / (self.stride[0] * grid_size)
-    a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
+    a = ((y[:, :, :2] + self.anchors) if is_pose26 else (y[:, :, :2] * 2.0 + (self.anchors - 0.5))) * norm
     if ndim == 3:
         a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
     return a.view(bs, self.nk, -1)
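
The only behavioral fork in _tf_kpts_decode is the xy decode: Pose26 heads predict a plain offset from the anchor, while earlier pose heads use the scaled-and-shifted YOLOv8-style decode. A toy comparison with scalar stand-ins for the anchor and normalization tensors (values are illustrative only):

import torch

y_xy = torch.tensor(0.3)     # raw keypoint xy prediction
anchor = torch.tensor(10.0)  # anchor cell centre (stand-in for self.anchors)
norm = torch.tensor(0.05)    # stand-in for self.strides / (self.stride[0] * grid_size)

classic = (y_xy * 2.0 + (anchor - 0.5)) * norm  # YOLOv8/YOLO11 pose decode
pose26 = (y_xy + anchor) * norm                 # Pose26 decode selected when is_pose26=True
print(classic.item(), pose26.item())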