lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/commands.py +247 -14
- lalamo/common.py +33 -0
- lalamo/data/__init__.py +3 -2
- lalamo/data/huggingface_message.py +4 -5
- lalamo/main.py +274 -9
- lalamo/message_processor.py +19 -1
- lalamo/model_import/common.py +17 -1
- lalamo/model_import/model_specs/mistral.py +5 -0
- lalamo/model_import/remote_registry.py +44 -0
- lalamo/models/__init__.py +3 -0
- lalamo/models/common.py +22 -0
- lalamo/models/compile_helpers.py +58 -0
- lalamo/models/language_model.py +342 -56
- lalamo/models/lm_helpers.py +198 -0
- lalamo/modules/decoder.py +4 -0
- lalamo/modules/token_mixers/mamba.py +345 -105
- lalamo/speculator/__init__.py +0 -2
- lalamo/speculator/inference.py +35 -61
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/METADATA +1 -1
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/RECORD +25 -23
- lalamo/speculator/estimator.py +0 -127
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/WHEEL +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/top_level.txt +0 -0
lalamo/models/language_model.py
CHANGED
@@ -1,14 +1,16 @@
-from collections.abc import Iterable
-from dataclasses import dataclass
+from collections.abc import Callable, Iterable, Iterator
+from dataclasses import dataclass, replace
+from itertools import batched
 from pathlib import Path
 from typing import NamedTuple

 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from einops import rearrange
+import numpy as np
+from einops import rearrange, repeat
 from jax import vmap
-from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+from jaxtyping import Array, Bool, Float, Int, Key, PRNGKeyArray

 from lalamo.message_processor import AssistantMessage, Message, MessageProcessor
 from lalamo.modules import (
@@ -21,7 +23,15 @@ from lalamo.modules import (
 )
 from lalamo.sampling import SamplingPolicy, make_policy

-from .common import TextModel, TextModelConfig
+from .common import BatchSizeInfo, BatchSizesComputedEvent, InferenceConfig, TextModel, TextModelConfig
+from .compile_helpers import compile_generate_tokens
+from .lm_helpers import (
+    decrease_batchsize_on_oom,
+    estimate_batchsizes_from_vram,
+    merge_small_buckets,
+    pad_keys_to_size,
+    pad_sequences,
+)

 __all__ = [
     "ForwardPassConfig",
@@ -31,7 +41,7 @@ __all__ = [
 ]


-_COMPILED_PROMPT_LENGTHS = [
+_COMPILED_PROMPT_LENGTHS = [256 * 2**i for i in range(12)]


 type ForwardPassConfig = DecoderForwardPassConfig
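For reference, the new comprehension evaluates to twelve power-of-two prompt buckets, doubling from 256 up to 524288 tokens:

print([256 * 2**i for i in range(12)])
# [256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288]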
@@ -94,6 +104,13 @@ class LanguageModelConfig(TextModelConfig[DecoderConfig]):
         return result


+class Chunk(eqx.Module):
+    tokens: Int[Array, "num_chunks batch chunk_size"]
+    indices: Int[Array, "num_chunks batch chunk_size"]
+    sequence_ends: Int[Array, "num_chunks batch"]
+    is_last_token_inside: Bool[Array, "num_chunks batch"]
+
+
 class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
     @property
     def stop_token_ids(self) -> tuple[int, ...]:
@@ -102,68 +119,143 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
     def default_sampling_policy(self) -> SamplingPolicy:
         return self.config.generation_config.default_policy()

+    @eqx.filter_jit
+    def _make_chunks(
+        self,
+        token_ids: Int[Array, "batch tokens"],
+        lengths_without_padding: Int[Array, " batch"] | None,
+        chunk_size: int,
+    ) -> Chunk:
+        batch_size, sequence_length = token_ids.shape
+        if lengths_without_padding is None:
+            lengths_without_padding = jnp.full((batch_size,), sequence_length, dtype=jnp.int32)
+
+        # If all sequences fit in a single chunk, use sequence_length as the chunk size
+        chunk_size = min(chunk_size, sequence_length)
+
+        n_chunks = (sequence_length + chunk_size - 1) // chunk_size
+        padded_length = n_chunks * chunk_size
+
+        token_ids = jnp.pad(token_ids, [(0, 0), (0, padded_length - sequence_length)])
+
+        # Reshape tokens to (num_chunks, batch, chunk_size)
+        tokens = rearrange(
+            token_ids,
+            "batch (num_chunks chunk_size) -> num_chunks batch chunk_size",
+            chunk_size=chunk_size,
+        )
+
+        # Create position indices (num_chunks, batch, chunk_size)
+        indices = jnp.arange(padded_length, dtype=jnp.int32)
+        indices = repeat(indices, "token_idx -> batch token_idx", batch=batch_size)
+        indices = rearrange(
+            indices,
+            "batch (num_chunks chunk_size) -> num_chunks batch chunk_size",
+            chunk_size=chunk_size,
+        )
+
+        # sequence_ends: for each chunk, how many valid tokens per batch item
+        chunk_starts = jnp.arange(n_chunks, dtype=jnp.int32) * chunk_size
+        sequence_ends = jnp.clip(
+            lengths_without_padding[None, :] - chunk_starts[:, None],
+            0,
+            chunk_size,
+        )
+
+        # last_token_inside: whether the last valid token (at index length-1) is in this chunk
+        last_token_idx = lengths_without_padding - 1
+        chunk_ends = chunk_starts + chunk_size
+        is_last_token_inside = (last_token_idx[None, :] >= chunk_starts[:, None]) & (
+            last_token_idx[None, :] < chunk_ends[:, None]
+        )
+
+        return Chunk(
+            tokens=tokens,
+            indices=indices,
+            sequence_ends=sequence_ends,
+            is_last_token_inside=is_last_token_inside,
+        )
+
     @eqx.filter_jit
     def _prefill(
         self,
         token_ids: Int[Array, "batch tokens"],
+        state_capacity: int,
         lengths_without_padding: Int[Array, " batch"] | None = None,
-        state_capacity: int | None = None,
         forward_pass_config: ForwardPassConfig | None = None,
+        chunk_size: int = 512,  # vllm default
     ) -> PrefillResults:
         batch_size, sequence_length = token_ids.shape
-        token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
-        if state_capacity is not None:
-            state = self.model.init_static_state(batch_size, state_capacity)
-        else:
-            state = None

-
-
-            token_positions,
-            state,
-            return_updated_state=True,
-            lengths_without_padding=lengths_without_padding,
-            forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
-            forward_pass_config=forward_pass_config,
-        )
+        if lengths_without_padding is None:
+            lengths_without_padding = jnp.full((batch_size,), sequence_length, dtype=jnp.int32)

-
-
-
-
+        chunks = self._make_chunks(token_ids, lengths_without_padding, chunk_size)
+
+        num_chunks, _, chunk_size = chunks.tokens.shape
+        state_capacity = max(state_capacity, num_chunks * chunk_size)

-
+        state = self.model.init_static_state(batch_size, state_capacity)
+        logits_like = jnp.zeros((batch_size, self.model.vocab_size), dtype=jnp.float32)
+
+        def apply_chunk(state_and_logits: tuple, chunk: Chunk) -> tuple:
+            state, prev_logits = state_and_logits
+            decoder_outputs = self.model(
+                chunk.tokens,
+                chunk.indices,
+                state,
+                return_updated_state=True,
+                lengths_without_padding=chunk.sequence_ends,
+                forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
+                forward_pass_config=forward_pass_config,
+            )
+            assert decoder_outputs.updated_state is not None
+
+            chunk_logits = decoder_outputs.logits[jnp.arange(batch_size), chunk.sequence_ends - 1, :]
+            new_logits = jnp.where(chunk.is_last_token_inside[:, None], chunk_logits, prev_logits)
+
+            return (decoder_outputs.updated_state, new_logits), None
+
+        (final_state, final_logits), _ = jax.lax.scan(apply_chunk, (state, logits_like), chunks)

-        assert decoder_outputs.updated_state is not None
         return PrefillResults(
-            last_token_logits=
-            last_token_indices=
-            state=
+            last_token_logits=final_logits,
+            last_token_indices=jnp.maximum(lengths_without_padding - 1, 0),
+            state=final_state,
         )

-    @eqx.filter_jit
     def generate_tokens(
         self,
         prompt_token_ids: Int[Array, "batch prompt_tokens"],
-
+        generation_config: GenerationConfig | None = None,
         prompt_lengths_without_padding: Int[Array, " batch"] | None = None,
         max_output_length: int = 8192,
         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
         forward_pass_config: ForwardPassConfig | None = None,
         num_top_logits_to_return: int | None = None,
         *,
-
+        keys: Key[Array, " batch"] | None = None,
     ) -> GenerationResults:
-
-
+        batch_size, sequence_length = prompt_token_ids.shape
+
+        sampling_policy = self.default_sampling_policy()
+        if generation_config is not None:
+            sampling_policy = generation_config.default_policy()
+
         if eos_token_ids is None:
             eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
+        if keys is None:
+            keys = jax.random.split(jax.random.key(0), num=batch_size)
+
+        if len(keys) != batch_size:
+            raise ValueError(
+                f"Length of 'keys' should be equal to the batch size, or keys should be None; got {len(keys)}",
+            )

-        batch_size, sequence_length = prompt_token_ids.shape
         prefill_results = self._prefill(
             prompt_token_ids,
-            prompt_lengths_without_padding,
             sequence_length + max_output_length,
+            prompt_lengths_without_padding,
             forward_pass_config=forward_pass_config,
         )

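The rewritten _prefill above threads the decoder state and the logits of each sequence's final real token through a jax.lax.scan over fixed-size chunks. Below is a minimal standalone sketch of that scan-and-select pattern only; fake_model is a stand-in for the real Decoder call, and Chunk/PrefillResults are simplified away, so it illustrates the idea rather than the library's API.

import jax
import jax.numpy as jnp
from einops import rearrange

VOCAB, CHUNK = 16, 4

def fake_model(state, tokens):
    # Stand-in forward pass: fold the chunk into a running state, emit per-sequence logits.
    state = state + tokens.sum(axis=-1, keepdims=True).astype(jnp.float32)
    logits = jax.nn.one_hot(tokens[:, -1] % VOCAB, VOCAB) + state
    return state, logits

def chunked_prefill(token_ids, lengths):
    batch, seq = token_ids.shape
    n_chunks = (seq + CHUNK - 1) // CHUNK
    padded = jnp.pad(token_ids, [(0, 0), (0, n_chunks * CHUNK - seq)])
    chunks = rearrange(padded, "b (n c) -> n b c", c=CHUNK)
    chunk_starts = jnp.arange(n_chunks) * CHUNK
    # Which chunk contains each sequence's last real (unpadded) token?
    last_idx = lengths - 1
    contains_last = (last_idx[None, :] >= chunk_starts[:, None]) & (
        last_idx[None, :] < (chunk_starts + CHUNK)[:, None]
    )

    def step(carry, xs):
        state, prev_logits = carry
        chunk, is_last_here = xs
        state, logits = fake_model(state, chunk)
        # Keep logits from the chunk holding the final token; otherwise carry the old ones forward.
        new_logits = jnp.where(is_last_here[:, None], logits, prev_logits)
        return (state, new_logits), None

    init = (jnp.zeros((batch, 1)), jnp.zeros((batch, VOCAB)))
    (_, last_logits), _ = jax.lax.scan(step, init, (chunks, contains_last))
    return last_logits

print(chunked_prefill(jnp.arange(24).reshape(2, 12) % VOCAB, jnp.array([12, 7])).shape)  # (2, 16)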
@@ -174,18 +266,14 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
             jnp.zeros(batch_size, dtype=jnp.bool),
         )

-        if key is None:
-            key = jax.random.PRNGKey(0)
-        keys = jax.random.split(key, num=max_output_length)
-
         def loop_iteration(
             state: DecodingState,
-
+            keys: Key[Array, " batch"],
         ) -> tuple[DecodingState, GenerationStepResults]:
             def sample_and_update() -> tuple[DecodingState, GenerationStepResults]:
                 upcasted_logits = state.last_token_logits.astype(jnp.float32)
                 processed_logits = vmap(sampling_policy.process_logits)(upcasted_logits)
-                next_token_ids = jax.random.categorical(
+                next_token_ids = jax.vmap(lambda k, logits: jax.random.categorical(k, logits))(keys, processed_logits)
                 next_token_ids = jnp.where(state.stop_flags, jnp.zeros(batch_size, dtype=jnp.int32), next_token_ids)
                 if num_top_logits_to_return is not None:
                     next_top_k_token_logits, next_top_k_token_ids = jax.lax.top_k(
@@ -214,7 +302,7 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
                 )
                 assert decoder_outputs.updated_state is not None, "updated_state should not be None"
                 new_state = DecodingState(
-                    decoder_outputs.logits.squeeze(1),
+                    decoder_outputs.logits.squeeze(1).astype(jnp.float32),
                     next_token_indices,
                     decoder_outputs.updated_state,
                     stop_flags,
@@ -234,7 +322,9 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):

             return jax.lax.cond(jnp.all(state.stop_flags), pad_and_repeat_state, sample_and_update)

-
+        per_step_keys: Key[Array, "batch max_len"] = jax.vmap(lambda k: jax.random.split(k, max_output_length))(keys)
+        per_step_keys: Key[Array, "max_len batch"] = jnp.swapaxes(per_step_keys, 0, 1)
+        _, generated = jax.lax.scan(loop_iteration, initial_state, per_step_keys)

         token_ids = rearrange(generated.token_ids, "iteration batch -> batch iteration")

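The key handling introduced in the hunks above (one PRNG key per batch element, split into per-step keys, transposed to a leading step axis, then scanned with per-sequence sampling) follows a standard JAX pattern. A small self-contained sketch with toy logits, not the real DecodingState machinery:

import jax
import jax.numpy as jnp

batch, steps, vocab = 3, 5, 11
keys = jax.random.split(jax.random.key(0), num=batch)            # one key per sequence
per_step = jax.vmap(lambda k: jax.random.split(k, steps))(keys)  # (batch, steps)
per_step = jnp.swapaxes(per_step, 0, 1)                          # (steps, batch) for scan

logits = jnp.zeros((batch, vocab))  # stand-in for the model's last-token logits

def step(carry, step_keys):
    # Each batch element samples with its own key, so sequences stay independent.
    token_ids = jax.vmap(jax.random.categorical)(step_keys, logits)
    return carry, token_ids

_, tokens = jax.lax.scan(step, None, per_step)
print(tokens.shape)  # (steps, batch)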
@@ -247,29 +337,222 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):

         return GenerationResults(token_ids, top_k_token_ids, top_k_token_logits)

+    def _generate_tokens_batch(
+        self,
+        batch: tuple[list[int], ...],
+        batch_keys: tuple[Key[Array, ""], ...],
+        *,
+        generation_config: GenerationConfig | None,
+        inference_config: InferenceConfig,
+        forward_pass_config: ForwardPassConfig | None,
+    ) -> Iterator[GenerationResults]:
+        assert inference_config.batch_size is not None
+        batch_size = inference_config.batch_size
+
+        padded_token_ids = pad_sequences(batch, (batch_size, inference_config.padded_length), dtype=jnp.int32)
+
+        lengths = jnp.array([len(tokens) for tokens in batch], dtype=jnp.int32)
+        padded_lengths = jnp.pad(lengths, (0, batch_size - len(batch)))
+
+        padded_keys = pad_keys_to_size(batch_keys, batch_size)
+
+        generate_tokens_fn = compile_generate_tokens(
+            self,
+            generation_config,
+            inference_config,
+            forward_pass_config=forward_pass_config,
+        )
+        results = generate_tokens_fn(
+            self,
+            prompt_token_ids=padded_token_ids,
+            prompt_lengths_without_padding=padded_lengths,
+            keys=padded_keys,
+        )
+        for i in range(len(batch)):
+            yield GenerationResults(
+                token_ids=results.token_ids[i],
+                top_k_token_ids=results.top_k_token_ids[i] if results.top_k_token_ids is not None else None,
+                top_k_token_logits=results.top_k_token_logits[i] if results.top_k_token_logits is not None else None,
+            )
+
+    def generate_tokens_many(
+        self,
+        tokenized: Iterable[list[int]],
+        generation_config: GenerationConfig | None = None,
+        inference_config: InferenceConfig = InferenceConfig(),  # noqa: B008
+        *,
+        forward_pass_config: ForwardPassConfig | None = None,
+        keys: Key[Array, " num_sequences"] | None = None,
+    ) -> Iterator[GenerationResults]:
+        tokenized = list(tokenized)  # load eagerly to RAM
+
+        if keys is None:
+            keys = jax.random.split(jax.random.key(0), num=len(tokenized))
+
+        if len(keys) != len(tokenized):
+            raise ValueError(
+                f"Length of 'keys' should be equal to the number of sequences passed or None; got {len(keys)}",
+            )
+
+        def process_batches(batch_size: int) -> Iterator[tuple[int, GenerationResults]]:
+            new_inference_config = replace(inference_config, batch_size=batch_size)
+
+            for batch_items in batched(zip(tokenized, keys, strict=True), batch_size):
+                real_batch, batch_keys = zip(*batch_items, strict=True)
+                yield from self._generate_tokens_batch(
+                    real_batch,
+                    batch_keys,
+                    generation_config=generation_config,
+                    inference_config=new_inference_config,
+                    forward_pass_config=forward_pass_config,
+                )
+
+        assert inference_config.batch_size is not None
+        yield from decrease_batchsize_on_oom(
+            process_batches,
+            starting_batch_size=inference_config.batch_size,
+        )
+
+    def estimate_memory_consumption(
+        self,
+        generation_config: GenerationConfig | None = None,
+        inference_config: InferenceConfig = InferenceConfig(),  # noqa: B008
+        *,
+        forward_pass_config: ForwardPassConfig | None = None,
+    ) -> int:
+        memory_analysis = compile_generate_tokens(
+            self,
+            generation_config=generation_config,
+            inference_config=inference_config,
+            forward_pass_config=forward_pass_config,
+        ).memory_analysis()
+
+        assert hasattr(memory_analysis, "argument_size_in_bytes")
+        assert hasattr(memory_analysis, "output_size_in_bytes")
+        assert hasattr(memory_analysis, "temp_size_in_bytes")
+
+        return (
+            memory_analysis.argument_size_in_bytes
+            + memory_analysis.output_size_in_bytes
+            + memory_analysis.temp_size_in_bytes
+        )
+
+    def _trim_at_eos(self, token_ids: list[int]) -> list[int]:
+        if not self.stop_token_ids:
+            return token_ids
+        stop_set = set(self.stop_token_ids)
+        end = next((i for i, token_id in enumerate(token_ids) if token_id in stop_set), len(token_ids))
+        return token_ids[: end + 1]
+
     def reply(
         self,
         messages: Iterable[Message],
-
-        forward_pass_config: ForwardPassConfig | None = None,
+        generation_config: GenerationConfig | None = None,
         *,
+        forward_pass_config: ForwardPassConfig | None = None,
         key: PRNGKeyArray | None = None,
     ) -> AssistantMessage:
         formatted_messages = self.message_processor.render_request(messages)
         token_ids = jnp.array(self.message_processor.tokenize_text(formatted_messages), dtype=jnp.int32)[None, :]
         response_ids = self.generate_tokens(
             token_ids,
-
+            generation_config,
             forward_pass_config=forward_pass_config,
-
+            keys=key[None, ...] if key is not None else None,
         ).token_ids.squeeze(0)
-
+        trimmed_ids = self._trim_at_eos(response_ids.tolist())
+        response_text = self.message_processor.detokenize(trimmed_ids)
         return self.message_processor.parse_response(response_text)

+    def reply_many(
+        self,
+        messages: Iterable[Iterable[Message]],
+        generation_config: GenerationConfig | None = None,
+        inference_config: InferenceConfig = InferenceConfig(),  # noqa: B008
+        *,
+        forward_pass_config: ForwardPassConfig | None = None,
+        keys: Key[Array, " num_sequences"] | None = None,
+        vram_bytes: int | None = None,
+        batch_sizes_callback: Callable[[BatchSizesComputedEvent], None] | None = None,
+    ) -> Iterator[tuple[int, AssistantMessage]]:
+        messages = list(messages)  # eagerly load the dataset into RAM
+
+        if keys is None:
+            keys = jax.random.split(jax.random.key(0), num=len(messages))
+
+        if len(keys) != len(messages):
+            raise ValueError(
+                f"Length of 'keys' should be equal to the number of sequences passed or None; got {len(keys)}",
+            )
+
+        if vram_bytes is not None and inference_config.batch_size is not None:
+            raise ValueError("You have to specify only one of batch_size and vram_gb, not both.")
+
+        if vram_bytes is None and inference_config.batch_size is None:
+            raise ValueError("You have to specify either batch_size or vram_gb, but you provided neither.")
+
+        tokenized: list[list[int]] = self.message_processor.tokenize_requests(messages)
+
+        buckets: dict[int, list[tuple[int, list[int]]]] = {}
+        max_prompt_length = max(_COMPILED_PROMPT_LENGTHS)
+        for idx, sequence in enumerate(tokenized):
+            assert len(sequence) <= max_prompt_length, (
+                f"Sequence length {len(sequence)} exceeds largest bucket {max_prompt_length}"
+            )
+            # we choose the smallest size from precomputed ones that is longer or equal to the current sequence
+            padded_len = min(length for length in _COMPILED_PROMPT_LENGTHS if length >= len(sequence))
+            buckets.setdefault(padded_len, []).append((idx, sequence))
+        sorted_lengths = sorted(buckets.keys())
+
+        if inference_config.batch_size is not None:
+            batch_size_per_bucket = dict.fromkeys(sorted_lengths, inference_config.batch_size)
+        else:
+            batch_size_per_bucket = estimate_batchsizes_from_vram(
+                lambda config: self.estimate_memory_consumption(inference_config=config),
+                sorted_lengths,
+                vram_bytes,  # type: ignore
+                inference_config,
+            )
+
+        buckets = merge_small_buckets(buckets, batch_size_per_bucket, min_batches=2)
+        assert sum(len(bucket) for bucket in buckets.values()) == len(tokenized)
+
+        if batch_sizes_callback is not None:
+            batch_sizes = tuple(
+                BatchSizeInfo(
+                    prefix_length=padded_length,
+                    num_elements=len(buckets[padded_length]),
+                    batch_size=batch_size_per_bucket.get(padded_length, 1),
+                )
+                for padded_length in sorted(buckets.keys())
+            )
+            batch_sizes_callback(BatchSizesComputedEvent(batch_sizes=batch_sizes))
+
+        # Process longest sequences first so batchsize=1 OOM happens as early as possible, if it does happen
+        for padded_length in sorted(buckets.keys(), reverse=True):
+            sequence_ids, sequence_tokenized = zip(*buckets[padded_length], strict=True)
+            sequence_ids = list(sequence_ids)
+            batch_size = batch_size_per_bucket[padded_length]
+
+            bucket_inference_config = replace(inference_config, batch_size=batch_size, padded_length=padded_length)
+
+            all_results = self.generate_tokens_many(
+                sequence_tokenized,
+                generation_config=generation_config,
+                inference_config=bucket_inference_config,
+                forward_pass_config=forward_pass_config,
+                keys=keys[np.array(sequence_ids)],
+            )
+
+            for idx, result in zip(sequence_ids, all_results, strict=True):
+                trimmed_ids = self._trim_at_eos(result.token_ids.tolist())
+                response = self.message_processor.parse_tokenized_response(trimmed_ids)
+                yield (idx, response)
+
     def stream_reply_text(
         self,
         messages: Iterable[Message],
-
+        generation_config: GenerationConfig | None = None,
         max_output_length: int = 8192,
         forward_pass_config: ForwardPassConfig | None = None,
         *,
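The bucketing added to reply_many pads each prompt to the smallest precompiled length that fits it, keeping the original index so results can be yielded as (index, response) pairs. A pure-Python sketch of just that grouping step (batch-size estimation, merge_small_buckets, and the OOM backoff are left out):

_COMPILED_PROMPT_LENGTHS = [256 * 2**i for i in range(12)]

def bucket_prompts(tokenized: list[list[int]]) -> dict[int, list[tuple[int, list[int]]]]:
    buckets: dict[int, list[tuple[int, list[int]]]] = {}
    for idx, sequence in enumerate(tokenized):
        # Smallest precompiled length that is >= the sequence length.
        padded_len = min(length for length in _COMPILED_PROMPT_LENGTHS if length >= len(sequence))
        buckets.setdefault(padded_len, []).append((idx, sequence))
    return buckets

prompts = [[1] * 100, [2] * 300, [3] * 260, [4] * 5000]
print({length: [i for i, _ in items] for length, items in bucket_prompts(prompts).items()})
# {256: [0], 512: [1, 2], 8192: [3]}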
@@ -279,7 +562,7 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
         token_ids = jnp.array(self.message_processor.tokenize_text(formatted_messages), dtype=jnp.int32)
         for token_id in self.stream_tokens(
             token_ids,
-
+            generation_config,
             max_output_length,
             forward_pass_config=forward_pass_config,
             key=key,
@@ -289,15 +572,17 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
     def stream_tokens(
         self,
         prompt_token_ids: Int[Array, " prompt_tokens"],
-
+        generation_config: GenerationConfig | None = None,
         max_output_length: int = 8192,
         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
         forward_pass_config: ForwardPassConfig | None = None,
         *,
         key: PRNGKeyArray | None = None,
     ) -> Iterable[Int[Array, ""]]:
-
-
+        sampling_policy = self.default_sampling_policy()
+        if generation_config is not None:
+            sampling_policy = generation_config.default_policy()
+
         if eos_token_ids is None:
             eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)

@@ -309,8 +594,8 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):

         prefill_results = self._prefill(
             padded_token_ids[None, :],
-            jnp.array([input_length], dtype=jnp.int32),
             padded_input_length + max_output_length,
+            lengths_without_padding=jnp.array([input_length], dtype=jnp.int32),
             forward_pass_config=forward_pass_config,
         )

@@ -341,6 +626,7 @@ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
                 next_token_indices.reshape(1, 1),
                 state.state,
                 return_updated_state=True,
+                forward_pass_mode=ForwardPassMode.SINGLE_TOKEN,
                 forward_pass_config=forward_pass_config,
             )
             assert decoder_outputs.updated_state is not None, "updated_state should not be None"