waxjax 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
1
+ include README.md
2
+ global-exclude __pycache__
waxjax-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,55 @@
1
+ Metadata-Version: 2.4
2
+ Name: waxjax
3
+ Version: 0.1.0
4
+ Summary: Load HuggingFace SafeTensors models to JAX/Flax
5
+ Home-page: https://github.com/ik12-b/waxjax
6
+ Author: Your Name
7
+ License: Apache-2.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: Apache Software License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.9
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: safetensors>=0.4.0
14
+ Requires-Dist: numpy>=1.24
15
+ Requires-Dist: jax>=0.4.0
16
+ Provides-Extra: torch
17
+ Requires-Dist: torch>=2.0; extra == "torch"
18
+ Provides-Extra: flax
19
+ Requires-Dist: flax>=0.7; extra == "flax"
20
+ Provides-Extra: all
21
+ Requires-Dist: torch>=2.0; extra == "all"
22
+ Requires-Dist: flax>=0.7; extra == "all"
23
+ Requires-Dist: ml_dtypes; extra == "all"
24
+ Provides-Extra: dev
25
+ Requires-Dist: pytest; extra == "dev"
26
+ Requires-Dist: black; extra == "dev"
27
+ Requires-Dist: ruff; extra == "dev"
28
+ Dynamic: author
29
+ Dynamic: classifier
30
+ Dynamic: description
31
+ Dynamic: description-content-type
32
+ Dynamic: home-page
33
+ Dynamic: license
34
+ Dynamic: provides-extra
35
+ Dynamic: requires-dist
36
+ Dynamic: requires-python
37
+ Dynamic: summary
38
+
39
+ # waxjax
40
+
41
+ Lightweight helper to load HuggingFace SafeTensors checkpoints into JAX/Flax.
42
+
43
+ Install (latest from source):
44
+
45
+ ```bash
46
+ pip install git+https://github.com/ik12-b/waxjax.git
47
+ ```
48
+
49
+ Build and publish (testing on TestPyPI first is recommended): see `RELEASE.md` or run locally:
50
+
51
+ ```bash
52
+ python -m pip install --upgrade build twine
53
+ python -m build
54
+ twine upload --repository testpypi dist/*
55
+ ```
waxjax-0.1.0/README.md ADDED
@@ -0,0 +1,17 @@
1
+ # waxjax
2
+
3
+ Lightweight helper to load HuggingFace SafeTensors checkpoints into JAX/Flax.
4
+
5
+ Install (latest from source):
6
+
7
+ ```bash
8
+ pip install git+https://github.com/ik12-b/waxjax.git
9
+ ```
10
+
11
+ Build and publish (testing on TestPyPI first is recommended): see `RELEASE.md` or run locally:
12
+
13
+ ```bash
14
+ python -m pip install --upgrade build twine
15
+ python -m build
16
+ twine upload --repository testpypi dist/*
17
+ ```
@@ -0,0 +1,3 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,9 @@
1
+ from .core.api import load, save, inspect, verify
2
+ from .core.rules import MappingRule
3
+ from .architectures._registry import register_architecture, ArchitectureConfig
4
+
5
+ __version__ = "0.1.0"
6
+ __all__ = [
7
+ "load", "save", "inspect", "verify",
8
+ "MappingRule", "register_architecture", "ArchitectureConfig",
9
+ ]
waxjax-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
waxjax-0.1.0/setup.py ADDED
@@ -0,0 +1,34 @@
1
"""Packaging script for the ``waxjax`` distribution."""
from pathlib import Path

from setuptools import find_packages, setup

# Directory that contains this script; the README is looked up next to it.
here = Path(__file__).parent
readme_path = here / "README.md"

# Use the README as the long description when it is present, else fall back
# to an empty string so the build still works without it.
if readme_path.exists():
    long_description = readme_path.read_text(encoding="utf-8")
else:
    long_description = ""

# Packages that are always installed together with waxjax.
RUNTIME_REQUIREMENTS = [
    "safetensors>=0.4.0",
    "numpy>=1.24",
    "jax>=0.4.0",
]

# Optional dependency groups, selectable as ``waxjax[<name>]``.
OPTIONAL_REQUIREMENTS = {
    "torch": ["torch>=2.0"],
    "flax": ["flax>=0.7"],
    "all": ["torch>=2.0", "flax>=0.7", "ml_dtypes"],
    "dev": ["pytest", "black", "ruff"],
}

