lalamo 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
2
2
  from lalamo.modules import Decoder
3
3
 
4
- __version__ = "0.2.5"
4
+ __version__ = "0.2.7"
5
5
 
6
6
  __all__ = [
7
7
  "REPO_TO_MODEL",
@@ -8,7 +8,6 @@ from safetensors.flax import load_file as load_safetensors
8
8
 
9
9
  from lalamo.model_import.configs import ForeignConfig
10
10
  from lalamo.quantization import QuantizationMode
11
- from lalamo.utils import torch_to_jax
12
11
 
13
12
  __all__ = [
14
13
  "HUGGINFACE_GENERATION_CONFIG_FILE",
@@ -38,6 +37,8 @@ class WeightsType(Enum):
38
37
 
39
38
  import torch
40
39
 
40
+ from lalamo.modules.torch_interop import torch_to_jax
41
+
41
42
  torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
42
43
  return {k: cast_if_float(torch_to_jax(v), float_dtype) for k, v in torch_weights.items()}
43
44
 
@@ -95,7 +96,7 @@ def awq_model_spec(
95
96
  )
96
97
 
97
98
 
98
- def build_quantized_models(model_specs: list[ModelSpec]):
99
+ def build_quantized_models(model_specs: list[ModelSpec]) -> list[ModelSpec]:
99
100
  quantization_compatible_repos: list[str] = [
100
101
  "Qwen/Qwen2.5-3B-Instruct",
101
102
  "Qwen/Qwen2.5-7B-Instruct",
@@ -0,0 +1,29 @@
1
+ import jax.numpy as jnp
2
+ import torch
3
+ from jaxtyping import Array
4
+
5
+ __all__ = ["jax_to_torch", "torch_to_jax"]
6
+
7
+
8
+ @torch.no_grad()
9
+ def _torch_to_jax_bfloat16(tensor: torch.Tensor) -> Array:
10
+ if tensor.dtype != torch.bfloat16:
11
+ raise ValueError("Trying to convert non-bfloat16 tensor to bfloat16")
12
+ intermediate_tensor = tensor.view(torch.uint16)
13
+ return jnp.array(intermediate_tensor).view("bfloat16")
14
+
15
+
16
def torch_to_jax(array: torch.Tensor) -> Array:
    """Convert a torch tensor into a JAX array.

    The tensor is first detached from the autograd graph and moved to CPU.
    bfloat16 tensors are delegated to ``_torch_to_jax_bfloat16`` because
    NumPy has no bfloat16 dtype. Every other dtype goes through
    ``Tensor.numpy()`` and ``jnp.array``.

    Args:
        array: Any torch tensor, on any device, with or without ``requires_grad``.

    Returns:
        A JAX array with the same shape and values.
    """
    host_tensor = array.detach().cpu()
    if host_tensor.dtype != torch.bfloat16:
        return jnp.array(host_tensor.numpy())
    return _torch_to_jax_bfloat16(host_tensor)
21
+
22
+
23
def jax_to_torch(array: Array) -> torch.Tensor:
    """Convert a JAX array into a torch tensor via DLPack.

    Non-bfloat16 arrays are handed to ``torch.utils.dlpack.from_dlpack``
    directly. bfloat16 arrays are first reinterpreted as ``uint16`` on the JAX
    side, exchanged, and then reinterpreted back as ``torch.bfloat16``. Both
    steps are same-width ``view`` calls, so no values are rounded.

    Args:
        array: Source JAX array.

    Returns:
        A torch tensor produced by ``from_dlpack``.
    """
    from torch.utils.dlpack import from_dlpack

    if array.dtype != jnp.bfloat16:
        return from_dlpack(array)
    # Exchange raw 16-bit words and restore the bfloat16 dtype on the torch side.
    raw_bits = array.view(jnp.uint16)
    return from_dlpack(raw_bits).view(torch.bfloat16)
lalamo/utils.py CHANGED
@@ -1,36 +1,8 @@
1
1
  import einops
2
2
  import jax.numpy as jnp
3
- import torch.utils.dlpack
4
3
  from jaxtyping import Array
5
4
 
6
- __all__ = [
7
- "jax_to_torch",
8
- "jax_uint4_to_packed_uint8",
9
- "torch_to_jax",
10
- ]
11
-
12
-
13
- @torch.no_grad()
14
- def _torch_to_jax_bfloat16(tensor: torch.Tensor) -> Array:
15
- # Credit: https://github.com/jax-ml/ml_dtypes/issues/81#issuecomment-2399636232
16
- if tensor.dtype != torch.bfloat16:
17
- raise ValueError("Trying to convert non-bfloat16 tensor to bfloat16")
18
- intermediate_tensor = tensor.view(torch.uint16)
19
- return jnp.array(intermediate_tensor).view("bfloat16")
20
-
21
-
22
- def torch_to_jax(array: torch.Tensor) -> Array:
23
- array = array.detach().cpu()
24
- if array.dtype == torch.bfloat16:
25
- return _torch_to_jax_bfloat16(array)
26
- return jnp.array(array.numpy())
27
-
28
-
29
- def jax_to_torch(array: Array) -> torch.Tensor:
30
- if array.dtype == jnp.bfloat16:
31
- intermediate_array = array.view(jnp.uint16)
32
- return torch.utils.dlpack.from_dlpack(intermediate_array).view(torch.bfloat16)
33
- return torch.utils.dlpack.from_dlpack(array)
5
+ __all__ = ["jax_uint4_to_packed_uint8"]
34
6
 
35
7
 
36
8
  def jax_uint4_to_packed_uint8(array: Array) -> Array:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.2.5
3
+ Version: 0.2.7
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,9 +1,9 @@
1
- lalamo/__init__.py,sha256=HNdYYIO9ypGyZ53AUbiDIp-i3W61ZupGATapp3rEFYQ,217
1
+ lalamo/__init__.py,sha256=pJx61SiYtLEREzNFL6L0V3TEa7F17hyj0jHuQMAZ7uw,217
2
2
  lalamo/common.py,sha256=uYLw68V4AF3zlENG3KAIKRpOFXVHv8xX_n0cc3qJnj4,1877
3
3
  lalamo/language_model.py,sha256=GiA_BDQuYCgVBFHljb_ltW_M7g3I1Siwm111M3Jc8MM,9286
4
4
  lalamo/main.py,sha256=K2RLyTcxvBCP0teSsminssj_oUkuQAQ5y9ixa1uOqas,9546
5
5
  lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
6
- lalamo/utils.py,sha256=QzkT0_82nd9pS5p0e7yOOdL_ZeKQr_Ftj4kFrWF35R8,1754
6
+ lalamo/utils.py,sha256=ihV9ojDMlAf2_Ja5kNZMIYLMQxpQXBlNOd9TIdMq0yM,815
7
7
  lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
8
8
  lalamo/model_import/common.py,sha256=sHXEGQUtVb6TRT5FOGtJG9pz1Ohy5v_LtunubVxZKqQ,3303
9
9
  lalamo/model_import/configs/__init__.py,sha256=JYXeco_kfzKZuWqEmG24qxeYWs-FuE1W1kNgoFNrBEw,461
@@ -22,7 +22,7 @@ lalamo/model_import/loaders/common.py,sha256=2FigeDMUwlMPUebX8DAK2Yh9aLgVtsfTj0S
22
22
  lalamo/model_import/loaders/executorch.py,sha256=nSvpylK8QL3nBk78P3FabLoyA87E3kv5CCpMfvuZe6Q,8886
23
23
  lalamo/model_import/loaders/huggingface.py,sha256=Ze_qB0fSxY8lH4ovH0t8jd5jiteasUWkS9HdgMZXCrs,10523
24
24
  lalamo/model_import/model_specs/__init__.py,sha256=_sJthAH1xXl5B9JPhRqMVP2t5KkhzqmKFHSRlOiFg8s,915
25
- lalamo/model_import/model_specs/common.py,sha256=Ob3yTMDczKUHMWBH0PaClbSvHJhKfZ-zbv2Z04YqMVg,3806
25
+ lalamo/model_import/model_specs/common.py,sha256=oPKd6kKmmUBPQD5UV_yzSsNwq6R-l3ecqBKDXiDYx8c,3850
26
26
  lalamo/model_import/model_specs/deepseek.py,sha256=9l3pVyC-ZoIaFG4xWhPDCbKkD2TsND286o0KzO0uxKo,788
27
27
  lalamo/model_import/model_specs/gemma.py,sha256=y4aDeaGGl4JPIanAgPMOlyfD_cx3Q7rpTKgDgx5AsX0,2299
28
28
  lalamo/model_import/model_specs/huggingface.py,sha256=ktDJ_qZxSGmHREydrYQaWi71bXJZiHqzHDoZeORENno,784
@@ -44,10 +44,11 @@ lalamo/modules/linear.py,sha256=loUGFu3wx-iGqDqGMphQorhqBm7b9lAqT4B0jAmoamk,2408
44
44
  lalamo/modules/mlp.py,sha256=bV8qJTjsQFGv-CA7d32UQFn6BX5zmCKWC5pgm29-W3U,2631
45
45
  lalamo/modules/normalization.py,sha256=BWCHv6ycFJ_qMGfxkusGfay9dWzUlbpuwmjbLy2rI68,2380
46
46
  lalamo/modules/rope.py,sha256=Vdt2J_W0MPDK52nHsroLVCfWMHyHW3AfrKZCZAE4VYs,9369
47
+ lalamo/modules/torch_interop.py,sha256=-mujd1zI4ec2w92Hd50RtDa0K3jl6ZSnPxc5r3Fp9nU,916
47
48
  lalamo/modules/utils.py,sha256=5QTdi34kEI5jix7TfTdB0mOYZbzZUul_T1y8eWCA6lQ,262
48
- lalamo-0.2.5.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
49
- lalamo-0.2.5.dist-info/METADATA,sha256=Uh-z7iYbNur26j9mBoH5OInfmBqXQJ2pRdSKRiru1xg,2645
50
- lalamo-0.2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
51
- lalamo-0.2.5.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
52
- lalamo-0.2.5.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
53
- lalamo-0.2.5.dist-info/RECORD,,
49
+ lalamo-0.2.7.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
50
+ lalamo-0.2.7.dist-info/METADATA,sha256=xO4NaahkCxodVGM71maSqxethUJsMTXLG-TVImtmEO4,2645
51
+ lalamo-0.2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
52
+ lalamo-0.2.7.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
53
+ lalamo-0.2.7.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
54
+ lalamo-0.2.7.dist-info/RECORD,,
File without changes