lalamo 0.2.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. lalamo/__init__.py +1 -1
  2. lalamo/common.py +79 -29
  3. lalamo/language_model.py +106 -83
  4. lalamo/main.py +91 -18
  5. lalamo/message_processor.py +170 -0
  6. lalamo/model_import/common.py +159 -43
  7. lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
  8. lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
  9. lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
  10. lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
  11. lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
  12. lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
  13. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
  14. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
  15. lalamo/model_import/huggingface_generation_config.py +44 -0
  16. lalamo/model_import/huggingface_tokenizer_config.py +85 -0
  17. lalamo/model_import/loaders/common.py +2 -1
  18. lalamo/model_import/loaders/huggingface.py +12 -10
  19. lalamo/model_import/model_specs/__init__.py +3 -2
  20. lalamo/model_import/model_specs/common.py +31 -32
  21. lalamo/model_import/model_specs/deepseek.py +1 -10
  22. lalamo/model_import/model_specs/gemma.py +2 -25
  23. lalamo/model_import/model_specs/huggingface.py +2 -12
  24. lalamo/model_import/model_specs/llama.py +2 -58
  25. lalamo/model_import/model_specs/mistral.py +9 -19
  26. lalamo/model_import/model_specs/pleias.py +3 -13
  27. lalamo/model_import/model_specs/polaris.py +5 -7
  28. lalamo/model_import/model_specs/qwen.py +12 -111
  29. lalamo/model_import/model_specs/reka.py +4 -13
  30. lalamo/modules/__init__.py +2 -1
  31. lalamo/modules/attention.py +90 -10
  32. lalamo/modules/common.py +51 -4
  33. lalamo/modules/decoder.py +90 -8
  34. lalamo/modules/decoder_layer.py +85 -8
  35. lalamo/modules/embedding.py +95 -29
  36. lalamo/modules/kv_cache.py +3 -3
  37. lalamo/modules/linear.py +170 -130
  38. lalamo/modules/mlp.py +40 -7
  39. lalamo/modules/normalization.py +24 -6
  40. lalamo/modules/rope.py +24 -6
  41. lalamo/sampling.py +99 -0
  42. lalamo/utils.py +86 -1
  43. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/METADATA +6 -6
  44. lalamo-0.3.1.dist-info/RECORD +58 -0
  45. lalamo-0.2.7.dist-info/RECORD +0 -54
  46. /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
  47. /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
  48. /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
  49. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/WHEEL +0 -0
  50. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/entry_points.txt +0 -0
  51. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/licenses/LICENSE +0 -0
  52. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/top_level.txt +0 -0
lalamo/__init__.py CHANGED
@@ -1,7 +1,7 @@
 from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
 from lalamo.modules import Decoder

-__version__ = "0.2.7"
+__version__ = "0.3.1"

 __all__ = [
     "REPO_TO_MODEL",
lalamo/common.py CHANGED
@@ -1,50 +1,100 @@
-from collections.abc import Iterable, Mapping
+from collections import defaultdict
+from collections.abc import Mapping, Sequence
+from typing import cast

 import jax.numpy as jnp
+from jax._src.api import ShapeDtypeStruct
 from jaxtyping import Array, DTypeLike

+from lalamo.utils import MapDictValues, MapSequence
+
 __all__ = [
     "DEFAULT_PRECISION",
-    "ParameterDict",
+    "ArrayLike",
     "ParameterPath",
+    "ParameterTree",
+    "dummy_array",
+    "flatten_parameters",
+    "unflatten_parameters",
 ]

 DEFAULT_PRECISION: DTypeLike = jnp.bfloat16


-type NestedParameters = Mapping[str, Array | NestedParameters] | Iterable[Array | NestedParameters]
+type ArrayLike = Array | ShapeDtypeStruct
+
+
+type ParameterTree[ArrayType: ArrayLike] = (
+    Mapping[str, ArrayType | ParameterTree[ArrayType]] | Sequence[ArrayType | ParameterTree[ArrayType]]
+)
+
+
+def dummy_array(shape: int | tuple[int, ...], dtype: DTypeLike) -> Array:
+    if isinstance(shape, int):
+        shape = (shape,)
+    return cast("Array", ShapeDtypeStruct(shape=shape, dtype=dtype))
+
+
+def flatten_parameters[ArrayType: ArrayLike](nested_parameters: ParameterTree[ArrayType]) -> dict[str, ArrayType]:
+    result: dict[str, ArrayType] = {}
+    if not isinstance(nested_parameters, Mapping):
+        nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
+    for key, value in nested_parameters.items():
+        key_path = ParameterPath(key)
+        if isinstance(value, (Array, ShapeDtypeStruct)):
+            result[key_path] = value
+        else:
+            update: dict[str, ArrayType] = {
+                str(key_path / subkey): subvalue for subkey, subvalue in flatten_parameters(value).items()
+            }
+            result.update(update)
+    return result
+
+
+type KeyTree = Mapping[str, str | KeyTree] | Sequence[str | KeyTree]
+
+
+def _unflatten_keys(flat_keys: Mapping[str, str]) -> KeyTree:
+    groups: dict[str, dict[str, str] | str] = defaultdict(dict)
+    for subkey, full_key in flat_keys.items():
+        match subkey.split(".", maxsplit=1):
+            case [head]:
+                groups[head] = full_key
+            case [head, tail]:
+                group = groups[head]
+                assert isinstance(group, dict)
+                group[tail] = full_key

+    unflattened_groups: dict[str, KeyTree] = {}
+    for subkey, group in groups.items():
+        if isinstance(group, str):
+            unflattened_groups[subkey] = group
+        else:
+            unflattened_groups[subkey] = _unflatten_keys(group)

-class ParameterDict(dict[str, Array]):
-    def __init__(self, **kwargs: Array | NestedParameters | Iterable[Array | NestedParameters]) -> None:
-        super().__init__(self._flatten(kwargs))
+    if any(key.isnumeric() for key in unflattened_groups):
+        assert set(unflattened_groups.keys()) == set(map(str, range(len(unflattened_groups))))
+        return [v for k, v in sorted(unflattened_groups.items(), key=lambda item: int(item[0]))]
+    return unflattened_groups

-    def __setitem__(
-        self,
-        key: str,
-        value: Array | NestedParameters | Iterable[Array | NestedParameters],
-    ) -> None:
-        key = ParameterPath(key)

-        if isinstance(value, Array):
-            super().__setitem__(key, value)
-            return
+def _recursive_map_dict[ArrayType: ArrayLike](
+    key_tree: KeyTree | str,
+    root_collection: Mapping[str, ArrayType],
+) -> ParameterTree[ArrayType] | ArrayType:
+    if isinstance(key_tree, str):
+        return root_collection[key_tree]
+    if isinstance(key_tree, Mapping):
+        return MapDictValues(lambda subtree: _recursive_map_dict(subtree, root_collection), key_tree)
+    if isinstance(key_tree, Sequence):
+        return MapSequence(lambda subtree: _recursive_map_dict(subtree, root_collection), key_tree)

-        for subkey, subvalue in self._flatten(value).items():
-            super().__setitem__(key / subkey, subvalue)

-    @classmethod
-    def _flatten(cls, nested_parameters: NestedParameters) -> dict[str, Array]:
-        result: dict[str, Array] = {}
-        if not isinstance(nested_parameters, Mapping):
-            nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
-        for key, value in nested_parameters.items():
-            key_path = ParameterPath(key)
-            if isinstance(value, Array):
-                result[key_path] = value
-            else:
-                result.update({key_path / subkey: subvalue for subkey, subvalue in cls._flatten(value).items()})
-        return result
+def unflatten_parameters[ArrayType: ArrayLike](flat_parameters: Mapping[str, ArrayType]) -> ParameterTree[ArrayType]:
+    unflattened_keys = _unflatten_keys({k: k for k in flat_parameters})
+    result = _recursive_map_dict(unflattened_keys, flat_parameters)
+    assert not isinstance(result, (Array, ShapeDtypeStruct))
+    return result


 class ParameterPath(str):
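
Note: ParameterDict is gone; nested weights now round-trip through the flatten_parameters/unflatten_parameters pair. A minimal sketch of that round-trip, assuming the dot-separated path convention implied by _unflatten_keys above (the parameter names are illustrative, not taken from the package):

import jax.numpy as jnp

from lalamo.common import flatten_parameters, unflatten_parameters

# Illustrative nested tree mixing a mapping and a sequence.
nested = {
    "embedding": {"weight": jnp.zeros((8, 4))},
    "layers": [{"up": jnp.zeros((4, 16)), "down": jnp.zeros((16, 4))}],
}

flat = flatten_parameters(nested)
# Sequence indices become numeric path segments:
# {"embedding.weight": ..., "layers.0.up": ..., "layers.0.down": ...}

tree = unflatten_parameters(flat)
# Numeric key groups are recognized and turned back into sequences,
# so tree["layers"] is again a one-element sequence.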
lalamo/language_model.py CHANGED
@@ -1,89 +1,28 @@
-from abc import abstractmethod
+import json
 from collections.abc import Iterable
-from dataclasses import dataclass
-from typing import NamedTuple
+from dataclasses import dataclass, replace
+from pathlib import Path
+from typing import NamedTuple, Self

 import equinox as eqx
 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+from safetensors.flax import load_file
+from tokenizers import Tokenizer

-from lalamo.modules import Decoder, KVCache
+from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
+from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
+from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule, WeightLayout, config_converter
+from lalamo.sampling import SamplingPolicy, make_policy

 __all__ = [
-    "BanTokensPolicy",
-    "CompositePolicy",
-    "GreedyPolicy",
+    "GenerationConfig",
     "LanguageModel",
-    "SamplingPolicy",
-    "TemperaturePolicy",
-    "TopKPolicy",
-    "TopPPolicy",
+    "LanguageModelConfig",
 ]


-class SamplingPolicy(eqx.Module):
-    @abstractmethod
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]: ...
-
-    def __call__(self, logits: Float[Array, " vocabulary"], *, key: PRNGKeyArray) -> Int[Array, ""]:
-        return jax.random.categorical(key, self.process_logits(logits))
-
-
-class GreedyPolicy(SamplingPolicy):
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        max_logit_value = jnp.max(logits)
-        return jnp.where(logits == max_logit_value, 1.0, -jnp.inf)
-
-
-class TemperaturePolicy(SamplingPolicy):
-    temperature: float = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        return logits / self.temperature
-
-
-class TopKPolicy(SamplingPolicy):
-    k: int = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        top_k_logits, _ = jax.lax.top_k(logits, self.k)
-        min_logit_val = jnp.min(top_k_logits)
-        return jnp.where(logits >= min_logit_val, logits, -jnp.inf)
-
-
-class TopPPolicy(SamplingPolicy):
-    p: float = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        sorted_indices = jnp.argsort(logits, descending=True)
-        sorted_logits = logits[sorted_indices]
-        cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits))
-
-        to_remove = cumulative_probs > self.p
-        to_remove = jnp.roll(to_remove, 1)
-        to_remove = to_remove.at[0].set(False)
-
-        return jnp.where(to_remove, -jnp.inf, logits)
-
-
-class BanTokensPolicy(SamplingPolicy):
-    banned_tokens: list[int] = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        banned_tokens_indices = jnp.asarray(self.banned_tokens, dtype=jnp.int32)
-        return logits.at[banned_tokens_indices].set(-jnp.inf)
-
-
-class CompositePolicy(SamplingPolicy):
-    policies: list[SamplingPolicy] = eqx.field(static=True)
-
-    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
-        for policy in self.policies:
-            logits = policy.process_logits(logits)
-        return logits
-
-
 class PrefillResults(NamedTuple):
     last_token_logits: Float[Array, " vocabulary"]
     last_token_position: Int[Array, ""]
@@ -98,9 +37,66 @@ class DecodingState(NamedTuple):


 @dataclass(frozen=True)
-class LanguageModel:
+class GenerationConfig:
+    stop_token_ids: tuple[int, ...]
+    temperature: float | None
+    top_k: int | None
+    top_p: float | None
+    banned_tokens: tuple[int, ...] | None
+
+    def default_policy(self) -> SamplingPolicy:
+        return make_policy(self.temperature, self.top_k, self.top_p, self.banned_tokens)
+
+
+@dataclass(frozen=True)
+class LanguageModelConfig:
+    decoder_config: DecoderConfig
+    message_processor_config: MessageProcessorConfig
+    generation_config: GenerationConfig
+
+
+class LanguageModel(LalamoModule[LanguageModelConfig]):
     decoder: Decoder
+    message_processor: MessageProcessor = eqx.field(static=True)
+
+    @classmethod
+    def load(cls, path: Path | str, weight_layout: WeightLayout = WeightLayout.AUTO) -> Self:
+        if isinstance(path, str):
+            path = Path(path)
+        with open(path / "config.json") as config_file:
+            config_json = json.load(config_file)
+        config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
+        weights = unflatten_parameters(load_file(path / "model.safetensors"))
+        decoder = config.decoder_config.empty().import_weights(weights, weight_layout)
+        tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
+        message_processor = MessageProcessor(config.message_processor_config, tokenizer)
+        return cls(config, decoder, message_processor)
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.decoder.activation_precision
+
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        return self.decoder.export_weights(weight_layout)
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        return replace(
+            self,
+            decoder=self.decoder.import_weights(weights, weight_layout),
+        )
+
+    @property
+    def stop_token_ids(self) -> tuple[int, ...]:
+        return self.config.generation_config.stop_token_ids

+    def default_sampling_policy(self) -> SamplingPolicy:
+        return self.config.generation_config.default_policy()
+
+    @eqx.filter_jit
     def _prefill(
         self,
         token_ids: Int[Array, " tokens"],
@@ -137,7 +133,8 @@ class LanguageModel:
             kv_cache=decoder_outputs.updated_kv_cache,
         )

-    def generate(
+    @eqx.filter_jit
+    def generate_tokens(
         self,
         prompt_token_ids: Int[Array, " prompt_tokens"],
         sampling_policy: SamplingPolicy | None = None,
@@ -148,7 +145,9 @@ class LanguageModel:
         key: PRNGKeyArray | None = None,
     ) -> Int[Array, " response_tokens"]:
         if sampling_policy is None:
-            sampling_policy = TemperaturePolicy(temperature=1.0)
+            sampling_policy = self.default_sampling_policy()
+        if eos_token_ids is None:
+            eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)

         (input_length,) = prompt_token_ids.shape
         prefill_results = self._prefill(
@@ -177,10 +176,7 @@ class LanguageModel:
             next_token_id = jax.random.categorical(key, processed_logits)
             next_token_position = state.last_token_position + 1

-            if eos_token_ids is not None:
-                stop_flag = state.stop_flag | jnp.any(next_token_id == eos_token_ids)
-            else:
-                stop_flag = state.stop_flag
+            stop_flag = state.stop_flag | jnp.any(next_token_id == eos_token_ids)

             decoder_outputs = self.decoder(
                 next_token_id.reshape(1),
@@ -207,7 +203,32 @@ class LanguageModel:

         return tokens

-    def stream(
+    def reply(
+        self,
+        messages: Iterable[Message],
+        sampling_policy: SamplingPolicy | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> AssistantMessage:
+        formatted_messages = self.message_processor.render_request(messages)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
+        response_ids = self.generate_tokens(token_ids, sampling_policy, key=key)
+        response_text = self.message_processor.detokenize(response_ids.tolist())
+        return self.message_processor.parse_response(response_text)
+
+    def stream_reply_text(
+        self,
+        messages: Iterable[Message],
+        sampling_policy: SamplingPolicy | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> Iterable[str]:
+        formatted_messages = self.message_processor.render_request(messages)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
+        for token_id in self.stream_tokens(token_ids, sampling_policy, key=key):
+            yield self.message_processor.detokenize([token_id.item()])
+
+    def stream_tokens(
         self,
         prompt_token_ids: Int[Array, " prompt_tokens"],
         sampling_policy: SamplingPolicy | None = None,
@@ -218,7 +239,9 @@ class LanguageModel:
         key: PRNGKeyArray | None = None,
     ) -> Iterable[Int[Array, ""]]:
         if sampling_policy is None:
-            sampling_policy = TemperaturePolicy(temperature=1.0)
+            sampling_policy = self.default_sampling_policy()
+        if eos_token_ids is None:
+            eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)

         (input_length,) = prompt_token_ids.shape
         prefill_results = self._prefill(
@@ -244,7 +267,7 @@ class LanguageModel:

             yield next_token_id

-            if eos_token_ids is not None and jnp.any(next_token_id == eos_token_ids):
+            if jnp.any(next_token_id == eos_token_ids):
                 return

             next_token_position = state.last_token_position + 1
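
Note: with sampling moved out to lalamo.sampling and the tokenizer folded into the model directory, inference becomes a few calls. A hedged sketch of the new API, assuming a directory previously produced by `lalamo convert` (the path and prompts are placeholders):

import jax

from lalamo.language_model import LanguageModel
from lalamo.message_processor import UserMessage

# load() expects config.json, model.safetensors, and tokenizer.json
# in the given directory; the path here is a placeholder.
model = LanguageModel.load("path/to/converted-model")

# reply() falls back to the policy built from GenerationConfig when no
# SamplingPolicy is passed explicitly.
answer = model.reply([UserMessage("What is JAX?")], key=jax.random.PRNGKey(0))

# The streaming variant yields detokenized text chunks token by token.
for chunk in model.stream_reply_text([UserMessage("Hi!")], key=jax.random.PRNGKey(1)):
    print(chunk, end="", flush=True)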
lalamo/main.py CHANGED
@@ -20,7 +20,17 @@ from rich.table import Table
 from safetensors.flax import save_file
 from typer import Argument, Exit, Option, Typer

+from lalamo.common import flatten_parameters
+from lalamo.language_model import LanguageModel
+from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
+from lalamo.model_import.common import (
+    DownloadingFileEvent,
+    FinishedDownloadingFileEvent,
+    FinishedInitializingModelEvent,
+    InitializingModelEvent,
+    StatusEvent,
+)
 from lalamo.modules import WeightLayout, config_converter
 from lalamo.utils import jax_uint4_to_packed_uint8

@@ -91,6 +101,52 @@ def _pack_uint4_weights(weights: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
     return packed_weights


+@app.command(help="Chat with a converted model.")
+def chat(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    weight_layout: Annotated[
+        WeightLayout | None,
+        Option(
+            help=(
+                "(EXPERIMENTAL) Order of dimensions in the weights of linear layers."
+                "\n\n\n\n"
+                "If set to AUTO, the layout will depend on the model."
+            ),
+            show_default="auto",
+        ),
+    ] = None,
+) -> None:
+    if weight_layout is None:
+        weight_layout = WeightLayout.AUTO
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        transient=True,
+    ) as progress:
+        progress.add_task("🚀 [cyan]Loading model...[/cyan]")
+        model = LanguageModel.load(model_path, weight_layout)
+    messages = []
+    while True:
+        user_text = console.input("[cyan]user> [/cyan]")
+        user_message = UserMessage(user_text)
+        messages.append(user_message)
+
+        console.print("[red]assistant> [/red]", end="")
+        model_response_tokens = []
+        for token in model.stream_reply_text(messages):
+            console.print(token, end="")
+            model_response_tokens.append(token)
+        console.print()
+        model_response_text = "".join(model_response_tokens)
+        messages.append(model.message_processor.parse_response(model_response_text))
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
@@ -118,7 +174,7 @@ def convert(
         WeightLayout | None,
         Option(
             help=(
-                "Order of dimensions in the weights of linear layers."
+                "(EXPERIMENTAL) Order of dimensions in the weights of linear layers."
                 "\n\n\n\n"
                 "If set to AUTO, the layout will depend on the model."
             ),
@@ -194,41 +250,58 @@ def convert(
         TextColumn("[progress.description]{task.description}"),
         transient=True,
     ) as progress:
-        progress.add_task("👨‍🍳 Cooking...")
-        model, metadata, tokenizer_file_paths = import_model(
+        event_to_task = {}
+
+        def progress_callback(event: StatusEvent) -> None:
+            match event:
+                case DownloadingFileEvent(file_spec):
+                    event_to_task[event] = progress.add_task(f"Retrieving {file_spec.filename}...")
+                case FinishedDownloadingFileEvent(file_spec):
+                    progress.remove_task(event_to_task[event])
+                case InitializingModelEvent():
+                    event_to_task[event] = progress.add_task("Initializing model...")
+                case FinishedInitializingModelEvent():
+                    progress.remove_task(event_to_task[event])
+
+        main_task = progress.add_task("👨‍🍳 Cooking...")
+        model, metadata = import_model(
             model_repo,
             precision=precision_dtype,
             context_length=context_length,
+            progress_callback=progress_callback,
         )
-        progress.add_task(f"💾 Saving the model to {output_dir}")
+        save_task = progress.add_task(f"💾 Saving the model to {output_dir}")
         output_dir.mkdir(parents=True, exist_ok=True)

-        weights = dict(model.export_weights(weight_layout))
-        packed_weights = _pack_uint4_weights(weights)
-        save_file(packed_weights, output_dir / "model.safetensors")
-
-        config_json = config_converter.unstructure(metadata, ModelMetadata)
-        with open(output_dir / "config.json", "w") as file:
-            json.dump(config_json, file, indent=4)
-
-        for path in tokenizer_file_paths:
-            shutil.copy(path, output_dir / path.name)
-
         if include_traces:
-            progress.add_task("🚁 Generating traces...")
+            trace_task = progress.add_task("🚁 Generating traces...")

             num_tokens = 512
             token_stride = 8
             token_ids = jnp.arange(0, num_tokens, dtype=jnp.int32)
             token_positions = jnp.arange(0, num_tokens * token_stride, token_stride, dtype=jnp.int32)
-            result = model(
+            result = model.decoder(
                 token_ids,
                 token_positions,
                 return_updated_kv_cache=True,
                 return_activation_trace=True,
             )
-            traces = dict(result.export())
+            traces = flatten_parameters(result.export())
             save_file(traces, output_dir / "traces.safetensors")
+            progress.remove_task(trace_task)
+        progress.remove_task(main_task)
+
+        model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
+        weights = flatten_parameters(model.export_weights(weight_layout))
+        del model
+
+        packed_weights = _pack_uint4_weights(weights)
+        save_file(packed_weights, output_dir / "model.safetensors")
+
+        config_json = config_converter.unstructure(metadata, ModelMetadata)
+        with open(output_dir / "config.json", "w") as file:
+            json.dump(config_json, file, indent=4)
+        progress.remove_task(save_task)


     console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{output_dir}`[/cyan]!")