lalamo 0.2.6__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.2.6 → lalamo-0.3.0}/PKG-INFO +6 -6
- {lalamo-0.2.6 → lalamo-0.3.0}/README.md +4 -6
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/__init__.py +1 -1
- lalamo-0.3.0/lalamo/common.py +110 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/language_model.py +106 -83
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/main.py +91 -18
- lalamo-0.3.0/lalamo/message_processor.py +170 -0
- lalamo-0.3.0/lalamo/model_import/common.py +227 -0
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/__init__.py +0 -1
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/common.py +11 -10
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/huggingface/common.py +9 -4
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/huggingface/gemma3.py +2 -2
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/huggingface/llama.py +2 -2
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/huggingface/mistral.py +1 -1
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/huggingface/qwen2.py +1 -1
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/huggingface/qwen3.py +1 -1
- lalamo-0.3.0/lalamo/model_import/huggingface_generation_config.py +44 -0
- lalamo-0.3.0/lalamo/model_import/huggingface_tokenizer_config.py +85 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/model_import/loaders/common.py +2 -1
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/model_import/loaders/huggingface.py +12 -10
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/model_import/model_specs/__init__.py +3 -2
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/model_import/model_specs/common.py +32 -33
- lalamo-0.3.0/lalamo/model_import/model_specs/deepseek.py +19 -0
- lalamo-0.3.0/lalamo/model_import/model_specs/gemma.py +53 -0
- lalamo-0.3.0/lalamo/model_import/model_specs/huggingface.py +18 -0
- lalamo-0.3.0/lalamo/model_import/model_specs/llama.py +44 -0
- lalamo-0.3.0/lalamo/model_import/model_specs/mistral.py +49 -0
- lalamo-0.3.0/lalamo/model_import/model_specs/pleias.py +18 -0
- lalamo-0.3.0/lalamo/model_import/model_specs/polaris.py +20 -0
- lalamo-0.3.0/lalamo/model_import/model_specs/qwen.py +237 -0
- lalamo-0.3.0/lalamo/model_import/model_specs/reka.py +19 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/__init__.py +2 -1
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/attention.py +90 -10
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/common.py +51 -4
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/decoder.py +90 -8
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/decoder_layer.py +85 -8
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/embedding.py +95 -29
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/kv_cache.py +3 -3
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/linear.py +170 -130
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/mlp.py +40 -7
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/normalization.py +24 -6
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/rope.py +24 -6
- lalamo-0.3.0/lalamo/sampling.py +99 -0
- lalamo-0.3.0/lalamo/utils.py +112 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo.egg-info/PKG-INFO +6 -6
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo.egg-info/SOURCES.txt +17 -12
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo.egg-info/requires.txt +2 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/pyproject.toml +6 -1
- {lalamo-0.2.6 → lalamo-0.3.0}/tests/test_generation.py +15 -23
- {lalamo-0.2.6 → lalamo-0.3.0}/tests/test_huggingface_models.py +3 -3
- lalamo-0.3.0/tests/test_parameter_tree.py +103 -0
- lalamo-0.2.6/lalamo/common.py +0 -60
- lalamo-0.2.6/lalamo/model_import/common.py +0 -111
- lalamo-0.2.6/lalamo/model_import/model_specs/deepseek.py +0 -28
- lalamo-0.2.6/lalamo/model_import/model_specs/gemma.py +0 -76
- lalamo-0.2.6/lalamo/model_import/model_specs/huggingface.py +0 -28
- lalamo-0.2.6/lalamo/model_import/model_specs/llama.py +0 -100
- lalamo-0.2.6/lalamo/model_import/model_specs/mistral.py +0 -59
- lalamo-0.2.6/lalamo/model_import/model_specs/pleias.py +0 -28
- lalamo-0.2.6/lalamo/model_import/model_specs/polaris.py +0 -22
- lalamo-0.2.6/lalamo/model_import/model_specs/qwen.py +0 -336
- lalamo-0.2.6/lalamo/model_import/model_specs/reka.py +0 -28
- lalamo-0.2.6/lalamo/utils.py +0 -27
- {lalamo-0.2.6 → lalamo-0.3.0}/LICENSE +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/executorch.py +0 -0
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/huggingface/__init__.py +0 -0
- {lalamo-0.2.6/lalamo/model_import/configs → lalamo-0.3.0/lalamo/model_import/decoder_configs}/huggingface/gemma2.py +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/model_import/loaders/executorch.py +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/activations.py +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/modules/utils.py +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo/quantization.py +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.2.6 → lalamo-0.3.0}/setup.cfg +0 -0

{lalamo-0.2.6 → lalamo-0.3.0}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.2.6
+Version: 0.3.0
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
@@ -13,10 +13,12 @@ Requires-Dist: huggingface-hub[hf-transfer]>=0.27.1
 Requires-Dist: jax>=0.4.38; sys_platform == "darwin"
 Requires-Dist: jax[cuda]>=0.4.38; sys_platform == "linux"
 Requires-Dist: jaxtyping>=0.2.36
+Requires-Dist: jinja2>=3.1.6
 Requires-Dist: ml-dtypes>=0.5.1
 Requires-Dist: optax>=0.2.4
 Requires-Dist: rich>=14.0.0
 Requires-Dist: thefuzz>=0.22.1
+Requires-Dist: tokenizers>=0.21.2
 Requires-Dist: typer>=0.15.1
 Requires-Dist: safetensors>=0.6.2
 Dynamic: license-file
@@ -48,9 +50,11 @@ uv run lalamo list-models
 To convert a model, run:

 ```bash
-uv run lalamo convert MODEL_REPO
+uv run lalamo convert MODEL_REPO
 ```

+Note: on some CPU platforms you may get an error saying `The precision 'F16_F16_F32' is not supported by dot_general on CPU`. This is due to a bug in XLA that causes matmuls inside `jax.jit` to work incorrectly on CPUs. The workaround is to set the environment variable `JAX_DISABLE_JIT=1` when running the conversion.
+
 After that, you can find the converted model in the `models` folder. For more options see `uv run lalamo convert --help`.

 ## Model Support
@@ -66,10 +70,6 @@ ModelSpec(
     quantization=None,
     repo="google/gemma-3-1b-it",
     config_type=HFGemma3TextConfig,
-    config_file_name="config.json",
-    weights_file_names=huggingface_weight_files(1),
     weights_type=WeightsType.SAFETENSORS,
-    tokenizer_files=HUGGINGFACE_TOKENIZER_FILES,
-    use_cases=tuple(),
 )
 ```
````
{lalamo-0.2.6 → lalamo-0.3.0}/README.md

````diff
@@ -25,9 +25,11 @@ uv run lalamo list-models
 To convert a model, run:

 ```bash
-uv run lalamo convert MODEL_REPO
+uv run lalamo convert MODEL_REPO
 ```

+Note: on some CPU platforms you may get an error saying `The precision 'F16_F16_F32' is not supported by dot_general on CPU`. This is due to a bug in XLA that causes matmuls inside `jax.jit` to work incorrectly on CPUs. The workaround is to set the environment variable `JAX_DISABLE_JIT=1` when running the conversion.
+
 After that, you can find the converted model in the `models` folder. For more options see `uv run lalamo convert --help`.

 ## Model Support
@@ -43,10 +45,6 @@ ModelSpec(
     quantization=None,
     repo="google/gemma-3-1b-it",
     config_type=HFGemma3TextConfig,
-    config_file_name="config.json",
-    weights_file_names=huggingface_weight_files(1),
     weights_type=WeightsType.SAFETENSORS,
-    tokenizer_files=HUGGINGFACE_TOKENIZER_FILES,
-    use_cases=tuple(),
 )
-```
+```
````
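A minimal sketch of the workaround described in the note above. Setting the variable from Python before JAX is first imported is an assumption on my part; the documented route is simply to export `JAX_DISABLE_JIT=1` in the shell before running `lalamo convert`.

```python
# Hedged sketch: disable jax.jit via the JAX_DISABLE_JIT environment variable,
# mirroring the CPU workaround mentioned in the README note.
import os

os.environ["JAX_DISABLE_JIT"] = "1"  # must be set before the first `import jax`

import jax
import jax.numpy as jnp

x = jnp.ones((8, 8), dtype=jnp.bfloat16)
# With JIT disabled this runs op-by-op, sidestepping the CPU dot_general precision error.
print(jax.jit(lambda a: a @ a)(x).dtype)
```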
lalamo-0.3.0/lalamo/common.py (new file)

````diff
@@ -0,0 +1,110 @@
+from collections import defaultdict
+from collections.abc import Mapping, Sequence
+from typing import cast
+
+import jax.numpy as jnp
+from jax._src.api import ShapeDtypeStruct
+from jaxtyping import Array, DTypeLike
+
+from lalamo.utils import MapDictValues, MapSequence
+
+__all__ = [
+    "DEFAULT_PRECISION",
+    "ArrayLike",
+    "ParameterPath",
+    "ParameterTree",
+    "dummy_array",
+    "flatten_parameters",
+    "unflatten_parameters",
+]
+
+DEFAULT_PRECISION: DTypeLike = jnp.bfloat16
+
+
+type ArrayLike = Array | ShapeDtypeStruct
+
+
+type ParameterTree[ArrayType: ArrayLike] = (
+    Mapping[str, ArrayType | ParameterTree[ArrayType]] | Sequence[ArrayType | ParameterTree[ArrayType]]
+)
+
+
+def dummy_array(shape: int | tuple[int, ...], dtype: DTypeLike) -> Array:
+    if isinstance(shape, int):
+        shape = (shape,)
+    return cast("Array", ShapeDtypeStruct(shape=shape, dtype=dtype))
+
+
+def flatten_parameters[ArrayType: ArrayLike](nested_parameters: ParameterTree[ArrayType]) -> dict[str, ArrayType]:
+    result: dict[str, ArrayType] = {}
+    if not isinstance(nested_parameters, Mapping):
+        nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
+    for key, value in nested_parameters.items():
+        key_path = ParameterPath(key)
+        if isinstance(value, (Array, ShapeDtypeStruct)):
+            result[key_path] = value
+        else:
+            update: dict[str, ArrayType] = {
+                str(key_path / subkey): subvalue for subkey, subvalue in flatten_parameters(value).items()
+            }
+            result.update(update)
+    return result
+
+
+type KeyTree = Mapping[str, str | KeyTree] | Sequence[str | KeyTree]
+
+
+def _unflatten_keys(flat_keys: Mapping[str, str]) -> KeyTree:
+    groups: dict[str, dict[str, str] | str] = defaultdict(dict)
+    for subkey, full_key in flat_keys.items():
+        match subkey.split(".", maxsplit=1):
+            case [head]:
+                groups[head] = full_key
+            case [head, tail]:
+                group = groups[head]
+                assert isinstance(group, dict)
+                group[tail] = full_key
+
+    unflattened_groups: dict[str, KeyTree] = {}
+    for subkey, group in groups.items():
+        if isinstance(group, str):
+            unflattened_groups[subkey] = group
+        else:
+            unflattened_groups[subkey] = _unflatten_keys(group)
+
+    if any(key.isnumeric() for key in unflattened_groups):
+        assert set(unflattened_groups.keys()) == set(map(str, range(len(unflattened_groups))))
+        return [v for k, v in sorted(unflattened_groups.items(), key=lambda item: int(item[0]))]
+    return unflattened_groups
+
+
+def _recursive_map_dict[ArrayType: ArrayLike](
+    key_tree: KeyTree | str,
+    root_collection: Mapping[str, ArrayType],
+) -> ParameterTree[ArrayType] | ArrayType:
+    if isinstance(key_tree, str):
+        return root_collection[key_tree]
+    if isinstance(key_tree, Mapping):
+        return MapDictValues(lambda subtree: _recursive_map_dict(subtree, root_collection), key_tree)
+    if isinstance(key_tree, Sequence):
+        return MapSequence(lambda subtree: _recursive_map_dict(subtree, root_collection), key_tree)
+
+
+def unflatten_parameters[ArrayType: ArrayLike](flat_parameters: Mapping[str, ArrayType]) -> ParameterTree[ArrayType]:
+    unflattened_keys = _unflatten_keys({k: k for k in flat_parameters})
+    result = _recursive_map_dict(unflattened_keys, flat_parameters)
+    assert not isinstance(result, (Array, ShapeDtypeStruct))
+    return result
+
+
+class ParameterPath(str):
+    __slots__ = ()
+
+    @property
+    def components(self) -> tuple[str, ...]:
+        return tuple(self.split("."))
+
+    def __truediv__(self, other: str | int) -> "ParameterPath":
+        if not self:
+            return ParameterPath(str(other))
+        return ParameterPath(self + "." + str(other))
````
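The new `lalamo/common.py` centralizes parameter-tree handling. Below is a small illustrative sketch of the round trip between nested and flat parameter dictionaries, assuming the `MapDictValues`/`MapSequence` wrappers from `lalamo.utils` behave as ordinary read-only mappings and sequences; the nested dictionary is invented for the example.

```python
# Illustrative sketch of the parameter-tree helpers added in lalamo/common.py.
import jax.numpy as jnp

from lalamo.common import ParameterPath, flatten_parameters, unflatten_parameters

nested = {
    "embedding": {"weight": jnp.zeros((8, 4))},
    "layers": [{"mlp": {"weight": jnp.zeros((4, 4))}}],
}

# Keys become dot-separated paths, e.g. "embedding.weight" and "layers.0.mlp.weight".
flat = flatten_parameters(nested)
print(sorted(flat))

# unflatten_parameters rebuilds the nested structure (lists are recovered from numeric keys).
roundtrip = unflatten_parameters(flat)
print(roundtrip["layers"][0]["mlp"]["weight"].shape)  # (4, 4)

# ParameterPath is a str subclass with a "/" join operator.
path = ParameterPath("layers") / 0 / "mlp"
print(path.components)  # ('layers', '0', 'mlp')
```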
{lalamo-0.2.6 → lalamo-0.3.0}/lalamo/language_model.py

````diff
@@ -1,89 +1,28 @@
-
+import json
 from collections.abc import Iterable
-from dataclasses import dataclass
-from
+from dataclasses import dataclass, replace
+from pathlib import Path
+from typing import NamedTuple, Self

 import equinox as eqx
 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+from safetensors.flax import load_file
+from tokenizers import Tokenizer

-from lalamo.
+from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
+from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
+from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule, WeightLayout, config_converter
+from lalamo.sampling import SamplingPolicy, make_policy

 __all__ = [
-    "
-    "CompositePolicy",
-    "GreedyPolicy",
+    "GenerationConfig",
     "LanguageModel",
-    "
-    "TemperaturePolicy",
-    "TopKPolicy",
-    "TopPPolicy",
+    "LanguageModelConfig",
 ]


-class SamplingPolicy(eqx.Module):
-    @abstractmethod
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]: ...
-
-    def __call__(self, logits: Float[Array, " vocabulary"], *, key: PRNGKeyArray) -> Int[Array, ""]:
-        return jax.random.categorical(key, self.process_logits(logits))
-
-
-class GreedyPolicy(SamplingPolicy):
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        max_logit_value = jnp.max(logits)
-        return jnp.where(logits == max_logit_value, 1.0, -jnp.inf)
-
-
-class TemperaturePolicy(SamplingPolicy):
-    temperature: float = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        return logits / self.temperature
-
-
-class TopKPolicy(SamplingPolicy):
-    k: int = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        top_k_logits, _ = jax.lax.top_k(logits, self.k)
-        min_logit_val = jnp.min(top_k_logits)
-        return jnp.where(logits >= min_logit_val, logits, -jnp.inf)
-
-
-class TopPPolicy(SamplingPolicy):
-    p: float = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        sorted_indices = jnp.argsort(logits, descending=True)
-        sorted_logits = logits[sorted_indices]
-        cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits))
-
-        to_remove = cumulative_probs > self.p
-        to_remove = jnp.roll(to_remove, 1)
-        to_remove = to_remove.at[0].set(False)
-
-        return jnp.where(to_remove, -jnp.inf, logits)
-
-
-class BanTokensPolicy(SamplingPolicy):
-    banned_tokens: list[int] = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        banned_tokens_indices = jnp.asarray(self.banned_tokens, dtype=jnp.int32)
-        return logits.at[banned_tokens_indices].set(-jnp.inf)
-
-
-class CompositePolicy(SamplingPolicy):
-    policies: list[SamplingPolicy] = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        for policy in self.policies:
-            logits = policy.process_logits(logits)
-        return logits
-
-
 class PrefillResults(NamedTuple):
     last_token_logits: Float[Array, " vocabulary"]
     last_token_position: Int[Array, ""]
@@ -98,9 +37,66 @@ class DecodingState(NamedTuple):


 @dataclass(frozen=True)
-class
+class GenerationConfig:
+    stop_token_ids: tuple[int, ...]
+    temperature: float | None
+    top_k: int | None
+    top_p: float | None
+    banned_tokens: tuple[int, ...] | None
+
+    def default_policy(self) -> SamplingPolicy:
+        return make_policy(self.temperature, self.top_k, self.top_p, self.banned_tokens)
+
+
+@dataclass(frozen=True)
+class LanguageModelConfig:
+    decoder_config: DecoderConfig
+    message_processor_config: MessageProcessorConfig
+    generation_config: GenerationConfig
+
+
+class LanguageModel(LalamoModule[LanguageModelConfig]):
     decoder: Decoder
+    message_processor: MessageProcessor = eqx.field(static=True)
+
+    @classmethod
+    def load(cls, path: Path | str, weight_layout: WeightLayout = WeightLayout.AUTO) -> Self:
+        if isinstance(path, str):
+            path = Path(path)
+        with open(path / "config.json") as config_file:
+            config_json = json.load(config_file)
+        config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
+        weights = unflatten_parameters(load_file(path / "model.safetensors"))
+        decoder = config.decoder_config.empty().import_weights(weights, weight_layout)
+        tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
+        message_processor = MessageProcessor(config.message_processor_config, tokenizer)
+        return cls(config, decoder, message_processor)
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.decoder.activation_precision
+
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        return self.decoder.export_weights(weight_layout)
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        return replace(
+            self,
+            decoder=self.decoder.import_weights(weights, weight_layout),
+        )
+
+    @property
+    def stop_token_ids(self) -> tuple[int, ...]:
+        return self.config.generation_config.stop_token_ids

+    def default_sampling_policy(self) -> SamplingPolicy:
+        return self.config.generation_config.default_policy()
+
+    @eqx.filter_jit
     def _prefill(
         self,
         token_ids: Int[Array, " tokens"],
@@ -137,7 +133,8 @@ class LanguageModel:
             kv_cache=decoder_outputs.updated_kv_cache,
         )

-
+    @eqx.filter_jit
+    def generate_tokens(
         self,
         prompt_token_ids: Int[Array, " prompt_tokens"],
         sampling_policy: SamplingPolicy | None = None,
@@ -148,7 +145,9 @@ class LanguageModel:
         key: PRNGKeyArray | None = None,
     ) -> Int[Array, " response_tokens"]:
         if sampling_policy is None:
-            sampling_policy =
+            sampling_policy = self.default_sampling_policy()
+        if eos_token_ids is None:
+            eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)

         (input_length,) = prompt_token_ids.shape
         prefill_results = self._prefill(
@@ -177,10 +176,7 @@ class LanguageModel:
             next_token_id = jax.random.categorical(key, processed_logits)
             next_token_position = state.last_token_position + 1

-
-                stop_flag = state.stop_flag | jnp.any(next_token_id == eos_token_ids)
-            else:
-                stop_flag = state.stop_flag
+            stop_flag = state.stop_flag | jnp.any(next_token_id == eos_token_ids)

             decoder_outputs = self.decoder(
                 next_token_id.reshape(1),
@@ -207,7 +203,32 @@ class LanguageModel:

         return tokens

-    def
+    def reply(
+        self,
+        messages: Iterable[Message],
+        sampling_policy: SamplingPolicy | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> AssistantMessage:
+        formatted_messages = self.message_processor.render_request(messages)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
+        response_ids = self.generate_tokens(token_ids, sampling_policy, key=key)
+        response_text = self.message_processor.detokenize(response_ids.tolist())
+        return self.message_processor.parse_response(response_text)
+
+    def stream_reply_text(
+        self,
+        messages: Iterable[Message],
+        sampling_policy: SamplingPolicy | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> Iterable[str]:
+        formatted_messages = self.message_processor.render_request(messages)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
+        for token_id in self.stream_tokens(token_ids, sampling_policy, key=key):
+            yield self.message_processor.detokenize([token_id.item()])
+
+    def stream_tokens(
         self,
         prompt_token_ids: Int[Array, " prompt_tokens"],
         sampling_policy: SamplingPolicy | None = None,
@@ -218,7 +239,9 @@ class LanguageModel:
         key: PRNGKeyArray | None = None,
     ) -> Iterable[Int[Array, ""]]:
         if sampling_policy is None:
-            sampling_policy =
+            sampling_policy = self.default_sampling_policy()
+        if eos_token_ids is None:
+            eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)

         (input_length,) = prompt_token_ids.shape
         prefill_results = self._prefill(
@@ -244,7 +267,7 @@ class LanguageModel:

             yield next_token_id

-            if
+            if jnp.any(next_token_id == eos_token_ids):
                 return

             next_token_position = state.last_token_position + 1
````
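The diff above replaces the old standalone sampling classes with a config-driven `LanguageModel` that bundles the decoder, tokenizer, and chat formatting. A minimal usage sketch of the new high-level API, assuming a model directory produced by `lalamo convert`; the path below is a placeholder.

```python
# Hedged sketch of the reworked LanguageModel API introduced in 0.3.0.
# The model directory is a placeholder; it should contain the config.json,
# model.safetensors, and tokenizer.json written by `lalamo convert`.
import jax

from lalamo.language_model import LanguageModel
from lalamo.message_processor import UserMessage

model = LanguageModel.load("models/gemma-3-1b-it")
messages = [UserMessage("Explain what a KV cache is in one sentence.")]

# One-shot reply using the default sampling policy from the model's GenerationConfig.
reply = model.reply(messages, key=jax.random.PRNGKey(0))
print(reply)

# Or stream the reply chunk by chunk, as the new `lalamo chat` command does.
for chunk in model.stream_reply_text(messages, key=jax.random.PRNGKey(1)):
    print(chunk, end="")
```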
{lalamo-0.2.6 → lalamo-0.3.0}/lalamo/main.py

````diff
@@ -20,7 +20,17 @@ from rich.table import Table
 from safetensors.flax import save_file
 from typer import Argument, Exit, Option, Typer

+from lalamo.common import flatten_parameters
+from lalamo.language_model import LanguageModel
+from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
+from lalamo.model_import.common import (
+    DownloadingFileEvent,
+    FinishedDownloadingFileEvent,
+    FinishedInitializingModelEvent,
+    InitializingModelEvent,
+    StatusEvent,
+)
 from lalamo.modules import WeightLayout, config_converter
 from lalamo.utils import jax_uint4_to_packed_uint8

@@ -91,6 +101,52 @@ def _pack_uint4_weights(weights: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
     return packed_weights


+@app.command(help="Chat with a converted model.")
+def chat(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    weight_layout: Annotated[
+        WeightLayout | None,
+        Option(
+            help=(
+                "(EXPERIMENTAL) Order of dimensions in the weights of linear layers."
+                "\n\n\n\n"
+                "If set to AUTO, the layout will depend on the model."
+            ),
+            show_default="auto",
+        ),
+    ] = None,
+) -> None:
+    if weight_layout is None:
+        weight_layout = WeightLayout.AUTO
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        transient=True,
+    ) as progress:
+        progress.add_task("🚀 [cyan]Loading model...[/cyan]")
+        model = LanguageModel.load(model_path, weight_layout)
+    messages = []
+    while True:
+        user_text = console.input("[cyan]user> [/cyan]")
+        user_message = UserMessage(user_text)
+        messages.append(user_message)
+
+        console.print("[red]assistant> [/red]", end="")
+        model_response_tokens = []
+        for token in model.stream_reply_text(messages):
+            console.print(token, end="")
+            model_response_tokens.append(token)
+        console.print()
+        model_response_text = "".join(model_response_tokens)
+        messages.append(model.message_processor.parse_response(model_response_text))
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
@@ -118,7 +174,7 @@ def convert(
         WeightLayout | None,
         Option(
             help=(
-                "Order of dimensions in the weights of linear layers."
+                "(EXPERIMENTAL) Order of dimensions in the weights of linear layers."
                 "\n\n\n\n"
                 "If set to AUTO, the layout will depend on the model."
             ),
@@ -194,41 +250,58 @@ def convert(
         TextColumn("[progress.description]{task.description}"),
         transient=True,
     ) as progress:
-
-
+        event_to_task = {}
+
+        def progress_callback(event: StatusEvent) -> None:
+            match event:
+                case DownloadingFileEvent(file_spec):
+                    event_to_task[event] = progress.add_task(f"Retrieving {file_spec.filename}...")
+                case FinishedDownloadingFileEvent(file_spec):
+                    progress.remove_task(event_to_task[event])
+                case InitializingModelEvent():
+                    event_to_task[event] = progress.add_task("Initializing model...")
+                case FinishedInitializingModelEvent():
+                    progress.remove_task(event_to_task[event])
+
+        main_task = progress.add_task("👨‍🍳 Cooking...")
+        model, metadata = import_model(
             model_repo,
             precision=precision_dtype,
             context_length=context_length,
+            progress_callback=progress_callback,
         )
-        progress.add_task(f"💾 Saving the model to {output_dir}")
+        save_task = progress.add_task(f"💾 Saving the model to {output_dir}")
         output_dir.mkdir(parents=True, exist_ok=True)

-        weights = dict(model.export_weights(weight_layout))
-        packed_weights = _pack_uint4_weights(weights)
-        save_file(packed_weights, output_dir / "model.safetensors")
-
-        config_json = config_converter.unstructure(metadata, ModelMetadata)
-        with open(output_dir / "config.json", "w") as file:
-            json.dump(config_json, file, indent=4)
-
-        for path in tokenizer_file_paths:
-            shutil.copy(path, output_dir / path.name)
-
         if include_traces:
-            progress.add_task("🚁 Generating traces...")
+            trace_task = progress.add_task("🚁 Generating traces...")

             num_tokens = 512
             token_stride = 8
             token_ids = jnp.arange(0, num_tokens, dtype=jnp.int32)
             token_positions = jnp.arange(0, num_tokens * token_stride, token_stride, dtype=jnp.int32)
-            result = model(
+            result = model.decoder(
                 token_ids,
                 token_positions,
                 return_updated_kv_cache=True,
                 return_activation_trace=True,
             )
-            traces =
+            traces = flatten_parameters(result.export())
             save_file(traces, output_dir / "traces.safetensors")
+            progress.remove_task(trace_task)
+        progress.remove_task(main_task)
+
+        model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
+        weights = flatten_parameters(model.export_weights(weight_layout))
+        del model
+
+        packed_weights = _pack_uint4_weights(weights)
+        save_file(packed_weights, output_dir / "model.safetensors")
+
+        config_json = config_converter.unstructure(metadata, ModelMetadata)
+        with open(output_dir / "config.json", "w") as file:
+            json.dump(config_json, file, indent=4)
+        progress.remove_task(save_task)

     console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{output_dir}`[/cyan]!")

````
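Both `chat` and `convert` fall back to the model's default sampling policy. Below is a hypothetical sketch of supplying a custom policy from Python instead, assuming `make_policy` accepts the same positional arguments that `GenerationConfig.default_policy()` passes above (temperature, top_k, top_p, banned_tokens); the model path is a placeholder.

```python
# Hypothetical override of the default sampling policy; the argument order for
# make_policy mirrors GenerationConfig.default_policy() shown in the diff above.
import jax

from lalamo.language_model import LanguageModel
from lalamo.message_processor import UserMessage
from lalamo.sampling import make_policy

model = LanguageModel.load("models/gemma-3-1b-it")  # placeholder path to a converted model
policy = make_policy(0.7, 50, 0.9, None)  # temperature, top_k, top_p, banned_tokens

reply = model.reply([UserMessage("Summarize the 0.3.0 changes in one line.")], policy, key=jax.random.PRNGKey(0))
print(reply)
```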