lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
@@ -1,14 +1,16 @@
-from collections.abc import Iterable
-from dataclasses import dataclass
+from collections.abc import Callable, Iterable, Iterator
+from dataclasses import dataclass, replace
+from itertools import batched
 from pathlib import Path
 from typing import NamedTuple
 
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from einops import rearrange
+import numpy as np
+from einops import rearrange, repeat
 from jax import vmap
-from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+from jaxtyping import Array, Bool, Float, Int, Key, PRNGKeyArray
 
 from lalamo.message_processor import AssistantMessage, Message, MessageProcessor
 from lalamo.modules import (
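
Note that the new `from itertools import batched` import means this module now needs Python 3.12 or newer at runtime; `batched` slices an iterable into fixed-size tuples, which the batch-generation helpers below rely on. A quick standalone illustration:

    from itertools import batched  # added in Python 3.12

    list(batched([1, 2, 3, 4, 5], 2))  # [(1, 2), (3, 4), (5,)]
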
@@ -21,7 +23,15 @@ from lalamo.modules import (
 )
 from lalamo.sampling import SamplingPolicy, make_policy
 
-from .common import TextModel, TextModelConfig
+from .common import BatchSizeInfo, BatchSizesComputedEvent, InferenceConfig, TextModel, TextModelConfig
+from .compile_helpers import compile_generate_tokens
+from .lm_helpers import (
+    decrease_batchsize_on_oom,
+    estimate_batchsizes_from_vram,
+    merge_small_buckets,
+    pad_keys_to_size,
+    pad_sequences,
+)
 
 __all__ = [
     "ForwardPassConfig",
@@ -31,7 +41,7 @@ __all__ = [
 ]
 
 
-_COMPILED_PROMPT_LENGTHS = [512 * 2**i for i in range(10)]
+_COMPILED_PROMPT_LENGTHS = [256 * 2**i for i in range(12)]
 
 
 type ForwardPassConfig = DecoderForwardPassConfig
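
The compiled prompt-length buckets now start at 256 tokens and double eleven times, covering 256 through 524,288 tokens instead of the previous 512 through 262,144. As in the bucket selection used by `reply_many` further down, a prompt is padded up to the smallest bucket that can hold it; a toy check of the arithmetic:

    lengths = [256 * 2**i for i in range(12)]       # 256, 512, 1024, ..., 524288
    assert lengths[0] == 256 and lengths[-1] == 524_288

    prompt_len = 700
    padded_len = min(length for length in lengths if length >= prompt_len)
    assert padded_len == 1024                       # smallest bucket that fits the prompt
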
@@ -94,6 +104,13 @@ class LanguageModelConfig(TextModelConfig[DecoderConfig]):
         return result
 
 
+class Chunk(eqx.Module):
+    tokens: Int[Array, "num_chunks batch chunk_size"]
+    indices: Int[Array, "num_chunks batch chunk_size"]
+    sequence_ends: Int[Array, "num_chunks batch"]
+    is_last_token_inside: Bool[Array, "num_chunks batch"]
+
+
 class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
     @property
     def stop_token_ids(self) -> tuple[int, ...]:
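
The `Chunk` container introduced above is an `eqx.Module`, i.e. a JAX pytree, so `jax.lax.scan` can consume it along its leading `num_chunks` axis during the chunked prefill. A minimal, self-contained sketch of that pattern (the `Pair` module here is illustrative only, not part of the package):

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class Pair(eqx.Module):
        a: jax.Array
        b: jax.Array

    xs = Pair(a=jnp.arange(3), b=jnp.arange(3) * 10)   # leading axis = number of scan steps

    def step(carry, x):
        return carry + x.a + x.b, None                 # each step sees one slice of the pytree

    total, _ = jax.lax.scan(step, jnp.int32(0), xs)    # total == (0+0) + (1+10) + (2+20) == 33
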
@@ -102,68 +119,143 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
     def default_sampling_policy(self) -> SamplingPolicy:
         return self.config.generation_config.default_policy()
 
+    @eqx.filter_jit
+    def _make_chunks(
+        self,
+        token_ids: Int[Array, "batch tokens"],
+        lengths_without_padding: Int[Array, " batch"] | None,
+        chunk_size: int,
+    ) -> Chunk:
+        batch_size, sequence_length = token_ids.shape
+        if lengths_without_padding is None:
+            lengths_without_padding = jnp.full((batch_size,), sequence_length, dtype=jnp.int32)
+
+        # If all sequences fit in a single chunk, use sequence_length as the chunk size
+        chunk_size = min(chunk_size, sequence_length)
+
+        n_chunks = (sequence_length + chunk_size - 1) // chunk_size
+        padded_length = n_chunks * chunk_size
+
+        token_ids = jnp.pad(token_ids, [(0, 0), (0, padded_length - sequence_length)])
+
+        # Reshape tokens to (num_chunks, batch, chunk_size)
+        tokens = rearrange(
+            token_ids,
+            "batch (num_chunks chunk_size) -> num_chunks batch chunk_size",
+            chunk_size=chunk_size,
+        )
+
+        # Create position indices (num_chunks, batch, chunk_size)
+        indices = jnp.arange(padded_length, dtype=jnp.int32)
+        indices = repeat(indices, "token_idx -> batch token_idx", batch=batch_size)
+        indices = rearrange(
+            indices,
+            "batch (num_chunks chunk_size) -> num_chunks batch chunk_size",
+            chunk_size=chunk_size,
+        )
+
+        # sequence_ends: for each chunk, how many valid tokens per batch item
+        chunk_starts = jnp.arange(n_chunks, dtype=jnp.int32) * chunk_size
+        sequence_ends = jnp.clip(
+            lengths_without_padding[None, :] - chunk_starts[:, None],
+            0,
+            chunk_size,
+        )
+
+        # is_last_token_inside: whether the last valid token (at index length - 1) is in this chunk
+        last_token_idx = lengths_without_padding - 1
+        chunk_ends = chunk_starts + chunk_size
+        is_last_token_inside = (last_token_idx[None, :] >= chunk_starts[:, None]) & (
+            last_token_idx[None, :] < chunk_ends[:, None]
+        )
+
+        return Chunk(
+            tokens=tokens,
+            indices=indices,
+            sequence_ends=sequence_ends,
+            is_last_token_inside=is_last_token_inside,
+        )
+
     @eqx.filter_jit
     def _prefill(
         self,
         token_ids: Int[Array, "batch tokens"],
+        state_capacity: int,
         lengths_without_padding: Int[Array, " batch"] | None = None,
-        state_capacity: int | None = None,
         forward_pass_config: ForwardPassConfig | None = None,
+        chunk_size: int = 512,  # vllm default
     ) -> PrefillResults:
         batch_size, sequence_length = token_ids.shape
-        token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
-        if state_capacity is not None:
-            state = self.model.init_static_state(batch_size, state_capacity)
-        else:
-            state = None
 
-        decoder_outputs = self.model(
-            token_ids,
-            token_positions,
-            state,
-            return_updated_state=True,
-            lengths_without_padding=lengths_without_padding,
-            forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
-            forward_pass_config=forward_pass_config,
-        )
+        if lengths_without_padding is None:
+            lengths_without_padding = jnp.full((batch_size,), sequence_length, dtype=jnp.int32)
 
-        if lengths_without_padding is not None:
-            last_logits_indices = lengths_without_padding - 1
-        else:
-            last_logits_indices = jnp.array([sequence_length - 1] * batch_size, dtype=jnp.int32)
+        chunks = self._make_chunks(token_ids, lengths_without_padding, chunk_size)
+
+        num_chunks, _, chunk_size = chunks.tokens.shape
+        state_capacity = max(state_capacity, num_chunks * chunk_size)
 
-        last_token_logits = vmap(lambda logits, index: logits[index])(decoder_outputs.logits, last_logits_indices)
+        state = self.model.init_static_state(batch_size, state_capacity)
+        logits_like = jnp.zeros((batch_size, self.model.vocab_size), dtype=jnp.float32)
+
+        def apply_chunk(state_and_logits: tuple, chunk: Chunk) -> tuple:
+            state, prev_logits = state_and_logits
+            decoder_outputs = self.model(
+                chunk.tokens,
+                chunk.indices,
+                state,
+                return_updated_state=True,
+                lengths_without_padding=chunk.sequence_ends,
+                forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
+                forward_pass_config=forward_pass_config,
+            )
+            assert decoder_outputs.updated_state is not None
+
+            chunk_logits = decoder_outputs.logits[jnp.arange(batch_size), chunk.sequence_ends - 1, :]
+            new_logits = jnp.where(chunk.is_last_token_inside[:, None], chunk_logits, prev_logits)
+
+            return (decoder_outputs.updated_state, new_logits), None
+
+        (final_state, final_logits), _ = jax.lax.scan(apply_chunk, (state, logits_like), chunks)
 
-        assert decoder_outputs.updated_state is not None
         return PrefillResults(
-            last_token_logits=last_token_logits,
-            last_token_indices=last_logits_indices,
-            state=decoder_outputs.updated_state,
+            last_token_logits=final_logits,
+            last_token_indices=jnp.maximum(lengths_without_padding - 1, 0),
+            state=final_state,
         )
 
-    @eqx.filter_jit
     def generate_tokens(
         self,
         prompt_token_ids: Int[Array, "batch prompt_tokens"],
-        sampling_policy: SamplingPolicy | None = None,
+        generation_config: GenerationConfig | None = None,
        prompt_lengths_without_padding: Int[Array, " batch"] | None = None,
         max_output_length: int = 8192,
         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
         forward_pass_config: ForwardPassConfig | None = None,
         num_top_logits_to_return: int | None = None,
         *,
-        key: PRNGKeyArray | None = None,
+        keys: Key[Array, " batch"] | None = None,
     ) -> GenerationResults:
-        if sampling_policy is None:
-            sampling_policy = self.default_sampling_policy()
+        batch_size, sequence_length = prompt_token_ids.shape
+
+        sampling_policy = self.default_sampling_policy()
+        if generation_config is not None:
+            sampling_policy = generation_config.default_policy()
+
         if eos_token_ids is None:
             eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
+        if keys is None:
+            keys = jax.random.split(jax.random.key(0), num=batch_size)
+
+        if len(keys) != batch_size:
+            raise ValueError(
+                f"Length of 'keys' should be equal to the batch size, or keys should be None; got {len(keys)}",
+            )
 
-        batch_size, sequence_length = prompt_token_ids.shape
         prefill_results = self._prefill(
             prompt_token_ids,
-            prompt_lengths_without_padding,
             sequence_length + max_output_length,
+            prompt_lengths_without_padding,
             forward_pass_config=forward_pass_config,
         )
 
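
To make the chunked prefill above concrete: suppose two prompts of 700 and 300 tokens arrive padded to a common width of 700, and `chunk_size` is 512. `_make_chunks` then produces two chunks (padding the batch out to 1024 columns), and its bookkeeping arrays work out as follows (a toy recomputation of the same expressions, not code from the package):

    import jax.numpy as jnp

    lengths = jnp.array([700, 300])                  # lengths_without_padding
    chunk_size, n_chunks = 512, 2                    # ceil(700 / 512) == 2, padded width 1024
    chunk_starts = jnp.arange(n_chunks) * chunk_size             # [0, 512]

    # how many valid tokens each chunk contributes per sequence
    sequence_ends = jnp.clip(lengths[None, :] - chunk_starts[:, None], 0, chunk_size)
    # [[512, 300],
    #  [188,   0]]

    # which chunk contains the final prompt token of each sequence,
    # i.e. where the last-token logits are picked up during the scan
    last = lengths - 1
    is_last_token_inside = (last[None, :] >= chunk_starts[:, None]) & (
        last[None, :] < (chunk_starts + chunk_size)[:, None]
    )
    # [[False,  True],
    #  [ True, False]]
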
@@ -174,18 +266,14 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
             jnp.zeros(batch_size, dtype=jnp.bool),
         )
 
-        if key is None:
-            key = jax.random.PRNGKey(0)
-        keys = jax.random.split(key, num=max_output_length)
-
         def loop_iteration(
             state: DecodingState,
-            key: PRNGKeyArray,
+            keys: Key[Array, " batch"],
         ) -> tuple[DecodingState, GenerationStepResults]:
             def sample_and_update() -> tuple[DecodingState, GenerationStepResults]:
                 upcasted_logits = state.last_token_logits.astype(jnp.float32)
                 processed_logits = vmap(sampling_policy.process_logits)(upcasted_logits)
-                next_token_ids = jax.random.categorical(key, processed_logits)
+                next_token_ids = jax.vmap(lambda k, logits: jax.random.categorical(k, logits))(keys, processed_logits)
                 next_token_ids = jnp.where(state.stop_flags, jnp.zeros(batch_size, dtype=jnp.int32), next_token_ids)
                 if num_top_logits_to_return is not None:
                     next_top_k_token_logits, next_top_k_token_ids = jax.lax.top_k(
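
Sampling now draws each sequence's next token with its own PRNG key by vmapping `jax.random.categorical` over the key and logit batches, so sequences in a batch are decoded with independent randomness. A standalone shape check (the batch and vocabulary sizes are made up):

    import jax
    import jax.numpy as jnp

    keys = jax.random.split(jax.random.key(0), num=4)       # one key per sequence
    logits = jnp.zeros((4, 32))                              # (batch, vocab), dummy values
    next_tokens = jax.vmap(jax.random.categorical)(keys, logits)
    assert next_tokens.shape == (4,)                         # one independent draw per sequence
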
@@ -214,7 +302,7 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
                 )
                 assert decoder_outputs.updated_state is not None, "updated_state should not be None"
                 new_state = DecodingState(
-                    decoder_outputs.logits.squeeze(1),
+                    decoder_outputs.logits.squeeze(1).astype(jnp.float32),
                     next_token_indices,
                     decoder_outputs.updated_state,
                     stop_flags,
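
The per-step logits are now upcast to `float32` before being stored in the new `DecodingState`, presumably so the carry of the surrounding `jax.lax.scan` keeps a stable dtype: the prefill logits entering the loop are `float32`, while the decoder may emit lower-precision logits. `lax.scan` rejects a carry whose dtype changes between steps, e.g.:

    import jax
    import jax.numpy as jnp

    def step(carry, _):
        return carry.astype(jnp.bfloat16), None   # carry dtype differs from the input carry

    try:
        jax.lax.scan(step, jnp.zeros((2,), jnp.float32), xs=None, length=3)
    except TypeError as err:
        print("scan requires a dtype-stable carry:", err)
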
@@ -234,7 +322,9 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
 
             return jax.lax.cond(jnp.all(state.stop_flags), pad_and_repeat_state, sample_and_update)
 
-        _, generated = jax.lax.scan(loop_iteration, initial_state, keys)
+        per_step_keys: Key[Array, "batch max_len"] = jax.vmap(lambda k: jax.random.split(k, max_output_length))(keys)
+        per_step_keys: Key[Array, "max_len batch"] = jnp.swapaxes(per_step_keys, 0, 1)
+        _, generated = jax.lax.scan(loop_iteration, initial_state, per_step_keys)
 
         token_ids = rearrange(generated.token_ids, "iteration batch -> batch iteration")
 
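
Each per-sequence key is split into one subkey per decoding step, and the resulting array is transposed to `(max_output_length, batch)` so that `lax.scan` hands one row of keys to every loop iteration. The shapes, in isolation:

    import jax
    import jax.numpy as jnp

    batch_size, max_output_length = 4, 3
    keys = jax.random.split(jax.random.key(0), num=batch_size)                    # (batch,)
    per_step = jax.vmap(lambda k: jax.random.split(k, max_output_length))(keys)   # (batch, steps)
    per_step = jnp.swapaxes(per_step, 0, 1)                                       # (steps, batch)
    assert per_step.shape == (max_output_length, batch_size)
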
@@ -247,29 +337,222 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
 
         return GenerationResults(token_ids, top_k_token_ids, top_k_token_logits)
 
+    def _generate_tokens_batch(
+        self,
+        batch: tuple[list[int], ...],
+        batch_keys: tuple[Key[Array, ""], ...],
+        *,
+        generation_config: GenerationConfig | None,
+        inference_config: InferenceConfig,
+        forward_pass_config: ForwardPassConfig | None,
+    ) -> Iterator[GenerationResults]:
+        assert inference_config.batch_size is not None
+        batch_size = inference_config.batch_size
+
+        padded_token_ids = pad_sequences(batch, (batch_size, inference_config.padded_length), dtype=jnp.int32)
+
+        lengths = jnp.array([len(tokens) for tokens in batch], dtype=jnp.int32)
+        padded_lengths = jnp.pad(lengths, (0, batch_size - len(batch)))
+
+        padded_keys = pad_keys_to_size(batch_keys, batch_size)
+
+        generate_tokens_fn = compile_generate_tokens(
+            self,
+            generation_config,
+            inference_config,
+            forward_pass_config=forward_pass_config,
+        )
+        results = generate_tokens_fn(
+            self,
+            prompt_token_ids=padded_token_ids,
+            prompt_lengths_without_padding=padded_lengths,
+            keys=padded_keys,
+        )
+        for i in range(len(batch)):
+            yield GenerationResults(
+                token_ids=results.token_ids[i],
+                top_k_token_ids=results.top_k_token_ids[i] if results.top_k_token_ids is not None else None,
+                top_k_token_logits=results.top_k_token_logits[i] if results.top_k_token_logits is not None else None,
+            )
+
+    def generate_tokens_many(
+        self,
+        tokenized: Iterable[list[int]],
+        generation_config: GenerationConfig | None = None,
+        inference_config: InferenceConfig = InferenceConfig(),  # noqa: B008
+        *,
+        forward_pass_config: ForwardPassConfig | None = None,
+        keys: Key[Array, " num_sequences"] | None = None,
+    ) -> Iterator[GenerationResults]:
+        tokenized = list(tokenized)  # load eagerly to RAM
+
+        if keys is None:
+            keys = jax.random.split(jax.random.key(0), num=len(tokenized))
+
+        if len(keys) != len(tokenized):
+            raise ValueError(
+                f"Length of 'keys' should be equal to the number of sequences passed or None; got {len(keys)}",
+            )
+
+        def process_batches(batch_size: int) -> Iterator[tuple[int, GenerationResults]]:
+            new_inference_config = replace(inference_config, batch_size=batch_size)
+
+            for batch_items in batched(zip(tokenized, keys, strict=True), batch_size):
+                real_batch, batch_keys = zip(*batch_items, strict=True)
+                yield from self._generate_tokens_batch(
+                    real_batch,
+                    batch_keys,
+                    generation_config=generation_config,
+                    inference_config=new_inference_config,
+                    forward_pass_config=forward_pass_config,
+                )
+
+        assert inference_config.batch_size is not None
+        yield from decrease_batchsize_on_oom(
+            process_batches,
+            starting_batch_size=inference_config.batch_size,
+        )
+
+    def estimate_memory_consumption(
+        self,
+        generation_config: GenerationConfig | None = None,
+        inference_config: InferenceConfig = InferenceConfig(),  # noqa: B008
+        *,
+        forward_pass_config: ForwardPassConfig | None = None,
+    ) -> int:
+        memory_analysis = compile_generate_tokens(
+            self,
+            generation_config=generation_config,
+            inference_config=inference_config,
+            forward_pass_config=forward_pass_config,
+        ).memory_analysis()
+
+        assert hasattr(memory_analysis, "argument_size_in_bytes")
+        assert hasattr(memory_analysis, "output_size_in_bytes")
+        assert hasattr(memory_analysis, "temp_size_in_bytes")
+
+        return (
+            memory_analysis.argument_size_in_bytes
+            + memory_analysis.output_size_in_bytes
+            + memory_analysis.temp_size_in_bytes
+        )
+
+    def _trim_at_eos(self, token_ids: list[int]) -> list[int]:
+        if not self.stop_token_ids:
+            return token_ids
+        stop_set = set(self.stop_token_ids)
+        end = next((i for i, token_id in enumerate(token_ids) if token_id in stop_set), len(token_ids))
+        return token_ids[: end + 1]
+
     def reply(
         self,
         messages: Iterable[Message],
-        sampling_policy: SamplingPolicy | None = None,
-        forward_pass_config: ForwardPassConfig | None = None,
+        generation_config: GenerationConfig | None = None,
         *,
+        forward_pass_config: ForwardPassConfig | None = None,
         key: PRNGKeyArray | None = None,
     ) -> AssistantMessage:
         formatted_messages = self.message_processor.render_request(messages)
         token_ids = jnp.array(self.message_processor.tokenize_text(formatted_messages), dtype=jnp.int32)[None, :]
         response_ids = self.generate_tokens(
             token_ids,
-            sampling_policy,
+            generation_config,
             forward_pass_config=forward_pass_config,
-            key=key,
+            keys=key[None, ...] if key is not None else None,
         ).token_ids.squeeze(0)
-        response_text = self.message_processor.detokenize(response_ids.tolist())
+        trimmed_ids = self._trim_at_eos(response_ids.tolist())
+        response_text = self.message_processor.detokenize(trimmed_ids)
         return self.message_processor.parse_response(response_text)
 
+    def reply_many(
+        self,
+        messages: Iterable[Iterable[Message]],
+        generation_config: GenerationConfig | None = None,
+        inference_config: InferenceConfig = InferenceConfig(),  # noqa: B008
+        *,
+        forward_pass_config: ForwardPassConfig | None = None,
+        keys: Key[Array, " num_sequences"] | None = None,
+        vram_bytes: int | None = None,
+        batch_sizes_callback: Callable[[BatchSizesComputedEvent], None] | None = None,
+    ) -> Iterator[tuple[int, AssistantMessage]]:
+        messages = list(messages)  # eagerly load the dataset into RAM
+
+        if keys is None:
+            keys = jax.random.split(jax.random.key(0), num=len(messages))
+
+        if len(keys) != len(messages):
+            raise ValueError(
+                f"Length of 'keys' should be equal to the number of sequences passed or None; got {len(keys)}",
+            )
+
+        if vram_bytes is not None and inference_config.batch_size is not None:
+            raise ValueError("You have to specify only one of batch_size and vram_bytes, not both.")
+
+        if vram_bytes is None and inference_config.batch_size is None:
+            raise ValueError("You have to specify either batch_size or vram_bytes, but you provided neither.")
+
+        tokenized: list[list[int]] = self.message_processor.tokenize_requests(messages)
+
+        buckets: dict[int, list[tuple[int, list[int]]]] = {}
+        max_prompt_length = max(_COMPILED_PROMPT_LENGTHS)
+        for idx, sequence in enumerate(tokenized):
+            assert len(sequence) <= max_prompt_length, (
+                f"Sequence length {len(sequence)} exceeds largest bucket {max_prompt_length}"
+            )
+            # choose the smallest precomputed length that is greater than or equal to the sequence length
+            padded_len = min(length for length in _COMPILED_PROMPT_LENGTHS if length >= len(sequence))
+            buckets.setdefault(padded_len, []).append((idx, sequence))
+        sorted_lengths = sorted(buckets.keys())
+
+        if inference_config.batch_size is not None:
+            batch_size_per_bucket = dict.fromkeys(sorted_lengths, inference_config.batch_size)
+        else:
+            batch_size_per_bucket = estimate_batchsizes_from_vram(
+                lambda config: self.estimate_memory_consumption(inference_config=config),
+                sorted_lengths,
+                vram_bytes,  # type: ignore
+                inference_config,
+            )
+
+        buckets = merge_small_buckets(buckets, batch_size_per_bucket, min_batches=2)
+        assert sum(len(bucket) for bucket in buckets.values()) == len(tokenized)
+
+        if batch_sizes_callback is not None:
+            batch_sizes = tuple(
+                BatchSizeInfo(
+                    prefix_length=padded_length,
+                    num_elements=len(buckets[padded_length]),
+                    batch_size=batch_size_per_bucket.get(padded_length, 1),
+                )
+                for padded_length in sorted(buckets.keys())
+            )
+            batch_sizes_callback(BatchSizesComputedEvent(batch_sizes=batch_sizes))
+
+        # Process the longest sequences first so that an OOM at batch_size=1 surfaces as early as possible
+        for padded_length in sorted(buckets.keys(), reverse=True):
+            sequence_ids, sequence_tokenized = zip(*buckets[padded_length], strict=True)
+            sequence_ids = list(sequence_ids)
+            batch_size = batch_size_per_bucket[padded_length]
+
+            bucket_inference_config = replace(inference_config, batch_size=batch_size, padded_length=padded_length)
+
+            all_results = self.generate_tokens_many(
+                sequence_tokenized,
+                generation_config=generation_config,
+                inference_config=bucket_inference_config,
+                forward_pass_config=forward_pass_config,
+                keys=keys[np.array(sequence_ids)],
+            )
+
+            for idx, result in zip(sequence_ids, all_results, strict=True):
+                trimmed_ids = self._trim_at_eos(result.token_ids.tolist())
+                response = self.message_processor.parse_tokenized_response(trimmed_ids)
+                yield (idx, response)
+
     def stream_reply_text(
         self,
         messages: Iterable[Message],
-        sampling_policy: SamplingPolicy | None = None,
+        generation_config: GenerationConfig | None = None,
         max_output_length: int = 8192,
         forward_pass_config: ForwardPassConfig | None = None,
         *,
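
Taken together, the new batch API tokenizes every request, groups the prompts into the compiled length buckets, picks a batch size per bucket (either a fixed `batch_size` or one estimated from `vram_bytes`, but not both), and yields `(index, AssistantMessage)` pairs as each bucket finishes, longest prompts first. A hedged usage sketch: `model` and `conversations` are placeholders rather than package API, and the `InferenceConfig(batch_size=...)` constructor call is assumed from the fields used above:

    # Hypothetical usage; `model` is an already-loaded LanguageModel and
    # `conversations` is a list of Iterable[Message], one per request.
    replies = model.reply_many(
        conversations,
        inference_config=InferenceConfig(batch_size=8),   # or pass vram_bytes=... instead
    )
    for index, assistant_message in replies:              # yielded bucket by bucket
        print(index, assistant_message)
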
@@ -279,7 +562,7 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
         token_ids = jnp.array(self.message_processor.tokenize_text(formatted_messages), dtype=jnp.int32)
         for token_id in self.stream_tokens(
             token_ids,
-            sampling_policy,
+            generation_config,
             max_output_length,
             forward_pass_config=forward_pass_config,
             key=key,
@@ -289,15 +572,17 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
     def stream_tokens(
         self,
         prompt_token_ids: Int[Array, " prompt_tokens"],
-        sampling_policy: SamplingPolicy | None = None,
+        generation_config: GenerationConfig | None = None,
         max_output_length: int = 8192,
         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
         forward_pass_config: ForwardPassConfig | None = None,
         *,
         key: PRNGKeyArray | None = None,
     ) -> Iterable[Int[Array, ""]]:
-        if sampling_policy is None:
-            sampling_policy = self.default_sampling_policy()
+        sampling_policy = self.default_sampling_policy()
+        if generation_config is not None:
+            sampling_policy = generation_config.default_policy()
+
         if eos_token_ids is None:
             eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
 
@@ -309,8 +594,8 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
 
         prefill_results = self._prefill(
             padded_token_ids[None, :],
-            jnp.array([input_length], dtype=jnp.int32),
             padded_input_length + max_output_length,
+            lengths_without_padding=jnp.array([input_length], dtype=jnp.int32),
             forward_pass_config=forward_pass_config,
         )
 
@@ -341,6 +626,7 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
                 next_token_indices.reshape(1, 1),
                 state.state,
                 return_updated_state=True,
+                forward_pass_mode=ForwardPassMode.SINGLE_TOKEN,
                 forward_pass_config=forward_pass_config,
             )
             assert decoder_outputs.updated_state is not None, "updated_state should not be None"