TROVE_CLASSIFIERS = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
]

setup(
    name="waxjax",
    version="0.1.0",
    description="Load HuggingFace SafeTensors models to JAX/Flax",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # NOTE(review): "Your Name" looks like an unfilled template value — confirm
    # the intended author before the next release.
    author="Your Name",
    license="Apache-2.0",
    url="https://github.com/ik12-b/waxjax",
    packages=find_packages(),
    python_requires=">=3.9",
    install_requires=RUNTIME_REQUIREMENTS,
    extras_require=OPTIONAL_REQUIREMENTS,
    classifiers=TROVE_CLASSIFIERS,
)
@@ -0,0 +1,72 @@
1
+ # tests/test_api.py
2
+ """Test public API: inspect, verify, unsupported model."""
3
+
4
+ import pytest
5
+ import json
6
+ import waxjax
7
+
8
+
9
class TestInspect:
    """Checks on the console report printed by ``waxjax.inspect``."""

    def test_inspect_runs_without_error(self, qwen2_dir, capsys):
        """The report for the Qwen2 fixture contains ``qwen2`` and a check mark."""
        waxjax.inspect(qwen2_dir)
        out = capsys.readouterr().out
        assert "qwen2" in out
        assert "✓" in out

    def test_inspect_shows_unsupported(self, tmp_path, capsys):
        """An unregistered ``model_type`` is flagged as unsupported in the report."""
        # torch is only an optional extra of this package, so skip the test
        # (instead of erroring out) when it is not installed.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file

        (tmp_path / "config.json").write_text(
            json.dumps({"model_type": "unknown_model_xyz"})
        )
        # Create a dummy safetensors checkpoint next to the config.
        save_file({"w": torch.randn(2, 2)}, str(tmp_path / "model.safetensors"))

        waxjax.inspect(str(tmp_path))
        out = capsys.readouterr().out
        # Either the cross mark or the "not supported yet" message is accepted.
        assert "✗" in out or "belum didukung" in out
29
+
30
+
31
class TestUnsupportedModel:
    """Loading checkpoints whose ``model_type`` is not in the registry."""

    def test_load_unknown_raises_valueerror(self, tmp_path):
        """Without custom rules, an unknown ``model_type`` raises ``ValueError``."""
        # torch is only an optional extra of this package, so skip the test
        # (instead of erroring out) when it is not installed.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file

        (tmp_path / "config.json").write_text(
            json.dumps({"model_type": "unknown_xyz"})
        )
        save_file({"w": torch.randn(2, 2)}, str(tmp_path / "model.safetensors"))

        with pytest.raises(ValueError, match="belum didukung"):
            waxjax.load(str(tmp_path))

    def test_load_with_custom_rules_bypasses_registry(self, tmp_path):
        """Custom rules can be used for a model that is not registered yet."""
        # Same optional-dependency guard as above.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file
        from waxjax.core.rules import MappingRule

        (tmp_path / "config.json").write_text(
            json.dumps({"model_type": "custom_unknown"})
        )
        save_file(
            {"fc.weight": torch.randn(4, 8)},
            str(tmp_path / "model.safetensors"),
        )

        rules = [
            MappingRule(
                "fc.weight",
                rename=lambda k: k.replace(".weight", ".kernel"),
                transform="transpose_2d",
            )
        ]
        params = waxjax.load(str(tmp_path), rules=rules)
        # "fc.weight" must end up as the nested entry params["fc"]["kernel"].
        assert "fc" in params
        assert "kernel" in params["fc"]
64
+
65
+
66
class TestVerify:
    """Checks on the report returned by ``waxjax.verify``."""

    def test_verify_passes_for_correct_conversion(self, qwen2_dir):
        """All sampled entries pass and the largest difference is below 1e-5."""
        converted = waxjax.load(qwen2_dir)
        outcome = waxjax.verify(converted, qwen2_dir, sample_n=5)
        passed, total = outcome["passed"], outcome["total"]
        assert passed == total
        assert outcome["max_diff"] < 1e-5
