lalamo 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
+ from lalamo.modules import Decoder
+
+ __version__ = "0.2.1"
+
+ __all__ = [
+     "REPO_TO_MODEL",
+     "Decoder",
+     "ModelSpec",
+     "import_model",
+ ]
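
Note: a minimal usage sketch of the re-exported API. The `import_model` argument names and return tuple below are inferred from the call site in `lalamo/main.py` further down, not from its definition; treat them as assumptions.

```python
from lalamo import REPO_TO_MODEL, import_model

# Look up the spec for a supported Hugging Face repo.
spec = REPO_TO_MODEL["google/gemma-3-1b-it"]

# Assumed signature, mirroring the call in lalamo/main.py.
model, metadata, tokenizer_file_paths = import_model(
    spec,
    precision=None,  # keep the model's native precision
    context_length=None,  # keep the model's native context length
)
```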
lalamo/common.py ADDED
@@ -0,0 +1,60 @@
+ from collections.abc import Iterable, Mapping
+
+ import jax.numpy as jnp
+ from jaxtyping import Array, DTypeLike
+
+ __all__ = [
+     "DEFAULT_PRECISION",
+     "ParameterDict",
+     "ParameterPath",
+ ]
+
+ DEFAULT_PRECISION: DTypeLike = jnp.bfloat16
+
+
+ type NestedParameters = Mapping[str, Array | NestedParameters] | Iterable[Array | NestedParameters]
+
+
+ class ParameterDict(dict[str, Array]):
+     def __init__(self, **kwargs: Array | NestedParameters | Iterable[Array | NestedParameters]) -> None:
+         super().__init__(self._flatten(kwargs))
+
+     def __setitem__(
+         self,
+         key: str,
+         value: Array | NestedParameters | Iterable[Array | NestedParameters],
+     ) -> None:
+         key = ParameterPath(key)
+
+         if isinstance(value, Array):
+             super().__setitem__(key, value)
+             return
+
+         for subkey, subvalue in self._flatten(value).items():
+             super().__setitem__(key / subkey, subvalue)
+
+     @classmethod
+     def _flatten(cls, nested_parameters: NestedParameters) -> dict[str, Array]:
+         result: dict[str, Array] = {}
+         if not isinstance(nested_parameters, Mapping):
+             nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
+         for key, value in nested_parameters.items():
+             key_path = ParameterPath(key)
+             if isinstance(value, Array):
+                 result[key_path] = value
+             else:
+                 result.update({key_path / subkey: subvalue for subkey, subvalue in cls._flatten(value).items()})
+         return result
+
+
+ class ParameterPath(str):
+     __slots__ = ()
+
+     @property
+     def components(self) -> tuple[str, ...]:
+         return tuple(self.split("."))
+
+     def __truediv__(self, other: str | int) -> "ParameterPath":
+         if not self:
+             return ParameterPath(str(other))
+         return ParameterPath(self + "." + str(other))
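
Note: a minimal sketch of how `ParameterDict` flattening behaves, based on the `_flatten` logic above; the array shapes are illustrative.

```python
import jax.numpy as jnp

from lalamo.common import ParameterDict, ParameterPath

# Nested mappings and iterables flatten into dot-separated keys.
params = ParameterDict(
    embedding=jnp.zeros((4, 2)),
    layers=[{"weight": jnp.ones((2, 2))}, {"weight": jnp.ones((2, 2))}],
)
assert set(params) == {"embedding", "layers.0.weight", "layers.1.weight"}

# ParameterPath composes components with the `/` operator.
assert ParameterPath("layers") / 0 / "weight" == "layers.0.weight"
```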
lalamo/language_model.py ADDED
@@ -0,0 +1,263 @@
+ from abc import abstractmethod
+ from collections.abc import Iterable
+ from dataclasses import dataclass
+ from typing import NamedTuple
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+
+ from lalamo.modules import Decoder, KVCache
+
+ __all__ = [
+     "BanTokensPolicy",
+     "CompositePolicy",
+     "GreedyPolicy",
+     "LanguageModel",
+     "SamplingPolicy",
+     "TemperaturePolicy",
+     "TopKPolicy",
+     "TopPPolicy",
+ ]
+
+
+ class SamplingPolicy(eqx.Module):
+     @abstractmethod
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]: ...
+
+     def __call__(self, logits: Float[Array, " vocabulary"], *, key: PRNGKeyArray) -> Int[Array, ""]:
+         return jax.random.categorical(key, self.process_logits(logits))
+
+
+ class GreedyPolicy(SamplingPolicy):
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         max_logit_value = jnp.max(logits)
+         return jnp.where(logits == max_logit_value, 1.0, -jnp.inf)
+
+
+ class TemperaturePolicy(SamplingPolicy):
+     temperature: float = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         return logits / self.temperature
+
+
+ class TopKPolicy(SamplingPolicy):
+     k: int = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         top_k_logits, _ = jax.lax.top_k(logits, self.k)
+         min_logit_val = jnp.min(top_k_logits)
+         return jnp.where(logits >= min_logit_val, logits, -jnp.inf)
+
+
+ class TopPPolicy(SamplingPolicy):
+     p: float = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         sorted_indices = jnp.argsort(logits, descending=True)
+         sorted_logits = logits[sorted_indices]
+         cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits))
+
+         to_remove_sorted = jnp.roll(cumulative_probs > self.p, 1).at[0].set(False)
+         # Scatter the mask back from sorted order to the original token order.
+         to_remove = jnp.zeros_like(to_remove_sorted).at[sorted_indices].set(to_remove_sorted)
+
+         return jnp.where(to_remove, -jnp.inf, logits)
+
+
+ class BanTokensPolicy(SamplingPolicy):
+     banned_tokens: list[int] = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         banned_tokens_indices = jnp.asarray(self.banned_tokens, dtype=jnp.int32)
+         return logits.at[banned_tokens_indices].set(-jnp.inf)
+
+
+ class CompositePolicy(SamplingPolicy):
+     policies: list[SamplingPolicy] = eqx.field(static=True)
+
+     def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+         for policy in self.policies:
+             logits = policy.process_logits(logits)
+         return logits
+
+
+ class PrefillResults(NamedTuple):
+     last_token_logits: Float[Array, " vocabulary"]
+     last_token_position: Int[Array, ""]
+     kv_cache: KVCache
+
+
+ class DecodingState(NamedTuple):
+     last_token_logits: Float[Array, " vocabulary"]
+     last_token_position: Int[Array, ""]
+     kv_cache: KVCache
+     stop_flag: Bool[Array, ""]
+
+
+ @dataclass(frozen=True)
+ class LanguageModel:
+     decoder: Decoder
+
+     def _prefill(
+         self,
+         token_ids: Int[Array, " tokens"],
+         length_without_padding: Int[Array, ""] | int | None = None,
+         kv_cache_capacity: int | None = None,
+     ) -> PrefillResults:
+         (num_tokens,) = token_ids.shape
+         token_positions = jnp.arange(num_tokens, dtype=jnp.int32)
+         if kv_cache_capacity is not None:
+             kv_cache = self.decoder.init_static_kv_cache(kv_cache_capacity)
+         else:
+             kv_cache = None
+
+         decoder_outputs = self.decoder(
+             token_ids,
+             token_positions,
+             kv_cache,
+             return_updated_kv_cache=True,
+             length_without_padding=length_without_padding,
+         )
+
+         if length_without_padding is not None:
+             last_logits_index = length_without_padding - 1
+         else:
+             last_logits_index = num_tokens - 1
+
+         last_token_logits = decoder_outputs.logits[last_logits_index, :]
+         last_token_position = jnp.array(last_logits_index, dtype=jnp.int32)
+
+         assert decoder_outputs.updated_kv_cache is not None
+         return PrefillResults(
+             last_token_logits=last_token_logits,
+             last_token_position=last_token_position,
+             kv_cache=decoder_outputs.updated_kv_cache,
+         )
+
+     def generate(
+         self,
+         prompt_token_ids: Int[Array, " prompt_tokens"],
+         sampling_policy: SamplingPolicy | None = None,
+         prompt_length_without_padding: Int[Array, ""] | int | None = None,
+         max_output_length: int = 8192,
+         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+         *,
+         key: PRNGKeyArray | None = None,
+     ) -> Int[Array, " response_tokens"]:
+         if sampling_policy is None:
+             sampling_policy = TemperaturePolicy(temperature=1.0)
+
+         (input_length,) = prompt_token_ids.shape
+         prefill_results = self._prefill(
+             prompt_token_ids,
+             prompt_length_without_padding,
+             input_length + max_output_length,
+         )
+
+         initial_state = DecodingState(
+             prefill_results.last_token_logits,
+             prefill_results.last_token_position,
+             prefill_results.kv_cache,
+             jnp.array(0, dtype=jnp.bool),
+         )
+
+         if key is None:
+             key = jax.random.PRNGKey(0)
+         keys = jax.random.split(key, num=max_output_length)
+
+         def loop_iteration(
+             state: DecodingState,
+             key: PRNGKeyArray,
+         ) -> tuple[DecodingState, Int[Array, ""]]:
+             def sample_and_update() -> tuple[DecodingState, Int[Array, ""]]:
+                 processed_logits = sampling_policy.process_logits(state.last_token_logits)
+                 next_token_id = jax.random.categorical(key, processed_logits)
+                 next_token_position = state.last_token_position + 1
+
+                 if eos_token_ids is not None:
+                     stop_flag = state.stop_flag | jnp.any(next_token_id == eos_token_ids)
+                 else:
+                     stop_flag = state.stop_flag
+
+                 decoder_outputs = self.decoder(
+                     next_token_id.reshape(1),
+                     next_token_position.reshape(1),
+                     state.kv_cache,
+                     return_updated_kv_cache=True,
+                 )
+                 assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+                 new_state = DecodingState(
+                     decoder_outputs.logits.squeeze(),
+                     next_token_position,
+                     decoder_outputs.updated_kv_cache,
+                     stop_flag,
+                 )
+                 return new_state, next_token_id
+
+             def pad_and_repeat_state() -> tuple[DecodingState, Int[Array, ""]]:
+                 pad_token = jnp.array(0, dtype=jnp.int32)
+                 return state, pad_token
+
+             return jax.lax.cond(state.stop_flag, pad_and_repeat_state, sample_and_update)
+
+         _, tokens = jax.lax.scan(loop_iteration, initial_state, keys)
+
+         return tokens
+
+     def stream(
+         self,
+         prompt_token_ids: Int[Array, " prompt_tokens"],
+         sampling_policy: SamplingPolicy | None = None,
+         prompt_length_without_padding: Int[Array, ""] | int | None = None,
+         max_output_length: int = 8192,
+         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+         *,
+         key: PRNGKeyArray | None = None,
+     ) -> Iterable[Int[Array, ""]]:
+         if sampling_policy is None:
+             sampling_policy = TemperaturePolicy(temperature=1.0)
+
+         (input_length,) = prompt_token_ids.shape
+         prefill_results = self._prefill(
+             prompt_token_ids,
+             prompt_length_without_padding,
+             input_length + max_output_length,
+         )
+
+         if key is None:
+             key = jax.random.PRNGKey(0)
+         keys = jax.random.split(key, num=max_output_length)
+
+         state = DecodingState(
+             prefill_results.last_token_logits,
+             prefill_results.last_token_position,
+             prefill_results.kv_cache,
+             jnp.array(0, dtype=jnp.bool),
+         )
+
+         for iter_key in keys:
+             processed_logits = sampling_policy.process_logits(state.last_token_logits)
+             next_token_id = jax.random.categorical(iter_key, processed_logits)
+
+             yield next_token_id
+
+             if eos_token_ids is not None and jnp.any(next_token_id == eos_token_ids):
+                 return
+
+             next_token_position = state.last_token_position + 1
+             decoder_outputs = self.decoder(
+                 next_token_id.reshape(1),
+                 next_token_position.reshape(1),
+                 state.kv_cache,
+                 return_updated_kv_cache=True,
+             )
+             assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+             state = DecodingState(
+                 decoder_outputs.logits.squeeze(),
+                 next_token_position,
+                 decoder_outputs.updated_kv_cache,
+                 state.stop_flag,
+             )
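
Note: a minimal sketch of composing the sampling policies defined above; the policy order and the logits values are illustrative, not a recommendation.

```python
import jax
import jax.numpy as jnp

from lalamo.language_model import CompositePolicy, TemperaturePolicy, TopKPolicy, TopPPolicy

# Policies apply left to right: temperature scaling, then top-k, then top-p.
policy = CompositePolicy(
    policies=[
        TemperaturePolicy(temperature=0.7),
        TopKPolicy(k=40),
        TopPPolicy(p=0.9),
    ],
)

logits = jnp.array([2.0, 1.0, 0.5, -1.0])  # stand-in for decoder output
token_id = policy(logits, key=jax.random.PRNGKey(0))
```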
lalamo/main.py ADDED
@@ -0,0 +1,299 @@
+ import json
+ import re
+ import shutil
+ import sys
+ from enum import Enum
+ from pathlib import Path
+ from typing import Annotated
+
+ import jax.numpy as jnp
+ import thefuzz.process
+ from click import Context as ClickContext
+ from click import Parameter as ClickParameter
+ from click import ParamType
+ from jaxtyping import DTypeLike
+ from rich import box
+ from rich.console import Console
+ from rich.panel import Panel
+ from rich.progress import Progress, SpinnerColumn, TextColumn
+ from rich.table import Table
+ from safetensors.flax import save_file
+ from typer import Argument, Exit, Option, Typer
+
+ from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
+ from lalamo.modules import WeightLayout, config_converter
+ from lalamo.utils import jax_uint4_to_packed_uint8
+
+ SCRIPT_NAME = Path(sys.argv[0]).name
+
+ DEFAULT_OUTPUT_DIR = Path("models")
+
+
+ class Precision(Enum):
+     FLOAT32 = "float32"
+     FLOAT16 = "float16"
+     BFLOAT16 = "bfloat16"
+
+
+ console = Console()
+ err_console = Console(stderr=True)
+ app = Typer(
+     rich_markup_mode="rich",
+     add_completion=False,
+     pretty_exceptions_show_locals=False,
+ )
+
+
+ class ModelParser(ParamType):
+     name: str = "Huggingface Model Repo"
+
+     def convert(self, value: str, param: ClickParameter | None, ctx: ClickContext | None) -> ModelSpec:
+         result = REPO_TO_MODEL.get(value)
+         if result is None:
+             closest_repo = _closest_repo(value)
+             error_message_parts = [
+                 f'"{value}".',
+             ]
+             if closest_repo:
+                 error_message_parts.append(
+                     f' Perhaps you meant "{closest_repo}"?',
+                 )
+             error_message_parts.append(
+                 f"\n\nUse the `{SCRIPT_NAME} list-models` command to see the list of currently supported models.",
+             )
+             error_message = "".join(error_message_parts)
+             self.fail(error_message, param, ctx)
+         return result
+
+
+ def _closest_repo(query: str, min_score: float = 80) -> str | None:
+     if not REPO_TO_MODEL:
+         return None
+     (closest_match, score), *_ = thefuzz.process.extract(query, list(REPO_TO_MODEL))
+     if closest_match and score >= min_score:
+         return closest_match
+     return None
+
+
+ def _error(message: str) -> None:
+     panel = Panel(message, box=box.ROUNDED, title="Error", title_align="left", border_style="red")
+     err_console.print(panel)
+     raise Exit(1)
+
+
+ def _pack_uint4_weights(weights: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
+     packed_weights = {}
+     for key, value in weights.items():
+         if value.dtype == jnp.uint4:
+             packed_weights[key] = jax_uint4_to_packed_uint8(value)
+         else:
+             packed_weights[key] = value
+     return packed_weights
+
+
+ @app.command(help="Convert the model for use with the Uzu inference engine.")
+ def convert(
+     model_repo: Annotated[
+         ModelSpec,
+         Argument(
+             help=(
+                 "Huggingface model repo. Example: [cyan]'meta-llama/Llama-3.2-1B-Instruct'[/cyan]."
+                 "\n\n\n\n"
+                 f"You can use the [cyan]`{SCRIPT_NAME} list-models`[/cyan] command to get a list of supported models."
+             ),
+             click_type=ModelParser(),
+             show_default=False,
+             metavar="MODEL_REPO",
+             autocompletion=lambda: list(REPO_TO_MODEL),
+         ),
+     ],
+     precision: Annotated[
+         Precision | None,
+         Option(
+             help="Precision to use for activations and non-quantized weights.",
+             show_default="Native precision of the model",
+         ),
+     ] = None,
+     weight_layout: Annotated[
+         WeightLayout | None,
+         Option(
+             help=(
+                 "Order of dimensions in the weights of linear layers."
+                 "\n\n\n\n"
+                 "If set to AUTO, the layout will depend on the model."
+             ),
+             show_default="auto",
+         ),
+     ] = None,
+     output_dir: Annotated[
+         Path | None,
+         Option(
+             help="Directory to save the converted model to.",
+             show_default="Saves the converted model in the `models/<model_name>` directory",
+         ),
+     ] = None,
+     context_length: Annotated[
+         int | None,
+         Option(
+             help="Maximum supported context length. Used to precompute positional embeddings.",
+             show_default="Model's native maximum context length.",
+         ),
+     ] = None,
+     include_traces: Annotated[
+         bool,
+         Option(
+             help="Export activation traces for debugging purposes.",
+         ),
+     ] = False,
+     overwrite: Annotated[
+         bool,
+         Option(
+             help="Overwrite existing model files.",
+         ),
+     ] = False,
+ ) -> None:
+     if precision is not None:
+         precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
+     else:
+         precision_dtype = None
+
+     if weight_layout is not None:
+         weight_layout = WeightLayout(weight_layout)
+     else:
+         weight_layout = WeightLayout.AUTO
+
+     if output_dir is None:
+         output_dir = DEFAULT_OUTPUT_DIR / model_repo.name
+
+     console.print(f"🚀 Converting [cyan]{model_repo.name}[/cyan] by [cyan]{model_repo.vendor}[/cyan].")
+     conversion_strs = [
+         f"⚙️ Using weight layout [cyan]{weight_layout}[/cyan]",
+     ]
+     if precision is not None:
+         conversion_strs.append(
+             f" and converting floating-point weights into [cyan]{precision.name.lower()}[/cyan] precision",
+         )
+     conversion_strs.append(".")
+     console.print("".join(conversion_strs))
+
+     if output_dir.exists() and not overwrite:
+         answer = console.input(
+             rf"⚠️ Output directory [cyan]{output_dir}[/cyan] already exists."
+             r" Do you want to overwrite it? [cyan]\[y/n][/cyan]: ",
+         )
+         while answer.lower() not in ["y", "n", "yes", "no"]:
+             answer = console.input("Please enter 'y' or 'n': ")
+         if answer.lower() in ["y", "yes"]:
+             shutil.rmtree(output_dir)
+         else:
+             console.print("Exiting...")
+             raise Exit
+
+     with Progress(
+         SpinnerColumn(),
+         TextColumn("[progress.description]{task.description}"),
+         transient=True,
+     ) as progress:
+         progress.add_task("👨‍🍳 Cooking...")
+         model, metadata, tokenizer_file_paths = import_model(
+             model_repo,
+             precision=precision_dtype,
+             context_length=context_length,
+         )
+         progress.add_task(f"💾 Saving the model to {output_dir}")
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         weights = dict(model.export_weights(weight_layout))
+         packed_weights = _pack_uint4_weights(weights)
+         save_file(packed_weights, output_dir / "model.safetensors")
+
+         config_json = config_converter.unstructure(metadata, ModelMetadata)
+         with open(output_dir / "config.json", "w") as file:
+             json.dump(config_json, file, indent=4)
+
+         for path in tokenizer_file_paths:
+             shutil.copy(path, output_dir / path.name)
+
+         if include_traces:
+             progress.add_task("🚁 Generating traces...")
+
+             num_tokens = 512
+             token_stride = 8
+             token_ids = jnp.arange(0, num_tokens, dtype=jnp.int32)
+             token_positions = jnp.arange(0, num_tokens * token_stride, token_stride, dtype=jnp.int32)
+             result = model(
+                 token_ids,
+                 token_positions,
+                 return_updated_kv_cache=True,
+                 return_activation_trace=True,
+             )
+             traces = dict(result.export())
+             save_file(traces, output_dir / "traces.safetensors")
+
+     console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{output_dir}`[/cyan]!")
+
+
+ def _model_size_string_to_int(
+     size_str: str,
+     _regex: re.Pattern = re.compile(r"(?P<number>(\d+)(\.\d*)?)(?P<suffix>[KMBT])"),
+ ) -> float:
+     match = _regex.match(size_str)
+     factors = {
+         "K": 1024**1,
+         "M": 1024**2,
+         "B": 1024**3,
+         "T": 1024**4,
+     }
+     if match:
+         return float(match.group("number")) * factors[match.group("suffix")]
+     raise ValueError(f"Invalid size string: {size_str}")
+
+
+ @app.command(help="List the supported models.")
+ def list_models(
+     plain: Annotated[
+         bool,
+         Option(
+             help="Only list repo names without fancy formatting.",
+         ),
+     ] = False,
+ ) -> None:
+     sorted_specs = sorted(
+         REPO_TO_MODEL.values(),
+         key=lambda spec: (
+             spec.vendor.lower(),
+             spec.family.lower(),
+             _model_size_string_to_int(spec.size),
+             spec.name.lower(),
+         ),
+     )
+
+     if plain:
+         for spec in sorted_specs:
+             console.print(spec.repo)
+         return
+
+     table = Table(
+         show_header=True,
+         header_style="bold",
+         show_lines=True,
+         box=box.ROUNDED,
+     )
+     table.add_column("Vendor", justify="left", style="magenta")
+     table.add_column("Family", justify="left", style="magenta", no_wrap=True)
+     table.add_column("Size", justify="right", style="magenta")
+     table.add_column("Quant", justify="left", style="magenta")
+     table.add_column("Repo", justify="left", style="cyan", no_wrap=True)
+     for spec in sorted_specs:
+         table.add_row(
+             spec.vendor,
+             spec.family,
+             spec.size,
+             str(spec.quantization),
+             spec.repo,
+         )
+     console.print(table)
+
+
+ if __name__ == "__main__":
+     app()
lalamo/quantization.py ADDED
@@ -0,0 +1,92 @@
+ from enum import Enum
+
+ from jax import numpy as jnp
+ from jaxtyping import Array, DTypeLike, Float
+
+ __all__ = ["QuantizationMode", "quantize_weights"]
+
+
+ class QuantizationMode(Enum):
+     UINT4 = "uint4"
+     INT8 = "int8"
+     UINT8 = "uint8"
+
+     @classmethod
+     def from_num_bits(cls, num_bits: int) -> "QuantizationMode":
+         bit_to_mode = {
+             4: cls.UINT4,
+             8: cls.UINT8,
+         }
+         if num_bits not in bit_to_mode:
+             raise ValueError(f"No quantization mode defined for {num_bits} bits")
+         return bit_to_mode[num_bits]
+
+     @property
+     def range(self) -> tuple[int, int]:
+         return MODE_TO_RANGE[self]
+
+     @property
+     def dtype(self) -> DTypeLike:
+         value_to_dtype = {
+             QuantizationMode.UINT4: jnp.uint4,
+             QuantizationMode.INT8: jnp.int8,
+             QuantizationMode.UINT8: jnp.uint8,
+         }
+         return value_to_dtype[self]
+
+     @property
+     def bits(self) -> int:
+         value_to_bits = {
+             QuantizationMode.UINT4: 4,
+             QuantizationMode.INT8: 8,
+             QuantizationMode.UINT8: 8,
+         }
+         return value_to_bits[self]
+
+     def __str__(self) -> str:
+         return self.value
+
+
+ MODE_TO_RANGE = {
+     QuantizationMode.UINT4: (0, 15),
+     QuantizationMode.INT8: (-128, 127),
+     QuantizationMode.UINT8: (0, 255),
+ }
+
+
+ def quantize_weights(x: Float[Array, "..."], mode: QuantizationMode) -> Float[Array, "..."]:
+     range_min, range_max = MODE_TO_RANGE[mode]
+     return jnp.clip(jnp.round(x), range_min, range_max)
+
+
+ def dynamically_quantize_activations(
+     x: Float[Array, " channels"],
+     mode: QuantizationMode,
+ ) -> Float[Array, " channels"]:
+     # Reference implementation: https://github.com/pytorch/pytorch/blob/2ccbacfa24cae724ec1ea3bc7de189e5bf948d46/torch/ao/quantization/fx/_decomposed.py#L790
+     range_min, range_max = mode.range
+     min_val = jnp.min(x)
+     max_val = jnp.max(x)
+     min_val_neg = jnp.minimum(min_val, 0)
+     max_val_pos = jnp.maximum(max_val, 0)
+
+     # Scale: spread the observed dynamic range across the quantized range.
+     scale = (max_val_pos - min_val_neg) / (range_max - range_min)
+     scale = jnp.maximum(scale, jnp.finfo(x.dtype).eps)
+
+     # Zero point: pick the endpoint with the smaller rounding error.
+     descaled_min = min_val_neg / scale
+     descaled_max = max_val_pos / scale
+     zero_point_from_min_error = range_min + descaled_min
+     zero_point_from_max_error = range_max + descaled_max
+     zero_point = jnp.where(
+         zero_point_from_min_error + zero_point_from_max_error > 0,
+         range_min - descaled_min,
+         range_max - descaled_max,
+     )
+     zero_point = jnp.round(jnp.clip(zero_point, range_min, range_max))
+     # Quantize, then immediately dequantize (fake quantization).
+     x_normalized = x / scale + zero_point
+     x_quantized = jnp.clip(jnp.round(x_normalized), range_min, range_max)
+
+     return (x_quantized - zero_point) * scale
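
Note: a minimal sketch of the fake-quantization round trip implemented by `dynamically_quantize_activations`; the input values are illustrative.

```python
import jax.numpy as jnp

from lalamo.quantization import QuantizationMode, dynamically_quantize_activations

x = jnp.array([-1.5, -0.2, 0.0, 0.7, 3.1], dtype=jnp.float32)
x_fake_quant = dynamically_quantize_activations(x, QuantizationMode.INT8)

# The output stays in float32, but every value is now expressible as
# (q - zero_point) * scale for an integer q in [-128, 127].
max_abs_error = jnp.max(jnp.abs(x - x_fake_quant))
```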
lalamo/utils.py ADDED
@@ -0,0 +1,55 @@
+ import einops
+ import jax.numpy as jnp
+ import torch.utils.dlpack
+ from jaxtyping import Array
+
+ __all__ = [
+     "jax_to_torch",
+     "jax_uint4_to_packed_uint8",
+     "torch_to_jax",
+ ]
+
+
+ @torch.no_grad()
+ def _torch_to_jax_bfloat16(tensor: torch.Tensor) -> Array:
+     # Credit: https://github.com/jax-ml/ml_dtypes/issues/81#issuecomment-2399636232
+     if tensor.dtype != torch.bfloat16:
+         raise ValueError("Trying to convert a non-bfloat16 tensor to bfloat16")
+     intermediate_tensor = tensor.view(torch.uint16)
+     return jnp.array(intermediate_tensor).view("bfloat16")
+
+
+ def torch_to_jax(array: torch.Tensor) -> Array:
+     array = array.detach().cpu()
+     if array.dtype == torch.bfloat16:
+         return _torch_to_jax_bfloat16(array)
+     return jnp.array(array.numpy())
+
+
+ def jax_to_torch(array: Array) -> torch.Tensor:
+     if array.dtype == jnp.bfloat16:
+         intermediate_array = array.view(jnp.uint16)
+         return torch.utils.dlpack.from_dlpack(intermediate_array).view(torch.bfloat16)
+     return torch.utils.dlpack.from_dlpack(array)
+
+
+ def jax_uint4_to_packed_uint8(array: Array) -> Array:
+     if array.dtype != jnp.uint4:
+         raise ValueError(f"Input array must have dtype jnp.uint4, but got {array.dtype}")
+
+     if not array.shape:
+         raise ValueError("Input array cannot be a scalar and must have at least one dimension.")
+
+     *_, last_dim = array.shape
+     if last_dim % 2 != 0:
+         raise ValueError(f"The last dimension of the input array must be even, but got shape {array.shape}")
+     # Element 0 of each pair becomes the low nibble, element 1 the high nibble.
+     low_nibbles, high_nibbles = einops.rearrange(
+         array.astype(jnp.uint8),
+         "... (dim_half two) -> two ... dim_half",
+         two=2,
+     )
+
+     packed = (high_nibbles << 4) | low_nibbles
+
+     return packed.astype(jnp.uint8)
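
Note: a minimal sketch of the nibble-packing convention used by `jax_uint4_to_packed_uint8`; pairs pack little-nibble-first.

```python
import jax.numpy as jnp

from lalamo.utils import jax_uint4_to_packed_uint8

nibbles = jnp.array([1, 2, 3, 4], dtype=jnp.uint4)
packed = jax_uint4_to_packed_uint8(nibbles)

# (1, 2) -> 0x21 and (3, 4) -> 0x43: the first element of each pair
# becomes the low nibble, the second the high nibble.
assert packed.tolist() == [0x21, 0x43]
```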
lalamo-0.2.1.dist-info/METADATA ADDED
@@ -0,0 +1,74 @@
+ Metadata-Version: 2.4
+ Name: lalamo
+ Version: 0.2.1
+ Summary: JAX library for optimization and export of models for use with the UZU inference engine.
+ Requires-Python: <4,>=3.12
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: cattrs>=24.1.2
+ Requires-Dist: click>=8.1.8
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: equinox>=0.11.11
+ Requires-Dist: huggingface-hub[hf-transfer]>=0.27.1
+ Requires-Dist: jax>=0.4.38; sys_platform == "darwin"
+ Requires-Dist: jax[cuda]>=0.4.38; sys_platform == "linux"
+ Requires-Dist: jaxtyping>=0.2.36
+ Requires-Dist: ml-dtypes>=0.5.1
+ Requires-Dist: optax>=0.2.4
+ Requires-Dist: rich>=14.0.0
+ Requires-Dist: thefuzz>=0.22.1
+ Requires-Dist: typer>=0.15.1
+ Dynamic: license-file
+
+ <p align="center">
+     <picture>
+         <img alt="Mirai" src="https://artifacts.trymirai.com/social/github/lalamo-header.jpg" style="max-width: 100%;">
+     </picture>
+ </p>
+
+ <a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
+ <a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
+ <a href="mailto:alexey@getmirai.co,dima@getmirai.co,aleksei@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
+ <a href="https://docs.trymirai.com/components/models"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
+ [![License](https://img.shields.io/badge/License-MIT-blue)](LICENSE)
+
+ # lalamo
+
+ A set of tools for adapting Large Language Models to on-device inference using the [uzu](https://github.com/trymirai/uzu) inference engine.
+
+ ## Quick Start
+
+ To get the list of [supported models](https://trymirai.com/models), run:
+
+ ```bash
+ uv run lalamo list-models
+ ```
+
+ To convert a model, run:
+
+ ```bash
+ uv run lalamo convert MODEL_REPO --precision float16
+ ```
+
+ After that, you can find the converted model in the `models` folder. For more options, see `uv run lalamo convert --help`.
+
+ ## Model Support
+
+ To add support for a new model, write the corresponding [ModelSpec](lalamo/model_import/model_specs), as shown in the example below:
+
+ ```python
+ ModelSpec(
+     vendor="Google",
+     family="Gemma-3",
+     name="Gemma-3-1B-Instruct",
+     size="1B",
+     quantization=None,
+     repo="google/gemma-3-1b-it",
+     config_type=HFGemma3TextConfig,
+     config_file_name="config.json",
+     weights_file_names=huggingface_weight_files(1),
+     weights_type=WeightsType.SAFETENSORS,
+     tokenizer_files=HUGGINGFACE_TOKENIZER_FILES,
+     use_cases=tuple(),
+ )
+ ```
lalamo-0.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+ lalamo/__init__.py,sha256=uKBR6vAH2AmdpPqz1q2zVVwQyCpWRWUHAfm-uQg8DAM,217
+ lalamo/common.py,sha256=uYLw68V4AF3zlENG3KAIKRpOFXVHv8xX_n0cc3qJnj4,1877
+ lalamo/language_model.py,sha256=GiA_BDQuYCgVBFHljb_ltW_M7g3I1Siwm111M3Jc8MM,9286
+ lalamo/main.py,sha256=K2RLyTcxvBCP0teSsminssj_oUkuQAQ5y9ixa1uOqas,9546
+ lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
+ lalamo/utils.py,sha256=QzkT0_82nd9pS5p0e7yOOdL_ZeKQr_Ftj4kFrWF35R8,1754
+ lalamo-0.2.1.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+ lalamo-0.2.1.dist-info/METADATA,sha256=1qDWPQiCYK_EIeff-oiaF7VeIksGNdZ4nCFikHXGJR4,2611
+ lalamo-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ lalamo-0.2.1.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+ lalamo-0.2.1.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+ lalamo-0.2.1.dist-info/RECORD,,
lalamo-0.2.1.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
lalamo-0.2.1.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ lalamo = lalamo.main:app
lalamo-0.2.1.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Mirai Tech Inc.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
lalamo-0.2.1.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ lalamo