waxjax 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- waxjax-0.1.0/MANIFEST.in +2 -0
- waxjax-0.1.0/PKG-INFO +55 -0
- waxjax-0.1.0/README.md +17 -0
- waxjax-0.1.0/pyproject.toml +3 -0
- waxjax-0.1.0/safejax/__init__.py +9 -0
- waxjax-0.1.0/setup.cfg +4 -0
- waxjax-0.1.0/setup.py +34 -0
- waxjax-0.1.0/tests/test_api.py +72 -0
- waxjax-0.1.0/tests/test_loader.py +65 -0
- waxjax-0.1.0/tests/test_mapper.py +88 -0
- waxjax-0.1.0/tests/test_nllb.py +93 -0
- waxjax-0.1.0/tests/test_qwen2.py +83 -0
- waxjax-0.1.0/tests/test_rules.py +77 -0
- waxjax-0.1.0/waxjax/__init__.py +16 -0
- waxjax-0.1.0/waxjax/architectures/__init__.py +2 -0
- waxjax-0.1.0/waxjax/architectures/_registry.py +1 -0
- waxjax-0.1.0/waxjax/architectures/nllb.py +1 -0
- waxjax-0.1.0/waxjax/architectures/qwen2.py +1 -0
- waxjax-0.1.0/waxjax/core/__init__.py +3 -0
- waxjax-0.1.0/waxjax/core/api.py +1 -0
- waxjax-0.1.0/waxjax/core/loader.py +1 -0
- waxjax-0.1.0/waxjax/core/mapper.py +1 -0
- waxjax-0.1.0/waxjax/core/nester.py +1 -0
- waxjax-0.1.0/waxjax/core/rules.py +1 -0
- waxjax-0.1.0/waxjax.egg-info/PKG-INFO +55 -0
- waxjax-0.1.0/waxjax.egg-info/SOURCES.txt +27 -0
- waxjax-0.1.0/waxjax.egg-info/dependency_links.txt +1 -0
- waxjax-0.1.0/waxjax.egg-info/requires.txt +19 -0
- waxjax-0.1.0/waxjax.egg-info/top_level.txt +2 -0
waxjax-0.1.0/MANIFEST.in
ADDED
waxjax-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: waxjax
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Load HuggingFace SafeTensors models to JAX/Flax
|
|
5
|
+
Home-page: https://github.com/ik12-b/waxjax
|
|
6
|
+
Author: Your Name
|
|
7
|
+
License: Apache-2.0
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Python: >=3.9
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
Requires-Dist: safetensors>=0.4.0
|
|
14
|
+
Requires-Dist: numpy>=1.24
|
|
15
|
+
Requires-Dist: jax>=0.4.0
|
|
16
|
+
Provides-Extra: torch
|
|
17
|
+
Requires-Dist: torch>=2.0; extra == "torch"
|
|
18
|
+
Provides-Extra: flax
|
|
19
|
+
Requires-Dist: flax>=0.7; extra == "flax"
|
|
20
|
+
Provides-Extra: all
|
|
21
|
+
Requires-Dist: torch>=2.0; extra == "all"
|
|
22
|
+
Requires-Dist: flax>=0.7; extra == "all"
|
|
23
|
+
Requires-Dist: ml_dtypes; extra == "all"
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest; extra == "dev"
|
|
26
|
+
Requires-Dist: black; extra == "dev"
|
|
27
|
+
Requires-Dist: ruff; extra == "dev"
|
|
28
|
+
Dynamic: author
|
|
29
|
+
Dynamic: classifier
|
|
30
|
+
Dynamic: description
|
|
31
|
+
Dynamic: description-content-type
|
|
32
|
+
Dynamic: home-page
|
|
33
|
+
Dynamic: license
|
|
34
|
+
Dynamic: provides-extra
|
|
35
|
+
Dynamic: requires-dist
|
|
36
|
+
Dynamic: requires-python
|
|
37
|
+
Dynamic: summary
|
|
38
|
+
|
|
39
|
+
# waxjax
|
|
40
|
+
|
|
41
|
+
Lightweight helper to load HuggingFace SafeTensors checkpoints into JAX/Flax.
|
|
42
|
+
|
|
43
|
+
Install (latest from source):
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
pip install git+https://github.com/ik12-b/waxjax.git
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Build and publish (testing against TestPyPI first is recommended): see `RELEASE.md` or run locally:
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
python -m pip install --upgrade build twine
|
|
53
|
+
python -m build
|
|
54
|
+
twine upload --repository testpypi dist/*
|
|
55
|
+
```
|
waxjax-0.1.0/README.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# waxjax
|
|
2
|
+
|
|
3
|
+
Lightweight helper to load HuggingFace SafeTensors checkpoints into JAX/Flax.
|
|
4
|
+
|
|
5
|
+
Install (latest from source):
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install git+https://github.com/ik12-b/waxjax.git
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
Build and publish (testing against TestPyPI first is recommended): see `RELEASE.md` or run locally:
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
python -m pip install --upgrade build twine
|
|
15
|
+
python -m build
|
|
16
|
+
twine upload --repository testpypi dist/*
|
|
17
|
+
```
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .core.api import load, save, inspect, verify
|
|
2
|
+
from .core.rules import MappingRule
|
|
3
|
+
from .architectures._registry import register_architecture, ArchitectureConfig
|
|
4
|
+
|
|
5
|
+
# Package version string.
__version__ = "0.1.0"

# Names exported by a star-import of this package.
__all__ = [
    "load",
    "save",
    "inspect",
    "verify",
    "MappingRule",
    "register_architecture",
    "ArchitectureConfig",
]
|
waxjax-0.1.0/setup.cfg
ADDED
waxjax-0.1.0/setup.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from setuptools import setup, find_packages
|
|
3
|
+
|
|
4
|
+
# Directory that contains this setup script.
here = Path(__file__).parent
readme_path = here / "README.md"

# The README doubles as the long description; fall back to an empty string
# when the file is not present next to this script.
if readme_path.exists():
    long_description = readme_path.read_text(encoding="utf-8")
else:
    long_description = ""

# Dependencies that are always installed.
core_requirements = [
    "safetensors>=0.4.0",
    "numpy>=1.24",
    "jax>=0.4.0",
]

# Optional dependency groups, selected with `pip install waxjax[<name>]`.
optional_requirements = {
    "torch": ["torch>=2.0"],
    "flax": ["flax>=0.7"],
    "all": ["torch>=2.0", "flax>=0.7", "ml_dtypes"],
    "dev": ["pytest", "black", "ruff"],
}

trove_classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
]

setup(
    name="waxjax",
    version="0.1.0",
    description="Load HuggingFace SafeTensors models to JAX/Flax",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Your Name",  # NOTE(review): looks like a template placeholder — confirm
    license="Apache-2.0",
    url="https://github.com/ik12-b/waxjax",
    packages=find_packages(),
    python_requires=">=3.9",
    install_requires=core_requirements,
    extras_require=optional_requirements,
    classifiers=trove_classifiers,
)
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# tests/test_api.py
|
|
2
|
+
"""Test public API: inspect, verify, unsupported model."""
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import json
|
|
6
|
+
import waxjax
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestInspect:
    """Checks on the text that ``waxjax.inspect`` prints to stdout."""

    def test_inspect_runs_without_error(self, qwen2_dir, capsys):
        """The report for the Qwen2 fixture mentions ``qwen2`` and a check mark."""
        waxjax.inspect(qwen2_dir)
        out = capsys.readouterr().out
        assert "qwen2" in out
        assert "✓" in out

    def test_inspect_shows_unsupported(self, tmp_path, capsys):
        """An unknown ``model_type`` is flagged in the printed report."""
        # torch is only an optional extra of this package, so skip (rather
        # than error) when it is not installed.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file

        (tmp_path / "config.json").write_text(
            json.dumps({"model_type": "unknown_model_xyz"})
        )
        # Create a dummy safetensors file next to the config.
        save_file({"w": torch.randn(2, 2)}, str(tmp_path / "model.safetensors"))

        waxjax.inspect(str(tmp_path))
        out = capsys.readouterr().out
        # Accept either the cross mark or the library's "not yet supported"
        # message (emitted in Indonesian).
        assert "✗" in out or "belum didukung" in out
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TestUnsupportedModel:
    """Behaviour of ``waxjax.load`` for a ``model_type`` with no registered architecture."""

    def test_load_unknown_raises_valueerror(self, tmp_path):
        """Loading an unknown ``model_type`` without rules raises ``ValueError``."""
        # torch is only an optional extra of this package, so skip (rather
        # than error) when it is not installed.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file

        (tmp_path / "config.json").write_text(
            json.dumps({"model_type": "unknown_xyz"})
        )
        save_file({"w": torch.randn(2, 2)}, str(tmp_path / "model.safetensors"))

        # The expected message fragment is the library's Indonesian
        # "not yet supported" text.
        with pytest.raises(ValueError, match="belum didukung"):
            waxjax.load(str(tmp_path))

    def test_load_with_custom_rules_bypasses_registry(self, tmp_path):
        """Custom rules can be used for a model that is not registered."""
        # Same optional-dependency guard as above.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file
        from waxjax.core.rules import MappingRule

        (tmp_path / "config.json").write_text(
            json.dumps({"model_type": "custom_unknown"})
        )
        save_file(
            {"fc.weight": torch.randn(4, 8)},
            str(tmp_path / "model.safetensors"),
        )

        rules = [
            MappingRule(
                "fc.weight",
                rename=lambda k: k.replace(".weight", ".kernel"),
                transform="transpose_2d",
            )
        ]
        params = waxjax.load(str(tmp_path), rules=rules)
        # The renamed key shows up as nested entries: params["fc"]["kernel"].
        assert "fc" in params
        assert "kernel" in params["fc"]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class TestVerify:
    """Checks the report returned by ``waxjax.verify``."""

    def test_verify_passes_for_correct_conversion(self, qwen2_dir):
        """Every sampled tensor passes and the largest difference stays below 1e-5."""
        loaded = waxjax.load(qwen2_dir)
        outcome = waxjax.verify(loaded, qwen2_dir, sample_n=5)

        passed_count = outcome["passed"]
        total_count = outcome["total"]
        assert passed_count == total_count
        assert outcome["max_diff"] < 1e-5
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# tests/test_loader.py
|
|
2
|
+
"""Test loader: baca safetensors, handle bfloat16, sharding."""
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import numpy as np
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from waxjax.core.loader import load_safetensors
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestLoader:
    """Tests for ``load_safetensors``: returned keys, dtype handling and shard merging."""

    def test_load_basic(self, qwen2_dir):
        """Loading the Qwen2 fixture yields a non-empty dict."""
        weights = load_safetensors(qwen2_dir)
        assert isinstance(weights, dict)
        assert len(weights) > 0

    def test_all_float32(self, qwen2_dir):
        """Every tensor must be float32 after loading."""
        weights = load_safetensors(qwen2_dir)
        for key, arr in weights.items():
            assert arr.dtype == np.float32, f"{key} bukan float32: {arr.dtype}"

    def test_expected_keys_present_qwen2(self, qwen2_dir):
        """The Qwen2 fixture exposes the expected checkpoint keys."""
        weights = load_safetensors(qwen2_dir)
        assert "model.embed_tokens.weight" in weights
        assert "model.layers.0.self_attn.q_proj.weight" in weights
        assert "model.layers.0.self_attn.q_proj.bias" in weights
        assert "model.norm.weight" in weights

    def test_expected_keys_present_nllb(self, nllb_dir):
        """The NLLB fixture exposes the expected checkpoint keys."""
        weights = load_safetensors(nllb_dir)
        assert "model.shared.weight" in weights
        assert "model.encoder.layers.0.self_attn.q_proj.weight" in weights
        assert "model.decoder.layers.0.encoder_attn.q_proj.weight" in weights
        assert "model.encoder.embed_positions.weight" in weights

    def test_missing_dir_raises(self, tmp_path):
        """A directory that does not exist raises ``FileNotFoundError``."""
        with pytest.raises(FileNotFoundError):
            load_safetensors(str(tmp_path / "nonexistent"))

    def test_bfloat16_handling(self, tmp_path):
        """bfloat16 tensors must be cast to float32 without error."""
        # torch is only an optional extra of this package, so skip (rather
        # than error) when it is not installed.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file

        bf16_tensor = torch.randn(4, 8).to(torch.bfloat16)
        save_file({"layer.weight": bf16_tensor}, str(tmp_path / "model.safetensors"))

        weights = load_safetensors(str(tmp_path))
        assert weights["layer.weight"].dtype == np.float32

    def test_sharded_loading(self, tmp_path):
        """Two safetensors shard files are merged into one dict."""
        # Same optional-dependency guard as above.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file

        save_file(
            {"shard1.weight": torch.randn(4, 4)},
            str(tmp_path / "model-00001-of-00002.safetensors"),
        )
        save_file(
            {"shard2.weight": torch.randn(4, 4)},
            str(tmp_path / "model-00002-of-00002.safetensors"),
        )

        weights = load_safetensors(str(tmp_path))
        assert "shard1.weight" in weights
        assert "shard2.weight" in weights
        assert len(weights) == 2
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# tests/test_mapper.py
|
|
2
|
+
"""Test mapper: apply rules, skip, tied weights."""
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import numpy as np
|
|
6
|
+
from waxjax.core.mapper import apply_rules
|
|
7
|
+
from waxjax.core.rules import MappingRule
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _linear_rules():
    """Return the two rules shared by the mapper tests.

    The first renames ``*.weight`` keys to ``*.kernel`` with the
    ``transpose_2d`` transform; the second keeps ``*.bias`` keys unchanged
    with the ``no_op`` transform.
    """
    weight_rule = MappingRule(
        "**.weight",
        rename=lambda k: k.replace(".weight", ".kernel"),
        transform="transpose_2d",
    )
    bias_rule = MappingRule("**.bias", rename=lambda k: k, transform="no_op")
    return [weight_rule, bias_rule]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestMapper:
    """Tests for ``apply_rules``; results are keyed by tuples of path segments."""

    def test_basic_rename_and_transpose(self):
        """A 2-D ``.weight`` becomes a transposed ``.kernel``."""
        mapped = apply_rules({"fc.weight": np.ones((4, 8))}, _linear_rules())
        kernel_key = ("fc", "kernel")
        assert kernel_key in mapped
        assert mapped[kernel_key].shape == (8, 4)

    def test_bias_not_transposed(self):
        """A 1-D bias keeps its shape."""
        mapped = apply_rules({"fc.bias": np.ones(8)}, _linear_rules())
        assert mapped[("fc", "bias")].shape == (8,)

    def test_skip_sentinel(self):
        """A key whose rename yields '__skip__' is left out of the output."""
        skip_rule = MappingRule(
            "tied.weight",
            rename=lambda k: "__skip__",
            transform="no_op",
        )
        source = {"tied.weight": np.ones((4, 4))}
        mapped = apply_rules(source, [skip_rule])
        assert ("tied", "weight") not in mapped
        assert len(mapped) == 0

    def test_priority_order(self):
        """The more specific rule wins over the more general one."""
        specific = MappingRule(
            "model.norm.weight",
            rename=lambda k: k.replace(".weight", ".scale"),
            transform="no_op",
        )
        generic = MappingRule(
            "**.weight",
            rename=lambda k: k.replace(".weight", ".kernel"),
            transform="transpose_2d",
        )
        source = {"model.norm.weight": np.ones(8)}
        # The generic rule is listed first on purpose: list order must not decide.
        mapped = apply_rules(source, [generic, specific])
        # Must use the specific rule -> .scale, not .kernel.
        assert ("model", "norm", "scale") in mapped
        assert ("model", "norm", "kernel") not in mapped

    def test_tied_weights_auto_created(self):
        """A tied weight is created automatically when its source exists but the target does not."""
        embedding_rule = MappingRule(
            "model.shared.weight",
            rename=lambda k: k.replace(".weight", ".embedding"),
            transform="no_op",
        )
        source = {"model.shared.weight": np.ones((512, 64))}
        mapped = apply_rules(
            source,
            [embedding_rule],
            tied_weights={"model.shared.weight": "lm_head.kernel"},
        )
        # The embedding itself must be present.
        assert ("model", "shared", "embedding") in mapped
        # lm_head.kernel must be created automatically, transposed.
        head_key = ("lm_head", "kernel")
        assert head_key in mapped
        assert mapped[head_key].shape == (64, 512)

    def test_unmatched_keys_warn(self):
        """A key with no matching rule produces a warning, not an error."""
        source = {"unknown.mystery": np.ones((4, 4))}
        with pytest.warns(UserWarning, match="tidak match"):
            mapped = apply_rules(source, _linear_rules())
        assert len(mapped) == 0
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# tests/test_nllb.py
|
|
2
|
+
"""Integration test end-to-end untuk arsitektur NLLB/M2M-100."""
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import numpy as np
|
|
6
|
+
from flax.traverse_util import flatten_dict
|
|
7
|
+
import waxjax
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestNLLBLoad:
    """End-to-end checks on ``waxjax.load`` for the NLLB fixture directory."""

    @staticmethod
    def _flat_params(model_dir):
        """Load the checkpoint and flatten the nested params to tuple keys."""
        return flatten_dict(waxjax.load(model_dir))

    def test_load_success(self, nllb_dir):
        """Loading returns a plain dict."""
        loaded = waxjax.load(nllb_dir)
        assert isinstance(loaded, dict)

    def test_shared_embedding_shape(self, nllb_dir):
        """model.shared.embedding has shape (V, D) and is not transposed."""
        loaded = waxjax.load(nllb_dir)
        shared = np.array(loaded["model"]["shared"]["embedding"])
        assert shared.shape == (512, 64)

    def test_tied_embed_tokens_skipped(self, nllb_dir):
        """Encoder/decoder embed_tokens are not in params (tied to shared)."""
        flat = self._flat_params(nllb_dir)
        for side in ("encoder", "decoder"):
            assert ("model", side, "embed_tokens", "embedding") not in flat

    def test_positional_embedding_present(self, nllb_dir):
        """Both stacks carry a positional embedding table."""
        flat = self._flat_params(nllb_dir)
        encoder_key = ("model", "encoder", "embed_positions", "embedding")
        decoder_key = ("model", "decoder", "embed_positions", "embedding")
        assert encoder_key in flat
        assert decoder_key in flat
        # Shape (MAXPOS+2, D) = (132, 64)
        positions = np.array(flat[encoder_key])
        assert positions.shape == (132, 64)

    def test_layernorm_has_scale_and_bias(self, nllb_dir):
        """NLLB uses LayerNorm -> has scale AND bias (unlike Qwen2's RMSNorm)."""
        flat = self._flat_params(nllb_dir)
        norm_prefix = ("model", "encoder", "layers", "0", "self_attn_layer_norm")
        assert norm_prefix + ("scale",) in flat
        assert norm_prefix + ("bias",) in flat

    def test_fc1_fc2_kernel_transposed(self, nllb_dir):
        """fc1/fc2 must be transposed: PT (FFN, D) -> Flax (D, FFN)."""
        flat = self._flat_params(nllb_dir)
        layer0 = ("model", "encoder", "layers", "0")
        fc1_kernel = np.array(flat[layer0 + ("fc1", "kernel")])
        fc2_kernel = np.array(flat[layer0 + ("fc2", "kernel")])
        # D=64, FFN=128
        assert fc1_kernel.shape == (64, 128)  # transposed from (128, 64)
        assert fc2_kernel.shape == (128, 64)  # transposed from (64, 128)

    def test_out_proj_not_o_proj(self, nllb_dir):
        """NLLB uses out_proj, not o_proj as in Qwen2."""
        flat = self._flat_params(nllb_dir)
        attn_prefix = ("model", "encoder", "layers", "0", "self_attn")
        assert attn_prefix + ("out_proj", "kernel") in flat
        assert attn_prefix + ("o_proj", "kernel") not in flat

    def test_cross_attention_present(self, nllb_dir):
        """The decoder must have encoder_attn (cross-attention)."""
        flat = self._flat_params(nllb_dir)
        layer0 = ("model", "decoder", "layers", "0")
        assert layer0 + ("encoder_attn", "q_proj", "kernel") in flat
        assert layer0 + ("encoder_attn_layer_norm", "scale") in flat

    def test_encoder_has_no_cross_attention(self, nllb_dir):
        """The encoder has no encoder_attn."""
        flat = self._flat_params(nllb_dir)
        cross_key = ("model", "encoder", "layers", "0", "encoder_attn", "q_proj", "kernel")
        assert cross_key not in flat

    def test_lm_head_tied_to_shared(self, nllb_dir):
        """lm_head.kernel == model.shared.embedding.T"""
        flat = self._flat_params(nllb_dir)
        shared = np.array(flat[("model", "shared", "embedding")])  # (V, D)
        head = np.array(flat[("lm_head", "kernel")])  # (D, V)
        assert head.shape == (64, 512)
        assert np.allclose(shared.T, head, atol=1e-6)

    def test_both_encoder_decoder_layers(self, nllb_dir):
        """Layers 0 and 1 exist in both the encoder and the decoder."""
        flat = self._flat_params(nllb_dir)
        for index in map(str, range(2)):
            encoder_layer = ("model", "encoder", "layers", index)
            decoder_layer = ("model", "decoder", "layers", index)
            # Encoder
            assert encoder_layer + ("self_attn", "q_proj", "kernel") in flat
            assert encoder_layer + ("fc1", "kernel") in flat
            # Decoder
            assert decoder_layer + ("self_attn", "q_proj", "kernel") in flat
            assert decoder_layer + ("encoder_attn", "q_proj", "kernel") in flat
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# tests/test_qwen2.py
|
|
2
|
+
"""Integration test end-to-end untuk arsitektur Qwen2."""
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import numpy as np
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
from flax.traverse_util import flatten_dict
|
|
8
|
+
import waxjax
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TestQwen2Load:
    """End-to-end checks on ``waxjax.load`` for the Qwen2 fixture directory."""

    @staticmethod
    def _flat_params(model_dir):
        """Load the checkpoint and flatten the nested params to tuple keys."""
        return flatten_dict(waxjax.load(model_dir))

    def test_load_returns_nested_dict(self, qwen2_dir):
        """The default result is a dict with a top-level ``model`` entry."""
        loaded = waxjax.load(qwen2_dir)
        assert isinstance(loaded, dict)
        assert "model" in loaded

    def test_load_frozen_dict(self, qwen2_dir):
        """``format="frozen"`` returns a Flax ``FrozenDict``."""
        from flax.core import FrozenDict

        loaded = waxjax.load(qwen2_dir, format="frozen")
        assert isinstance(loaded, FrozenDict)

    def test_embedding_shape(self, qwen2_dir):
        """The token embedding keeps its (V, H) layout."""
        loaded = waxjax.load(qwen2_dir)
        table = np.array(loaded["model"]["embed_tokens"]["embedding"])
        # (V=512, H=64) - not transposed
        assert table.shape == (512, 64)

    def test_linear_kernel_transposed(self, qwen2_dir):
        """The q_proj kernel has the expected shape after conversion."""
        flat = self._flat_params(qwen2_dir)
        kernel_key = ("model", "layers", "0", "self_attn", "q_proj", "kernel")
        q_kernel = np.array(flat[kernel_key])
        # PT shape (QH*HD=64, H=64) -> Flax shape (H=64, QH*HD=64)
        assert q_kernel.shape == (64, 64)

    def test_rmsnorm_has_scale_not_kernel(self, qwen2_dir):
        """The norm weight is exposed as ``scale``, never as ``kernel``."""
        flat = self._flat_params(qwen2_dir)
        norm_prefix = ("model", "layers", "0", "input_layernorm")
        assert norm_prefix + ("scale",) in flat
        assert norm_prefix + ("kernel",) not in flat

    def test_rmsnorm_no_bias(self, qwen2_dir):
        """Qwen2 uses RMSNorm, so there is no bias."""
        flat = self._flat_params(qwen2_dir)
        bias_key = ("model", "layers", "0", "input_layernorm", "bias")
        assert bias_key not in flat

    def test_tied_lm_head_created(self, qwen2_dir):
        """lm_head.kernel must exist even though it is not stored in the file."""
        flat = self._flat_params(qwen2_dir)
        head_key = ("lm_head", "kernel")
        assert head_key in flat
        # Shape: (H=64, V=512)
        assert flat[head_key].shape == (64, 512)

    def test_tied_lm_head_values_correct(self, qwen2_dir):
        """lm_head.kernel == embed.T numerically."""
        flat = self._flat_params(qwen2_dir)
        table = np.array(flat[("model", "embed_tokens", "embedding")])
        head = np.array(flat[("lm_head", "kernel")])
        assert np.allclose(table.T, head, atol=1e-6)

    def test_all_layers_present(self, qwen2_dir):
        """Layers 0 and 1 both carry attention and MLP kernels."""
        flat = self._flat_params(qwen2_dir)
        for index in map(str, range(2)):  # NL=2 in the fixture
            layer = ("model", "layers", index)
            assert layer + ("self_attn", "q_proj", "kernel") in flat
            assert layer + ("mlp", "gate_proj", "kernel") in flat

    def test_all_values_jnp_array(self, qwen2_dir):
        """Every leaf of the params tree is a ``jnp.ndarray``."""
        flat = self._flat_params(qwen2_dir)
        for key, val in flat.items():
            assert isinstance(val, jnp.ndarray), f"{key} bukan jnp.ndarray"

    def test_no_nan_or_inf(self, qwen2_dir):
        """No leaf contains NaN or infinite values."""
        flat = self._flat_params(qwen2_dir)
        for key, val in flat.items():
            values = np.array(val)
            assert not np.any(np.isnan(values)), f"NaN di {key}"
            assert not np.any(np.isinf(values)), f"Inf di {key}"
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# tests/test_rules.py
|
|
2
|
+
"""Unit test untuk MappingRule — matching dan transform."""
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import numpy as np
|
|
6
|
+
from waxjax.core.rules import MappingRule, transpose_2d, no_op
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestMappingRuleMatch:
    """Pattern matching, automatic priority and constructor validation of ``MappingRule``."""

    @staticmethod
    def _passthrough(pattern):
        """Build a rule for *pattern* that keeps the key and uses the ``no_op`` transform."""
        return MappingRule(pattern, rename=lambda k: k, transform="no_op")

    def test_exact_match(self):
        """A literal pattern matches only the identical key."""
        rule = self._passthrough("lm_head.weight")
        assert rule.matches("lm_head.weight")
        assert not rule.matches("lm_head.bias")

    def test_single_wildcard(self):
        """``*`` stands for exactly one dot-separated segment."""
        rule = self._passthrough("model.layers.*.self_attn.q_proj.weight")
        assert rule.matches("model.layers.0.self_attn.q_proj.weight")
        assert rule.matches("model.layers.11.self_attn.q_proj.weight")
        # * must not match a dot
        assert not rule.matches("model.layers.0.1.self_attn.q_proj.weight")

    def test_double_wildcard(self):
        """``**`` matches across several segments."""
        rule = self._passthrough("model.**.weight")
        assert rule.matches("model.layers.0.self_attn.q_proj.weight")
        assert rule.matches("model.norm.weight")

    def test_priority_auto(self):
        """A pattern with fewer wildcards gets the higher priority."""
        specific = self._passthrough("model.layers.*.self_attn_layer_norm.weight")
        generic = self._passthrough("model.layers.*.*.weight")
        assert specific.priority > generic.priority

    def test_invalid_transform_raises(self):
        """An unknown transform name is rejected at construction time."""
        with pytest.raises(ValueError, match="tidak dikenal"):
            MappingRule("*.weight", rename=lambda k: k, transform="unknown_transform")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class TestMappingRuleTransform:
    """Checks on the transform and rename steps of ``MappingRule``."""

    def test_transpose_2d(self):
        """The ``transpose_2d`` transform swaps the axes of a 2-D array."""
        rule = MappingRule("*.weight", rename=lambda k: k, transform="transpose_2d")
        transposed = rule.apply_transform(np.ones((4, 8)))
        assert transposed.shape == (8, 4)

    def test_no_transpose_1d(self):
        """A 1-D array (bias) must not be transposed even when transpose_2d is used."""
        vector = np.ones(8)
        assert transpose_2d(vector).shape == (8,)

    def test_custom_callable_transform(self):
        """A plain callable is accepted as the transform."""

        def double(arr):
            return arr * 2

        rule = MappingRule("*.weight", rename=lambda k: k, transform=double)
        doubled = rule.apply_transform(np.array([1.0, 2.0]))
        assert np.allclose(doubled, [2.0, 4.0])

    def test_rename_applied(self):
        """``apply_rename`` runs the rule's rename callable on the key."""
        rule = MappingRule(
            "*.weight",
            rename=lambda k: k.replace(".weight", ".kernel"),
            transform="no_op",
        )
        renamed = rule.apply_rename("model.fc.weight")
        assert renamed == "model.fc.kernel"
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Compatibility shim: expose the public API under the expected
|
|
2
|
+
package name `waxjax` while the real implementation lives in
|
|
3
|
+
`safejax` (to preserve the original repository layout).
|
|
4
|
+
"""
|
|
5
|
+
# The public API is re-exported through this package's own subpackages rather
# than by importing the top-level `safejax` package object; per the original
# author this avoids circular imports during package initialization.
# NOTE(review): if the `waxjax.core.*` / `waxjax.architectures.*` modules are
# themselves `from safejax... import *` shims, `safejax` is still imported
# indirectly — confirm that this is intended and that `safejax` ships those
# modules.
from .core.api import load, save, inspect, verify
from .core.rules import MappingRule
from .architectures._registry import register_architecture, ArchitectureConfig

# Package version string.
__version__ = "0.1.0"

# Names exported by a star-import of this package.
__all__ = [
    "load",
    "save",
    "inspect",
    "verify",
    "MappingRule",
    "register_architecture",
    "ArchitectureConfig",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from safejax.architectures._registry import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from safejax.architectures.nllb import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from safejax.architectures.qwen2 import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from safejax.core.api import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from safejax.core.loader import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from safejax.core.mapper import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from safejax.core.nester import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from safejax.core.rules import *
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: waxjax
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Load HuggingFace SafeTensors models to JAX/Flax
|
|
5
|
+
Home-page: https://github.com/ik12-b/waxjax
|
|
6
|
+
Author: Your Name
|
|
7
|
+
License: Apache-2.0
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Python: >=3.9
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
Requires-Dist: safetensors>=0.4.0
|
|
14
|
+
Requires-Dist: numpy>=1.24
|
|
15
|
+
Requires-Dist: jax>=0.4.0
|
|
16
|
+
Provides-Extra: torch
|
|
17
|
+
Requires-Dist: torch>=2.0; extra == "torch"
|
|
18
|
+
Provides-Extra: flax
|
|
19
|
+
Requires-Dist: flax>=0.7; extra == "flax"
|
|
20
|
+
Provides-Extra: all
|
|
21
|
+
Requires-Dist: torch>=2.0; extra == "all"
|
|
22
|
+
Requires-Dist: flax>=0.7; extra == "all"
|
|
23
|
+
Requires-Dist: ml_dtypes; extra == "all"
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest; extra == "dev"
|
|
26
|
+
Requires-Dist: black; extra == "dev"
|
|
27
|
+
Requires-Dist: ruff; extra == "dev"
|
|
28
|
+
Dynamic: author
|
|
29
|
+
Dynamic: classifier
|
|
30
|
+
Dynamic: description
|
|
31
|
+
Dynamic: description-content-type
|
|
32
|
+
Dynamic: home-page
|
|
33
|
+
Dynamic: license
|
|
34
|
+
Dynamic: provides-extra
|
|
35
|
+
Dynamic: requires-dist
|
|
36
|
+
Dynamic: requires-python
|
|
37
|
+
Dynamic: summary
|
|
38
|
+
|
|
39
|
+
# waxjax
|
|
40
|
+
|
|
41
|
+
Lightweight helper to load HuggingFace SafeTensors checkpoints into JAX/Flax.
|
|
42
|
+
|
|
43
|
+
Install (latest from source):
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
pip install git+https://github.com/ik12-b/waxjax.git
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Build and publish (testing against TestPyPI first is recommended): see `RELEASE.md` or run locally:
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
python -m pip install --upgrade build twine
|
|
53
|
+
python -m build
|
|
54
|
+
twine upload --repository testpypi dist/*
|
|
55
|
+
```
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
MANIFEST.in
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
setup.py
|
|
5
|
+
safejax/__init__.py
|
|
6
|
+
tests/test_api.py
|
|
7
|
+
tests/test_loader.py
|
|
8
|
+
tests/test_mapper.py
|
|
9
|
+
tests/test_nllb.py
|
|
10
|
+
tests/test_qwen2.py
|
|
11
|
+
tests/test_rules.py
|
|
12
|
+
waxjax/__init__.py
|
|
13
|
+
waxjax.egg-info/PKG-INFO
|
|
14
|
+
waxjax.egg-info/SOURCES.txt
|
|
15
|
+
waxjax.egg-info/dependency_links.txt
|
|
16
|
+
waxjax.egg-info/requires.txt
|
|
17
|
+
waxjax.egg-info/top_level.txt
|
|
18
|
+
waxjax/architectures/__init__.py
|
|
19
|
+
waxjax/architectures/_registry.py
|
|
20
|
+
waxjax/architectures/nllb.py
|
|
21
|
+
waxjax/architectures/qwen2.py
|
|
22
|
+
waxjax/core/__init__.py
|
|
23
|
+
waxjax/core/api.py
|
|
24
|
+
waxjax/core/loader.py
|
|
25
|
+
waxjax/core/mapper.py
|
|
26
|
+
waxjax/core/nester.py
|
|
27
|
+
waxjax/core/rules.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|