kyolo 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76):
  1. kyolo/__init__.py +55 -0
  2. kyolo/conversion/__init__.py +46 -0
  3. kyolo/conversion/cli.py +117 -0
  4. kyolo/conversion/convert.py +429 -0
  5. kyolo/conversion/exceptions.py +51 -0
  6. kyolo/conversion/mappings.py +89 -0
  7. kyolo/heads/__init__.py +5 -0
  8. kyolo/heads/detect.py +91 -0
  9. kyolo/layers/__init__.py +79 -0
  10. kyolo/layers/attention.py +264 -0
  11. kyolo/layers/blocks.py +253 -0
  12. kyolo/layers/common.py +173 -0
  13. kyolo/layers/dfl.py +60 -0
  14. kyolo/layers/letterbox.py +151 -0
  15. kyolo/layers/reparam.py +256 -0
  16. kyolo/losses/__init__.py +5 -0
  17. kyolo/losses/detection_loss.py +155 -0
  18. kyolo/models/__init__.py +213 -0
  19. kyolo/models/base.py +112 -0
  20. kyolo/models/yolo11/__init__.py +93 -0
  21. kyolo/models/yolo11/config.py +14 -0
  22. kyolo/models/yolo11/convert_yolo11_torch_to_keras.py +57 -0
  23. kyolo/models/yolo11/yolo11_model.py +90 -0
  24. kyolo/models/yolo12/__init__.py +93 -0
  25. kyolo/models/yolo12/config.py +14 -0
  26. kyolo/models/yolo12/convert_yolo12_torch_to_keras.py +57 -0
  27. kyolo/models/yolo12/yolo12_model.py +93 -0
  28. kyolo/models/yolo26/__init__.py +93 -0
  29. kyolo/models/yolo26/config.py +14 -0
  30. kyolo/models/yolo26/convert_yolo26_torch_to_keras.py +57 -0
  31. kyolo/models/yolo26/yolo26_model.py +104 -0
  32. kyolo/models/yolov10/__init__.py +109 -0
  33. kyolo/models/yolov10/config.py +15 -0
  34. kyolo/models/yolov10/convert_yolov10_torch_to_keras.py +57 -0
  35. kyolo/models/yolov10/yolov10_model.py +100 -0
  36. kyolo/models/yolov5/__init__.py +93 -0
  37. kyolo/models/yolov5/config.py +14 -0
  38. kyolo/models/yolov5/convert_yolov5_torch_to_keras.py +57 -0
  39. kyolo/models/yolov5/yolov5_model.py +95 -0
  40. kyolo/models/yolov6/__init__.py +77 -0
  41. kyolo/models/yolov6/config.py +13 -0
  42. kyolo/models/yolov6/convert_yolov6_torch_to_keras.py +57 -0
  43. kyolo/models/yolov6/yolov6_model.py +101 -0
  44. kyolo/models/yolov7/__init__.py +63 -0
  45. kyolo/models/yolov7/config.py +13 -0
  46. kyolo/models/yolov7/convert_yolov7_torch_to_keras.py +57 -0
  47. kyolo/models/yolov7/yolov7_model.py +106 -0
  48. kyolo/models/yolov8/__init__.py +93 -0
  49. kyolo/models/yolov8/config.py +14 -0
  50. kyolo/models/yolov8/convert_yolov8_torch_to_keras.py +57 -0
  51. kyolo/models/yolov8/yolov8_model.py +88 -0
  52. kyolo/models/yolov9/__init__.py +93 -0
  53. kyolo/models/yolov9/config.py +14 -0
  54. kyolo/models/yolov9/convert_yolov9_torch_to_keras.py +57 -0
  55. kyolo/models/yolov9/yolov9_model.py +95 -0
  56. kyolo/ops/__init__.py +24 -0
  57. kyolo/ops/anchors.py +123 -0
  58. kyolo/ops/boxes.py +114 -0
  59. kyolo/ops/tal.py +152 -0
  60. kyolo/postprocessing/__init__.py +17 -0
  61. kyolo/postprocessing/nms.py +171 -0
  62. kyolo/postprocessing/postprocessor.py +90 -0
  63. kyolo/preprocessing/__init__.py +5 -0
  64. kyolo/preprocessing/preprocessor.py +131 -0
  65. kyolo/training/__init__.py +5 -0
  66. kyolo/training/detector.py +120 -0
  67. kyolo/utils/__init__.py +20 -0
  68. kyolo/utils/coco.py +127 -0
  69. kyolo/utils/visualization.py +209 -0
  70. kyolo/version.py +5 -0
  71. kyolo-0.1.0.dist-info/METADATA +254 -0
  72. kyolo-0.1.0.dist-info/RECORD +76 -0
  73. kyolo-0.1.0.dist-info/WHEEL +5 -0
  74. kyolo-0.1.0.dist-info/entry_points.txt +2 -0
  75. kyolo-0.1.0.dist-info/licenses/LICENSE +201 -0
  76. kyolo-0.1.0.dist-info/top_level.txt +1 -0
