lalamo-0.2.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo-0.2.1/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Mirai Tech Inc.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
lalamo-0.2.1/PKG-INFO ADDED
@@ -0,0 +1,74 @@
+ Metadata-Version: 2.4
+ Name: lalamo
+ Version: 0.2.1
+ Summary: JAX library for optimization and export of models for use with the UZU inference engine.
+ Requires-Python: <4,>=3.12
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: cattrs>=24.1.2
+ Requires-Dist: click>=8.1.8
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: equinox>=0.11.11
+ Requires-Dist: huggingface-hub[hf-transfer]>=0.27.1
+ Requires-Dist: jax>=0.4.38; sys_platform == "darwin"
+ Requires-Dist: jax[cuda]>=0.4.38; sys_platform == "linux"
+ Requires-Dist: jaxtyping>=0.2.36
+ Requires-Dist: ml-dtypes>=0.5.1
+ Requires-Dist: optax>=0.2.4
+ Requires-Dist: rich>=14.0.0
+ Requires-Dist: thefuzz>=0.22.1
+ Requires-Dist: typer>=0.15.1
+ Dynamic: license-file
+
+ <p align="center">
+ <picture>
+ <img alt="Mirai" src="https://artifacts.trymirai.com/social/github/lalamo-header.jpg" style="max-width: 100%;">
+ </picture>
+ </p>
+
+ <a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
+ <a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
+ <a href="mailto:alexey@getmirai.co,dima@getmirai.co,aleksei@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
+ <a href="https://docs.trymirai.com/components/models"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
+ [![License](https://img.shields.io/badge/License-MIT-blue)](LICENSE)
+
+ # lalamo
+
+ A set of tools for adapting Large Language Models to on-device inference using the [uzu](https://github.com/trymirai/uzu) inference engine.
+
+ ## Quick Start
+
+ To get the list of [supported models](https://trymirai.com/models), run:
+
+ ```bash
+ uv run lalamo list-models
+ ```
+
+ To convert a model, run:
+
+ ```bash
+ uv run lalamo convert MODEL_REPO --precision float16
+ ```
+
+ After that, you can find the converted model in the `models` folder. For more options, see `uv run lalamo convert --help`.
+
+ ## Model Support
+
+ To add support for a new model, write the corresponding [ModelSpec](lalamo/model_import/model_specs), as shown in the example below:
+
+ ```python
+ ModelSpec(
+     vendor="Google",
+     family="Gemma-3",
+     name="Gemma-3-1B-Instruct",
+     size="1B",
+     quantization=None,
+     repo="google/gemma-3-1b-it",
+     config_type=HFGemma3TextConfig,
+     config_file_name="config.json",
+     weights_file_names=huggingface_weight_files(1),
+     weights_type=WeightsType.SAFETENSORS,
+     tokenizer_files=HUGGINGFACE_TOKENIZER_FILES,
+     use_cases=tuple(),
+ )
+ ```
lalamo-0.2.1/README.md ADDED
@@ -0,0 +1,52 @@
+ <p align="center">
+ <picture>
+ <img alt="Mirai" src="https://artifacts.trymirai.com/social/github/lalamo-header.jpg" style="max-width: 100%;">
+ </picture>
+ </p>
+
+ <a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
+ <a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
+ <a href="mailto:alexey@getmirai.co,dima@getmirai.co,aleksei@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
+ <a href="https://docs.trymirai.com/components/models"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
+ [![License](https://img.shields.io/badge/License-MIT-blue)](LICENSE)
+
+ # lalamo
+
+ A set of tools for adapting Large Language Models to on-device inference using the [uzu](https://github.com/trymirai/uzu) inference engine.
+
+ ## Quick Start
+
+ To get the list of [supported models](https://trymirai.com/models), run:
+
+ ```bash
+ uv run lalamo list-models
+ ```
+
+ To convert a model, run:
+
+ ```bash
+ uv run lalamo convert MODEL_REPO --precision float16
+ ```
+
+ After that, you can find the converted model in the `models` folder. For more options, see `uv run lalamo convert --help`.
+
+ ## Model Support
+
+ To add support for a new model, write the corresponding [ModelSpec](lalamo/model_import/model_specs), as shown in the example below:
+
+ ```python
+ ModelSpec(
+     vendor="Google",
+     family="Gemma-3",
+     name="Gemma-3-1B-Instruct",
+     size="1B",
+     quantization=None,
+     repo="google/gemma-3-1b-it",
+     config_type=HFGemma3TextConfig,
+     config_file_name="config.json",
+     weights_file_names=huggingface_weight_files(1),
+     weights_type=WeightsType.SAFETENSORS,
+     tokenizer_files=HUGGINGFACE_TOKENIZER_FILES,
+     use_cases=tuple(),
+ )
+ ```
@@ -0,0 +1,11 @@
+ from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
+ from lalamo.modules import Decoder
+
+ __version__ = "0.2.1"
+
+ __all__ = [
+     "REPO_TO_MODEL",
+     "Decoder",
+     "ModelSpec",
+     "import_model",
+ ]
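
The package's top level re-exports the pieces needed to load a model programmatically. A minimal usage sketch, assuming `REPO_TO_MODEL` is keyed by Hugging Face repo id and that `import_model` accepts a `ModelSpec` (neither signature appears in this diff):

```python
# Hypothetical usage of the re-exported API; the exact shapes of
# REPO_TO_MODEL and import_model are not shown in this diff.
from lalamo import REPO_TO_MODEL, import_model

# Assumption: REPO_TO_MODEL maps Hugging Face repo ids to ModelSpec entries.
spec = REPO_TO_MODEL["google/gemma-3-1b-it"]

# Assumption: import_model loads the weights described by the spec and
# returns a model built around a lalamo.modules.Decoder.
model = import_model(spec)
```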
@@ -0,0 +1,60 @@
+ from collections.abc import Iterable, Mapping
+
+ import jax.numpy as jnp
+ from jaxtyping import Array, DTypeLike
+
+ __all__ = [
+     "DEFAULT_PRECISION",
+     "ParameterDict",
+     "ParameterPath",
+ ]
+
+ DEFAULT_PRECISION: DTypeLike = jnp.bfloat16
+
+
+ type NestedParameters = Mapping[str, Array | NestedParameters] | Iterable[Array | NestedParameters]
+
+
+ class ParameterDict(dict[str, Array]):
+     def __init__(self, **kwargs: Array | NestedParameters | Iterable[Array | NestedParameters]) -> None:
+         super().__init__(self._flatten(kwargs))
+
+     def __setitem__(
+         self,
+         key: str,
+         value: Array | NestedParameters | Iterable[Array | NestedParameters],
+     ) -> None:
+         key = ParameterPath(key)
+
+         if isinstance(value, Array):
+             super().__setitem__(key, value)
+             return
+
+         for subkey, subvalue in self._flatten(value).items():
+             super().__setitem__(key / subkey, subvalue)
+
+     @classmethod
+     def _flatten(cls, nested_parameters: NestedParameters) -> dict[str, Array]:
+         result: dict[str, Array] = {}
+         if not isinstance(nested_parameters, Mapping):
+             nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
+         for key, value in nested_parameters.items():
+             key_path = ParameterPath(key)
+             if isinstance(value, Array):
+                 result[key_path] = value
+             else:
+                 result.update({key_path / subkey: subvalue for subkey, subvalue in cls._flatten(value).items()})
+         return result
+
+
+ class ParameterPath(str):
+     __slots__ = ()
+
+     @property
+     def components(self) -> tuple[str, ...]:
+         return tuple(self.split("."))
+
+     def __truediv__(self, other: str | int) -> "ParameterPath":
+         if not self:
+             return ParameterPath(str(other))
+         return ParameterPath(self + "." + str(other))
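
`ParameterDict` flattens arbitrarily nested mappings and iterables of arrays into a flat dict with dot-separated keys, and `ParameterPath` supplies the `/` join used during flattening. A short sketch of that behavior (the module's import path is truncated in this diff; the parameter names are illustrative only):

```python
import jax.numpy as jnp

# ParameterDict and ParameterPath as defined in the module above.
params = ParameterDict(
    embedding=jnp.zeros((8, 4)),
    layers=[  # iterables are keyed by position: "layers.0", "layers.1", ...
        {"weight": jnp.zeros((4, 4)), "bias": jnp.zeros((4,))},
    ],
)
assert set(params) == {"embedding", "layers.0.weight", "layers.0.bias"}

# ParameterPath joins components with ".", and an empty path is a no-op prefix.
assert ParameterPath("layers") / 0 / "weight" == "layers.0.weight"
assert ParameterPath("") / "embedding" == "embedding"
```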
@@ -0,0 +1,263 @@
+ from abc import abstractmethod
+ from collections.abc import Iterable
+ from dataclasses import dataclass
+ from typing import NamedTuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+
+ from lalamo.modules import Decoder, KVCache
+
+ __all__ = [
+     "BanTokensPolicy",
+     "CompositePolicy",
+     "GreedyPolicy",
+     "LanguageModel",
+     "SamplingPolicy",
+     "TemperaturePolicy",
+     "TopKPolicy",
+     "TopPPolicy",
+ ]
+
+
+ class SamplingPolicy(eqx.Module):
+     @abstractmethod
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]: ...
+
+     def __call__(self, logits: Float[Array, " vocabulary"], *, key: PRNGKeyArray) -> Int[Array, ""]:
+         return jax.random.categorical(key, self.process_logits(logits))
+
+
+ class GreedyPolicy(SamplingPolicy):
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         max_logit_value = jnp.max(logits)
+         return jnp.where(logits == max_logit_value, 1.0, -jnp.inf)
+
+
+ class TemperaturePolicy(SamplingPolicy):
+     temperature: float = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         return logits / self.temperature
+
+
+ class TopKPolicy(SamplingPolicy):
+     k: int = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         top_k_logits, _ = jax.lax.top_k(logits, self.k)
+         min_logit_val = jnp.min(top_k_logits)
+         return jnp.where(logits >= min_logit_val, logits, -jnp.inf)
+
+
+ class TopPPolicy(SamplingPolicy):
+     p: float = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         sorted_indices = jnp.argsort(logits, descending=True)
+         sorted_logits = logits[sorted_indices]
+         cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits))
+
+         to_remove_sorted = cumulative_probs > self.p
+         to_remove_sorted = jnp.roll(to_remove_sorted, 1)
+         to_remove_sorted = to_remove_sorted.at[0].set(False)
+
+         # The mask is computed in sorted order; scatter it back to the
+         # original token order before applying it to the logits.
+         to_remove = jnp.zeros_like(to_remove_sorted).at[sorted_indices].set(to_remove_sorted)
+
+         return jnp.where(to_remove, -jnp.inf, logits)
+
+
+ class BanTokensPolicy(SamplingPolicy):
+     banned_tokens: list[int] = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         banned_tokens_indices = jnp.asarray(self.banned_tokens, dtype=jnp.int32)
+         return logits.at[banned_tokens_indices].set(-jnp.inf)
+
+
+ class CompositePolicy(SamplingPolicy):
+     policies: list[SamplingPolicy] = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         for policy in self.policies:
+             logits = policy.process_logits(logits)
+         return logits
+
+
+ class PrefillResults(NamedTuple):
+     last_token_logits: Float[Array, " vocabulary"]
+     last_token_position: Int[Array, ""]
+     kv_cache: KVCache
+
+
+ class DecodingState(NamedTuple):
+     last_token_logits: Float[Array, " vocabulary"]
+     last_token_position: Int[Array, ""]
+     kv_cache: KVCache
+     stop_flag: Bool[Array, ""]
+
+
+ @dataclass(frozen=True)
+ class LanguageModel:
+     decoder: Decoder
+
+     def _prefill(
+         self,
+         token_ids: Int[Array, " tokens"],
+         length_without_padding: Int[Array, ""] | int | None = None,
+         kv_cache_capacity: int | None = None,
+     ) -> PrefillResults:
+         (num_tokens,) = token_ids.shape
+         token_positions = jnp.arange(num_tokens, dtype=jnp.int32)
+         if kv_cache_capacity is not None:
+             kv_cache = self.decoder.init_static_kv_cache(kv_cache_capacity)
+         else:
+             kv_cache = None
+
+         decoder_outputs = self.decoder(
+             token_ids,
+             token_positions,
+             kv_cache,
+             return_updated_kv_cache=True,
+             length_without_padding=length_without_padding,
+         )
+
+         if length_without_padding is not None:
+             last_logits_index = length_without_padding - 1
+         else:
+             last_logits_index = num_tokens - 1
+
+         last_token_logits = decoder_outputs.logits[last_logits_index, :]
+         last_token_position = jnp.array(last_logits_index, dtype=jnp.int32)
+
+         assert decoder_outputs.updated_kv_cache is not None
+         return PrefillResults(
+             last_token_logits=last_token_logits,
+             last_token_position=last_token_position,
+             kv_cache=decoder_outputs.updated_kv_cache,
+         )
+
+     def generate(
+         self,
+         prompt_token_ids: Int[Array, " prompt_tokens"],
+         sampling_policy: SamplingPolicy | None = None,
+         prompt_length_without_padding: Int[Array, ""] | int | None = None,
+         max_output_length: int = 8192,
+         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+         *,
+         key: PRNGKeyArray | None = None,
+     ) -> Int[Array, " response_tokens"]:
+         if sampling_policy is None:
+             sampling_policy = TemperaturePolicy(temperature=1.0)
+
+         (input_length,) = prompt_token_ids.shape
+         prefill_results = self._prefill(
+             prompt_token_ids,
+             prompt_length_without_padding,
+             input_length + max_output_length,
+         )
+
+         initial_state = DecodingState(
+             prefill_results.last_token_logits,
+             prefill_results.last_token_position,
+             prefill_results.kv_cache,
+             jnp.array(0, dtype=jnp.bool),
+         )
+
+         if key is None:
+             key = jax.random.PRNGKey(0)
+         keys = jax.random.split(key, num=max_output_length)
+
+         def loop_iteration(
+             state: DecodingState,
+             key: PRNGKeyArray,
+         ) -> tuple[DecodingState, Int[Array, ""]]:
+             def sample_and_update() -> tuple[DecodingState, Int[Array, ""]]:
+                 processed_logits = sampling_policy.process_logits(state.last_token_logits)
+                 next_token_id = jax.random.categorical(key, processed_logits)
+                 next_token_position = state.last_token_position + 1
+
+                 if eos_token_ids is not None:
+                     stop_flag = state.stop_flag | jnp.any(next_token_id == eos_token_ids)
+                 else:
+                     stop_flag = state.stop_flag
+
+                 decoder_outputs = self.decoder(
+                     next_token_id.reshape(1),
+                     next_token_position.reshape(1),
+                     state.kv_cache,
+                     return_updated_kv_cache=True,
+                 )
+                 assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+                 new_state = DecodingState(
+                     decoder_outputs.logits.squeeze(),
+                     next_token_position,
+                     decoder_outputs.updated_kv_cache,
+                     stop_flag,
+                 )
+                 return new_state, next_token_id
+
+             def pad_and_repeat_state() -> tuple[DecodingState, Int[Array, ""]]:
+                 pad_token = jnp.array(0, dtype=jnp.int32)
+                 return state, pad_token
+
+             return jax.lax.cond(state.stop_flag, pad_and_repeat_state, sample_and_update)
+
+         _, tokens = jax.lax.scan(loop_iteration, initial_state, keys)
+
+         return tokens
+
+     def stream(
+         self,
+         prompt_token_ids: Int[Array, " prompt_tokens"],
+         sampling_policy: SamplingPolicy | None = None,
+         prompt_length_without_padding: Int[Array, ""] | int | None = None,
+         max_output_length: int = 8192,
+         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+         *,
+         key: PRNGKeyArray | None = None,
+     ) -> Iterable[Int[Array, ""]]:
+         if sampling_policy is None:
+             sampling_policy = TemperaturePolicy(temperature=1.0)
+
+         (input_length,) = prompt_token_ids.shape
+         prefill_results = self._prefill(
+             prompt_token_ids,
+             prompt_length_without_padding,
+             input_length + max_output_length,
+         )
+
+         if key is None:
+             key = jax.random.PRNGKey(0)
+         keys = jax.random.split(key, num=max_output_length)
+
+         state = DecodingState(
+             prefill_results.last_token_logits,
+             prefill_results.last_token_position,
+             prefill_results.kv_cache,
+             jnp.array(0, dtype=jnp.bool),
+         )
+
+         for iter_key in keys:
+             processed_logits = sampling_policy.process_logits(state.last_token_logits)
+             next_token_id = jax.random.categorical(iter_key, processed_logits)
+
+             yield next_token_id
+
+             if eos_token_ids is not None and jnp.any(next_token_id == eos_token_ids):
+                 return
+
+             next_token_position = state.last_token_position + 1
+             decoder_outputs = self.decoder(
+                 next_token_id.reshape(1),
+                 next_token_position.reshape(1),
+                 state.kv_cache,
+                 return_updated_kv_cache=True,
+             )
+             assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+             state = DecodingState(
+                 decoder_outputs.logits.squeeze(),
+                 next_token_position,
+                 decoder_outputs.updated_kv_cache,
+                 state.stop_flag,
+             )
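
`CompositePolicy` chains policies by running each `process_logits` in sequence, so a temperature, top-k, top-p pipeline is a single policy object. A sketch under the definitions above (their import path is not shown in this diff; the toy logits and thresholds are illustrative):

```python
import jax
import jax.numpy as jnp

# Policies as defined above; their import path is not shown in this diff.
policy = CompositePolicy(
    policies=[
        TemperaturePolicy(temperature=0.8),  # sharpen the distribution
        TopKPolicy(k=2),                     # keep only the two highest logits
        TopPPolicy(p=0.95),                  # then apply nucleus filtering
    ],
)

logits = jnp.array([2.0, 1.0, 0.5, -1.0])
processed = policy.process_logits(logits)  # each policy runs in order

# SamplingPolicy.__call__ draws a token id from the processed distribution.
token = policy(logits, key=jax.random.PRNGKey(0))
```

`LanguageModel.generate` and `stream` accept such a policy through their `sampling_policy` parameter and fall back to `TemperaturePolicy(temperature=1.0)` when none is given.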