lalamo 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff shows the changes between two package versions as published to a supported public registry. It is provided for informational purposes only.
- lalamo/__init__.py +1 -1
- lalamo/model_import/model_specs/common.py +1 -1
- lalamo/modules/torch_interop.py +29 -0
- lalamo/utils.py +1 -29
- {lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/METADATA +1 -1
- {lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/RECORD +10 -9
- {lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/WHEEL +0 -0
- {lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
@@ -7,8 +7,8 @@ from jaxtyping import Array, DTypeLike
 from safetensors.flax import load_file as load_safetensors
 
 from lalamo.model_import.configs import ForeignConfig
+from lalamo.modules.torch_interop import torch_to_jax
 from lalamo.quantization import QuantizationMode
-from lalamo.utils import torch_to_jax
 
 __all__ = [
     "HUGGINFACE_GENERATION_CONFIG_FILE",
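The only change to lalamo/__init__.py is the import source for torch_to_jax, which moved from lalamo.utils to the new lalamo.modules.torch_interop module (diff below). Downstream code that imported the helpers from their old location needs the new path; a minimal migration sketch, assuming no other public names changed:

    # Before (lalamo 0.2.5):
    #   from lalamo.utils import jax_to_torch, torch_to_jax
    # After (lalamo 0.2.6):
    from lalamo.modules.torch_interop import jax_to_torch, torch_to_jax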
lalamo/modules/torch_interop.py
ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+import torch
+from jaxtyping import Array
+
+__all__ = ["jax_to_torch", "torch_to_jax"]
+
+
+@torch.no_grad()
+def _torch_to_jax_bfloat16(tensor: torch.Tensor) -> Array:
+    if tensor.dtype != torch.bfloat16:
+        raise ValueError("Trying to convert non-bfloat16 tensor to bfloat16")
+    intermediate_tensor = tensor.view(torch.uint16)
+    return jnp.array(intermediate_tensor).view("bfloat16")
+
+
+def torch_to_jax(array: torch.Tensor) -> Array:
+    array = array.detach().cpu()
+    if array.dtype == torch.bfloat16:
+        return _torch_to_jax_bfloat16(array)
+    return jnp.array(array.numpy())
+
+
+def jax_to_torch(array: Array) -> torch.Tensor:
+    from torch.utils import dlpack as _dlpack
+
+    if array.dtype == jnp.bfloat16:
+        intermediate_array = array.view(jnp.uint16)
+        return _dlpack.from_dlpack(intermediate_array).view(torch.bfloat16)
+    return _dlpack.from_dlpack(array)
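The new module isolates the torch-to-JAX conversion helpers. NumPy has no bfloat16 dtype, so jnp.array(tensor.numpy()) cannot handle bfloat16 tensors directly; _torch_to_jax_bfloat16 therefore reinterprets the tensor's bits as uint16 and re-views them as bfloat16 on the JAX side (an ml_dtypes workaround, credited in the 0.2.5 source below), while jax_to_torch transfers via DLPack with the same uint16 detour. A minimal round-trip sketch, assuming lalamo 0.2.6 with both torch and jax installed:

    # Round-trip sketch, assuming lalamo 0.2.6 with torch and jax installed.
    import torch

    from lalamo.modules.torch_interop import jax_to_torch, torch_to_jax

    tensor = torch.randn(2, 3, dtype=torch.bfloat16)
    jax_array = torch_to_jax(tensor)    # bits travel through a uint16 view
    restored = jax_to_torch(jax_array)  # DLPack transfer, same uint16 detour
    assert torch.equal(restored.cpu(), tensor)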
lalamo/utils.py
CHANGED
@@ -1,36 +1,8 @@
 import einops
 import jax.numpy as jnp
-import torch.utils.dlpack
 from jaxtyping import Array
 
-__all__ = [
-    "jax_to_torch",
-    "jax_uint4_to_packed_uint8",
-    "torch_to_jax",
-]
-
-
-@torch.no_grad()
-def _torch_to_jax_bfloat16(tensor: torch.Tensor) -> Array:
-    # Credit: https://github.com/jax-ml/ml_dtypes/issues/81#issuecomment-2399636232
-    if tensor.dtype != torch.bfloat16:
-        raise ValueError("Trying to convert non-bfloat16 tensor to bfloat16")
-    intermediate_tensor = tensor.view(torch.uint16)
-    return jnp.array(intermediate_tensor).view("bfloat16")
-
-
-def torch_to_jax(array: torch.Tensor) -> Array:
-    array = array.detach().cpu()
-    if array.dtype == torch.bfloat16:
-        return _torch_to_jax_bfloat16(array)
-    return jnp.array(array.numpy())
-
-
-def jax_to_torch(array: Array) -> torch.Tensor:
-    if array.dtype == jnp.bfloat16:
-        intermediate_array = array.view(jnp.uint16)
-        return torch.utils.dlpack.from_dlpack(intermediate_array).view(torch.bfloat16)
-    return torch.utils.dlpack.from_dlpack(array)
+__all__ = ["jax_uint4_to_packed_uint8"]
 
 
 def jax_uint4_to_packed_uint8(array: Array) -> Array:
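With the interop helpers gone, lalamo.utils no longer imports torch at all, and its only public name is jax_uint4_to_packed_uint8, whose body lies outside this hunk. As a generic illustration of what packing uint4 values two-per-byte involves (not lalamo's actual implementation; the pairing axis and low-nibble-first order are assumptions):

    import jax.numpy as jnp
    from jaxtyping import Array


    def pack_uint4_pairs(array: Array) -> Array:
        # Pair consecutive 4-bit values (held in uint8 for clarity) and
        # combine each pair into one byte, low nibble first (order assumed).
        low, high = array[..., 0::2], array[..., 1::2]
        return (low | (high << 4)).astype(jnp.uint8)


    print(pack_uint4_pairs(jnp.array([1, 2, 3, 4], dtype=jnp.uint8)))  # [33 67]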
{lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-lalamo/__init__.py,sha256=
+lalamo/__init__.py,sha256=VhtKkFHsGy_199K6ToPyIuepm2yVsi7knqRNCG5PiXk,217
 lalamo/common.py,sha256=uYLw68V4AF3zlENG3KAIKRpOFXVHv8xX_n0cc3qJnj4,1877
 lalamo/language_model.py,sha256=GiA_BDQuYCgVBFHljb_ltW_M7g3I1Siwm111M3Jc8MM,9286
 lalamo/main.py,sha256=K2RLyTcxvBCP0teSsminssj_oUkuQAQ5y9ixa1uOqas,9546
 lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
-lalamo/utils.py,sha256=
+lalamo/utils.py,sha256=ihV9ojDMlAf2_Ja5kNZMIYLMQxpQXBlNOd9TIdMq0yM,815
 lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
 lalamo/model_import/common.py,sha256=sHXEGQUtVb6TRT5FOGtJG9pz1Ohy5v_LtunubVxZKqQ,3303
 lalamo/model_import/configs/__init__.py,sha256=JYXeco_kfzKZuWqEmG24qxeYWs-FuE1W1kNgoFNrBEw,461
@@ -22,7 +22,7 @@ lalamo/model_import/loaders/common.py,sha256=2FigeDMUwlMPUebX8DAK2Yh9aLgVtsfTj0S
 lalamo/model_import/loaders/executorch.py,sha256=nSvpylK8QL3nBk78P3FabLoyA87E3kv5CCpMfvuZe6Q,8886
 lalamo/model_import/loaders/huggingface.py,sha256=Ze_qB0fSxY8lH4ovH0t8jd5jiteasUWkS9HdgMZXCrs,10523
 lalamo/model_import/model_specs/__init__.py,sha256=_sJthAH1xXl5B9JPhRqMVP2t5KkhzqmKFHSRlOiFg8s,915
-lalamo/model_import/model_specs/common.py,sha256=
+lalamo/model_import/model_specs/common.py,sha256=H56uwiyDsm6GD-1uxCm-zGBLuxE_EN84ivHi80niL0g,3822
 lalamo/model_import/model_specs/deepseek.py,sha256=9l3pVyC-ZoIaFG4xWhPDCbKkD2TsND286o0KzO0uxKo,788
 lalamo/model_import/model_specs/gemma.py,sha256=y4aDeaGGl4JPIanAgPMOlyfD_cx3Q7rpTKgDgx5AsX0,2299
 lalamo/model_import/model_specs/huggingface.py,sha256=ktDJ_qZxSGmHREydrYQaWi71bXJZiHqzHDoZeORENno,784
@@ -44,10 +44,11 @@ lalamo/modules/linear.py,sha256=loUGFu3wx-iGqDqGMphQorhqBm7b9lAqT4B0jAmoamk,2408
 lalamo/modules/mlp.py,sha256=bV8qJTjsQFGv-CA7d32UQFn6BX5zmCKWC5pgm29-W3U,2631
 lalamo/modules/normalization.py,sha256=BWCHv6ycFJ_qMGfxkusGfay9dWzUlbpuwmjbLy2rI68,2380
 lalamo/modules/rope.py,sha256=Vdt2J_W0MPDK52nHsroLVCfWMHyHW3AfrKZCZAE4VYs,9369
+lalamo/modules/torch_interop.py,sha256=-mujd1zI4ec2w92Hd50RtDa0K3jl6ZSnPxc5r3Fp9nU,916
 lalamo/modules/utils.py,sha256=5QTdi34kEI5jix7TfTdB0mOYZbzZUul_T1y8eWCA6lQ,262
-lalamo-0.2.
-lalamo-0.2.
-lalamo-0.2.
-lalamo-0.2.
-lalamo-0.2.
-lalamo-0.2.
+lalamo-0.2.6.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.2.6.dist-info/METADATA,sha256=3IJMF8cNEPM2dmIZfDcNlpOT-SQkvJBkZnWPxjA3CWY,2645
+lalamo-0.2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lalamo-0.2.6.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.2.6.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.2.6.dist-info/RECORD,,
{lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/WHEEL
File without changes
{lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/entry_points.txt
File without changes
{lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/licenses/LICENSE
File without changes
{lalamo-0.2.5.dist-info → lalamo-0.2.6.dist-info}/top_level.txt
File without changes