kyolo/__init__.py ADDED
@@ -0,0 +1,55 @@
1
+ """kyolo — the YOLO object-detection family in pure Keras 3.
2
+
3
+ Backend-agnostic (TensorFlow / JAX / PyTorch) implementations of YOLOv5, v6,
4
+ v7, v8, v9, v10, YOLO11, YOLO12 and YOLO26, with preprocessing, postprocessing,
5
+ training/fine-tuning support and PyTorch->Keras weight-conversion utilities.
6
+
7
+ Quickstart
8
+ ----------
9
+ import keras
10
+ from kyolo import yolov8n, YOLOPreprocessor, YOLOPostprocessor
11
+
12
+ model = yolov8n(nc=80)
13
+ pre = YOLOPreprocessor(image_size=640)
14
+ post = YOLOPostprocessor(nc=80)
15
+
16
+ batch = pre(image) # {"images", "ratio", "pad"}
17
+ feats = model(batch["images"]) # list of 3 raw feature maps
18
+ detections = post(feats) # (B, max_det, 6): x1,y1,x2,y2,score,cls
19
+
20
+ Every model variant is a factory function (``yolov5n``, ``yolo11s``, ``yolov9c``,
21
+ ...). See ``kyolo.models.MODEL_NAMES`` for the full list.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from . import models
27
+ from .losses import YOLODetectionLoss
28
+ from .models import * # noqa: F401,F403 (per-variant factories + family classes)
29
+ from .models import MODEL_NAMES, list_models
30
+ from .postprocessing import (
31
+ NonMaxSuppression,
32
+ YOLOPostprocessor,
33
+ detections_to_list,
34
+ )
35
+ from .preprocessing import YOLOPreprocessor
36
+ from .training import YOLODetector
37
+ from .version import __version__, version
38
+
39
# Public API of the top-level ``kyolo`` package: version metadata, the model
# registry, every per-variant factory / family class re-exported from
# ``kyolo.models`` (spliced in from that subpackage's own ``__all__``), then the
# pre/post-processing and training/loss entry points.
__all__ = (
    ["__version__", "version", "models", "list_models", "MODEL_NAMES"]
    + list(models.__all__)
    + [
        # pre / post
        "YOLOPreprocessor",
        "YOLOPostprocessor",
        "NonMaxSuppression",
        "detections_to_list",
        # training / loss
        "YOLODetector",
        "YOLODetectionLoss",
    ]
)
kyolo/conversion/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ """PyTorch -> Keras 3 weight conversion for :mod:`kyolo`.
2
+
3
+ The official YOLO weights are AGPL-3.0 and are **not** redistributed by this
4
+ project. These utilities let a user convert a ``.pt`` checkpoint they have
5
+ obtained themselves into Keras ``.weights.h5`` format.
6
+
7
+ Public API
8
+ ----------
9
+ * :func:`load_torch_state_dict` -- read a ``.pt`` into ordered numpy arrays.
10
+ * :func:`transfer_by_order` -- robust positional transfer (recommended).
11
+ * :func:`transfer_torch_to_keras` -- best-effort name-based transfer.
12
+ * :func:`convert_weights` -- high-level: load, transfer, save.
13
+
14
+ See :mod:`kyolo.conversion.convert` for details and the important caveat that a
15
+ clean transfer must still be validated against the reference outputs.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from .convert import (
21
+ convert_weights,
22
+ load_torch_state_dict,
23
+ transfer_by_order,
24
+ transfer_torch_to_keras,
25
+ )
26
+ from .exceptions import (
27
+ WeightConversionError,
28
+ WeightCountMismatchError,
29
+ WeightMappingError,
30
+ WeightShapeMismatchError,
31
+ )
32
+ from .mappings import DEFAULT_MAPPING, NAME_MAPPINGS, get_mapping
33
+
34
# Names re-exported by ``kyolo.conversion``: the conversion entry points, the
# name-mapping helpers used by the best-effort "name" method, and the
# conversion exception classes.
__all__ = [
    # conversion functions
    "transfer_torch_to_keras", "transfer_by_order",
    "load_torch_state_dict", "convert_weights",
    # name-mapping helpers
    "DEFAULT_MAPPING", "NAME_MAPPINGS", "get_mapping",
    # exception classes
    "WeightConversionError", "WeightMappingError",
    "WeightShapeMismatchError", "WeightCountMismatchError",
]
kyolo/conversion/cli.py ADDED
@@ -0,0 +1,117 @@
1
+ """Command-line entry point for PyTorch -> Keras weight conversion.
2
+
3
+ Exposed as the ``kyolo-convert`` console script (see ``pyproject.toml``). Builds
4
+ the requested kyolo model in deploy (inference) form, then transfers weights
5
+ from a user-supplied ``.pt`` file. Each model also ships an equivalent
6
+ co-located converter at ``kyolo/models/<name>/convert_<name>_torch_to_keras.py``.
7
+
8
+ Example
9
+ -------
10
+ kyolo-convert --model yolov8n --weights yolov8n.pt \
11
+ --output yolov8n.weights.h5 --method order
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ from typing import Optional, Sequence
18
+
19
+ from .convert import convert_weights
20
+ from .mappings import get_mapping
21
+
22
+
23
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser used by the ``kyolo-convert`` console script.

    Options, in registration order: ``--model`` and ``--weights`` (both
    required), ``--output``, ``--method`` (``order`` | ``name``), ``--nc``,
    ``--imgsz`` and the ``--quiet`` flag.
    """
    description = (
        "Convert an official YOLO PyTorch (.pt) checkpoint into kyolo Keras "
        "weights (.weights.h5). The weights themselves are AGPL-3.0 and must "
        "be obtained by you; this tool only converts a file you already have."
    )
    parser = argparse.ArgumentParser(prog="kyolo-convert", description=description)
    add = parser.add_argument

    # Required inputs.
    add("--model", required=True,
        help="Model name, e.g. yolov8n, yolo11s, yolov9c, yolov5m.")
    add("--weights", required=True,
        help="Path to the source PyTorch .pt checkpoint.")

    # Where to write and how to transfer.
    add("--output", default=None,
        help="Output path for the Keras weights (default: <weights-stem>.weights.h5).")
    add("--method", choices=("order", "name"), default="order",
        help="Transfer strategy. 'order' (default) is the robust positional "
             "transfer; 'name' is best-effort name matching.")

    # How to build the target model.
    add("--nc", type=int, default=80,
        help="Number of classes the model was trained on (default: 80 = COCO).")
    add("--imgsz", type=int, default=640,
        help="Square input image size used to build the model (default: 640).")

    add("--quiet", action="store_true", help="Suppress progress output.")
    return parser
