jax-image-models 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,36 @@
1
+ Metadata-Version: 2.4
2
+ Name: jax-image-models
3
+ Version: 0.3.3
4
+ Summary: Jax Image Modeling of Models
5
+ License-File: LICENSE
6
+ Requires-Python: >=3.11
7
+ Requires-Dist: flax>=0.10.6
8
+ Requires-Dist: jax>=0.6.2
9
+ Requires-Dist: jaxtyping>=0.3.2
10
+ Requires-Dist: safetensors>=0.5.3
11
+ Requires-Dist: tokamax>=0.0.12
12
+ Description-Content-Type: text/markdown
13
+
14
+ # Jax Image Modeling of Models (jimm)
15
+ Docs are at: [https://pythoncrazy.github.io/jimm](https://pythoncrazy.github.io/jimm)
16
+ - This aims to be the JAX counterpart to timm, except that for image-text models (CLIP, SigLIP, etc.) the text model is fully supported as well.
17
+ - Built with Flax NNX; supports loading weights from `pytorch_model.bin` and safetensors files, either from a local directory or from the Hugging Face Hub.
18
+
19
+ Models Supported:
20
+ - Vision Transformers
21
+ - Both with a classification linear layer, or not
22
+ - Using a CLS Token for pooling, or using Multihead Attention Pooling
23
+ - Can load any standard variant of Vision Transformers of any size/resolution (e.g. "google/vit-base-patch16-224" or "google/vit-large-patch16-384")
24
+ - CLIP
25
+ - Can load any CLIP checkpoint from the Hugging Face Hub (such as "openai/clip-vit-base-patch32" or "geolocal/StreetCLIP")
26
+ - SigLIP
27
+ - Can load any non-NaFlex version of the SigLIP model, from both SigLIP v1 and SigLIP v2 (e.g. "google/siglip-base-patch16-256" or "google/siglip2-large-patch16-512"), from Hugging Face or from a local directory
28
+ ## Installation
29
+ ### Using pixi.sh:
30
+ `pixi add jimm@https://github.com/pythoncrazy/jimm.git --pypi`
31
+ ### Using uv
32
+ `uv add --dev git+https://github.com/pythoncrazy/jimm.git`
33
+ or if you prefer to not add as a direct dependency:
34
+ `uv pip install git+https://github.com/pythoncrazy/jimm.git`
35
+ ### Using pip/conda
36
+ `pip install git+https://github.com/pythoncrazy/jimm.git`
@@ -0,0 +1,25 @@
1
+ jimm/__init__.py,sha256=7rsfJf4ma5htZG7IaXCklJ52m9fELLZiLNyxqdtdkBk,849
2
+ jimm/common/autotuning.py,sha256=8E-S2gtHwSpJKZgR-BQlsosebJH_zn195HVeDhJClIE,4998
3
+ jimm/common/loading_utils.py,sha256=zxwytlhLD26dpWrnCjd4LAQmLtWeD9Z2GSDH-6qt4Rc,9398
4
+ jimm/common/sharding.py,sha256=_hXhuX4C6lneNHo28gK7CkBDxlvPujzlLlG65MzgVjs,4161
5
+ jimm/common/tokamax_attention.py,sha256=CSJs_s2XKniuq8__fjTUoPoprMsqweNCpUQmwr_whQE,4160
6
+ jimm/common/transformer.py,sha256=h-cpzg8Os-nU39izuh6bt3UQBwlsVK21DTh0BOie5m4,12194
7
+ jimm/common/utils.py,sha256=r6sbZVmeZbwj6opK6MTf61Y7w4d_3zmuXVFkWQoSXcE,6706
8
+ jimm/common/vit.py,sha256=_pZv4hXk67Xg0hwjkQnoQ_Q5RXwyGO1Hw7hokdfWLm4,14940
9
+ jimm/models/__init__.py,sha256=dkC9Ujsrr0qEjPNl2vaoTJEOOADy766ZOUEKVsCbrbQ,311
10
+ jimm/models/clip/__init__.py,sha256=Ayy0m9sNHurvQy3l-OlhWoc_NsGeCUN8Oxp_Jw7NHbg,117
11
+ jimm/models/clip/clip_model.py,sha256=K3oyN1yrisrOf1JpspiLPOD52j-ieeSTPub0sEX4EQs,27822
12
+ jimm/models/clip/params.py,sha256=mZOGjL0at_uv3dljiGbIumkSkXjf-J0zf958vTuwUKo,24504
13
+ jimm/models/clip/sharding.py,sha256=KIfUU_27dJTk75r1QZtfaBzmvd2E8udS-4gzp4w2jxc,3115
14
+ jimm/models/siglip/__init__.py,sha256=2-hoGa2kAYhsNC-n52m_5PTE35XYNGXuFAQSnMtQ2QE,131
15
+ jimm/models/siglip/params.py,sha256=sc4_x8332boIXNg_FpnfIQKkFnBUpmPMuIcycZudOJs,29485
16
+ jimm/models/siglip/sharding.py,sha256=Ck0HRAYz4gYN5XJws05lKxad04GfpylszaSM2KMN0zc,3029
17
+ jimm/models/siglip/siglip_model.py,sha256=k-yf92I18o28rkBBO3EhGdW-zNiGY39B7hVCQbhdQbQ,26495
18
+ jimm/models/vit/__init__.py,sha256=Hw1rwR3mN13X-0TYmq-H2cZsXUmK95326sYADF3VjWM,74
19
+ jimm/models/vit/params.py,sha256=IaP0Z4nzdXIXWm28P-x8ERJ6lNNrd8HbL7CXFxigYd0,12890
20
+ jimm/models/vit/sharding.py,sha256=8LLOvAOfY0770jF3GZth4JiWzuqMZfERfMqdJ7a-vsg,2963
21
+ jimm/models/vit/vit_model.py,sha256=WenSypHga0G_4kCyBUR5Gru5wDFsVNcEqst3rl6B1qo,9057
22
+ jax_image_models-0.3.3.dist-info/METADATA,sha256=PfaDFPWkrf1BZtIEJGDWd0Gtx2TnSNnGH7Vk5w_teoA,1762
23
+ jax_image_models-0.3.3.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
24
+ jax_image_models-0.3.3.dist-info/licenses/LICENSE,sha256=wYMBl9KN_WLSFwJ2lMHMchvv7IhGvIPl4alCEzjFw6I,1070
25
+ jax_image_models-0.3.3.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.29.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Pinak Paliwal
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
jimm/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ from .common.autotuning import AutotuningResult
2
+ from .common.autotuning import autotune
3
+ from .common.autotuning import autotuned_fn
4
+ from .common.autotuning import cached_autotune
5
+ from .common.autotuning import load_autotune_result
6
+ from .common.tokamax_attention import make_tokamax_attention
7
+ from .common.tokamax_attention import tokamax_attention_fn as tokamax_attention
8
+ from .models import (
9
+ CLIP,
10
+ CLIPTextModel,
11
+ CLIPVisionModel,
12
+ SigLIP,
13
+ SigLIPTextModel,
14
+ SigLIPVisionModel,
15
+ VisionTransformer,
16
+ )
17
+
18
# Public API of the ``jimm`` package. Every name listed here is imported at the
# top of this module from ``.common`` or ``.models``.
__all__ = [
    # Tokamax attention helpers.
    "tokamax_attention",
    "make_tokamax_attention",
    # Autotuning helpers.
    "autotune",
    "cached_autotune",
    "autotuned_fn",
    "load_autotune_result",
    "AutotuningResult",
    # Model classes.
    "VisionTransformer",
    "CLIP",
    "CLIPTextModel",
    "CLIPVisionModel",
    "SigLIP",
    "SigLIPTextModel",
    "SigLIPVisionModel",
]
@@ -0,0 +1,150 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import functools
5
+ import hashlib
6
+ import os
7
+ from collections.abc import Callable
8
+ from typing import Any
9
+
10
+ from jax.extend import backend as _jax_backend
11
+ from tokamax._src.autotuning import api as _autotune_api
12
+
13
# Alias of tokamax's result type so the functions in this module can name it
# without spelling out the private ``tokamax._src`` path.
AutotuningResult = _autotune_api.AutotuningResult
14
+
15
+
16
def _lower(jitted_fn: Any, *args: Any) -> Any:
    """Lower ``jitted_fn`` and return the underlying lowered computation.

    The object returned by ``jitted_fn.lower(*args)`` may be a wrapper that
    exposes the raw lowered computation through a ``.lowered`` attribute (the
    module notes that ``nnx.jit`` does this). When that attribute exists its
    value is returned; otherwise the object itself is returned unchanged.

    Args:
        jitted_fn (Any): A callable exposing a ``.lower(*args)`` method, such
            as an ``nnx.jit`` or ``jax.jit`` function.
        *args (Any): Sample arguments passed straight to ``.lower``.

    Returns:
        Any: The ``.lowered`` attribute of the lowering result if present,
        else the lowering result itself.
    """
    wrapped = jitted_fn.lower(*args)
    try:
        # Wrapper objects carry the raw lowered computation on ``.lowered``.
        return wrapped.lowered
    except AttributeError:
        # Already the raw object: nothing to unwrap.
        return wrapped
31
+
32
+
33
def autotune(
    jitted_fn: Any,
    *sample_args: Any,
    save_path: str | os.PathLike | None = None,
    **kwargs: Any,
) -> AutotuningResult:
    """Autotune the tokamax ops of a jitted function, optionally saving the result.

    ``jitted_fn`` is lowered with ``sample_args`` and the lowered computation
    is handed to tokamax's autotuner. The returned object can be used as a
    context manager::

        result = jimm.autotune(forward, model, image, text)
        with result:
            out = forward(model, image, text)

    Args:
        jitted_fn (Any): Jitted callable containing tokamax ops.
        *sample_args (Any): Representative inputs; shapes/dtypes must match
            production inputs.
        save_path (str | os.PathLike | None): When given, the result is
            dumped to this file.
        **kwargs (Any): Forwarded to ``tokamax.autotune``
            (e.g. ``all_implementations``, ``progress_bar``).

    Returns:
        AutotuningResult: Best config for each op found in ``jitted_fn``.
    """
    lowered = _lower(jitted_fn, *sample_args)
    tuned = _autotune_api.autotune(lowered, **kwargs)

    # Nothing to persist unless the caller asked for it.
    if save_path is None:
        return tuned

    with open(save_path, "w") as sink:
        tuned.dump(sink)
    return tuned
64
+
65
+
66
def load_autotune_result(path: str | os.PathLike) -> AutotuningResult:
    """Read an :class:`AutotuningResult` back from a file.

    Args:
        path (str | os.PathLike): File previously written by :func:`autotune`
            or :func:`cached_autotune`.

    Returns:
        AutotuningResult: The deserialized result.
    """
    with open(path) as source:
        loaded = AutotuningResult.load(source)
    return loaded
78
+
79
+
80
def cached_autotune(
    jitted_fn: Any,
    *sample_args: Any,
    cache_dir: str | os.PathLike,
    **kwargs: Any,
) -> AutotuningResult | None:
    """Load autotuning results from disk, or run autotune and cache them.

    The cache key is the first 16 hex digits of an MD5 digest over the device
    kind plus the sorted op autotuning keys, so the same entry is reused
    whenever the tunable ops and the hardware are unchanged.

    The cache file is written to a temporary sibling path and then moved into
    place with ``os.replace``. An interrupted or failed dump therefore never
    leaves a truncated file that later calls would try to load.

    Args:
        jitted_fn (Any): Jitted callable containing tokamax ops.
        *sample_args (Any): Representative inputs used to extract op shapes.
        cache_dir (str | os.PathLike): Directory for JSON cache files; created
            if it does not exist.
        **kwargs (Any): Forwarded to ``tokamax.autotune``.

    Returns:
        AutotuningResult | None: Loaded or freshly computed result, or ``None``
        if no tunable ops exist.
    """
    bound_args = _autotune_api.get_bound_args(_lower(jitted_fn, *sample_args))
    if not bound_args:
        return None

    device_kind = _jax_backend.get_default_device().device_kind
    key_parts = sorted(str(ba.autotuning_cache_key) for ba in bound_args)
    fingerprint = (device_kind + ":" + "|".join(key_parts)).encode()
    # MD5 is only a cache fingerprint here. ``usedforsecurity=False`` keeps it
    # usable on builds that restrict MD5, and the digest is unchanged.
    cache_key = hashlib.md5(fingerprint, usedforsecurity=False).hexdigest()[:16]

    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f"{cache_key}.json")

    if os.path.exists(cache_path):
        return load_autotune_result(cache_path)

    result = _autotune_api.autotune(list(bound_args), **kwargs)

    # Write to a per-process temp file first, then publish it in one step.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            result.dump(f)
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Best-effort cleanup; the original error is what the caller needs.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise
    return result