@@ -0,0 +1,65 @@
1
+ # tests/test_loader.py
2
+ """Loader tests: reading safetensors, bfloat16 handling, sharding."""
3
+
4
+ import pytest
5
+ import numpy as np
6
+ from pathlib import Path
7
+ from waxjax.core.loader import load_safetensors
8
+
9
+
10
class TestLoader:
    """Tests for ``load_safetensors``: reading, dtype casting and shard merging."""

    def test_load_basic(self, qwen2_dir):
        """Loading the Qwen2 fixture yields a non-empty dict."""
        weights = load_safetensors(qwen2_dir)
        assert isinstance(weights, dict)
        assert len(weights) > 0

    def test_all_float32(self, qwen2_dir):
        """Every tensor must be float32 after loading."""
        weights = load_safetensors(qwen2_dir)
        for key, arr in weights.items():
            assert arr.dtype == np.float32, f"{key} bukan float32: {arr.dtype}"

    def test_expected_keys_present_qwen2(self, qwen2_dir):
        """The Qwen2 fixture exposes its embedding, attention and norm keys."""
        weights = load_safetensors(qwen2_dir)
        expected_keys = (
            "model.embed_tokens.weight",
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.self_attn.q_proj.bias",
            "model.norm.weight",
        )
        for name in expected_keys:
            assert name in weights

    def test_expected_keys_present_nllb(self, nllb_dir):
        """The NLLB fixture exposes shared, encoder, decoder and position keys."""
        weights = load_safetensors(nllb_dir)
        expected_keys = (
            "model.shared.weight",
            "model.encoder.layers.0.self_attn.q_proj.weight",
            "model.decoder.layers.0.encoder_attn.q_proj.weight",
            "model.encoder.embed_positions.weight",
        )
        for name in expected_keys:
            assert name in weights

    def test_missing_dir_raises(self, tmp_path):
        """A directory that does not exist raises ``FileNotFoundError``."""
        missing = tmp_path / "nonexistent"
        with pytest.raises(FileNotFoundError):
            load_safetensors(str(missing))

    def test_bfloat16_handling(self, tmp_path):
        """A bfloat16 tensor must be cast to float32 without error."""
        # torch is only an optional extra of this package, so skip the test
        # (instead of erroring out) when it is not installed.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file

        bf16_tensor = torch.randn(4, 8).to(torch.bfloat16)
        save_file({"layer.weight": bf16_tensor}, str(tmp_path / "model.safetensors"))

        weights = load_safetensors(str(tmp_path))
        assert weights["layer.weight"].dtype == np.float32

    def test_sharded_loading(self, tmp_path):
        """Two safetensors files in one directory are merged correctly."""
        # Same optional-dependency guard as above.
        torch = pytest.importorskip("torch")
        from safetensors.torch import save_file

        save_file(
            {"shard1.weight": torch.randn(4, 4)},
            str(tmp_path / "model-00001-of-00002.safetensors"),
        )
        save_file(
            {"shard2.weight": torch.randn(4, 4)},
            str(tmp_path / "model-00002-of-00002.safetensors"),
        )

        weights = load_safetensors(str(tmp_path))
        assert "shard1.weight" in weights
        assert "shard2.weight" in weights
        assert len(weights) == 2
@@ -0,0 +1,88 @@
1
+ # tests/test_mapper.py
2
+ """Test mapper: apply rules, skip, tied weights."""
3
+
4
+ import pytest
5
+ import numpy as np
6
+ from waxjax.core.mapper import apply_rules
7
+ from waxjax.core.rules import MappingRule
8
+
9
+
10
def _linear_rules():
    """Build the pair of rules shared by the mapper tests.

    Returns:
        A two-element list. The first rule matches ``**.weight``, replaces
        ``.weight`` with ``.kernel`` in the key and uses the ``transpose_2d``
        transform. The second matches ``**.bias``, keeps the key unchanged
        and uses ``no_op``.
    """
    weight_rule = MappingRule(
        "**.weight",
        rename=lambda k: k.replace(".weight", ".kernel"),
        transform="transpose_2d",
    )
    bias_rule = MappingRule("**.bias", rename=lambda k: k, transform="no_op")
    return [weight_rule, bias_rule]
