lalamo 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. lalamo/__init__.py +20 -5
  2. lalamo/data/__init__.py +8 -0
  3. lalamo/data/huggingface_message.py +38 -0
  4. lalamo/data/lalamo_completions.py +43 -0
  5. lalamo/data/utils.py +8 -0
  6. lalamo/language_model.py +152 -69
  7. lalamo/main.py +271 -43
  8. lalamo/message_processor.py +11 -1
  9. lalamo/model_import/common.py +10 -6
  10. lalamo/model_import/decoder_configs/__init__.py +3 -0
  11. lalamo/model_import/decoder_configs/executorch.py +12 -6
  12. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  13. lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  14. lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  15. lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  16. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  17. lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  18. lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  21. lalamo/model_import/huggingface_tokenizer_config.py +1 -3
  22. lalamo/model_import/loaders/executorch.py +10 -9
  23. lalamo/model_import/loaders/huggingface.py +104 -9
  24. lalamo/model_import/loaders/utils.py +92 -0
  25. lalamo/model_import/model_specs/__init__.py +4 -1
  26. lalamo/model_import/model_specs/common.py +15 -12
  27. lalamo/model_import/model_specs/gpt_oss.py +21 -0
  28. lalamo/modules/__init__.py +35 -7
  29. lalamo/modules/activations.py +24 -14
  30. lalamo/modules/attention.py +73 -20
  31. lalamo/modules/common.py +8 -57
  32. lalamo/modules/decoder.py +48 -34
  33. lalamo/modules/decoder_layer.py +57 -43
  34. lalamo/modules/embedding.py +13 -19
  35. lalamo/modules/kv_cache.py +53 -16
  36. lalamo/modules/linear.py +260 -79
  37. lalamo/modules/mlp.py +395 -23
  38. lalamo/modules/normalization.py +2 -3
  39. lalamo/modules/rope.py +32 -21
  40. lalamo/modules/utils.py +10 -0
  41. lalamo/speculator/__init__.py +11 -0
  42. lalamo/speculator/common.py +22 -0
  43. lalamo/speculator/inference.py +75 -0
  44. lalamo/speculator/ngram.py +154 -0
  45. lalamo/speculator/utils.py +52 -0
  46. lalamo/utils.py +27 -0
  47. {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
  48. lalamo-0.4.0.dist-info/RECORD +71 -0
  49. lalamo-0.3.4.dist-info/RECORD +0 -59
  50. {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
  51. {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
  52. {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
  53. {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/__init__.py CHANGED
@@ -1,11 +1,26 @@
-from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
-from lalamo.modules import Decoder
+from lalamo.language_model import LanguageModel
+from lalamo.message_processor import (
+    AssistantMessage,
+    ContentBlock,
+    Image,
+    Message,
+    SystemMessage,
+    ToolSchema,
+    UserMessage,
+)
+from lalamo.model_import import ModelSpec, import_model
 
-__version__ = "0.3.4"
+__version__ = "0.4.0"
 
 __all__ = [
-    "REPO_TO_MODEL",
-    "Decoder",
+    "AssistantMessage",
+    "ContentBlock",
+    "Image",
+    "LanguageModel",
+    "Message",
     "ModelSpec",
+    "SystemMessage",
+    "ToolSchema",
+    "UserMessage",
     "import_model",
 ]
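The top-level package now exposes the chat-oriented API instead of the raw Decoder. A minimal usage sketch, assuming a model directory (hypothetical path) that contains the config.json, model.safetensors, and tokenizer.json read by LanguageModel.load in the language_model.py diff below:

# Minimal sketch; "path/to/converted-model" is a hypothetical directory
# produced by import_model, containing config.json, model.safetensors,
# and tokenizer.json.
from lalamo import LanguageModel, UserMessage

model = LanguageModel.load("path/to/converted-model")
answer = model.reply([UserMessage("What is the capital of France?")])
print(answer)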
lalamo/data/__init__.py ADDED
@@ -0,0 +1,8 @@
+from .huggingface_message import import_hf_parquet
+from .utils import get_prefixes_ending_in_user_message
+
+__all__ = [
+    "get_prefixes_ending_in_user_message",
+    "import_hf_parquet",
+]
+
lalamo/data/huggingface_message.py ADDED
@@ -0,0 +1,38 @@
+from collections.abc import Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, Self
+
+import cattrs
+import polars as pl
+
+from lalamo.message_processor import AssistantMessage, Message, UserMessage
+
+
+@dataclass(frozen=True)
+class HFMessage:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+
+    role: str
+    content: str
+
+    @classmethod
+    def from_dict(cls, obj: dict) -> Self:
+        return cls._converter.structure(obj, cls)
+
+    def as_message(self) -> Message:
+        match self.role:
+            case "user":
+                return UserMessage(self.content)
+            case "assistant":
+                return AssistantMessage(None, self.content)
+            case other:
+                raise ValueError(f"Cannot convert {other} message")
+
+def import_hf_parquet(path: Path | str) -> Iterable[list[Message]]:
+    path = Path(path)
+
+    dataframe = pl.scan_parquet(path).collect()
+
+    for conversation in dataframe.get_column("conversation").shuffle(1337):
+        yield [HFMessage.from_dict(message).as_message() for message in conversation]
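A hedged sketch of consuming the new loader; the parquet filename is hypothetical, and the file must hold a "conversation" column of {role, content} records as the code above expects. Note that only "user" and "assistant" roles are converted; any other role raises a ValueError.

# Hypothetical dataset file; conversations are yielded in a fixed
# shuffled order (seed 1337 above).
from lalamo.data import import_hf_parquet

for conversation in import_hf_parquet("conversations.parquet"):
    for message in conversation:  # UserMessage / AssistantMessage instances
        print(type(message).__name__)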
lalamo/data/lalamo_completions.py ADDED
@@ -0,0 +1,43 @@
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import IO, Any, ClassVar, Self
+
+import msgpack
+from cattrs.preconf.msgpack import MsgpackConverter
+from cattrs.preconf.msgpack import make_converter as make_msgpack_converter
+
+
+@dataclass(frozen=True)
+class LalamoCompletion:
+    _converter: ClassVar[MsgpackConverter] = make_msgpack_converter()
+
+    prefix_token_ids: list[int]
+    completion_token_ids: list[int]
+    completion_token_logits: list[dict[int, float]]
+
+    def __post_init__(self) -> None:
+        if len(self.completion_token_ids) != len(self.completion_token_logits):
+            raise ValueError(f"({len(self.completion_token_ids)=}) != ({len(self.completion_token_logits)=})")
+
+    def serialize(self) -> bytes:
+        return self._converter.dumps(self)
+
+    @classmethod
+    def deserialize(cls, data: bytes | IO[bytes]) -> Self:
+        if isinstance(data, bytes):
+            obj: Any = msgpack.unpackb(data, strict_map_key=False)
+        else:
+            obj = msgpack.unpack(data, strict_map_key=False)
+
+        return cls._converter.structure(obj, cls)
+
+    @classmethod
+    def deserialize_many(cls, data: bytes | IO[bytes]) -> Iterable[Self]:
+        if isinstance(data, bytes):
+            unpacker = msgpack.Unpacker(strict_map_key=False)
+            unpacker.feed(data)
+        else:
+            unpacker = msgpack.Unpacker(file_like=data, strict_map_key=False)
+
+        for obj in unpacker:
+            yield cls._converter.structure(obj, cls)
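A round-trip sketch of the serialization above, with made-up token ids and logits. strict_map_key=False is what lets the integer keys in completion_token_logits survive msgpack decoding, and __post_init__ enforces one logits dict per completion token.

# Round-trip with made-up values; the two completion fields must have
# equal length, per the __post_init__ check above.
from lalamo.data.lalamo_completions import LalamoCompletion

completion = LalamoCompletion(
    prefix_token_ids=[1, 15, 42],
    completion_token_ids=[7, 9],
    completion_token_logits=[{7: -0.1, 3: -2.5}, {9: -0.3, 2: -1.9}],
)
assert LalamoCompletion.deserialize(completion.serialize()) == completion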
lalamo/data/utils.py ADDED
@@ -0,0 +1,8 @@
+from collections.abc import Iterable
+
+from lalamo.message_processor import Message, UserMessage
+
+
+def get_prefixes_ending_in_user_message(conversation: Iterable[Message]) -> list[list[Message]]:
+    conversation = list(conversation)
+    return [conversation[: i + 1] for i, msg in enumerate(conversation) if isinstance(msg, UserMessage)]
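A worked example of the helper above, with made-up messages: every returned prefix ends at a UserMessage, so each one is a valid generation prompt.

# Made-up conversation; the helper yields one prefix per user turn.
from lalamo.data import get_prefixes_ending_in_user_message
from lalamo.message_processor import AssistantMessage, UserMessage

conversation = [
    UserMessage("Hi"),
    AssistantMessage(None, "Hello!"),
    UserMessage("How are you?"),
    AssistantMessage(None, "Great."),
]
prefixes = get_prefixes_ending_in_user_message(conversation)
assert len(prefixes) == 2
assert len(prefixes[0]) == 1  # [user_1]
assert len(prefixes[1]) == 3  # [user_1, assistant_1, user_2]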
lalamo/language_model.py CHANGED
@@ -7,33 +7,56 @@ from typing import NamedTuple, Self
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from einops import rearrange
+from jax import vmap
 from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
-from safetensors.flax import load_file
 from tokenizers import Tokenizer
 
 from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
 from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
-from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule, WeightLayout, config_converter
+from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule, config_converter
+from lalamo.modules.common import ForwardPassMode
+from lalamo.modules.decoder import DecoderForwardPassConfig
 from lalamo.sampling import SamplingPolicy, make_policy
+from lalamo.utils import open_safetensors
 
 __all__ = [
+    "ForwardPassConfig",
     "GenerationConfig",
     "LanguageModel",
     "LanguageModelConfig",
 ]
 
 
+_COMPILED_PROMPT_LENGTHS = [512 * 2**i for i in range(10)]
+
+
+type ForwardPassConfig = DecoderForwardPassConfig
+
+
 class PrefillResults(NamedTuple):
-    last_token_logits: Float[Array, " vocabulary"]
-    last_token_position: Int[Array, ""]
+    last_token_logits: Float[Array, "batch vocabulary"]
+    last_token_indices: Int[Array, " batch"]
     kv_cache: KVCache
 
 
 class DecodingState(NamedTuple):
-    last_token_logits: Float[Array, " vocabulary"]
-    last_token_position: Int[Array, ""]
+    last_token_logits: Float[Array, "batch vocabulary"]
+    last_token_indices: Int[Array, " batch"]
     kv_cache: KVCache
-    stop_flag: Bool[Array, ""]
+    stop_flags: Bool[Array, " batch"]
+
+
+class GenerationStepResults(NamedTuple):
+    token_ids: Int[Array, " batch"]
+    top_k_token_ids: Int[Array, " batch k"] | None
+    top_k_token_logits: Float[Array, " batch k"] | None
+
+
+class GenerationResults(NamedTuple):
+    token_ids: Int[Array, "batch response_tokens"]
+    top_k_token_ids: Int[Array, "batch response_tokens k"] | None
+    top_k_token_logits: Float[Array, "batch response_tokens k"] | None
 
 
 @dataclass(frozen=True)
@@ -60,14 +83,15 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
     message_processor: MessageProcessor = eqx.field(static=True)
 
     @classmethod
-    def load(cls, path: Path | str, weight_layout: WeightLayout = WeightLayout.AUTO) -> Self:
+    def load(cls, path: Path | str) -> Self:
         if isinstance(path, str):
             path = Path(path)
         with open(path / "config.json") as config_file:
             config_json = json.load(config_file)
         config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
-        weights = unflatten_parameters(load_file(path / "model.safetensors"))
-        decoder = config.decoder_config.empty().import_weights(weights, weight_layout)
+        with open_safetensors(path / "model.safetensors") as weights_dict:
+            weights = unflatten_parameters(weights_dict)
+        decoder = config.decoder_config.empty().import_weights(weights)
         tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
         message_processor = MessageProcessor(config.message_processor_config, tokenizer)
         return cls(config, decoder, message_processor)
@@ -76,17 +100,16 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
     def activation_precision(self) -> DTypeLike:
         return self.decoder.activation_precision
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
-        return self.decoder.export_weights(weight_layout)
+    def export_weights(self) -> ParameterTree:
+        return self.decoder.export_weights()
 
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         return replace(
             self,
-            decoder=self.decoder.import_weights(weights, weight_layout),
+            decoder=self.decoder.import_weights(weights),
         )
 
     @property
@@ -99,14 +122,15 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
     @eqx.filter_jit
     def _prefill(
         self,
-        token_ids: Int[Array, " tokens"],
-        length_without_padding: Int[Array, ""] | int | None = None,
+        token_ids: Int[Array, "batch tokens"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
         kv_cache_capacity: int | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
     ) -> PrefillResults:
-        (num_tokens,) = token_ids.shape
-        token_positions = jnp.arange(num_tokens, dtype=jnp.int32)
+        batch_size, sequence_length = token_ids.shape
+        token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
         if kv_cache_capacity is not None:
-            kv_cache = self.decoder.init_static_kv_cache(kv_cache_capacity)
+            kv_cache = self.decoder.init_static_kv_cache(batch_size, kv_cache_capacity)
         else:
             kv_cache = None
 
@@ -115,52 +139,56 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             token_positions,
             kv_cache,
             return_updated_kv_cache=True,
-            length_without_padding=length_without_padding,
+            lengths_without_padding=lengths_without_padding,
+            forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
+            forward_pass_config=forward_pass_config,
         )
 
-        if length_without_padding is not None:
-            last_logits_index = length_without_padding - 1
+        if lengths_without_padding is not None:
+            last_logits_indices = lengths_without_padding - 1
         else:
-            last_logits_index = num_tokens - 1
+            last_logits_indices = jnp.array([sequence_length - 1] * batch_size, dtype=jnp.int32)
 
-        last_token_logits = decoder_outputs.logits[last_logits_index, :]
-        last_token_position = jnp.array(last_logits_index, dtype=jnp.int32)
+        last_token_logits = vmap(lambda logits, index: logits[index])(decoder_outputs.logits, last_logits_indices)
 
         assert decoder_outputs.updated_kv_cache is not None
         return PrefillResults(
             last_token_logits=last_token_logits,
-            last_token_position=last_token_position,
+            last_token_indices=last_logits_indices,
             kv_cache=decoder_outputs.updated_kv_cache,
         )
 
     @eqx.filter_jit
     def generate_tokens(
         self,
-        prompt_token_ids: Int[Array, " prompt_tokens"],
+        prompt_token_ids: Int[Array, "batch prompt_tokens"],
         sampling_policy: SamplingPolicy | None = None,
-        prompt_length_without_padding: Int[Array, ""] | int | None = None,
+        prompt_lengths_without_padding: Int[Array, " batch"] | None = None,
         max_output_length: int = 8192,
         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+        num_top_logits_to_return: int | None = None,
         *,
         key: PRNGKeyArray | None = None,
-    ) -> Int[Array, " response_tokens"]:
+    ) -> GenerationResults:
         if sampling_policy is None:
             sampling_policy = self.default_sampling_policy()
         if eos_token_ids is None:
             eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
 
-        (input_length,) = prompt_token_ids.shape
+        batch_size, sequence_length = prompt_token_ids.shape
         prefill_results = self._prefill(
             prompt_token_ids,
-            prompt_length_without_padding,
-            input_length + max_output_length,
+            prompt_lengths_without_padding,
+            sequence_length + max_output_length,
+            forward_pass_config=forward_pass_config,
        )
 
         initial_state = DecodingState(
             prefill_results.last_token_logits,
-            prefill_results.last_token_position,
+            prefill_results.last_token_indices,
             prefill_results.kv_cache,
-            jnp.array(0, dtype=jnp.bool),
+            jnp.zeros(batch_size, dtype=jnp.bool),
         )
 
         if key is None:
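The batched _prefill now gathers each row's final logits with a per-row index via vmap, replacing the old scalar indexing. A standalone sketch of that gather, with illustrative shapes:

# Per-row gather, as in _prefill above: row b contributes
# logits[b, last_indices[b]], giving one logits vector per batch row.
import jax.numpy as jnp
from jax import vmap

logits = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)  # (batch, tokens, vocabulary)
last_indices = jnp.array([1, 3], dtype=jnp.int32)  # last unpadded position per row
last_logits = vmap(lambda row, index: row[index])(logits, last_indices)
assert last_logits.shape == (2, 3)  # (batch, vocabulary)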
@@ -170,49 +198,88 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         def loop_iteration(
             state: DecodingState,
             key: PRNGKeyArray,
-        ) -> tuple[DecodingState, Int[Array, ""]]:
-            def sample_and_update() -> tuple[DecodingState, Int[Array, ""]]:
-                processed_logits = sampling_policy.process_logits(state.last_token_logits)
-                next_token_id = jax.random.categorical(key, processed_logits)
-                next_token_position = state.last_token_position + 1
-
-                stop_flag = state.stop_flag | jnp.any(next_token_id == eos_token_ids)
+        ) -> tuple[DecodingState, GenerationStepResults]:
+            def sample_and_update() -> tuple[DecodingState, GenerationStepResults]:
+                upcasted_logits = state.last_token_logits.astype(jnp.float32)
+                processed_logits = vmap(sampling_policy.process_logits)(upcasted_logits)
+                next_token_ids = jax.random.categorical(key, processed_logits)
+                next_token_ids = jnp.where(state.stop_flags, jnp.zeros(batch_size, dtype=jnp.int32), next_token_ids)
+                if num_top_logits_to_return is not None:
+                    next_top_k_token_logits, next_top_k_token_ids = jax.lax.top_k(
+                        processed_logits,
+                        num_top_logits_to_return,
+                    )
+                else:
+                    next_top_k_token_ids = None
+                    next_top_k_token_logits = None
+                next_token_indices = state.last_token_indices + 1
+
+                stop_flags = state.stop_flags | jnp.any(next_token_ids[:, None] == eos_token_ids[None, :], axis=-1)
+
+                if batch_size == 1:
+                    forward_pass_mode = ForwardPassMode.SINGLE_TOKEN
+                else:
+                    forward_pass_mode = ForwardPassMode.MULTI_TOKEN
 
                 decoder_outputs = self.decoder(
-                    next_token_id.reshape(1),
-                    next_token_position.reshape(1),
+                    next_token_ids[:, None],
+                    next_token_indices[:, None],
                     state.kv_cache,
                     return_updated_kv_cache=True,
+                    forward_pass_mode=forward_pass_mode,
+                    forward_pass_config=forward_pass_config,
                 )
                 assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
                 new_state = DecodingState(
-                    decoder_outputs.logits.squeeze(),
-                    next_token_position,
+                    decoder_outputs.logits.squeeze(1),
+                    next_token_indices,
                     decoder_outputs.updated_kv_cache,
-                    stop_flag,
+                    stop_flags,
                 )
-                return new_state, next_token_id
+                return new_state, GenerationStepResults(next_token_ids, next_top_k_token_ids, next_top_k_token_logits)
 
-            def pad_and_repeat_state() -> tuple[DecodingState, Int[Array, ""]]:
-                pad_token = jnp.array(0, dtype=jnp.int32)
-                return state, pad_token
+            def pad_and_repeat_state() -> tuple[DecodingState, GenerationStepResults]:
+                (batch_size,) = state.stop_flags.shape
+                pad_token = jnp.zeros(batch_size, dtype=jnp.int32)
+                if num_top_logits_to_return is not None:
+                    top_k_token_ids = jnp.zeros((batch_size, num_top_logits_to_return), dtype=jnp.int32)
+                    top_k_token_logits = jnp.zeros((batch_size, num_top_logits_to_return), dtype=jnp.float32)
+                else:
+                    top_k_token_ids = None
+                    top_k_token_logits = None
+                return state, GenerationStepResults(pad_token, top_k_token_ids, top_k_token_logits)
 
-            return jax.lax.cond(state.stop_flag, pad_and_repeat_state, sample_and_update)
+            return jax.lax.cond(jnp.all(state.stop_flags), pad_and_repeat_state, sample_and_update)
 
-        _, tokens = jax.lax.scan(loop_iteration, initial_state, keys)
+        _, generated = jax.lax.scan(loop_iteration, initial_state, keys)
 
-        return tokens
+        token_ids = rearrange(generated.token_ids, "iteration batch -> batch iteration")
+
+        if num_top_logits_to_return is not None:
+            top_k_token_ids = rearrange(generated.top_k_token_ids, "iteration batch k -> batch iteration k")
+            top_k_token_logits = rearrange(generated.top_k_token_logits, "iteration batch k -> batch iteration k")
+        else:
+            top_k_token_ids = None
+            top_k_token_logits = None
+
+        return GenerationResults(token_ids, top_k_token_ids, top_k_token_logits)
 
     def reply(
         self,
         messages: Iterable[Message],
         sampling_policy: SamplingPolicy | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
         *,
         key: PRNGKeyArray | None = None,
     ) -> AssistantMessage:
         formatted_messages = self.message_processor.render_request(messages)
-        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
-        response_ids = self.generate_tokens(token_ids, sampling_policy, key=key)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)[None, :]
+        response_ids = self.generate_tokens(
+            token_ids,
+            sampling_policy,
+            forward_pass_config=forward_pass_config,
+            key=key,
+        ).token_ids.squeeze(0)
         response_text = self.message_processor.detokenize(response_ids.tolist())
         return self.message_processor.parse_response(response_text)
 
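generate_tokens now runs whole batches: prompts are right-padded to a shared length, true lengths go in prompt_lengths_without_padding, and the optional num_top_logits_to_return requests per-step top-k logits. A hedged sketch, assuming `model` is a loaded LanguageModel and using made-up token ids:

# Hedged sketch; `model` is a loaded LanguageModel, token ids are made up.
# Row 0 is one token shorter than row 1, so it is right-padded with 0.
import jax
import jax.numpy as jnp

prompts = jnp.array([[5, 8, 13, 0], [3, 1, 4, 1]], dtype=jnp.int32)
results = model.generate_tokens(
    prompts,
    prompt_lengths_without_padding=jnp.array([3, 4], dtype=jnp.int32),
    max_output_length=64,
    num_top_logits_to_return=5,
    key=jax.random.PRNGKey(0),
)
# results.token_ids has shape (batch, response_tokens); rows that hit an
# EOS token early are padded with token id 0, matching pad_and_repeat_state.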
@@ -220,21 +287,29 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         self,
         messages: Iterable[Message],
         sampling_policy: SamplingPolicy | None = None,
+        max_output_length: int = 8192,
+        forward_pass_config: ForwardPassConfig | None = None,
         *,
         key: PRNGKeyArray | None = None,
     ) -> Iterable[str]:
         formatted_messages = self.message_processor.render_request(messages)
         token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
-        for token_id in self.stream_tokens(token_ids, sampling_policy, key=key):
+        for token_id in self.stream_tokens(
+            token_ids,
+            sampling_policy,
+            max_output_length,
+            forward_pass_config=forward_pass_config,
+            key=key,
+        ):
             yield self.message_processor.detokenize([token_id.item()])
 
     def stream_tokens(
         self,
         prompt_token_ids: Int[Array, " prompt_tokens"],
         sampling_policy: SamplingPolicy | None = None,
-        prompt_length_without_padding: Int[Array, ""] | int | None = None,
         max_output_length: int = 8192,
         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
         *,
         key: PRNGKeyArray | None = None,
     ) -> Iterable[Int[Array, ""]]:
@@ -244,10 +319,16 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
 
         (input_length,) = prompt_token_ids.shape
+
+        padded_input_length = min(length for length in _COMPILED_PROMPT_LENGTHS if length >= input_length)
+        padded_token_ids = jnp.zeros((padded_input_length,), dtype=jnp.int32)
+        padded_token_ids = padded_token_ids.at[:input_length].set(prompt_token_ids)
+
         prefill_results = self._prefill(
-            prompt_token_ids,
-            prompt_length_without_padding,
-            input_length + max_output_length,
+            padded_token_ids[None, :],
+            jnp.array([input_length], dtype=jnp.int32),
+            padded_input_length + max_output_length,
+            forward_pass_config=forward_pass_config,
         )
 
         if key is None:
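The streaming path now pads every prompt up to the nearest bucket in _COMPILED_PROMPT_LENGTHS (512, 1024, ..., 262144), so the JIT-compiled _prefill only ever sees ten distinct prompt shapes instead of one per prompt length. The selection logic, isolated:

# Bucket selection as used above: the smallest compiled length that fits.
_COMPILED_PROMPT_LENGTHS = [512 * 2**i for i in range(10)]  # 512, 1024, ..., 262144

input_length = 700
padded_input_length = min(length for length in _COMPILED_PROMPT_LENGTHS if length >= input_length)
assert padded_input_length == 1024

Note that a prompt longer than the largest bucket would leave the generator empty, and min() would raise a ValueError.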
@@ -256,13 +337,14 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
 
         state = DecodingState(
             prefill_results.last_token_logits,
-            prefill_results.last_token_position,
+            prefill_results.last_token_indices,
             prefill_results.kv_cache,
-            jnp.array(0, dtype=jnp.bool),
+            jnp.array([0], dtype=jnp.bool),
         )
 
         for iter_key in keys:
-            processed_logits = sampling_policy.process_logits(state.last_token_logits)
+            upcasted_logits = state.last_token_logits.astype(jnp.float32)
+            processed_logits = sampling_policy.process_logits(upcasted_logits.squeeze(0))
             next_token_id = jax.random.categorical(iter_key, processed_logits)
 
             yield next_token_id
@@ -270,17 +352,18 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             if jnp.any(next_token_id == eos_token_ids):
                 return
 
-            next_token_position = state.last_token_position + 1
+            next_token_indices = state.last_token_indices + 1
             decoder_outputs = self.decoder(
-                next_token_id.reshape(1),
-                next_token_position.reshape(1),
+                next_token_id.reshape(1, 1),
+                next_token_indices.reshape(1, 1),
                 state.kv_cache,
                 return_updated_kv_cache=True,
+                forward_pass_config=forward_pass_config,
             )
             assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
             state = DecodingState(
-                decoder_outputs.logits.squeeze(),
-                next_token_position,
+                decoder_outputs.logits.squeeze(1),
+                next_token_indices,
                 decoder_outputs.updated_kv_cache,
-                state.stop_flag,
+                state.stop_flags,
             )