72
+
73
+
74
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Parse arguments, build the model and run the conversion.

    Args:
        argv: Arguments to parse. ``None`` lets :mod:`argparse` read
            ``sys.argv[1:]``.

    Returns:
        ``0`` if every model variable was transferred (and there was at least
        one), otherwise ``1`` -- convenient for scripting.

    Raises:
        SystemExit: On an unknown ``--model``, on argument errors (raised by
            :mod:`argparse`), or when the conversion fails with an
            ``ImportError`` / ``OSError`` / ``ValueError`` (e.g. torch not
            installed, checkpoint not readable, architecture/weights mismatch,
            bad output path). The error text is printed instead of a traceback.
    """
    args = _build_parser().parse_args(argv)

    # Imported here (not at module load) so that `kyolo.conversion` stays
    # importable even while `kyolo.models` is under development.
    import kyolo.models as models

    # Factory names use underscores; accept hyphenated spellings as well.
    key = args.model.replace("-", "_")
    factory = getattr(models, key, None)
    if factory is None or key not in models.MODEL_NAMES:
        raise SystemExit(
            f"Unknown model {args.model!r}. Available: {', '.join(models.MODEL_NAMES)}"
        )

    verbose = not args.quiet
    if verbose:
        print(f"Building {args.model} (nc={args.nc}, imgsz={args.imgsz}, deploy=True) ...")
    model = factory(
        nc=args.nc,
        input_shape=(args.imgsz, args.imgsz, 3),
        deploy=True,
    )

    # Look the mapping up with the validated, normalized name so it refers to
    # the same variant as the factory that was actually built.
    name_mapping = get_mapping(key) if args.method == "name" else None
    try:
        report = convert_weights(
            model,
            args.weights,
            output_path=args.output,
            method=args.method,
            name_mapping=name_mapping,
            verbose=verbose,
        )
    except (ImportError, OSError, ValueError) as exc:
        # These are user-facing failures (missing torch, unreadable file,
        # variant/shape mismatch, invalid output path): report them as a clean
        # CLI error (exit status 1) rather than a stack trace.
        raise SystemExit(f"Conversion failed: {exc}") from exc

    transferred = report.get("transferred", 0)
    total = report.get("total", 0)
    if verbose:
        print(f"Done: {transferred}/{total} variables transferred.")
    # Non-zero exit if nothing (or not everything) transferred, to aid scripting.
    return 0 if transferred == total and total > 0 else 1


if __name__ == "__main__":
    raise SystemExit(main())
kyolo/conversion/convert.py ADDED
@@ -0,0 +1,429 @@
1
+ """PyTorch -> Keras 3 weight conversion for the YOLO family.
2
+
3
+ This module moves parameters out of an official (Ultralytics / WongKinYiu)
4
+ ``.pt`` checkpoint and into an equivalent :mod:`kyolo` Keras 3 model. Nothing
5
+ here depends on a particular Keras backend: reading/writing variables goes
6
+ through :meth:`keras.Variable.assign`, and tensors are shuttled around as plain
7
+ NumPy arrays.
8
+
9
+ Two transfer strategies are provided:
10
+
11
+ ``transfer_by_order`` (method ``"order"``, RECOMMENDED)
12
+ Walks the Keras variables and the Torch ``state_dict`` **in build order**
13
+ and zips them together positionally. Because the :mod:`kyolo` builders
14
+ create their sub-layers in the exact same order the PyTorch modules are
15
+ registered, this lines up without relying on fragile string matching. It
16
+ only needs the *counts* and *shapes* to agree, which it validates loudly.
17
+
18
+ ``transfer_torch_to_keras`` (method ``"name"``, best-effort)
19
+ Derives a candidate Torch key for every Keras variable from its dotted
20
+ ``path`` (e.g. ``model.0.conv/kernel`` -> ``model.0.conv.weight``) and looks
21
+ it up directly. This is convenient when the two graphs are structurally
22
+ identical, but the derived names can drift from a specific checkpoint's
23
+ layout (the detection head in particular), so it may need per-model mapping
24
+ tuning. Prefer ``"order"`` unless you have a reason not to.
25
+
26
+ Layout conventions handled here
27
+ -------------------------------
28
+ * Conv2D kernel: Keras is ``HWIO`` ``(kh, kw, in/groups, out)`` while PyTorch is
29
+ ``OIHW`` ``(out, in/groups, kh, kw)`` -> ``np.transpose(w, (2, 3, 1, 0))``.
30
+ Grouped / depth-wise convolutions use the same transpose.
31
+ * Dense/Linear kernel: Keras ``(in, out)`` vs Torch ``(out, in)`` -> transpose.
32
+ * BatchNormalization: Keras ``gamma, beta, moving_mean, moving_variance`` map to
33
+ Torch ``weight, bias, running_mean, running_var``.
34
+ * Conv/Dense bias: copied verbatim.
35
+
36
+ .. important::
37
+ A successful, no-error transfer does **not** by itself guarantee a
38
+ bit-exact model. Padding conventions, activation choices and the exact
39
+ assembly of the detection head must be validated by running the converted
40
+ model against the PyTorch reference on the same input and comparing outputs.
41
+ Treat the reports returned here as a first, necessary check -- not proof of
42
+ correctness.
43
+ """
44
+
45
+ from __future__ import annotations
46
+
47
+ import os
48
+ from collections import OrderedDict
49
+ from typing import Dict, List, Optional, Tuple
50
+
51
+ import numpy as np
52
+
53
+ from .mappings import DEFAULT_MAPPING
54
+
55
# Public entry points of this module; the underscore-prefixed helpers below are
# implementation details.
__all__ = [
    "load_torch_state_dict", "transfer_by_order",
    "transfer_torch_to_keras", "convert_weights",
]
61
+
62
+
63
# Keras variable leaf name -> PyTorch parameter suffix. Conv/Dense ``kernel``
# and BatchNorm ``gamma`` both become ``weight``; BatchNorm ``beta`` and a
# Conv/Dense ``bias`` both become ``bias``; BN statistics map to ``running_*``.
_SUFFIX_MAP = dict(
    kernel="weight",
    gamma="weight",
    beta="bias",
    moving_mean="running_mean",
    moving_variance="running_var",
    bias="bias",
)
72
+
73
+
74
+ # --------------------------------------------------------------------------- #
75
+ # Loading a torch checkpoint into plain numpy arrays
76
+ # --------------------------------------------------------------------------- #
77
def _require_torch():
    """Return the ``torch`` module, importing it only when actually needed.

    torch is only used to read ``.pt`` checkpoints, so it is imported lazily
    and a missing installation is reported with install instructions.

    Raises:
        ImportError: If PyTorch is not installed (chained to the original error).
    """
    try:
        import torch
    except ImportError as missing:  # pragma: no cover - trivial
        hint = (
            "PyTorch is required to read '.pt' checkpoints but is not installed. "
            "Install it with `pip install torch` or `pip install kyolo[conversion]`. "
            "torch is only needed for weight conversion, not for running kyolo."
        )
        raise ImportError(hint) from missing
    else:
        return torch
88
+
89
+
90
def _looks_like_state_dict(obj, torch) -> bool:
    """Heuristically decide whether ``obj`` is a ``name -> tensor`` mapping.

    Returns ``True`` for a non-empty ``dict`` in which at least one value is a
    tensor according to ``torch.is_tensor``; ``False`` for anything else.
    """
    is_nonempty_dict = isinstance(obj, dict) and bool(obj)
    if not is_nonempty_dict:
        return False
    for value in obj.values():
        if torch.is_tensor(value):
            return True
    return False
95
+
96
+
97
def _to_state_dict(obj, torch) -> Dict[str, "object"]:
    """Reduce a loaded checkpoint object to a plain ``name -> tensor`` mapping.

    Candidates are checked in this order:

    1. a non-dict object exposing ``state_dict()`` (e.g. an ``nn.Module``,
       which Ultralytics pickles directly) -- cast with ``.float()`` first,
       falling back to the uncast module if that raises;
    2. a dict whose ``"model"``, then ``"ema"``, then ``"state_dict"`` entry is
       not ``None`` (Ultralytics-style / generic containers), unwrapped
       recursively;
    3. a dict that itself holds at least one tensor (a raw ``state_dict``).

    Raises:
        ValueError: If none of the above applies.
    """
    if isinstance(obj, dict):
        # Container checkpoints: first non-None entry wins, in this order.
        for container_key in ("model", "ema", "state_dict"):
            inner = obj.get(container_key)
            if inner is not None:
                return _to_state_dict(inner, torch)
        if _looks_like_state_dict(obj, torch):
            return obj
    elif hasattr(obj, "state_dict"):
        try:
            return obj.float().state_dict()
        except Exception:  # pragma: no cover - some modules dislike .float()
            return obj.state_dict()

    raise ValueError(
        "Could not interpret the checkpoint as a state_dict. Expected a raw "
        "state_dict, an nn.Module, or a dict containing a 'model'/'ema'/"
        f"'state_dict' entry, but got: {type(obj)!r}."
    )
126
+
127
+
128
def load_torch_state_dict(path: str) -> "OrderedDict[str, np.ndarray]":
    """Read a ``.pt`` checkpoint into an ordered ``name -> numpy.ndarray`` map.

    The checkpoint's parameter order is preserved; the ``"order"`` transfer
    depends on it. Entries whose name ends in ``num_batches_tracked`` and
    non-tensor values are dropped. Every remaining tensor is detached, moved to
    CPU and converted to ``float32`` via ``.float()`` before becoming NumPy.

    .. warning::
        The file is unpickled with ``weights_only=False`` so that Ultralytics'
        pickled ``nn.Module`` objects can be restored. Unpickling can run
        arbitrary code -- only load checkpoints from a source you trust.

    Args:
        path: Path to the ``.pt`` checkpoint.

    Returns:
        An ``OrderedDict`` of NumPy arrays in the checkpoint's parameter order.

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If the checkpoint cannot be interpreted as a state_dict or
            contains no tensors.
    """
    torch = _require_torch()

    try:
        # Needed to unpickle full nn.Module checkpoints (see warning above).
        checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not accept the ``weights_only`` keyword.
        checkpoint = torch.load(path, map_location="cpu")

    arrays: "OrderedDict[str, np.ndarray]" = OrderedDict(
        (name, tensor.detach().cpu().float().numpy())
        for name, tensor in _to_state_dict(checkpoint, torch).items()
        if not name.endswith("num_batches_tracked") and torch.is_tensor(tensor)
    )
    if not arrays:
        raise ValueError(f"No tensor parameters were found in checkpoint: {path!r}")
    return arrays
162
+
163
+
164
+ # --------------------------------------------------------------------------- #
165
+ # Shared helpers
166
+ # --------------------------------------------------------------------------- #
167
def _leaf(path: str) -> str:
    """Return the last ``/``-separated component of a Keras ``variable.path``.

    A path containing no ``/`` is returned unchanged.
    """
    _, _, tail = path.rpartition("/")
    return tail
170
+
171
+
172
def _var_kind(var) -> str:
    """Classify a Keras variable to choose its PyTorch->Keras layout conversion.

    Returns ``"conv_kernel"`` (4-D kernel), ``"dense_kernel"`` (2-D kernel),
    ``"bias"``, ``"bn_gamma"``, ``"bn_beta"``, ``"bn_mean"``, ``"bn_var"``, or
    ``"other"`` for everything else (including kernels of any other rank).
    Only the leaf of ``var.path`` and the rank of ``var.shape`` are inspected.
    """
    leaf = _leaf(var.path)
    rank = len(var.shape)

    # Any leaf *ending* in "kernel" is treated as a kernel (suffix match).
    if leaf.endswith("kernel"):
        return {4: "conv_kernel", 2: "dense_kernel"}.get(rank, "other")

    exact_kinds = {
        "bias": "bias",
        "gamma": "bn_gamma",
        "beta": "bn_beta",
        "moving_mean": "bn_mean",
        "moving_variance": "bn_var",
    }
    return exact_kinds.get(leaf, "other")
193
+
194
+
195
def _transpose_for_kind(arr: np.ndarray, kind: str) -> np.ndarray:
    """Permute a PyTorch-layout array into the Keras layout implied by ``kind``.

    * ``"conv_kernel"``: OIHW ``(out, in/groups, kh, kw)`` -> HWIO
      ``(kh, kw, in/groups, out)``. This matches Keras ``Conv2D`` kernels,
      including grouped ``Conv2D(groups=...)`` used as a depth-wise conv.
      NOTE(review): a ``keras.layers.DepthwiseConv2D`` kernel is laid out as
      ``(kh, kw, in, depth_multiplier)`` and would not match this permutation.
    * ``"dense_kernel"``: ``(out, in)`` -> ``(in, out)``.
    * any other kind: the input array is returned as-is (same object).
    """
    axes_by_kind = {"conv_kernel": (2, 3, 1, 0), "dense_kernel": (1, 0)}
    axes = axes_by_kind.get(kind)
    return arr if axes is None else np.transpose(arr, axes)
203
+
204
+
205
def _assign(var, arr: np.ndarray) -> None:
    """Write ``arr`` into the Keras variable ``var``, cast to ``var.dtype`` first."""
    casted = arr.astype(var.dtype)
    var.assign(casted)
208
+
209
+
210
+ # --------------------------------------------------------------------------- #
211
+ # Order-based transfer (robust, recommended)
212
+ # --------------------------------------------------------------------------- #
213
def transfer_by_order(
    keras_model,
    torch_state: "Dict[str, np.ndarray]",
    verbose: bool = True,
    strict_shapes: bool = True,
) -> Dict[str, int]:
    """Positionally zip Keras variables against the Torch parameters.

    Both sides are walked in build order: ``keras_model.weights`` yields
    variables layer-by-layer (trainable then non-trainable within each layer),
    which matches the ``weight, bias, running_mean, running_var`` grouping of a
    PyTorch ``state_dict``. The correct transpose is chosen from each Keras
    variable's kind; only element counts and post-transpose shapes need to line
    up, both of which are validated.

    Every pair is converted and shape-checked *before* any variable is written,
    so when a ``ValueError`` is raised the model's weights are left untouched
    rather than half-overwritten.

    Args:
        keras_model: The target Keras model (already built).
        torch_state: Ordered ``name -> ndarray`` mapping from
            :func:`load_torch_state_dict`.
        verbose: Print a short progress summary.
        strict_shapes: Raise on any per-tensor shape mismatch. When ``False``,
            mismatches are skipped and counted instead.

    Returns:
        ``{"transferred": int, "skipped": int, "total": int}``.

    Raises:
        ValueError: If the number of Keras variables and Torch tensors differ,
            or (when ``strict_shapes``) on a shape mismatch.
    """
    keras_vars = [(v, _var_kind(v)) for v in keras_model.weights]
    torch_items: List[Tuple[str, np.ndarray]] = list(torch_state.items())

    if len(keras_vars) != len(torch_items):
        raise ValueError(
            "Parameter count mismatch between the Keras model and the Torch "
            f"checkpoint: Keras has {len(keras_vars)} variables but the "
            f"checkpoint has {len(torch_items)} tensors. This usually means the "
            "architecture/config does not match the weights (wrong variant, "
            "wrong `nc`, or a deploy/train mismatch)."
        )

    # Pass 1: convert and validate every pair without mutating the model, so a
    # strict-mode failure cannot leave it partially overwritten.
    pending: List[Tuple[object, np.ndarray]] = []
    skipped = 0
    for idx, ((var, kind), (tk, arr)) in enumerate(zip(keras_vars, torch_items)):
        converted = _transpose_for_kind(arr, kind)
        if tuple(converted.shape) == tuple(var.shape):
            pending.append((var, converted))
            continue
        message = (
            f"[{idx}] shape mismatch: Keras '{var.path}' expects "
            f"{tuple(var.shape)} (kind={kind}) but Torch '{tk}' is "
            f"{tuple(arr.shape)} -> {tuple(converted.shape)} after transpose."
        )
        if strict_shapes:
            raise ValueError(message)
        if verbose:
            print(f" skip: {message}")
        skipped += 1

    # Pass 2: all checks passed (or mismatches were skipped) -- commit.
    for var, converted in pending:
        _assign(var, converted)
    transferred = len(pending)

    total = len(keras_vars)
    if verbose:
        print(
            f"[order] transferred {transferred}/{total} variables"
            + (f" ({skipped} skipped)" if skipped else "")
        )
    return {"transferred": transferred, "skipped": skipped, "total": total}
281
+
282
+
283
+ # --------------------------------------------------------------------------- #
284
+ # Name-based transfer (best-effort)
285
+ # --------------------------------------------------------------------------- #
286
def transfer_torch_to_keras(
    keras_model,
    torch_state: "Dict[str, np.ndarray]",
    name_mapping: Optional[Dict[str, str]] = None,
    verbose: bool = True,
    strict: bool = False,
) -> Dict[str, object]:
    """Name-driven transfer: derive each Torch key from the Keras variable path.

    For a variable whose path is ``model.0.conv/kernel`` this builds the
    candidate key ``model.0.conv.weight`` (via the leaf-suffix rules
    ``kernel->weight``, ``gamma->weight``, ``beta->bias``,
    ``moving_mean->running_mean``, ``moving_variance->running_var``,
    ``bias->bias``) and then applies every ``old -> new`` substring replacement
    in ``name_mapping``, in order. If the resulting key exists in
    ``torch_state`` the value is transposed to the Keras layout and assigned;
    otherwise the variable is recorded as a miss.

    Only the *last* ``/`` of the path is turned into ``.``; any earlier ``/``
    separators (nested layers) are kept unless ``name_mapping`` rewrites them.

    All lookups and shape checks happen before any variable is written, so
    with ``strict=True`` a raised error leaves the model's weights untouched.

    This path is best-effort. The dotted :mod:`kyolo` naming mirrors the
    reference modules, but individual checkpoints (especially detection heads)
    can diverge and need a per-model ``name_mapping``. When in doubt use
    :func:`transfer_by_order`.

    Args:
        keras_model: The target Keras model (already built).
        torch_state: ``name -> ndarray`` mapping from
            :func:`load_torch_state_dict`.
        name_mapping: Ordered substring replacements applied to each derived
            Torch key. Defaults to :data:`kyolo.conversion.mappings.DEFAULT_MAPPING`.
        verbose: Print a short summary (and the first few misses).
        strict: Raise on the first miss / shape mismatch instead of recording it.

    Returns:
        ``{"transferred", "skipped", "total", "misses"}`` where ``misses`` is a
        list of ``(keras_path, torch_key, reason)`` tuples.

    Raises:
        KeyError: With ``strict=True``, when no Torch key matches a variable.
        ValueError: With ``strict=True``, on a shape mismatch.
    """
    if name_mapping is None:
        name_mapping = DEFAULT_MAPPING

    pending: List[Tuple[object, np.ndarray]] = []
    misses: List[Tuple[str, str, str]] = []
    total = 0

    # Pass 1: resolve and validate every variable without touching the model.
    for var in keras_model.weights:
        total += 1
        path = var.path
        layer_path, sep, leaf = path.rpartition("/")
        if not sep:
            # No "/" at all: the whole path serves as both prefix and leaf.
            layer_path = path
        suffix = _SUFFIX_MAP.get(leaf, leaf)
        torch_key = f"{layer_path}.{suffix}"
        for old, new in name_mapping.items():
            torch_key = torch_key.replace(old, new)

        if torch_key not in torch_state:
            if strict:
                raise KeyError(
                    f"No Torch parameter matched Keras variable '{path}' "
                    f"(tried '{torch_key}'). Adjust the name_mapping or use "
                    "method='order'."
                )
            misses.append((path, torch_key, "missing"))
            continue

        arr = torch_state[torch_key]
        converted = _transpose_for_kind(arr, _var_kind(var))
        if tuple(converted.shape) != tuple(var.shape):
            reason = f"shape {tuple(arr.shape)}->{tuple(converted.shape)} != {tuple(var.shape)}"
            if strict:
                raise ValueError(f"Shape mismatch for '{path}' <- '{torch_key}': {reason}.")
            misses.append((path, torch_key, reason))
            continue

        pending.append((var, converted))

    # Pass 2: commit the resolved weights.
    for var, converted in pending:
        _assign(var, converted)
    transferred = len(pending)

    if verbose:
        print(f"[name] transferred {transferred}/{total} variables ({len(misses)} misses)")
        for path, torch_key, reason in misses[:10]:
            print(f" miss: {path} <- {torch_key} ({reason})")
        if len(misses) > 10:
            print(f" ... and {len(misses) - 10} more")

    return {
        "transferred": transferred,
        "skipped": len(misses),
        "total": total,
        "misses": misses,
    }
376
+
377
+
378
+ # --------------------------------------------------------------------------- #
379
+ # High-level driver
380
+ # --------------------------------------------------------------------------- #
381
def convert_weights(
    model,
    torch_weights_path: str,
    output_path: Optional[str] = None,
    method: str = "order",
    name_mapping: Optional[Dict[str, str]] = None,
    verbose: bool = True,
) -> Dict[str, object]:
    """Load a ``.pt`` file, transfer it into ``model`` and save ``.weights.h5``.

    Cheap argument checks (``method`` and the ``output_path`` suffix) run
    before the checkpoint is read, so a typo fails immediately instead of
    after a potentially slow load and transfer.

    Args:
        model: A built Keras model to receive the weights.
        torch_weights_path: Path to the source ``.pt`` checkpoint.
        output_path: Where to write the Keras weights; must end in
            ``.weights.h5`` (required by Keras 3 ``save_weights``). Defaults to
            ``<stem>.weights.h5`` in the current working directory, where
            ``<stem>`` is the checkpoint's file name without its extension.
            Pass ``False`` to skip saving.
        method: ``"order"`` (recommended) or ``"name"``.
        name_mapping: Optional substring mapping for the ``"name"`` method.
        verbose: Print progress and a final summary.

    Returns:
        The transfer report from the chosen method.

    Raises:
        ValueError: For an unknown ``method`` or an ``output_path`` that does
            not end in ``.weights.h5`` (both checked before loading), and for
            any ``ValueError`` raised while loading or transferring.
    """
    if method not in ("order", "name"):
        raise ValueError(f"Unknown method {method!r}; expected 'order' or 'name'.")
    if output_path is not False and output_path is not None:
        if not str(output_path).endswith(".weights.h5"):
            raise ValueError(
                "The output path must end in `.weights.h5` (required by Keras "
                f"`save_weights`). Received: output_path={output_path!r}"
            )

    torch_state = load_torch_state_dict(torch_weights_path)

    if method == "order":
        report = transfer_by_order(model, torch_state, verbose=verbose)
    else:
        report = transfer_torch_to_keras(
            model, torch_state, name_mapping=name_mapping, verbose=verbose
        )

    if output_path is not False:
        if output_path is None:
            # Written relative to the current working directory.
            stem = os.path.splitext(os.path.basename(torch_weights_path))[0]
            output_path = f"{stem}.weights.h5"
        model.save_weights(output_path)
        if verbose:
            print(f"Saved Keras weights to: {output_path}")

    if verbose:
        print(
            "Reminder: a clean transfer is necessary but not sufficient. "
            "Validate the converted model's outputs against the PyTorch "
            "reference before trusting it."
        )
    return report