119
+
120
+
121
def autotuned_fn(
    jitted_fn: Any,
    *sample_args: Any,
    cache_dir: str | os.PathLike,
    **kwargs: Any,
) -> Callable[..., Any]:
    """Return a wrapper that runs ``jitted_fn`` under its tuned tokamax configs.

    :func:`cached_autotune` is called once, eagerly, when this function runs.
    The returned wrapper then enters the resulting :class:`AutotuningResult`
    as a context manager around every call. If there was nothing to tune
    (``cached_autotune`` returned ``None``) a no-op context is used instead.

    Args:
        jitted_fn (Any): Jitted callable containing tokamax ops.
        *sample_args (Any): Representative inputs used for autotuning.
        cache_dir (str | os.PathLike): Directory for the autotune cache.
        **kwargs (Any): Forwarded to ``tokamax.autotune``.

    Returns:
        Callable[..., Any]: Wrapped callable that applies tuned configs on
        every invocation.
    """
    tuned = cached_autotune(jitted_fn, *sample_args, cache_dir=cache_dir, **kwargs)
    if tuned is None:
        ctx = contextlib.nullcontext()
    else:
        ctx = tuned

    @functools.wraps(jitted_fn)
    def wrapper(*args: Any, **kw: Any) -> Any:
        with ctx:
            return jitted_fn(*args, **kw)

    return wrapper