16
+
17
+
18
class TestMapper:
    """Tests for ``apply_rules``: renaming, transforms, skipping, tied weights."""

    def test_basic_rename_and_transpose(self):
        """A 2-D ``fc.weight`` becomes ``("fc", "kernel")`` with swapped axes."""
        mapped = apply_rules({"fc.weight": np.ones((4, 8))}, _linear_rules())
        target = ("fc", "kernel")
        assert target in mapped
        assert mapped[target].shape == (8, 4)

    def test_bias_not_transposed(self):
        """A 1-D bias keeps its shape."""
        mapped = apply_rules({"fc.bias": np.ones(8)}, _linear_rules())
        assert mapped[("fc", "bias")].shape == (8,)

    def test_skip_sentinel(self):
        """A key whose rename yields ``'__skip__'`` does not reach the output."""
        skip_rule = MappingRule(
            "tied.weight",
            rename=lambda k: "__skip__",
            transform="no_op",
        )
        mapped = apply_rules({"tied.weight": np.ones((4, 4))}, [skip_rule])
        assert ("tied", "weight") not in mapped
        assert len(mapped) == 0

    def test_priority_order(self):
        """The more specific rule wins over the more general one."""
        specific = MappingRule(
            "model.norm.weight",
            rename=lambda k: k.replace(".weight", ".scale"),
            transform="no_op",
        )
        generic = MappingRule(
            "**.weight",
            rename=lambda k: k.replace(".weight", ".kernel"),
            transform="transpose_2d",
        )
        # The generic rule is listed first on purpose: list order must not
        # decide which rule is applied.
        mapped = apply_rules({"model.norm.weight": np.ones(8)}, [generic, specific])
        # The specific rule must be used: ".scale", not ".kernel".
        assert ("model", "norm", "scale") in mapped
        assert ("model", "norm", "kernel") not in mapped

    def test_tied_weights_auto_created(self):
        """A tied weight is created when its source exists but the target does not."""
        embedding_rule = MappingRule(
            "model.shared.weight",
            rename=lambda k: k.replace(".weight", ".embedding"),
            transform="no_op",
        )
        source = {"model.shared.weight": np.ones((512, 64))}
        mapped = apply_rules(
            source,
            [embedding_rule],
            tied_weights={"model.shared.weight": "lm_head.kernel"},
        )
        # The embedding itself must be present.
        assert ("model", "shared", "embedding") in mapped
        # lm_head.kernel must be created automatically, transposed.
        tied_key = ("lm_head", "kernel")
        assert tied_key in mapped
        assert mapped[tied_key].shape == (64, 512)

    def test_unmatched_keys_warn(self):
        """A key that matches no rule produces a warning, not an error."""
        orphan = {"unknown.mystery": np.ones((4, 4))}
        with pytest.warns(UserWarning, match="tidak match"):
            mapped = apply_rules(orphan, _linear_rules())
        assert len(mapped) == 0
