lalamo 0.2.2__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {lalamo-0.2.2 → lalamo-0.2.4}/PKG-INFO +1 -1
  2. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/__init__.py +1 -1
  3. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/__init__.py +3 -2
  4. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/executorch.py +1 -1
  5. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/loaders/__init__.py +2 -2
  6. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/common.py +10 -4
  7. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/llama.py +29 -30
  8. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo.egg-info/PKG-INFO +1 -1
  9. {lalamo-0.2.2 → lalamo-0.2.4}/LICENSE +0 -0
  10. {lalamo-0.2.2 → lalamo-0.2.4}/README.md +0 -0
  11. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/common.py +0 -0
  12. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/language_model.py +0 -0
  13. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/main.py +0 -0
  14. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/__init__.py +0 -0
  15. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/common.py +0 -0
  16. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/common.py +0 -0
  17. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/huggingface/__init__.py +0 -0
  18. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/huggingface/common.py +0 -0
  19. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/huggingface/gemma2.py +0 -0
  20. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/huggingface/gemma3.py +0 -0
  21. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/huggingface/llama.py +0 -0
  22. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/huggingface/mistral.py +0 -0
  23. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/huggingface/qwen2.py +0 -0
  24. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/huggingface/qwen3.py +0 -0
  25. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/loaders/common.py +0 -0
  26. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/loaders/executorch.py +0 -0
  27. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/loaders/huggingface.py +0 -0
  28. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/__init__.py +0 -0
  29. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/deepseek.py +0 -0
  30. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/gemma.py +0 -0
  31. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/huggingface.py +0 -0
  32. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/mistral.py +0 -0
  33. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/pleias.py +0 -0
  34. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/polaris.py +0 -0
  35. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/qwen.py +0 -0
  36. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/reka.py +0 -0
  37. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/__init__.py +0 -0
  38. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/activations.py +0 -0
  39. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/attention.py +0 -0
  40. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/common.py +0 -0
  41. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/decoder.py +0 -0
  42. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/decoder_layer.py +0 -0
  43. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/embedding.py +0 -0
  44. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/kv_cache.py +0 -0
  45. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/linear.py +0 -0
  46. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/mlp.py +0 -0
  47. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/normalization.py +0 -0
  48. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/rope.py +0 -0
  49. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/modules/utils.py +0 -0
  50. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/quantization.py +0 -0
  51. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo/utils.py +0 -0
  52. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo.egg-info/SOURCES.txt +0 -0
  53. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo.egg-info/dependency_links.txt +0 -0
  54. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo.egg-info/entry_points.txt +0 -0
  55. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo.egg-info/requires.txt +0 -0
  56. {lalamo-0.2.2 → lalamo-0.2.4}/lalamo.egg-info/top_level.txt +0 -0
  57. {lalamo-0.2.2 → lalamo-0.2.4}/pyproject.toml +0 -0
  58. {lalamo-0.2.2 → lalamo-0.2.4}/setup.cfg +0 -0
  59. {lalamo-0.2.2 → lalamo-0.2.4}/tests/test_generation.py +0 -0
  60. {lalamo-0.2.2 → lalamo-0.2.4}/tests/test_huggingface_models.py +0 -0
{lalamo-0.2.2 → lalamo-0.2.4}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.2.2
+Version: 0.2.4
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
{lalamo-0.2.2 → lalamo-0.2.4}/lalamo/__init__.py

@@ -1,7 +1,7 @@
 from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
 from lalamo.modules import Decoder
 
-__version__ = "0.2.2"
+__version__ = "0.2.4"
 
 __all__ = [
     "REPO_TO_MODEL",
{lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/__init__.py

@@ -1,5 +1,6 @@
 from .common import ForeignConfig
-from .executorch import ETLlamaConfig
+
+# from .executorch import ETLlamaConfig
 from .huggingface import (
     HFGemma2Config,
     HFGemma3Config,
@@ -11,7 +12,7 @@ from .huggingface import (
 )
 
 __all__ = [
-    "ETLlamaConfig",
+    # "ETLlamaConfig",
     "ForeignConfig",
     "HFGemma2Config",
     "HFGemma3Config",
{lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/configs/executorch.py

@@ -3,7 +3,7 @@ from dataclasses import dataclass
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
 
-from lalamo.model_import.loaders import load_executorch
+from lalamo.model_import.loaders.executorch import load_executorch
 from lalamo.modules import (
     Activation,
     AttentionConfig,
{lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/loaders/__init__.py

@@ -1,7 +1,7 @@
-from .executorch import load_executorch
+# from .executorch import load_executorch
 from .huggingface import load_huggingface
 
 __all__ = [
-    "load_executorch",
+    # "load_executorch",
     "load_huggingface",
 ]
{lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/common.py

@@ -3,7 +3,6 @@ from enum import Enum
 from pathlib import Path
 
 import jax.numpy as jnp
-import torch
 from jaxtyping import Array, DTypeLike
 from safetensors.flax import load_file as load_safetensors
 
@@ -17,9 +16,9 @@ __all__ = [
     "ModelSpec",
     "TokenizerFileSpec",
     "UseCase",
-    "huggingface_weight_files",
     "awq_model_spec",
     "build_quantized_models",
+    "huggingface_weight_files",
 ]
 
 
@@ -36,6 +35,9 @@ class WeightsType(Enum):
     def load(self, filename: Path | str, float_dtype: DTypeLike) -> dict[str, jnp.ndarray]:
         if self == WeightsType.SAFETENSORS:
             return {k: cast_if_float(v, float_dtype) for k, v in load_safetensors(filename).items()}
+
+        import torch
+
         torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
         return {k: cast_if_float(torch_to_jax(v), float_dtype) for k, v in torch_weights.items()}
 
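Note on the hunk above: moving `import torch` from module level into `WeightsType.load` defers the dependency until a Torch checkpoint is actually loaded, so the safetensors path works in environments without torch installed. A minimal standalone sketch of the same deferred-import pattern (illustrative names, not lalamo's actual module):

    from enum import Enum
    from pathlib import Path


    class WeightsType(Enum):
        SAFETENSORS = "safetensors"
        TORCH = "torch"

        def load(self, filename: Path | str) -> dict:
            if self == WeightsType.SAFETENSORS:
                # safetensors path: no torch required.
                from safetensors.numpy import load_file

                return dict(load_file(filename))
            # Deferred import: torch is only resolved here, when a .pth
            # checkpoint is actually being loaded.
            import torch

            return torch.load(filename, map_location="cpu", weights_only=True)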
@@ -72,11 +74,15 @@ def huggingface_weight_files(num_shards: int) -> tuple[str, ...]:
     return tuple(f"model-{i:05d}-of-{num_shards:05d}.safetensors" for i in range(1, num_shards + 1))
 
 
-def awq_model_spec(model_spec: ModelSpec, repo: str, quantization: QuantizationMode = QuantizationMode.UINT4) -> ModelSpec:
+def awq_model_spec(
+    model_spec: ModelSpec,
+    repo: str,
+    quantization: QuantizationMode = QuantizationMode.UINT4,
+) -> ModelSpec:
     return ModelSpec(
         vendor=model_spec.vendor,
         family=model_spec.family,
-        name="{}-AWQ".format(model_spec.name),
+        name=f"{model_spec.name}-AWQ",
         size=model_spec.size,
         quantization=quantization,
         repo=repo,
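For context, `awq_model_spec` derives a quantized variant from an existing spec: same vendor, family, and size, a name with an `-AWQ` suffix, and the given repo and quantization mode. A hedged usage sketch (`base_spec` and the repo id below are hypothetical):

    # Assuming `base_spec` is an existing ModelSpec for a float model.
    awq_spec = awq_model_spec(
        base_spec,
        repo="example-org/SomeModel-AWQ",  # hypothetical AWQ repo id
    )
    assert awq_spec.name == f"{base_spec.name}-AWQ"
    assert awq_spec.quantization == QuantizationMode.UINT4  # default mode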
{lalamo-0.2.2 → lalamo-0.2.4}/lalamo/model_import/model_specs/llama.py

@@ -1,7 +1,6 @@
 from dataclasses import replace
 
-from lalamo.model_import.configs import ETLlamaConfig, HFLlamaConfig
-from lalamo.quantization import QuantizationMode
+from lalamo.model_import.configs import HFLlamaConfig
 
 from .common import (
     HUGGINFACE_GENERATION_CONFIG_FILE,
@@ -54,20 +53,20 @@ LLAMA32 = [
         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=tuple(),
     ),
-    ModelSpec(
-        vendor="Meta",
-        family="Llama-3.2",
-        name="Llama-3.2-1B-Instruct-QLoRA",
-        size="1B",
-        quantization=QuantizationMode.UINT4,
-        repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
-        config_type=ETLlamaConfig,
-        config_file_name="params.json",
-        weights_file_names=("consolidated.00.pth",),
-        weights_type=WeightsType.TORCH,
-        tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-1B-Instruct"),
-        use_cases=tuple(),
-    ),
+    # ModelSpec(
+    #     vendor="Meta",
+    #     family="Llama-3.2",
+    #     name="Llama-3.2-1B-Instruct-QLoRA",
+    #     size="1B",
+    #     quantization=QuantizationMode.UINT4,
+    #     repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
+    #     config_type=ETLlamaConfig,
+    #     config_file_name="params.json",
+    #     weights_file_names=("consolidated.00.pth",),
+    #     weights_type=WeightsType.TORCH,
+    #     tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-1B-Instruct"),
+    #     use_cases=tuple(),
+    # ),
     ModelSpec(
         vendor="Meta",
         family="Llama-3.2",
@@ -82,20 +81,20 @@ LLAMA32 = [
         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
         use_cases=tuple(),
     ),
-    ModelSpec(
-        vendor="Meta",
-        family="Llama-3.2",
-        name="Llama-3.2-3B-Instruct-QLoRA",
-        size="3B",
-        quantization=QuantizationMode.UINT4,
-        repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
-        config_type=ETLlamaConfig,
-        config_file_name="params.json",
-        weights_file_names=("consolidated.00.pth",),
-        tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-3B-Instruct"),
-        weights_type=WeightsType.TORCH,
-        use_cases=tuple(),
-    ),
+    # ModelSpec(
+    #     vendor="Meta",
+    #     family="Llama-3.2",
+    #     name="Llama-3.2-3B-Instruct-QLoRA",
+    #     size="3B",
+    #     quantization=QuantizationMode.UINT4,
+    #     repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
+    #     config_type=ETLlamaConfig,
+    #     config_file_name="params.json",
+    #     weights_file_names=("consolidated.00.pth",),
+    #     tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-3B-Instruct"),
+    #     weights_type=WeightsType.TORCH,
+    #     use_cases=tuple(),
+    # ),
 ]
 
 LLAMA_MODELS = LLAMA31 + LLAMA32
{lalamo-0.2.2 → lalamo-0.2.4}/lalamo.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.2.2
+Version: 0.2.4
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown