jax-image-models 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_image_models-0.3.3.dist-info/METADATA +36 -0
- jax_image_models-0.3.3.dist-info/RECORD +25 -0
- jax_image_models-0.3.3.dist-info/WHEEL +4 -0
- jax_image_models-0.3.3.dist-info/licenses/LICENSE +21 -0
- jimm/__init__.py +33 -0
- jimm/common/autotuning.py +150 -0
- jimm/common/loading_utils.py +248 -0
- jimm/common/sharding.py +98 -0
- jimm/common/tokamax_attention.py +108 -0
- jimm/common/transformer.py +299 -0
- jimm/common/utils.py +157 -0
- jimm/common/vit.py +340 -0
- jimm/models/__init__.py +13 -0
- jimm/models/clip/__init__.py +3 -0
- jimm/models/clip/clip_model.py +635 -0
- jimm/models/clip/params.py +520 -0
- jimm/models/clip/sharding.py +57 -0
- jimm/models/siglip/__init__.py +3 -0
- jimm/models/siglip/params.py +608 -0
- jimm/models/siglip/sharding.py +56 -0
- jimm/models/siglip/siglip_model.py +614 -0
- jimm/models/vit/__init__.py +3 -0
- jimm/models/vit/params.py +272 -0
- jimm/models/vit/sharding.py +55 -0
- jimm/models/vit/vit_model.py +209 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: jax-image-models
|
|
3
|
+
Version: 0.3.3
|
|
4
|
+
Summary: Jax Image Modeling of Models
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Python: >=3.11
|
|
7
|
+
Requires-Dist: flax>=0.10.6
|
|
8
|
+
Requires-Dist: jax>=0.6.2
|
|
9
|
+
Requires-Dist: jaxtyping>=0.3.2
|
|
10
|
+
Requires-Dist: safetensors>=0.5.3
|
|
11
|
+
Requires-Dist: tokamax>=0.0.12
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
|
|
14
|
+
# Jax Image Modeling of Models (jimm)
|
|
15
|
+
Docs are at: [https://pythoncrazy.github.io/jimm](https://pythoncrazy.github.io/jimm)
|
|
16
|
+
- This aims to be the JAX counterpart to timm, with one exception: for image-text models (CLIP, SigLIP, etc.), the text model is fully supported as well.
|
|
17
|
+
- Built with Flax NNX; supports loading weights from `pytorch_model.bin` and safetensors files, either from a local directory or from the Hugging Face Hub.
|
|
18
|
+
|
|
19
|
+
Models Supported:
|
|
20
|
+
- Vision Transformers
|
|
21
|
+
- Both with a classification linear layer, or not
|
|
22
|
+
- Using a CLS Token for pooling, or using Multihead Attention Pooling
|
|
23
|
+
- Can load any standard variant of Vision Transformers of any size/resolution (e.g. "google/vit-base-patch16-224" or "google/vit-large-patch16-384")
|
|
24
|
+
- CLIP
|
|
25
|
+
- Can load any checkpoint of the CLIP model from Hugging Face (such as "openai/clip-vit-base-patch32" or "geolocal/StreetCLIP")
|
|
26
|
+
- SigLIP
|
|
27
|
+
- Can load any non-naflex version of the SigLIP model, from both SigLIP v1 and SigLIP v2 (e.g. "google/siglip-base-patch16-256" or "google/siglip2-large-patch16-512", from Hugging Face or locally)
|
|
28
|
+
## Installation
|
|
29
|
+
### Using pixi.sh:
|
|
30
|
+
`pixi add jimm@https://github.com/pythoncrazy/jimm.git --pypi`
|
|
31
|
+
### Using uv
|
|
32
|
+
`uv add --dev git+https://github.com/pythoncrazy/jimm.git`
|
|
33
|
+
or if you prefer to not add as a direct dependency:
|
|
34
|
+
`uv pip install git+https://github.com/pythoncrazy/jimm.git`
|
|
35
|
+
### Using pip/conda
|
|
36
|
+
`pip install git+https://github.com/pythoncrazy/jimm.git`
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
jimm/__init__.py,sha256=7rsfJf4ma5htZG7IaXCklJ52m9fELLZiLNyxqdtdkBk,849
|
|
2
|
+
jimm/common/autotuning.py,sha256=8E-S2gtHwSpJKZgR-BQlsosebJH_zn195HVeDhJClIE,4998
|
|
3
|
+
jimm/common/loading_utils.py,sha256=zxwytlhLD26dpWrnCjd4LAQmLtWeD9Z2GSDH-6qt4Rc,9398
|
|
4
|
+
jimm/common/sharding.py,sha256=_hXhuX4C6lneNHo28gK7CkBDxlvPujzlLlG65MzgVjs,4161
|
|
5
|
+
jimm/common/tokamax_attention.py,sha256=CSJs_s2XKniuq8__fjTUoPoprMsqweNCpUQmwr_whQE,4160
|
|
6
|
+
jimm/common/transformer.py,sha256=h-cpzg8Os-nU39izuh6bt3UQBwlsVK21DTh0BOie5m4,12194
|
|
7
|
+
jimm/common/utils.py,sha256=r6sbZVmeZbwj6opK6MTf61Y7w4d_3zmuXVFkWQoSXcE,6706
|
|
8
|
+
jimm/common/vit.py,sha256=_pZv4hXk67Xg0hwjkQnoQ_Q5RXwyGO1Hw7hokdfWLm4,14940
|
|
9
|
+
jimm/models/__init__.py,sha256=dkC9Ujsrr0qEjPNl2vaoTJEOOADy766ZOUEKVsCbrbQ,311
|
|
10
|
+
jimm/models/clip/__init__.py,sha256=Ayy0m9sNHurvQy3l-OlhWoc_NsGeCUN8Oxp_Jw7NHbg,117
|
|
11
|
+
jimm/models/clip/clip_model.py,sha256=K3oyN1yrisrOf1JpspiLPOD52j-ieeSTPub0sEX4EQs,27822
|
|
12
|
+
jimm/models/clip/params.py,sha256=mZOGjL0at_uv3dljiGbIumkSkXjf-J0zf958vTuwUKo,24504
|
|
13
|
+
jimm/models/clip/sharding.py,sha256=KIfUU_27dJTk75r1QZtfaBzmvd2E8udS-4gzp4w2jxc,3115
|
|
14
|
+
jimm/models/siglip/__init__.py,sha256=2-hoGa2kAYhsNC-n52m_5PTE35XYNGXuFAQSnMtQ2QE,131
|
|
15
|
+
jimm/models/siglip/params.py,sha256=sc4_x8332boIXNg_FpnfIQKkFnBUpmPMuIcycZudOJs,29485
|
|
16
|
+
jimm/models/siglip/sharding.py,sha256=Ck0HRAYz4gYN5XJws05lKxad04GfpylszaSM2KMN0zc,3029
|
|
17
|
+
jimm/models/siglip/siglip_model.py,sha256=k-yf92I18o28rkBBO3EhGdW-zNiGY39B7hVCQbhdQbQ,26495
|
|
18
|
+
jimm/models/vit/__init__.py,sha256=Hw1rwR3mN13X-0TYmq-H2cZsXUmK95326sYADF3VjWM,74
|
|
19
|
+
jimm/models/vit/params.py,sha256=IaP0Z4nzdXIXWm28P-x8ERJ6lNNrd8HbL7CXFxigYd0,12890
|
|
20
|
+
jimm/models/vit/sharding.py,sha256=8LLOvAOfY0770jF3GZth4JiWzuqMZfERfMqdJ7a-vsg,2963
|
|
21
|
+
jimm/models/vit/vit_model.py,sha256=WenSypHga0G_4kCyBUR5Gru5wDFsVNcEqst3rl6B1qo,9057
|
|
22
|
+
jax_image_models-0.3.3.dist-info/METADATA,sha256=PfaDFPWkrf1BZtIEJGDWd0Gtx2TnSNnGH7Vk5w_teoA,1762
|
|
23
|
+
jax_image_models-0.3.3.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
24
|
+
jax_image_models-0.3.3.dist-info/licenses/LICENSE,sha256=wYMBl9KN_WLSFwJ2lMHMchvv7IhGvIPl4alCEzjFw6I,1070
|
|
25
|
+
jax_image_models-0.3.3.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Pinak Paliwal
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
jimm/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from .common.autotuning import AutotuningResult
|
|
2
|
+
from .common.autotuning import autotune
|
|
3
|
+
from .common.autotuning import autotuned_fn
|
|
4
|
+
from .common.autotuning import cached_autotune
|
|
5
|
+
from .common.autotuning import load_autotune_result
|
|
6
|
+
from .common.tokamax_attention import make_tokamax_attention
|
|
7
|
+
from .common.tokamax_attention import tokamax_attention_fn as tokamax_attention
|
|
8
|
+
from .models import (
|
|
9
|
+
CLIP,
|
|
10
|
+
CLIPTextModel,
|
|
11
|
+
CLIPVisionModel,
|
|
12
|
+
SigLIP,
|
|
13
|
+
SigLIPTextModel,
|
|
14
|
+
SigLIPVisionModel,
|
|
15
|
+
VisionTransformer,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"tokamax_attention",
|
|
20
|
+
"make_tokamax_attention",
|
|
21
|
+
"autotune",
|
|
22
|
+
"cached_autotune",
|
|
23
|
+
"autotuned_fn",
|
|
24
|
+
"load_autotune_result",
|
|
25
|
+
"AutotuningResult",
|
|
26
|
+
"VisionTransformer",
|
|
27
|
+
"CLIP",
|
|
28
|
+
"CLIPTextModel",
|
|
29
|
+
"CLIPVisionModel",
|
|
30
|
+
"SigLIP",
|
|
31
|
+
"SigLIPTextModel",
|
|
32
|
+
"SigLIPVisionModel",
|
|
33
|
+
]
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import functools
|
|
5
|
+
import hashlib
|
|
6
|
+
import os
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from jax.extend import backend as _jax_backend
|
|
11
|
+
from tokamax._src.autotuning import api as _autotune_api
|
|
12
|
+
|
|
13
|
+
AutotuningResult = _autotune_api.AutotuningResult
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _lower(jitted_fn: Any, *args: Any) -> Any:
    """Lower a jitted callable and return the raw lowered computation.

    Calls ``jitted_fn.lower(*args)``. If the object that comes back exposes a
    ``lowered`` attribute (as the wrapper returned by ``nnx.jit`` does), that
    attribute is returned; otherwise the object itself is returned unchanged.

    Args:
        jitted_fn (Any): An ``nnx.jit`` or ``jax.jit`` callable.
        *args (Any): Sample arguments for lowering.

    Returns:
        Any: Raw ``jax.stages.Lowered`` computation.
    """
    staged = jitted_fn.lower(*args)
    try:
        # Wrapper objects keep the underlying lowering under ``.lowered``.
        return staged.lowered
    except AttributeError:
        # Already the raw object; nothing to unwrap.
        return staged
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def autotune(
    jitted_fn: Any,
    *sample_args: Any,
    save_path: str | os.PathLike | None = None,
    **kwargs: Any,
) -> AutotuningResult:
    """Autotune the tokamax ops of a jitted function, optionally saving the result.

    ``jitted_fn`` is lowered with ``sample_args`` and the lowered computation
    is handed to tokamax's autotuner. The returned result can be used as a
    context manager::

        result = jimm.autotune(forward, model, image, text)
        with result:
            out = forward(model, image, text)

    Args:
        jitted_fn (Any): Jitted callable containing tokamax ops.
        *sample_args (Any): Representative inputs; shapes/dtypes must match
            production inputs.
        save_path (str | os.PathLike | None): When given, the result is
            dumped to this file. Nothing is written when it is ``None``.
        **kwargs (Any): Forwarded to ``tokamax.autotune``
            (e.g. ``all_implementations``, ``progress_bar``).

    Returns:
        AutotuningResult: Best config for each op found in ``jitted_fn``.
    """
    lowered = _lower(jitted_fn, *sample_args)
    tuned = _autotune_api.autotune(lowered, **kwargs)

    if save_path is None:
        return tuned

    with open(save_path, "w") as sink:
        tuned.dump(sink)
    return tuned
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def load_autotune_result(path: str | os.PathLike) -> AutotuningResult:
    """Read an :class:`AutotuningResult` back from a file.

    Args:
        path (str | os.PathLike): Path written by :func:`autotune` or
            :func:`cached_autotune`.

    Returns:
        AutotuningResult: Deserialized result.
    """
    with open(path) as source:
        loaded = AutotuningResult.load(source)
    return loaded
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def cached_autotune(
    jitted_fn: Any,
    *sample_args: Any,
    cache_dir: str | os.PathLike,
    **kwargs: Any,
) -> AutotuningResult | None:
    """Load autotuning results from disk, or run autotune and cache them.

    The cache key is the first 16 hex characters of an MD5 digest over the
    default device kind and the sorted op autotuning keys, so the same entry
    is reused whenever model architecture and hardware are unchanged.

    The cache file is written to a temporary sibling and moved into place
    with ``os.replace``. An interrupted or concurrent run therefore never
    leaves a truncated JSON file for later calls to load.

    Args:
        jitted_fn (Any): Jitted callable containing tokamax ops.
        *sample_args (Any): Representative inputs used to extract op shapes.
        cache_dir (str | os.PathLike): Directory for JSON cache files; created
            if it does not exist.
        **kwargs (Any): Forwarded to ``tokamax.autotune``.

    Returns:
        AutotuningResult | None: Loaded or freshly computed result, or ``None``
            if no tunable ops exist.
    """
    bound_args = _autotune_api.get_bound_args(_lower(jitted_fn, *sample_args))
    if not bound_args:
        return None

    device_kind = _jax_backend.get_default_device().device_kind
    key_parts = sorted(str(ba.autotuning_cache_key) for ba in bound_args)
    key_material = (device_kind + ":" + "|".join(key_parts)).encode()
    # MD5 is only a file-name fingerprint here. Declaring that keeps the call
    # working on FIPS-restricted Python builds; the digest is unchanged, so
    # existing cache files remain valid.
    cache_key = hashlib.md5(key_material, usedforsecurity=False).hexdigest()[:16]

    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f"{cache_key}.json")

    if os.path.exists(cache_path):
        return load_autotune_result(cache_path)

    result = _autotune_api.autotune(list(bound_args), **kwargs)

    # Write to a per-process temporary file first, then publish atomically.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            result.dump(f)
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Do not leave partial files behind; the original error is re-raised.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise
    return result
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def autotuned_fn(
    jitted_fn: Any,
    *sample_args: Any,
    cache_dir: str | os.PathLike,
    **kwargs: Any,
) -> Callable[..., Any]:
    """Wrap a jitted function so each call runs inside the tuned-config context.

    :func:`cached_autotune` is called once, eagerly, when this function is
    invoked. The returned wrapper enters the resulting
    :class:`AutotuningResult` as a context manager around every call to
    ``jitted_fn``. When there is nothing to tune (``cached_autotune`` returns
    ``None``) a no-op context is used instead.

    Args:
        jitted_fn (Any): Jitted callable containing tokamax ops.
        *sample_args (Any): Representative inputs used for autotuning.
        cache_dir (str | os.PathLike): Directory for the autotune cache.
        **kwargs (Any): Forwarded to ``tokamax.autotune``.

    Returns:
        Callable[..., Any]: Wrapped callable that applies tuned configs on
            every invocation.
    """
    tuned = cached_autotune(jitted_fn, *sample_args, cache_dir=cache_dir, **kwargs)
    if tuned is None:
        config_scope = contextlib.nullcontext()
    else:
        config_scope = tuned

    def _call_with_tuned_configs(*args: Any, **kw: Any) -> Any:
        with config_scope:
            return jitted_fn(*args, **kw)

    # Copy the wrapped function's metadata (name, docstring, ...) onto the wrapper.
    return functools.wraps(jitted_fn)(_call_with_tuned_configs)
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
from math import prod
|
|
5
|
+
from typing import Any, Dict, Tuple, TypeVar
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
from flax import nnx
|
|
10
|
+
from huggingface_hub import hf_hub_download
|
|
11
|
+
from jaxtyping import Array, DTypeLike
|
|
12
|
+
from safetensors.flax import load_file as load_safetensors_flax_file
|
|
13
|
+
|
|
14
|
+
_M = TypeVar("_M", bound=nnx.Module)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def load_params_and_config(
    model_name_or_path: str,
    use_pytorch: bool = False,
    default_config_filename: str = "config.json",
    default_pytorch_filename: str = "pytorch_model.bin",
    default_safetensors_filename: str = "model.safetensors",
) -> Tuple[Dict[str, Array], Dict[str, Any]]:
    """Load model parameters and configuration from local directory or HuggingFace Hub.

    If ``model_name_or_path`` is an existing directory, the config and weights
    files are read from it. Otherwise it is treated as a Hub repository ID and
    both files are fetched with ``hf_hub_download``.

    Args:
        model_name_or_path (str): Local directory path or HuggingFace model ID.
        use_pytorch (bool): Whether to load from PyTorch weights. Defaults to False.
        default_config_filename (str): Config filename. Defaults to "config.json".
        default_pytorch_filename (str): PyTorch weights filename. Defaults to "pytorch_model.bin".
        default_safetensors_filename (str): Safetensors filename. Defaults to "model.safetensors".

    Returns:
        Tuple[Dict[str, Array], Dict[str, Any]]: Loaded parameters and configuration.
    """
    # The weights filename does not depend on where the files come from.
    weights_filename = default_pytorch_filename if use_pytorch else default_safetensors_filename

    if os.path.isdir(model_name_or_path):
        config_file_path = os.path.join(model_name_or_path, default_config_filename)
        weights_file_path = os.path.join(model_name_or_path, weights_filename)
    else:
        config_file_path = hf_hub_download(repo_id=model_name_or_path, filename=default_config_filename)
        weights_file_path = hf_hub_download(repo_id=model_name_or_path, filename=weights_filename)

    # JSON is UTF-8; do not depend on the platform's locale encoding.
    with open(config_file_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    if use_pytorch:
        # Imported lazily so torch is only required for PyTorch checkpoints.
        import torch

        # ``weights_only=True`` restricts unpickling to tensors and plain
        # containers, so a downloaded checkpoint cannot execute arbitrary code.
        state_dict = torch.load(weights_file_path, map_location="cpu", weights_only=True)
        params_fstate = {k: jnp.array(v.numpy()) for k, v in state_dict.items()}
    else:
        params_fstate = load_safetensors_flax_file(weights_file_path)

    return params_fstate, config
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def stoi(k: str) -> int | str:
|
|
60
|
+
"""Convert a string to int if numeric, else return as-is.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
k (str): Key to convert.
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
int | str: Integer if numeric string, else original string.
|
|
67
|
+
"""
|
|
68
|
+
try:
|
|
69
|
+
return int(k)
|
|
70
|
+
except ValueError:
|
|
71
|
+
return k
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def map_to_bonsai_key(
|
|
75
|
+
mapping: dict[str, tuple[str, Any]],
|
|
76
|
+
key: str,
|
|
77
|
+
) -> tuple[str | None, Any]:
|
|
78
|
+
"""Match a flat HF key against regex patterns and return the flax target key.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
mapping (dict[str, tuple[str, Any]]): Dict of {regex_pattern: (flax_key_template, transform)}.
|
|
82
|
+
key (str): HuggingFace parameter key to match.
|
|
83
|
+
|
|
84
|
+
Returns:
|
|
85
|
+
tuple[str | None, Any]: (flax_key, transform) if matched, else (None, None).
|
|
86
|
+
"""
|
|
87
|
+
for pattern, (target, transform) in mapping.items():
|
|
88
|
+
if re.fullmatch(pattern, key):
|
|
89
|
+
return re.sub(pattern, target, key), transform
|
|
90
|
+
return None, None
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def to_scan_batched_keys(keys: tuple) -> tuple[tuple | None, int | None]:
|
|
94
|
+
"""Convert a per-layer flat state key to its scan-batched equivalent.
|
|
95
|
+
|
|
96
|
+
With nnx.scan/vmap, transformer layers are stored as a single batched module
|
|
97
|
+
under a "layers" key rather than separate "layers_N" keys. This converts
|
|
98
|
+
keys like ("encoder", "layers_3", "attn", "query", "kernel") to
|
|
99
|
+
("encoder", "layers", "attn", "query", "kernel") and returns the layer index.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
keys (tuple): Flat state key tuple potentially containing a "layers_N" component.
|
|
103
|
+
|
|
104
|
+
Returns:
|
|
105
|
+
tuple[tuple | None, int | None]: (batched_keys, layer_idx) if a layers_N component
|
|
106
|
+
is found, else (None, None).
|
|
107
|
+
"""
|
|
108
|
+
new_keys = list(keys)
|
|
109
|
+
layer_idx = None
|
|
110
|
+
for i, k in enumerate(new_keys):
|
|
111
|
+
if isinstance(k, str) and k.startswith("layers_"):
|
|
112
|
+
try:
|
|
113
|
+
layer_idx = int(k[7:])
|
|
114
|
+
new_keys[i] = "layers"
|
|
115
|
+
break
|
|
116
|
+
except ValueError:
|
|
117
|
+
pass
|
|
118
|
+
if layer_idx is None:
|
|
119
|
+
return None, None
|
|
120
|
+
return tuple(new_keys), layer_idx
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _reshape_if_compatible(tensor: Array, target_shape: tuple[int, ...], hf_key: str, bonsai_keys: tuple[Any, ...]) -> Array:
    """Return ``tensor`` in ``target_shape``, reshaping only when sizes agree.

    Args:
        tensor (Array): Source tensor.
        target_shape (tuple[int, ...]): Shape the destination expects.
        hf_key (str): Source key, used only in the error message.
        bonsai_keys (tuple[Any, ...]): Destination key path, used only in the
            error message (joined with ".").

    Returns:
        Array: ``tensor`` itself when the shape already matches, otherwise a
            reshaped tensor.

    Raises:
        ValueError: If the element count differs from that of ``target_shape``.
    """
    current_shape = tensor.shape
    if current_shape == target_shape:
        return tensor

    if tensor.size == prod(target_shape):
        return tensor.reshape(target_shape)

    dotted = ".".join(map(str, bonsai_keys))
    raise ValueError(f"Shape mismatch for {hf_key} -> {dotted}: got {current_shape}, expected {target_shape}")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def apply_mapping(
    model: _M,
    params_fstate: dict[str, Any],
    mapping: dict[str, tuple[str, Any]],
    param_dtype: DTypeLike,
) -> _M:
    """Apply regex-based HF parameter mappings to a model in-place.

    Every entry of ``params_fstate`` whose key matches a pattern in
    ``mapping`` is cast to ``param_dtype``, optionally transposed, and written
    into the model's ``nnx.Param`` state. A target path that exists in the
    state is written directly. A ``layers_N`` path whose scan-batched
    ``layers`` form exists is collected per layer and written as one stacked
    tensor once all checkpoint entries have been seen. Anything else is
    silently skipped.

    Args:
        model (_M): Model whose parameters are overwritten.
        params_fstate (dict[str, Any]): Flat checkpoint, source key -> tensor.
        mapping (dict[str, tuple[str, Any]]): {regex_pattern: (target_template, transform)}.
            ``transform.value`` is unpacked into three items, of which only
            the first (a transpose permutation or ``None``) is used here.
            It is presumably an Enum member; confirm against the callers.
        param_dtype (DTypeLike): Dtype the tensors are cast to before assignment.

    Returns:
        _M: The same ``model`` instance, updated.

    Raises:
        ValueError: If a tensor's element count does not match its target, or
            if some layer indices of a scan-batched parameter are missing.
    """
    param_vars = dict(nnx.to_flat_state(nnx.state(model, nnx.Param)))
    # stacked path -> {layer index -> converted tensor}
    per_layer: dict[tuple, dict[int, Any]] = {}

    for hf_key, tensor in params_fstate.items():
        bonsai_key, transform = map_to_bonsai_key(mapping, hf_key)
        if bonsai_key is None:
            continue

        path = tuple(stoi(part) for part in bonsai_key.split("."))
        permute_rule, _, _ = transform.value
        converted = tensor.astype(param_dtype)
        if permute_rule is not None:
            converted = jnp.transpose(converted, permute_rule)

        if path in param_vars:
            target = param_vars[path]
            target[...] = _reshape_if_compatible(converted, target[...].shape, hf_key, path)
            continue

        # Not a direct hit: it may be one layer of a scan-batched parameter.
        stacked_path, layer_idx = to_scan_batched_keys(path)
        if stacked_path is None or layer_idx is None or stacked_path not in param_vars:
            continue
        per_layer.setdefault(stacked_path, {})[layer_idx] = converted

    for stacked_path, by_index in per_layer.items():
        target = param_vars[stacked_path]
        num_layers = target[...].shape[0]
        missing = sorted(set(range(num_layers)) - set(by_index))
        if missing:
            bonsai_key = ".".join(str(key) for key in stacked_path)
            raise ValueError(f"Missing scanned layers for {bonsai_key}: {missing}")

        stacked = jnp.stack([by_index[i] for i in range(num_layers)], axis=0)
        dotted = ".".join(str(key) for key in stacked_path)
        target[...] = _reshape_if_compatible(stacked, target[...].shape, dotted, stacked_path)

    nnx.update(model, nnx.from_flat_state(param_vars))
    return model
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _slice_layer(d: dict, idx: int) -> dict:
    """Return a copy of nested dict ``d`` with every JAX array indexed at ``idx``.

    Nested dicts are processed recursively; values that are neither dicts nor
    ``jax.Array`` instances are carried over unchanged.
    """

    def pick(value):
        if isinstance(value, dict):
            return _slice_layer(value, idx)
        if isinstance(value, jax.Array):
            return value[idx]
        return value

    return {key: pick(value) for key, value in d.items()}
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _infer_num_layers(d: dict) -> int | None:
|
|
194
|
+
"""Return the leading dimension of the first JAX array found in d."""
|
|
195
|
+
for value in d.values():
|
|
196
|
+
if isinstance(value, jax.Array):
|
|
197
|
+
return int(value.shape[0])
|
|
198
|
+
if isinstance(value, dict):
|
|
199
|
+
n = _infer_num_layers(value)
|
|
200
|
+
if n is not None:
|
|
201
|
+
return n
|
|
202
|
+
return None
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _is_numeric_key(key: Any) -> bool:
    """Return True for an ``int`` key or a string made only of digits."""
    if isinstance(key, int):
        return True
    if isinstance(key, str):
        return key.isdigit()
    return False
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _should_expand_layers_dict(d: dict) -> bool:
    """Decide whether a ``layers`` dict holds scan-batched parameters.

    `nnx.scan` stores stacked transformer blocks under a single ``layers`` key
    whose children are named module fields like ``attn`` and ``mlp``. In
    contrast, ``nnx.Sequential`` also uses a ``layers`` key, but its children
    are numeric submodule indices. Those sequential containers must not be
    expanded.

    Returns True only when ``d`` has at least one non-numeric key and contains
    a JAX array somewhere inside it.
    """
    has_named_child = any(not _is_numeric_key(key) for key in d)
    return has_named_child and _infer_num_layers(d) is not None
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def expand_scanned_layers(state_dict: dict) -> dict:
    """Expand scan-batched layer parameters into per-layer entries for HF saving.

    With nnx.scan/vmap, all transformer layer parameters are stored as a single
    batched module under a "layers" key with a leading layer dimension. Each
    such dict is replaced by separate "layers_N" sub-dicts, one per index along
    that leading dimension. Every other nested dict is processed recursively
    and non-dict values are copied through unchanged.

    Args:
        state_dict (dict): Model state dict potentially containing a "layers" key
            with batched parameters.

    Returns:
        dict: State dict with "layers" expanded into "layers_0", ..., "layers_(N-1)".
    """
    expanded = {}
    for key, value in state_dict.items():
        if not isinstance(value, dict):
            expanded[key] = value
            continue

        if key == "layers" and _should_expand_layers_dict(value):
            num_layers = _infer_num_layers(value)
            if num_layers is not None:
                expanded.update((f"layers_{i}", _slice_layer(value, i)) for i in range(num_layers))
                continue

        expanded[key] = expand_scanned_layers(value)
    return expanded
|
jimm/common/sharding.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from typing import Any, Protocol
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
from jax.sharding import NamedSharding
|
|
6
|
+
from jax.sharding import PartitionSpec as P
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ShardingSpec(Protocol):
|
|
10
|
+
"""Protocol defining the sharding specification for model parameters.
|
|
11
|
+
|
|
12
|
+
Specs represent per-layer (non-stacked) shapes. The Transformer stacks
|
|
13
|
+
layers via nnx.vmap and patches the stacked Variable metadata to prepend
|
|
14
|
+
None for the scan axis, so the optimizer sees the correct 4-D spec.
|
|
15
|
+
|
|
16
|
+
attn_qkv_kernel (in_features, num_heads, head_dim)
|
|
17
|
+
attn_qkv_bias (num_heads, head_dim)
|
|
18
|
+
attn_out_kernel (num_heads, head_dim, out_features)
|
|
19
|
+
attn_out_bias (out_features,)
|
|
20
|
+
mlp_up_kernel (in_features, intermediate_size)
|
|
21
|
+
mlp_up_bias (intermediate_size,)
|
|
22
|
+
mlp_down_kernel (intermediate_size, out_features)
|
|
23
|
+
mlp_down_bias (out_features,)
|
|
24
|
+
layernorm (hidden_size,)
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
attn_qkv_kernel: tuple[str | None, str | None, str | None]
|
|
28
|
+
attn_qkv_bias: tuple[str | None, str | None]
|
|
29
|
+
attn_out_kernel: tuple[str | None, str | None, str | None]
|
|
30
|
+
attn_out_bias: tuple[str | None]
|
|
31
|
+
mlp_up_kernel: tuple[str | None, str | None]
|
|
32
|
+
mlp_up_bias: tuple[str | None]
|
|
33
|
+
mlp_down_kernel: tuple[str | None, str | None]
|
|
34
|
+
mlp_down_bias: tuple[str | None]
|
|
35
|
+
layernorm: tuple[str | None]
|
|
36
|
+
patch_conv_kernel: tuple[str | None, str | None, str | None, str | None]
|
|
37
|
+
patch_conv_bias: tuple[str | None]
|
|
38
|
+
embed: tuple[str | None, str | None]
|
|
39
|
+
pos_embed_3d: tuple[str | None, str | None, str | None]
|
|
40
|
+
pos_embed_2d: tuple[str | None, str | None]
|
|
41
|
+
vision_pos_id: tuple[str | None, str | None]
|
|
42
|
+
text_pos_embed: tuple[str | None, str | None]
|
|
43
|
+
cls_token: tuple[str | None, str | None, str | None]
|
|
44
|
+
probe_token: tuple[str | None, str | None, str | None]
|
|
45
|
+
proj_kernel: tuple[str | None, str | None]
|
|
46
|
+
proj_bias: tuple[str | None]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclasses.dataclass(frozen=True)
|
|
50
|
+
class NoSharding:
|
|
51
|
+
"""No sharding - all parameters replicated."""
|
|
52
|
+
|
|
53
|
+
attn_qkv_kernel: tuple[str | None, str | None, str | None] = (None, None, None)
|
|
54
|
+
attn_qkv_bias: tuple[str | None, str | None] = (None, None)
|
|
55
|
+
attn_out_kernel: tuple[str | None, str | None, str | None] = (None, None, None)
|
|
56
|
+
attn_out_bias: tuple[str | None] = (None,)
|
|
57
|
+
mlp_up_kernel: tuple[str | None, str | None] = (None, None)
|
|
58
|
+
mlp_up_bias: tuple[str | None] = (None,)
|
|
59
|
+
mlp_down_kernel: tuple[str | None, str | None] = (None, None)
|
|
60
|
+
mlp_down_bias: tuple[str | None] = (None,)
|
|
61
|
+
layernorm: tuple[str | None] = (None,)
|
|
62
|
+
patch_conv_kernel: tuple[str | None, str | None, str | None, str | None] = (None, None, None, None)
|
|
63
|
+
patch_conv_bias: tuple[str | None] = (None,)
|
|
64
|
+
embed: tuple[str | None, str | None] = (None, None)
|
|
65
|
+
pos_embed_3d: tuple[str | None, str | None, str | None] = (None, None, None)
|
|
66
|
+
pos_embed_2d: tuple[str | None, str | None] = (None, None)
|
|
67
|
+
vision_pos_id: tuple[str | None, str | None] = (None, None)
|
|
68
|
+
text_pos_embed: tuple[str | None, str | None] = (None, None)
|
|
69
|
+
cls_token: tuple[str | None, str | None, str | None] = (None, None, None)
|
|
70
|
+
probe_token: tuple[str | None, str | None, str | None] = (None, None, None)
|
|
71
|
+
proj_kernel: tuple[str | None, str | None] = (None, None)
|
|
72
|
+
proj_bias: tuple[str | None] = (None,)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def sharding_of(value: Any) -> NamedSharding:
|
|
76
|
+
"""Returns the traced NamedSharding for a value in explicit mode."""
|
|
77
|
+
|
|
78
|
+
return jax.typeof(value).sharding
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def named_sharding_like(reference: Any, spec: P | tuple[str | None, ...]) -> NamedSharding:
    """Build a ``NamedSharding`` on the mesh of ``reference``'s sharding.

    ``spec`` may already be a ``PartitionSpec``; a plain tuple of axis names
    (or ``None``) is converted to one first.
    """
    partition = spec if isinstance(spec, P) else P(*spec)
    mesh = sharding_of(reference).mesh
    return NamedSharding(mesh, partition)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def replicated_sharding(reference: Any) -> NamedSharding:
    """Build a fully replicated sharding on ``reference``'s mesh.

    The partition spec has one ``None`` entry per dimension of ``reference``
    (``reference.ndim``).
    """
    unsharded_axes = (None for _ in range(reference.ndim))
    return named_sharding_like(reference, P(*unsharded_axes))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def reshard_like(value: Any, reference: Any) -> Any:
    """Reshard ``value`` to the sharding reported for ``reference``.

    Delegates to ``jax.sharding.reshard`` with ``sharding_of(reference)`` as
    the target sharding.
    """
    reshard = jax.sharding.reshard
    return reshard(value, sharding_of(reference))
|