@@ -0,0 +1,93 @@
1
+ # tests/test_nllb.py
2
+ """End-to-end integration tests for the NLLB/M2M-100 architecture."""
3
+
4
+ import pytest
5
+ import numpy as np
6
+ from flax.traverse_util import flatten_dict
7
+ import waxjax
8
+
9
+
10
class TestNLLBLoad:
    """End-to-end checks on the parameters ``waxjax.load`` builds for the NLLB fixture."""

    # Key prefixes of the first encoder / decoder layer in the flattened tree.
    _ENC0 = ("model", "encoder", "layers", "0")
    _DEC0 = ("model", "decoder", "layers", "0")

    @staticmethod
    def _flat_params(model_dir):
        """Load ``model_dir`` and return the parameters keyed by path tuples."""
        return flatten_dict(waxjax.load(model_dir))

    def test_load_success(self, nllb_dir):
        """Loading returns a plain ``dict``."""
        assert isinstance(waxjax.load(nllb_dir), dict)

    def test_shared_embedding_shape(self, nllb_dir):
        """model.shared.embedding keeps shape (V, D), i.e. it is not transposed."""
        tree = waxjax.load(nllb_dir)
        shared = np.array(tree["model"]["shared"]["embedding"])
        assert shared.shape == (512, 64)

    def test_tied_embed_tokens_skipped(self, nllb_dir):
        """Encoder/decoder embed_tokens are absent from params (tied to shared)."""
        flat = self._flat_params(nllb_dir)
        for stack in ("encoder", "decoder"):
            assert ("model", stack, "embed_tokens", "embedding") not in flat

    def test_positional_embedding_present(self, nllb_dir):
        """Both stacks carry positional embeddings; the encoder one is (132, 64)."""
        flat = self._flat_params(nllb_dir)
        encoder_pos = ("model", "encoder", "embed_positions", "embedding")
        decoder_pos = ("model", "decoder", "embed_positions", "embedding")
        assert encoder_pos in flat
        assert decoder_pos in flat
        # Shape (MAXPOS+2, D) = (132, 64)
        assert np.array(flat[encoder_pos]).shape == (132, 64)

    def test_layernorm_has_scale_and_bias(self, nllb_dir):
        """NLLB uses LayerNorm, so it has scale AND bias (unlike Qwen2's RMSNorm)."""
        flat = self._flat_params(nllb_dir)
        norm = self._ENC0 + ("self_attn_layer_norm",)
        assert norm + ("scale",) in flat
        assert norm + ("bias",) in flat

    def test_fc1_fc2_kernel_transposed(self, nllb_dir):
        """fc1/fc2 must be transposed: PyTorch (FFN, D) becomes Flax (D, FFN)."""
        flat = self._flat_params(nllb_dir)
        fc1 = np.array(flat[self._ENC0 + ("fc1", "kernel")])
        fc2 = np.array(flat[self._ENC0 + ("fc2", "kernel")])
        # D=64, FFN=128
        assert fc1.shape == (64, 128)  # transposed from (128, 64)
        assert fc2.shape == (128, 64)  # transposed from (64, 128)

    def test_out_proj_not_o_proj(self, nllb_dir):
        """NLLB uses out_proj, not o_proj as Qwen2 does."""
        flat = self._flat_params(nllb_dir)
        attention = self._ENC0 + ("self_attn",)
        assert attention + ("out_proj", "kernel") in flat
        assert attention + ("o_proj", "kernel") not in flat

    def test_cross_attention_present(self, nllb_dir):
        """The decoder must have encoder_attn (cross-attention)."""
        flat = self._flat_params(nllb_dir)
        assert self._DEC0 + ("encoder_attn", "q_proj", "kernel") in flat
        assert self._DEC0 + ("encoder_attn_layer_norm", "scale") in flat

    def test_encoder_has_no_cross_attention(self, nllb_dir):
        """The encoder has no encoder_attn."""
        flat = self._flat_params(nllb_dir)
        assert self._ENC0 + ("encoder_attn", "q_proj", "kernel") not in flat

    def test_lm_head_tied_to_shared(self, nllb_dir):
        """lm_head.kernel == model.shared.embedding.T"""
        flat = self._flat_params(nllb_dir)
        shared = np.array(flat[("model", "shared", "embedding")])  # (V, D)
        lm = np.array(flat[("lm_head", "kernel")])  # (D, V)
        assert lm.shape == (64, 512)
        assert np.allclose(shared.T, lm, atol=1e-6)

    def test_both_encoder_decoder_layers(self, nllb_dir):
        """Layers 0 and 1 exist in both the encoder and the decoder."""
        flat = self._flat_params(nllb_dir)
        # (stack, key tail) pairs checked for every layer index, in order.
        per_layer = (
            ("encoder", ("self_attn", "q_proj", "kernel")),
            ("encoder", ("fc1", "kernel")),
            ("decoder", ("self_attn", "q_proj", "kernel")),
            ("decoder", ("encoder_attn", "q_proj", "kernel")),
        )
        for i in range(2):
            for stack, tail in per_layer:
                assert (("model", stack, "layers", str(i)) + tail) in flat
