lalamo 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
2
2
  from lalamo.modules import Decoder
3
3
 
4
- __version__ = "0.2.2"
4
+ __version__ = "0.2.4"
5
5
 
6
6
  __all__ = [
7
7
  "REPO_TO_MODEL",
@@ -1,5 +1,6 @@
1
1
  from .common import ForeignConfig
2
- from .executorch import ETLlamaConfig
2
+
3
+ # from .executorch import ETLlamaConfig
3
4
  from .huggingface import (
4
5
  HFGemma2Config,
5
6
  HFGemma3Config,
@@ -11,7 +12,7 @@ from .huggingface import (
11
12
  )
12
13
 
13
14
  __all__ = [
14
- "ETLlamaConfig",
15
+ # "ETLlamaConfig",
15
16
  "ForeignConfig",
16
17
  "HFGemma2Config",
17
18
  "HFGemma3Config",
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
3
  import jax.numpy as jnp
4
4
  from jaxtyping import Array, DTypeLike
5
5
 
6
- from lalamo.model_import.loaders import load_executorch
6
+ from lalamo.model_import.loaders.executorch import load_executorch
7
7
  from lalamo.modules import (
8
8
  Activation,
9
9
  AttentionConfig,
@@ -1,7 +1,7 @@
1
- from .executorch import load_executorch
1
+ # from .executorch import load_executorch
2
2
  from .huggingface import load_huggingface
3
3
 
4
4
  __all__ = [
5
- "load_executorch",
5
+ # "load_executorch",
6
6
  "load_huggingface",
7
7
  ]
@@ -3,7 +3,6 @@ from enum import Enum
3
3
  from pathlib import Path
4
4
 
5
5
  import jax.numpy as jnp
6
- import torch
7
6
  from jaxtyping import Array, DTypeLike
8
7
  from safetensors.flax import load_file as load_safetensors
9
8
 
@@ -17,9 +16,9 @@ __all__ = [
17
16
  "ModelSpec",
18
17
  "TokenizerFileSpec",
19
18
  "UseCase",
20
- "huggingface_weight_files",
21
19
  "awq_model_spec",
22
20
  "build_quantized_models",
21
+ "huggingface_weight_files",
23
22
  ]
24
23
 
25
24
 
@@ -36,6 +35,9 @@ class WeightsType(Enum):
36
35
  def load(self, filename: Path | str, float_dtype: DTypeLike) -> dict[str, jnp.ndarray]:
37
36
  if self == WeightsType.SAFETENSORS:
38
37
  return {k: cast_if_float(v, float_dtype) for k, v in load_safetensors(filename).items()}
38
+
39
+ import torch
40
+
39
41
  torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
40
42
  return {k: cast_if_float(torch_to_jax(v), float_dtype) for k, v in torch_weights.items()}
41
43
 
@@ -72,11 +74,15 @@ def huggingface_weight_files(num_shards: int) -> tuple[str, ...]:
72
74
  return tuple(f"model-{i:05d}-of-{num_shards:05d}.safetensors" for i in range(1, num_shards + 1))
73
75
 
74
76
 
75
- def awq_model_spec(model_spec: ModelSpec, repo: str, quantization: QuantizationMode = QuantizationMode.UINT4) -> ModelSpec:
77
+ def awq_model_spec(
78
+ model_spec: ModelSpec,
79
+ repo: str,
80
+ quantization: QuantizationMode = QuantizationMode.UINT4,
81
+ ) -> ModelSpec:
76
82
  return ModelSpec(
77
83
  vendor=model_spec.vendor,
78
84
  family=model_spec.family,
79
- name="{}-AWQ".format(model_spec.name),
85
+ name=f"{model_spec.name}-AWQ",
80
86
  size=model_spec.size,
81
87
  quantization=quantization,
82
88
  repo=repo,
@@ -1,7 +1,6 @@
1
1
  from dataclasses import replace
2
2
 
3
- from lalamo.model_import.configs import ETLlamaConfig, HFLlamaConfig
4
- from lalamo.quantization import QuantizationMode
3
+ from lalamo.model_import.configs import HFLlamaConfig
5
4
 
6
5
  from .common import (
7
6
  HUGGINFACE_GENERATION_CONFIG_FILE,
@@ -54,20 +53,20 @@ LLAMA32 = [
54
53
  tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
55
54
  use_cases=tuple(),
56
55
  ),
57
- ModelSpec(
58
- vendor="Meta",
59
- family="Llama-3.2",
60
- name="Llama-3.2-1B-Instruct-QLoRA",
61
- size="1B",
62
- quantization=QuantizationMode.UINT4,
63
- repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
64
- config_type=ETLlamaConfig,
65
- config_file_name="params.json",
66
- weights_file_names=("consolidated.00.pth",),
67
- weights_type=WeightsType.TORCH,
68
- tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-1B-Instruct"),
69
- use_cases=tuple(),
70
- ),
56
+ # ModelSpec(
57
+ # vendor="Meta",
58
+ # family="Llama-3.2",
59
+ # name="Llama-3.2-1B-Instruct-QLoRA",
60
+ # size="1B",
61
+ # quantization=QuantizationMode.UINT4,
62
+ # repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
63
+ # config_type=ETLlamaConfig,
64
+ # config_file_name="params.json",
65
+ # weights_file_names=("consolidated.00.pth",),
66
+ # weights_type=WeightsType.TORCH,
67
+ # tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-1B-Instruct"),
68
+ # use_cases=tuple(),
69
+ # ),
71
70
  ModelSpec(
72
71
  vendor="Meta",
73
72
  family="Llama-3.2",
@@ -82,20 +81,20 @@ LLAMA32 = [
82
81
  tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
83
82
  use_cases=tuple(),
84
83
  ),
85
- ModelSpec(
86
- vendor="Meta",
87
- family="Llama-3.2",
88
- name="Llama-3.2-3B-Instruct-QLoRA",
89
- size="3B",
90
- quantization=QuantizationMode.UINT4,
91
- repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
92
- config_type=ETLlamaConfig,
93
- config_file_name="params.json",
94
- weights_file_names=("consolidated.00.pth",),
95
- tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-3B-Instruct"),
96
- weights_type=WeightsType.TORCH,
97
- use_cases=tuple(),
98
- ),
84
+ # ModelSpec(
85
+ # vendor="Meta",
86
+ # family="Llama-3.2",
87
+ # name="Llama-3.2-3B-Instruct-QLoRA",
88
+ # size="3B",
89
+ # quantization=QuantizationMode.UINT4,
90
+ # repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
91
+ # config_type=ETLlamaConfig,
92
+ # config_file_name="params.json",
93
+ # weights_file_names=("consolidated.00.pth",),
94
+ # tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-3B-Instruct"),
95
+ # weights_type=WeightsType.TORCH,
96
+ # use_cases=tuple(),
97
+ # ),
99
98
  ]
100
99
 
101
100
  LLAMA_MODELS = LLAMA31 + LLAMA32
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,4 +1,4 @@
1
- lalamo/__init__.py,sha256=1n--wwvN86epEr8MSH_-qSZPiHPDNRg45XRCMg6aV0o,217
1
+ lalamo/__init__.py,sha256=mEbbq3bHm0JhMKHAv7egBwAjSSpWpmxYeTeo9df4X8o,217
2
2
  lalamo/common.py,sha256=uYLw68V4AF3zlENG3KAIKRpOFXVHv8xX_n0cc3qJnj4,1877
3
3
  lalamo/language_model.py,sha256=GiA_BDQuYCgVBFHljb_ltW_M7g3I1Siwm111M3Jc8MM,9286
4
4
  lalamo/main.py,sha256=K2RLyTcxvBCP0teSsminssj_oUkuQAQ5y9ixa1uOqas,9546
@@ -6,9 +6,9 @@ lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
6
6
  lalamo/utils.py,sha256=QzkT0_82nd9pS5p0e7yOOdL_ZeKQr_Ftj4kFrWF35R8,1754
7
7
  lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
8
8
  lalamo/model_import/common.py,sha256=sHXEGQUtVb6TRT5FOGtJG9pz1Ohy5v_LtunubVxZKqQ,3303
9
- lalamo/model_import/configs/__init__.py,sha256=AbVkVT4tiofvQKym2maTv0dazEbalVrqFZjxqyVzk5o,456
9
+ lalamo/model_import/configs/__init__.py,sha256=JYXeco_kfzKZuWqEmG24qxeYWs-FuE1W1kNgoFNrBEw,461
10
10
  lalamo/model_import/configs/common.py,sha256=MKAinEL7WXkijS3IrfiTRgx2l6otpnIaJG_CajosMCU,1803
11
- lalamo/model_import/configs/executorch.py,sha256=yOa3mdM2FK3xahRclqhDbpste1DAW2kps2f-GgKnrfc,5255
11
+ lalamo/model_import/configs/executorch.py,sha256=Kx_T-B5jumfWf9vj20We4FF0GkSkTmIYeWOss88-qYA,5266
12
12
  lalamo/model_import/configs/huggingface/__init__.py,sha256=kWHUnZDwGQCbA3Ucm-FEDr8zZ2yZ3yviPVftlNgMk30,460
13
13
  lalamo/model_import/configs/huggingface/common.py,sha256=p6oEKIT2Ezh_d8eDXYzHaJaqjPriQrAzz2bkEq_HkgY,1698
14
14
  lalamo/model_import/configs/huggingface/gemma2.py,sha256=oIefI_ad-7DtzXmisFczkKPuOQ-KkzMkKWTk9likaMs,4101
@@ -17,16 +17,16 @@ lalamo/model_import/configs/huggingface/llama.py,sha256=_vOalgc24uhMcPyCqyxWOZk8
17
17
  lalamo/model_import/configs/huggingface/mistral.py,sha256=39qsX_Twml8C0xz0CayVZse2uaHJtKS9-54B8nQw_5k,4148
18
18
  lalamo/model_import/configs/huggingface/qwen2.py,sha256=GnO1_DKDewiB4AW8lJu_x30lL-GgB9GYc64rl6XqfYI,4963
19
19
  lalamo/model_import/configs/huggingface/qwen3.py,sha256=UJ-EP0geHmGXnT_Ioy7Z7V4vns_dKz2YpPe-GLPQg20,5029
20
- lalamo/model_import/loaders/__init__.py,sha256=tocl2MQcMH5mLBkkGwdufDjscDvzbQ24Qz37_vKz1_o,144
20
+ lalamo/model_import/loaders/__init__.py,sha256=Olg7a79phusilNgEa7PTgx1JgQQJLgAVg18T8isp0mw,148
21
21
  lalamo/model_import/loaders/common.py,sha256=2FigeDMUwlMPUebX8DAK2Yh9aLgVtsfTj0S431p7A0o,1782
22
22
  lalamo/model_import/loaders/executorch.py,sha256=nSvpylK8QL3nBk78P3FabLoyA87E3kv5CCpMfvuZe6Q,8886
23
23
  lalamo/model_import/loaders/huggingface.py,sha256=Ze_qB0fSxY8lH4ovH0t8jd5jiteasUWkS9HdgMZXCrs,10523
24
24
  lalamo/model_import/model_specs/__init__.py,sha256=_sJthAH1xXl5B9JPhRqMVP2t5KkhzqmKFHSRlOiFg8s,915
25
- lalamo/model_import/model_specs/common.py,sha256=ygfNjwVZBrjNkCVuv66R1vy5hXjgbAJyDc0QJfRfgik,3789
25
+ lalamo/model_import/model_specs/common.py,sha256=Ob3yTMDczKUHMWBH0PaClbSvHJhKfZ-zbv2Z04YqMVg,3806
26
26
  lalamo/model_import/model_specs/deepseek.py,sha256=9l3pVyC-ZoIaFG4xWhPDCbKkD2TsND286o0KzO0uxKo,788
27
27
  lalamo/model_import/model_specs/gemma.py,sha256=y4aDeaGGl4JPIanAgPMOlyfD_cx3Q7rpTKgDgx5AsX0,2299
28
28
  lalamo/model_import/model_specs/huggingface.py,sha256=ktDJ_qZxSGmHREydrYQaWi71bXJZiHqzHDoZeORENno,784
29
- lalamo/model_import/model_specs/llama.py,sha256=oPnHw8qV2l_cfQcW6OPTfehatP-ovLMPppIZVJ8yOWI,3234
29
+ lalamo/model_import/model_specs/llama.py,sha256=7eXfMwj_VZpeHAuXmPk1jcA_X7iXsJ8AWf6pk_Qy7rg,3226
30
30
  lalamo/model_import/model_specs/mistral.py,sha256=xDX2SyTruGR7A8LI_Ypa6qAP5nVyYhxLffoxS2F6bmI,1649
31
31
  lalamo/model_import/model_specs/pleias.py,sha256=zLRjmT6PXFtykqSYpaRtVObP306urMjF2J6dTKdAbQM,747
32
32
  lalamo/model_import/model_specs/polaris.py,sha256=TiGlXI3j7HP9bs01jdcysBNFxvNKnxTF30wuv5Jg2mQ,768
@@ -45,9 +45,9 @@ lalamo/modules/mlp.py,sha256=bV8qJTjsQFGv-CA7d32UQFn6BX5zmCKWC5pgm29-W3U,2631
45
45
  lalamo/modules/normalization.py,sha256=BWCHv6ycFJ_qMGfxkusGfay9dWzUlbpuwmjbLy2rI68,2380
46
46
  lalamo/modules/rope.py,sha256=Vdt2J_W0MPDK52nHsroLVCfWMHyHW3AfrKZCZAE4VYs,9369
47
47
  lalamo/modules/utils.py,sha256=5QTdi34kEI5jix7TfTdB0mOYZbzZUul_T1y8eWCA6lQ,262
48
- lalamo-0.2.2.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
49
- lalamo-0.2.2.dist-info/METADATA,sha256=6vrCJSOr_hGJeCWIA46DbL6OnoEL3rdK3xYrxeqeVRo,2611
50
- lalamo-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
51
- lalamo-0.2.2.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
52
- lalamo-0.2.2.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
53
- lalamo-0.2.2.dist-info/RECORD,,
48
+ lalamo-0.2.4.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
49
+ lalamo-0.2.4.dist-info/METADATA,sha256=mTCoEZB9eNgl86j-CSoT8YmFboXd5SUm4IC0YLgxBuk,2611
50
+ lalamo-0.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
51
+ lalamo-0.2.4.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
52
+ lalamo-0.2.4.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
53
+ lalamo-0.2.4.dist-info/RECORD,,
File without changes