@@ -0,0 +1,248 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from math import prod
5
+ from typing import Any, Dict, Tuple, TypeVar
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ from flax import nnx
10
+ from huggingface_hub import hf_hub_download
11
+ from jaxtyping import Array, DTypeLike
12
+ from safetensors.flax import load_file as load_safetensors_flax_file
13
+
14
# Type variable bound to ``nnx.Module``; ``apply_mapping`` uses it so the
# returned model has the same static type as the model passed in.
_M = TypeVar("_M", bound=nnx.Module)
15
+
16
+
17
def load_params_and_config(
    model_name_or_path: str,
    use_pytorch: bool = False,
    default_config_filename: str = "config.json",
    default_pytorch_filename: str = "pytorch_model.bin",
    default_safetensors_filename: str = "model.safetensors",
) -> Tuple[Dict[str, Array], Dict[str, Any]]:
    """Load model parameters and configuration from a local directory or the HuggingFace Hub.

    If ``model_name_or_path`` is an existing directory, the config and weights
    files are read from it. Otherwise it is treated as a Hub repo id and both
    files are fetched with ``hf_hub_download``.

    Args:
        model_name_or_path (str): Local directory path or HuggingFace model ID.
        use_pytorch (bool): Whether to load from PyTorch weights. Defaults to False.
        default_config_filename (str): Config filename. Defaults to "config.json".
        default_pytorch_filename (str): PyTorch weights filename. Defaults to "pytorch_model.bin".
        default_safetensors_filename (str): Safetensors filename. Defaults to "model.safetensors".

    Returns:
        Tuple[Dict[str, Array], Dict[str, Any]]: Loaded parameters and configuration.
    """
    # The weights filename depends only on the format, not on where it lives.
    weights_filename = default_pytorch_filename if use_pytorch else default_safetensors_filename

    if os.path.isdir(model_name_or_path):
        config_file_path = os.path.join(model_name_or_path, default_config_filename)
        weights_file_path = os.path.join(model_name_or_path, weights_filename)
    else:
        config_file_path = hf_hub_download(repo_id=model_name_or_path, filename=default_config_filename)
        weights_file_path = hf_hub_download(repo_id=model_name_or_path, filename=weights_filename)

    # JSON config files are UTF-8; do not depend on the platform default encoding.
    with open(config_file_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    if use_pytorch:
        # Imported lazily so torch is only needed for the PyTorch path.
        import torch

        # SECURITY: ``pytorch_model.bin`` is a pickle file, possibly downloaded
        # from the Hub. ``weights_only=True`` restricts unpickling to tensors
        # and plain containers, so a malicious checkpoint cannot run code.
        state_dict = torch.load(weights_file_path, map_location="cpu", weights_only=True)
        params_fstate = {k: jnp.array(v.numpy()) for k, v in state_dict.items()}
    else:
        params_fstate = load_safetensors_flax_file(weights_file_path)

    return params_fstate, config
57
+
58
+
59
+ def stoi(k: str) -> int | str:
60
+ """Convert a string to int if numeric, else return as-is.
61
+
62
+ Args:
63
+ k (str): Key to convert.
64
+
65
+ Returns:
66
+ int | str: Integer if numeric string, else original string.
67
+ """
68
+ try:
69
+ return int(k)
70
+ except ValueError:
71
+ return k
72
+
73
+
74
+ def map_to_bonsai_key(
75
+ mapping: dict[str, tuple[str, Any]],
76
+ key: str,
77
+ ) -> tuple[str | None, Any]:
78
+ """Match a flat HF key against regex patterns and return the flax target key.
79
+
80
+ Args:
81
+ mapping (dict[str, tuple[str, Any]]): Dict of {regex_pattern: (flax_key_template, transform)}.
82
+ key (str): HuggingFace parameter key to match.
83
+
84
+ Returns:
85
+ tuple[str | None, Any]: (flax_key, transform) if matched, else (None, None).
86
+ """
87
+ for pattern, (target, transform) in mapping.items():
88
+ if re.fullmatch(pattern, key):
89
+ return re.sub(pattern, target, key), transform
90
+ return None, None
91
+
92
+
93
+ def to_scan_batched_keys(keys: tuple) -> tuple[tuple | None, int | None]:
94
+ """Convert a per-layer flat state key to its scan-batched equivalent.
95
+
96
+ With nnx.scan/vmap, transformer layers are stored as a single batched module
97
+ under a "layers" key rather than separate "layers_N" keys. This converts
98
+ keys like ("encoder", "layers_3", "attn", "query", "kernel") to
99
+ ("encoder", "layers", "attn", "query", "kernel") and returns the layer index.
100
+
101
+ Args:
102
+ keys (tuple): Flat state key tuple potentially containing a "layers_N" component.
103
+
104
+ Returns:
105
+ tuple[tuple | None, int | None]: (batched_keys, layer_idx) if a layers_N component
106
+ is found, else (None, None).
107
+ """
108
+ new_keys = list(keys)
109
+ layer_idx = None
110
+ for i, k in enumerate(new_keys):
111
+ if isinstance(k, str) and k.startswith("layers_"):
112
+ try:
113
+ layer_idx = int(k[7:])
114
+ new_keys[i] = "layers"
115
+ break
116
+ except ValueError:
117
+ pass
118
+ if layer_idx is None:
119
+ return None, None
120
+ return tuple(new_keys), layer_idx
121
+
122
+
123
def _reshape_if_compatible(tensor: Array, target_shape: tuple[int, ...], hf_key: str, bonsai_keys: tuple[Any, ...]) -> Array:
    """Return ``tensor`` in ``target_shape``, reshaping only when sizes agree.

    Args:
        tensor (Array): Source array.
        target_shape (tuple[int, ...]): Shape the destination expects.
        hf_key (str): Source key, used only in the error message.
        bonsai_keys (tuple[Any, ...]): Destination key path, used only in the error message.

    Returns:
        Array: ``tensor`` itself if it already has ``target_shape``, otherwise
        ``tensor.reshape(target_shape)``.

    Raises:
        ValueError: If the element counts of ``tensor`` and ``target_shape`` differ.
    """
    current_shape = tensor.shape
    if current_shape == target_shape:
        return tensor

    if prod(target_shape) == tensor.size:
        return tensor.reshape(target_shape)

    dotted_target = ".".join(map(str, bonsai_keys))
    raise ValueError(f"Shape mismatch for {hf_key} -> {dotted_target}: got {current_shape}, expected {target_shape}")
133
+
134
+
135
def apply_mapping(
    model: _M,
    params_fstate: dict[str, Any],
    mapping: dict[str, tuple[str, Any]],
    param_dtype: DTypeLike,
) -> _M:
    """Copy HF parameters into ``model`` using regex-based key mappings.

    Each entry of ``params_fstate`` whose key matches a pattern in ``mapping``
    is cast to ``param_dtype`` and, if the transform carries a permutation,
    transposed. It is then written to the matching ``nnx.Param`` of the model.
    Keys that only match after a ``layers_N`` component is rewritten to
    ``layers`` are collected per layer index and written as one stacked array.
    Keys with no destination are skipped silently.

    Args:
        model (_M): Model whose parameters are overwritten in place.
        params_fstate (dict[str, Any]): Flat mapping of HF keys to arrays.
        mapping (dict[str, tuple[str, Any]]): {regex_pattern: (flax_key_template, transform)};
            ``transform.value`` must unpack into three items, the first being
            an optional permutation for ``jnp.transpose``.
        param_dtype (DTypeLike): Dtype every loaded tensor is cast to.

    Returns:
        _M: The same ``model`` instance, updated.

    Raises:
        ValueError: If a tensor's element count does not match its destination,
            or if some layer indices of a stacked parameter are missing.
    """
    flat_state = dict(nnx.to_flat_state(nnx.state(model, nnx.Param)))
    # Stacked destination path -> {layer index -> transformed tensor}.
    pending_layers: dict[tuple, dict[int, Any]] = {}

    for hf_key, tensor in params_fstate.items():
        bonsai_key, transform = map_to_bonsai_key(mapping, hf_key)
        if bonsai_key is None:
            continue

        path = tuple(stoi(part) for part in bonsai_key.split("."))
        permute_rule, _, _ = transform.value
        converted = tensor.astype(param_dtype)
        if permute_rule is not None:
            converted = jnp.transpose(converted, permute_rule)

        if path in flat_state:
            # Direct hit: write straight into the existing variable.
            variable = flat_state[path]
            variable[...] = _reshape_if_compatible(converted, variable[...].shape, hf_key, path)
        else:
            stacked_path, layer_idx = to_scan_batched_keys(path)
            if stacked_path is None or layer_idx is None or stacked_path not in flat_state:
                continue
            pending_layers.setdefault(stacked_path, {})[layer_idx] = converted

    for stacked_path, per_layer in pending_layers.items():
        variable = flat_state[stacked_path]
        dotted_path = ".".join(str(part) for part in stacked_path)
        # The leading axis of the stacked variable is the layer axis.
        depth = variable[...].shape[0]
        absent = sorted(set(range(depth)) - set(per_layer))
        if absent:
            raise ValueError(f"Missing scanned layers for {dotted_path}: {absent}")

        stacked = jnp.stack([per_layer[i] for i in range(depth)], axis=0)
        variable[...] = _reshape_if_compatible(stacked, variable[...].shape, dotted_path, stacked_path)

    nnx.update(model, nnx.from_flat_state(flat_state))
    return model
178
+
179
+
180
def _slice_layer(d: dict, idx: int) -> dict:
    """Return a copy of nested dict ``d`` with every JAX array indexed at ``idx``.

    Nested dicts are processed recursively. Values that are neither dicts nor
    ``jax.Array`` instances are carried over unchanged.
    """

    def pick(value: Any) -> Any:
        if isinstance(value, dict):
            return _slice_layer(value, idx)
        return value[idx] if isinstance(value, jax.Array) else value

    return {name: pick(value) for name, value in d.items()}
191
+
192
+
193
+ def _infer_num_layers(d: dict) -> int | None:
194
+ """Return the leading dimension of the first JAX array found in d."""
195
+ for value in d.values():
196
+ if isinstance(value, jax.Array):
197
+ return int(value.shape[0])
198
+ if isinstance(value, dict):
199
+ n = _infer_num_layers(value)
200
+ if n is not None:
201
+ return n
202
+ return None
203
+
204
+
205
def _is_numeric_key(key: Any) -> bool:
    """Return True if ``key`` is an ``int`` or a string made only of digits."""
    if isinstance(key, str):
        return key.isdigit()
    return isinstance(key, int)
208
+
209
+
210
def _should_expand_layers_dict(d: dict) -> bool:
    """Tell whether ``d`` looks like a scan-batched ``layers`` dict.

    `nnx.scan` stores stacked transformer blocks under a single ``layers`` key
    whose children are named module fields like ``attn`` and ``mlp``.
    ``nnx.Sequential`` also uses a ``layers`` key, but its children are numeric
    submodule indices, and such containers must not be expanded. So this
    returns True only when ``d`` has at least one non-numeric key and contains
    a JAX array somewhere inside it.
    """
    has_named_child = any(not _is_numeric_key(name) for name in d)
    if not has_named_child:
        # Empty dict, or every child is a numeric index (Sequential-style).
        return False
    return _infer_num_layers(d) is not None
222
+
223
+
224
def expand_scanned_layers(state_dict: dict) -> dict:
    """Split scan-batched ``layers`` entries into per-layer ``layers_N`` entries.

    With nnx.scan/vmap, all transformer layer parameters are stored as one
    batched module under a "layers" key with a leading layer dimension. Such
    entries are replaced by separate "layers_0", ..., "layers_(N-1)" sub-dicts,
    each holding that layer's slice. Every other nested dict is processed
    recursively, and non-dict values are copied as they are.

    Args:
        state_dict (dict): Model state dict potentially containing a "layers" key
            with batched parameters.

    Returns:
        dict: State dict with "layers" expanded into "layers_0", ..., "layers_(N-1)".
    """
    expanded: dict = {}
    for name, child in state_dict.items():
        if not isinstance(child, dict):
            expanded[name] = child
            continue

        if name == "layers" and _should_expand_layers_dict(child):
            depth = _infer_num_layers(child)
            if depth is not None:
                for layer in range(depth):
                    expanded[f"layers_{layer}"] = _slice_layer(child, layer)
                continue

        expanded[name] = expand_scanned_layers(child)
    return expanded
@@ -0,0 +1,98 @@
1
+ import dataclasses
2
+ from typing import Any, Protocol
3
+
4
+ import jax
5
+ from jax.sharding import NamedSharding
6
+ from jax.sharding import PartitionSpec as P
7
+
8
+
9
class ShardingSpec(Protocol):
    """Protocol defining the sharding specification for model parameters.

    Each attribute is a tuple of mesh-axis names (or ``None``), one entry per
    parameter axis.

    Specs represent per-layer (non-stacked) shapes. The Transformer stacks
    layers via nnx.vmap and patches the stacked Variable metadata to prepend
    None for the scan axis, so the optimizer sees the correct 4-D spec.

    attn_qkv_kernel   (in_features, num_heads, head_dim)
    attn_qkv_bias     (num_heads, head_dim)
    attn_out_kernel   (num_heads, head_dim, out_features)
    attn_out_bias     (out_features,)
    mlp_up_kernel     (in_features, intermediate_size)
    mlp_up_bias       (intermediate_size,)
    mlp_down_kernel   (intermediate_size, out_features)
    mlp_down_bias     (out_features,)
    layernorm         (hidden_size,)
    """

    # Transformer block parameters; axis meanings are listed in the docstring.
    attn_qkv_kernel: tuple[str | None, str | None, str | None]
    attn_qkv_bias: tuple[str | None, str | None]
    attn_out_kernel: tuple[str | None, str | None, str | None]
    attn_out_bias: tuple[str | None]
    mlp_up_kernel: tuple[str | None, str | None]
    mlp_up_bias: tuple[str | None]
    mlp_down_kernel: tuple[str | None, str | None]
    mlp_down_bias: tuple[str | None]
    layernorm: tuple[str | None]
    # Patch embedding, positional/token embeddings, pooling tokens and
    # projection heads. Axis meanings are not documented in this file;
    # presumably one entry per array axis as above (TODO confirm against the
    # modules that consume these specs).
    patch_conv_kernel: tuple[str | None, str | None, str | None, str | None]
    patch_conv_bias: tuple[str | None]
    embed: tuple[str | None, str | None]
    pos_embed_3d: tuple[str | None, str | None, str | None]
    pos_embed_2d: tuple[str | None, str | None]
    vision_pos_id: tuple[str | None, str | None]
    text_pos_embed: tuple[str | None, str | None]
    cls_token: tuple[str | None, str | None, str | None]
    probe_token: tuple[str | None, str | None, str | None]
    proj_kernel: tuple[str | None, str | None]
    proj_bias: tuple[str | None]
47
+
48
+
49
@dataclasses.dataclass(frozen=True)
class NoSharding:
    """No sharding - all parameters replicated.

    Provides every attribute declared by ``ShardingSpec``, each defaulting to
    a tuple of ``None`` of the matching length. The dataclass is frozen, so
    instances are immutable.
    """

    # Transformer block parameters.
    attn_qkv_kernel: tuple[str | None, str | None, str | None] = (None, None, None)
    attn_qkv_bias: tuple[str | None, str | None] = (None, None)
    attn_out_kernel: tuple[str | None, str | None, str | None] = (None, None, None)
    attn_out_bias: tuple[str | None] = (None,)
    mlp_up_kernel: tuple[str | None, str | None] = (None, None)
    mlp_up_bias: tuple[str | None] = (None,)
    mlp_down_kernel: tuple[str | None, str | None] = (None, None)
    mlp_down_bias: tuple[str | None] = (None,)
    layernorm: tuple[str | None] = (None,)
    # Patch embedding, embeddings, pooling tokens and projection heads.
    patch_conv_kernel: tuple[str | None, str | None, str | None, str | None] = (None, None, None, None)
    patch_conv_bias: tuple[str | None] = (None,)
    embed: tuple[str | None, str | None] = (None, None)
    pos_embed_3d: tuple[str | None, str | None, str | None] = (None, None, None)
    pos_embed_2d: tuple[str | None, str | None] = (None, None)
    vision_pos_id: tuple[str | None, str | None] = (None, None)
    text_pos_embed: tuple[str | None, str | None] = (None, None)
    cls_token: tuple[str | None, str | None, str | None] = (None, None, None)
    probe_token: tuple[str | None, str | None, str | None] = (None, None, None)
    proj_kernel: tuple[str | None, str | None] = (None, None)
    proj_bias: tuple[str | None] = (None,)
73
+
74
+
75
def sharding_of(value: Any) -> NamedSharding:
    """Return the ``sharding`` recorded on ``jax.typeof(value)``.

    The module describes this as the traced NamedSharding of a value in
    explicit mode.
    """
    abstract_value = jax.typeof(value)
    return abstract_value.sharding
79
+
80
+
81
def named_sharding_like(reference: Any, spec: P | tuple[str | None, ...]) -> NamedSharding:
    """Build a NamedSharding with ``spec`` on the mesh of ``reference``.

    ``spec`` may be a ``PartitionSpec`` or a plain tuple of axis names, which
    is converted to a ``PartitionSpec`` first.
    """
    partition_spec = spec if isinstance(spec, P) else P(*spec)
    mesh = sharding_of(reference).mesh
    return NamedSharding(mesh, partition_spec)
87
+
88
+
89
def replicated_sharding(reference: Any) -> NamedSharding:
    """Build a fully replicated sharding on the mesh of ``reference``.

    The partition spec has one ``None`` entry per dimension of ``reference``.
    """
    all_replicated = P(*(None for _ in range(reference.ndim)))
    return named_sharding_like(reference, all_replicated)
93
+
94
+
95
def reshard_like(value: Any, reference: Any) -> Any:
    """Reshard ``value`` so that it uses the sharding of ``reference``."""
    target_sharding = sharding_of(reference)
    return jax.sharding.reshard(value, target_sharding)