@@ -0,0 +1,83 @@
1
+ # tests/test_qwen2.py
2
+ """End-to-end integration tests for the Qwen2 architecture."""
3
+
4
+ import pytest
5
+ import numpy as np
6
+ import jax.numpy as jnp
7
+ from flax.traverse_util import flatten_dict
8
+ import waxjax
9
+
10
+
11
class TestQwen2Load:
    """End-to-end checks on the parameters ``waxjax.load`` builds for the Qwen2 fixture."""

    # Frequently used key paths in the flattened parameter tree.
    _LAYER0 = ("model", "layers", "0")
    _LM_HEAD = ("lm_head", "kernel")

    @staticmethod
    def _flat_params(model_dir):
        """Load ``model_dir`` and return the parameters keyed by path tuples."""
        return flatten_dict(waxjax.load(model_dir))

    def test_load_returns_nested_dict(self, qwen2_dir):
        """The default result is a ``dict`` with a top-level ``"model"`` entry."""
        tree = waxjax.load(qwen2_dir)
        assert isinstance(tree, dict)
        assert "model" in tree

    def test_load_frozen_dict(self, qwen2_dir):
        """``format="frozen"`` yields a Flax ``FrozenDict``."""
        from flax.core import FrozenDict

        assert isinstance(waxjax.load(qwen2_dir, format="frozen"), FrozenDict)

    def test_embedding_shape(self, qwen2_dir):
        """The token embedding keeps its (V, H) layout."""
        tree = waxjax.load(qwen2_dir)
        table = np.array(tree["model"]["embed_tokens"]["embedding"])
        # (V=512, H=64) — not transposed
        assert table.shape == (512, 64)

    def test_linear_kernel_transposed(self, qwen2_dir):
        """The q_proj kernel of layer 0 has the expected Flax shape."""
        flat = self._flat_params(qwen2_dir)
        q_kernel = np.array(flat[self._LAYER0 + ("self_attn", "q_proj", "kernel")])
        # PyTorch shape (QH*HD=64, H=64) becomes Flax shape (H=64, QH*HD=64)
        assert q_kernel.shape == (64, 64)

    def test_rmsnorm_has_scale_not_kernel(self, qwen2_dir):
        """The norm weight is stored as ``scale``, never as ``kernel``."""
        flat = self._flat_params(qwen2_dir)
        norm = self._LAYER0 + ("input_layernorm",)
        assert norm + ("scale",) in flat
        assert norm + ("kernel",) not in flat

    def test_rmsnorm_no_bias(self, qwen2_dir):
        """Qwen2 uses RMSNorm, which has no bias."""
        flat = self._flat_params(qwen2_dir)
        assert self._LAYER0 + ("input_layernorm", "bias") not in flat

    def test_tied_lm_head_created(self, qwen2_dir):
        """lm_head.kernel must exist even though it is not stored in the file."""
        flat = self._flat_params(qwen2_dir)
        assert self._LM_HEAD in flat
        # Shape: (H=64, V=512)
        assert flat[self._LM_HEAD].shape == (64, 512)

    def test_tied_lm_head_values_correct(self, qwen2_dir):
        """lm_head.kernel equals embed.T numerically."""
        flat = self._flat_params(qwen2_dir)
        embed = np.array(flat[("model", "embed_tokens", "embedding")])
        lm = np.array(flat[self._LM_HEAD])
        assert np.allclose(embed.T, lm, atol=1e-6)

    def test_all_layers_present(self, qwen2_dir):
        """Both layers of the fixture contribute attention and MLP kernels."""
        flat = self._flat_params(qwen2_dir)
        per_layer = (
            ("self_attn", "q_proj", "kernel"),
            ("mlp", "gate_proj", "kernel"),
        )
        for i in range(2):  # NL=2 in the fixture
            for tail in per_layer:
                assert (("model", "layers", str(i)) + tail) in flat

    def test_all_values_jnp_array(self, qwen2_dir):
        """Every leaf of the tree is a ``jnp.ndarray``."""
        flat = self._flat_params(qwen2_dir)
        for key, val in flat.items():
            assert isinstance(val, jnp.ndarray), f"{key} bukan jnp.ndarray"

    def test_no_nan_or_inf(self, qwen2_dir):
        """No leaf contains NaN or infinite values."""
        flat = self._flat_params(qwen2_dir)
        for key, val in flat.items():
            arr = np.array(val)
            assert not np.any(np.isnan(arr)), f"NaN di {key}"
            assert not np.any(np.isinf(arr)), f"Inf di {key}"
