lalamo 0.2.4.tar.gz → 0.2.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {lalamo-0.2.4 → lalamo-0.2.6}/PKG-INFO +2 -1
  2. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/__init__.py +1 -1
  3. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/common.py +1 -1
  4. lalamo-0.2.6/lalamo/modules/torch_interop.py +29 -0
  5. lalamo-0.2.6/lalamo/utils.py +27 -0
  6. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/PKG-INFO +2 -1
  7. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/SOURCES.txt +1 -0
  8. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/requires.txt +1 -0
  9. {lalamo-0.2.4 → lalamo-0.2.6}/pyproject.toml +1 -0
  10. lalamo-0.2.4/lalamo/utils.py +0 -55
  11. {lalamo-0.2.4 → lalamo-0.2.6}/LICENSE +0 -0
  12. {lalamo-0.2.4 → lalamo-0.2.6}/README.md +0 -0
  13. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/common.py +0 -0
  14. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/language_model.py +0 -0
  15. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/main.py +0 -0
  16. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/__init__.py +0 -0
  17. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/common.py +0 -0
  18. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/__init__.py +0 -0
  19. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/common.py +0 -0
  20. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/executorch.py +0 -0
  21. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/huggingface/__init__.py +0 -0
  22. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/huggingface/common.py +0 -0
  23. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/huggingface/gemma2.py +0 -0
  24. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/huggingface/gemma3.py +0 -0
  25. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/huggingface/llama.py +0 -0
  26. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/huggingface/mistral.py +0 -0
  27. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/huggingface/qwen2.py +0 -0
  28. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/configs/huggingface/qwen3.py +0 -0
  29. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/loaders/__init__.py +0 -0
  30. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/loaders/common.py +0 -0
  31. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/loaders/executorch.py +0 -0
  32. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/loaders/huggingface.py +0 -0
  33. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/__init__.py +0 -0
  34. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/deepseek.py +0 -0
  35. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/gemma.py +0 -0
  36. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/huggingface.py +0 -0
  37. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/llama.py +0 -0
  38. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/mistral.py +0 -0
  39. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/pleias.py +0 -0
  40. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/polaris.py +0 -0
  41. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/qwen.py +0 -0
  42. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/reka.py +0 -0
  43. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/__init__.py +0 -0
  44. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/activations.py +0 -0
  45. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/attention.py +0 -0
  46. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/common.py +0 -0
  47. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/decoder.py +0 -0
  48. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/decoder_layer.py +0 -0
  49. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/embedding.py +0 -0
  50. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/kv_cache.py +0 -0
  51. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/linear.py +0 -0
  52. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/mlp.py +0 -0
  53. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/normalization.py +0 -0
  54. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/rope.py +0 -0
  55. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/modules/utils.py +0 -0
  56. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo/quantization.py +0 -0
  57. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/dependency_links.txt +0 -0
  58. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/entry_points.txt +0 -0
  59. {lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/top_level.txt +0 -0
  60. {lalamo-0.2.4 → lalamo-0.2.6}/setup.cfg +0 -0
  61. {lalamo-0.2.4 → lalamo-0.2.6}/tests/test_generation.py +0 -0
  62. {lalamo-0.2.4 → lalamo-0.2.6}/tests/test_huggingface_models.py +0 -0
{lalamo-0.2.4 → lalamo-0.2.6}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.2.4
+Version: 0.2.6
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
@@ -18,6 +18,7 @@ Requires-Dist: optax>=0.2.4
 Requires-Dist: rich>=14.0.0
 Requires-Dist: thefuzz>=0.22.1
 Requires-Dist: typer>=0.15.1
+Requires-Dist: safetensors>=0.6.2
 Dynamic: license-file
 
 <p align="center">
{lalamo-0.2.4 → lalamo-0.2.6}/lalamo/__init__.py
@@ -1,7 +1,7 @@
 from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
 from lalamo.modules import Decoder
 
-__version__ = "0.2.4"
+__version__ = "0.2.6"
 
 __all__ = [
     "REPO_TO_MODEL",
{lalamo-0.2.4 → lalamo-0.2.6}/lalamo/model_import/model_specs/common.py
@@ -7,8 +7,8 @@ from jaxtyping import Array, DTypeLike
 from safetensors.flax import load_file as load_safetensors
 
 from lalamo.model_import.configs import ForeignConfig
+from lalamo.modules.torch_interop import torch_to_jax
 from lalamo.quantization import QuantizationMode
-from lalamo.utils import torch_to_jax
 
 __all__ = [
     "HUGGINFACE_GENERATION_CONFIG_FILE",
lalamo-0.2.6/lalamo/modules/torch_interop.py
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+import torch
+from jaxtyping import Array
+
+__all__ = ["jax_to_torch", "torch_to_jax"]
+
+
+@torch.no_grad()
+def _torch_to_jax_bfloat16(tensor: torch.Tensor) -> Array:
+    if tensor.dtype != torch.bfloat16:
+        raise ValueError("Trying to convert non-bfloat16 tensor to bfloat16")
+    intermediate_tensor = tensor.view(torch.uint16)
+    return jnp.array(intermediate_tensor).view("bfloat16")
+
+
+def torch_to_jax(array: torch.Tensor) -> Array:
+    array = array.detach().cpu()
+    if array.dtype == torch.bfloat16:
+        return _torch_to_jax_bfloat16(array)
+    return jnp.array(array.numpy())
+
+
+def jax_to_torch(array: Array) -> torch.Tensor:
+    from torch.utils import dlpack as _dlpack
+
+    if array.dtype == jnp.bfloat16:
+        intermediate_array = array.view(jnp.uint16)
+        return _dlpack.from_dlpack(intermediate_array).view(torch.bfloat16)
+    return _dlpack.from_dlpack(array)
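NumPy has no bfloat16 dtype, so torch_to_jax cannot round-trip such tensors through tensor.numpy(); instead it bit-casts them to uint16 and reinterprets the bits as bfloat16 on the JAX side (the 0.2.4 source credits this trick to a comment in jax-ml/ml_dtypes issue #81), while jax_to_torch reverses the dance over DLPack. A minimal round-trip sketch, assuming a CPU tensor:

    import torch
    import jax.numpy as jnp
    from lalamo.modules.torch_interop import jax_to_torch, torch_to_jax

    # bfloat16 takes the uint16 bit-cast path instead of going through NumPy.
    weights = torch.randn(4, 4, dtype=torch.bfloat16)
    jax_weights = torch_to_jax(weights)
    assert jax_weights.dtype == jnp.bfloat16

    # The reverse direction hands the buffer back to torch via DLPack,
    # again bit-casting around the missing bfloat16 support.
    restored = jax_to_torch(jax_weights)
    assert restored.dtype == torch.bfloat16
    assert torch.equal(restored, weights)  # exact: both directions are bit-casts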
lalamo-0.2.6/lalamo/utils.py
@@ -0,0 +1,27 @@
+import einops
+import jax.numpy as jnp
+from jaxtyping import Array
+
+__all__ = ["jax_uint4_to_packed_uint8"]
+
+
+def jax_uint4_to_packed_uint8(array: Array) -> Array:
+    if array.dtype != jnp.uint4:
+        raise ValueError(f"Input array must have dtype jnp.uint4, but got {array.dtype}")
+
+    if not array.shape:
+        raise ValueError("Input array cannot be a scalar and must have at least one dimension.")
+
+    *_, last_dim = array.shape
+    if last_dim % 2 != 0:
+        raise ValueError(f"The last dimension of the input array must be even, but got shape {array.shape}")
+
+    low_nibbles, high_nibbles = einops.rearrange(
+        array.astype(jnp.uint8),
+        "... (dim_half two) -> two ... dim_half",
+        two=2,
+    )
+
+    packed = (high_nibbles << 4) | low_nibbles
+
+    return packed.astype(jnp.uint8)
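The packer walks the last axis in consecutive pairs: the even-indexed element of each pair becomes the low nibble and the odd-indexed element the high nibble of one output byte, so the last dimension halves. This is the usual two-values-per-byte layout for 4-bit quantized weights, though that use is an inference from the surrounding quantization code rather than something the diff states. A small worked example:

    import jax.numpy as jnp
    from lalamo.utils import jax_uint4_to_packed_uint8

    nibbles = jnp.array([1, 2, 3, 4], dtype=jnp.uint4)
    packed = jax_uint4_to_packed_uint8(nibbles)

    # Pairs (1, 2) and (3, 4) pack as (2 << 4) | 1 = 33 and (4 << 4) | 3 = 67.
    print(packed)  # [33 67], dtype uint8, shape (2,)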
{lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.2.4
+Version: 0.2.6
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
@@ -18,6 +18,7 @@ Requires-Dist: optax>=0.2.4
 Requires-Dist: rich>=14.0.0
 Requires-Dist: thefuzz>=0.22.1
 Requires-Dist: typer>=0.15.1
+Requires-Dist: safetensors>=0.6.2
 Dynamic: license-file
 
 <p align="center">
{lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/SOURCES.txt
@@ -53,6 +53,7 @@ lalamo/modules/linear.py
 lalamo/modules/mlp.py
 lalamo/modules/normalization.py
 lalamo/modules/rope.py
+lalamo/modules/torch_interop.py
 lalamo/modules/utils.py
 tests/test_generation.py
 tests/test_huggingface_models.py
{lalamo-0.2.4 → lalamo-0.2.6}/lalamo.egg-info/requires.txt
@@ -9,6 +9,7 @@ optax>=0.2.4
 rich>=14.0.0
 thefuzz>=0.22.1
 typer>=0.15.1
+safetensors>=0.6.2
 
 [:sys_platform == "darwin"]
 jax>=0.4.38
{lalamo-0.2.4 → lalamo-0.2.6}/pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
     "rich>=14.0.0",
     "thefuzz>=0.22.1",
     "typer>=0.15.1",
+    "safetensors>=0.6.2",
     # "jax-metal>=0.1.1 ; sys_platform == 'darwin'",
 ]
 readme = { file = "README.md", content-type = "text/markdown" }
lalamo-0.2.4/lalamo/utils.py
@@ -1,55 +0,0 @@
-import einops
-import jax.numpy as jnp
-import torch.utils.dlpack
-from jaxtyping import Array
-
-__all__ = [
-    "jax_to_torch",
-    "jax_uint4_to_packed_uint8",
-    "torch_to_jax",
-]
-
-
-@torch.no_grad()
-def _torch_to_jax_bfloat16(tensor: torch.Tensor) -> Array:
-    # Credit: https://github.com/jax-ml/ml_dtypes/issues/81#issuecomment-2399636232
-    if tensor.dtype != torch.bfloat16:
-        raise ValueError("Trying to convert non-bfloat16 tensor to bfloat16")
-    intermediate_tensor = tensor.view(torch.uint16)
-    return jnp.array(intermediate_tensor).view("bfloat16")
-
-
-def torch_to_jax(array: torch.Tensor) -> Array:
-    array = array.detach().cpu()
-    if array.dtype == torch.bfloat16:
-        return _torch_to_jax_bfloat16(array)
-    return jnp.array(array.numpy())
-
-
-def jax_to_torch(array: Array) -> torch.Tensor:
-    if array.dtype == jnp.bfloat16:
-        intermediate_array = array.view(jnp.uint16)
-        return torch.utils.dlpack.from_dlpack(intermediate_array).view(torch.bfloat16)
-    return torch.utils.dlpack.from_dlpack(array)
-
-
-def jax_uint4_to_packed_uint8(array: Array) -> Array:
-    if array.dtype != jnp.uint4:
-        raise ValueError(f"Input array must have dtype jnp.uint4, but got {array.dtype}")
-
-    if not array.shape:
-        raise ValueError("Input array cannot be a scalar and must have at least one dimension.")
-
-    *_, last_dim = array.shape
-    if last_dim % 2 != 0:
-        raise ValueError(f"The last dimension of the input array must be even, but got shape {array.shape}")
-
-    low_nibbles, high_nibbles = einops.rearrange(
-        array.astype(jnp.uint8),
-        "... (dim_half two) -> two ... dim_half",
-        two=2,
-    )
-
-    packed = (high_nibbles << 4) | low_nibbles
-
-    return packed.astype(jnp.uint8)