@@ -0,0 +1,77 @@
1
+ # tests/test_rules.py
2
+ """Unit tests for MappingRule — matching and transforms."""
3
+
4
+ import pytest
5
+ import numpy as np
6
+ from waxjax.core.rules import MappingRule, transpose_2d, no_op
7
+
8
+
9
class TestMappingRuleMatch:
    """Pattern matching, automatic priority and validation of ``MappingRule``."""

    def test_exact_match(self):
        """A pattern without wildcards matches only the identical key."""
        exact = MappingRule("lm_head.weight", rename=lambda k: k, transform="no_op")
        assert exact.matches("lm_head.weight")
        assert not exact.matches("lm_head.bias")

    def test_single_wildcard(self):
        """``*`` stands for exactly one dot-separated segment."""
        per_layer = MappingRule(
            "model.layers.*.self_attn.q_proj.weight",
            rename=lambda k: k,
            transform="no_op",
        )
        matching_keys = (
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.11.self_attn.q_proj.weight",
        )
        for candidate in matching_keys:
            assert per_layer.matches(candidate)
        # * must not match a dot
        assert not per_layer.matches("model.layers.0.1.self_attn.q_proj.weight")

    def test_double_wildcard(self):
        """``**`` may span several dot-separated segments."""
        anywhere = MappingRule(
            "model.**.weight",
            rename=lambda k: k,
            transform="no_op",
        )
        matching_keys = (
            "model.layers.0.self_attn.q_proj.weight",
            "model.norm.weight",
        )
        for candidate in matching_keys:
            assert anywhere.matches(candidate)

    def test_priority_auto(self):
        """A pattern with fewer wildcards receives the higher priority."""
        specific = MappingRule(
            "model.layers.*.self_attn_layer_norm.weight",
            rename=lambda k: k,
            transform="no_op",
        )
        generic = MappingRule(
            "model.layers.*.*.weight",
            rename=lambda k: k,
            transform="no_op",
        )
        assert specific.priority > generic.priority

    def test_invalid_transform_raises(self):
        """An unknown transform name is rejected at construction time."""
        with pytest.raises(ValueError, match="tidak dikenal"):
            MappingRule("*.weight", rename=lambda k: k, transform="unknown_transform")
48
+
49
+
50
class TestMappingRuleTransform:
    """Transform and rename behaviour of ``MappingRule``."""

    def test_transpose_2d(self):
        """The ``transpose_2d`` transform swaps the axes of a 2-D array."""
        matrix = np.ones((4, 8))
        transposing = MappingRule(
            "*.weight", rename=lambda k: k, transform="transpose_2d"
        )
        assert transposing.apply_transform(matrix).shape == (8, 4)

    def test_no_transpose_1d(self):
        """A 1-D array (bias) must not be transposed even by ``transpose_2d``."""
        vector = np.ones(8)
        assert transpose_2d(vector).shape == (8,)

    def test_custom_callable_transform(self):
        """A user-supplied callable is accepted as the transform."""

        def twice(values):
            return values * 2

        doubling = MappingRule("*.weight", rename=lambda k: k, transform=twice)
        sample = np.array([1.0, 2.0])
        assert np.allclose(doubling.apply_transform(sample), [2.0, 4.0])

    def test_rename_applied(self):
        """``apply_rename`` returns the key produced by the rename callable."""
        renaming = MappingRule(
            "*.weight",
            rename=lambda k: k.replace(".weight", ".kernel"),
            transform="no_op",
        )
        assert renaming.apply_rename("model.fc.weight") == "model.fc.kernel"
@@ -0,0 +1,16 @@
1
+ """Compatibility shim: expose the public API under the expected
2
+ package name `waxjax` while the real implementation lives in
3
+ `safejax` (to preserve the original repository layout).
4
+ """
5
+ # Re-export the public API without importing `safejax` package object
6
+ # to avoid circular-import issues during package initialization.
7
+ from .core.api import load, save, inspect, verify
8
+ from .core.rules import MappingRule
9
+ from .architectures._registry import register_architecture, ArchitectureConfig
10
+
11
+ __version__ = "0.1.0"
12
+
13
+ __all__ = [
14
+ "load", "save", "inspect", "verify",
15
+ "MappingRule", "register_architecture", "ArchitectureConfig",
16
+ ]
@@ -0,0 +1,2 @@
1
+ from . import _registry, qwen2, nllb
2
+ __all__ = ["_registry", "qwen2", "nllb"]
@@ -0,0 +1 @@
1
+ from safejax.architectures._registry import *
@@ -0,0 +1 @@
1
+ from safejax.architectures.nllb import *
@@ -0,0 +1 @@
1
+ from safejax.architectures.qwen2 import *
@@ -0,0 +1,3 @@
1
+ # waxjax.core package — wrappers re-exporting safejax.core
2
+ from . import loader, mapper, nester, rules, api
3
+ __all__ = ["loader", "mapper", "nester", "rules", "api"]
@@ -0,0 +1 @@
1
+ from safejax.core.api import *
@@ -0,0 +1 @@
1
+ from safejax.core.loader import *
@@ -0,0 +1 @@
1
+ from safejax.core.mapper import *
@@ -0,0 +1 @@
1
+ from safejax.core.nester import *
@@ -0,0 +1 @@
1
+ from safejax.core.rules import *
@@ -0,0 +1,55 @@
1
+ Metadata-Version: 2.4
2
+ Name: waxjax
3
+ Version: 0.1.0
4
+ Summary: Load HuggingFace SafeTensors models to JAX/Flax
5
+ Home-page: https://github.com/ik12-b/waxjax
6
+ Author: Your Name
7
+ License: Apache-2.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: Apache Software License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.9
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: safetensors>=0.4.0
14
+ Requires-Dist: numpy>=1.24
15
+ Requires-Dist: jax>=0.4.0
16
+ Provides-Extra: torch
17
+ Requires-Dist: torch>=2.0; extra == "torch"
18
+ Provides-Extra: flax
19
+ Requires-Dist: flax>=0.7; extra == "flax"
20
+ Provides-Extra: all
21
+ Requires-Dist: torch>=2.0; extra == "all"
22
+ Requires-Dist: flax>=0.7; extra == "all"
23
+ Requires-Dist: ml_dtypes; extra == "all"
24
+ Provides-Extra: dev
25
+ Requires-Dist: pytest; extra == "dev"
26
+ Requires-Dist: black; extra == "dev"
27
+ Requires-Dist: ruff; extra == "dev"
28
+ Dynamic: author
29
+ Dynamic: classifier
30
+ Dynamic: description
31
+ Dynamic: description-content-type
32
+ Dynamic: home-page
33
+ Dynamic: license
34
+ Dynamic: provides-extra
35
+ Dynamic: requires-dist
36
+ Dynamic: requires-python
37
+ Dynamic: summary
38
+
39
+ # waxjax
40
+
41
+ Lightweight helper to load HuggingFace SafeTensors checkpoints into JAX/Flax.
42
+
43
+ Install (latest from source):
44
+
45
+ ```bash
46
+ pip install git+https://github.com/ik12-b/waxjax.git
47
+ ```
48
+
49
+ Build and publish (testing on TestPyPI first is recommended): see `RELEASE.md` or run locally:
50
+
51
+ ```bash
52
+ python -m pip install --upgrade build twine
53
+ python -m build
54
+ twine upload --repository testpypi dist/*
55
+ ```
@@ -0,0 +1,27 @@
1
+ MANIFEST.in
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ safejax/__init__.py
6
+ tests/test_api.py
7
+ tests/test_loader.py
8
+ tests/test_mapper.py
9
+ tests/test_nllb.py
10
+ tests/test_qwen2.py
11
+ tests/test_rules.py
12
+ waxjax/__init__.py
13
+ waxjax.egg-info/PKG-INFO
14
+ waxjax.egg-info/SOURCES.txt
15
+ waxjax.egg-info/dependency_links.txt
16
+ waxjax.egg-info/requires.txt
17
+ waxjax.egg-info/top_level.txt
18
+ waxjax/architectures/__init__.py
19
+ waxjax/architectures/_registry.py
20
+ waxjax/architectures/nllb.py
21
+ waxjax/architectures/qwen2.py
22
+ waxjax/core/__init__.py
23
+ waxjax/core/api.py
24
+ waxjax/core/loader.py
25
+ waxjax/core/mapper.py
26
+ waxjax/core/nester.py
27
+ waxjax/core/rules.py
@@ -0,0 +1,19 @@
1
+ safetensors>=0.4.0
2
+ numpy>=1.24
3
+ jax>=0.4.0
4
+
5
+ [all]
6
+ torch>=2.0
7
+ flax>=0.7
8
+ ml_dtypes
9
+
10
+ [dev]
11
+ pytest
12
+ black
13
+ ruff
14
+
15
+ [flax]
16
+ flax>=0.7
17
+
18
+ [torch]
19
+ torch>=2.0
@@ -0,0 +1,2 @@
1
+ safejax
2
+ waxjax