lalamo 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- lalamo/__init__.py +20 -5
- lalamo/data/__init__.py +8 -0
- lalamo/data/huggingface_message.py +38 -0
- lalamo/data/lalamo_completions.py +43 -0
- lalamo/data/utils.py +8 -0
- lalamo/language_model.py +152 -69
- lalamo/main.py +273 -45
- lalamo/message_processor.py +11 -1
- lalamo/model_import/common.py +10 -6
- lalamo/model_import/decoder_configs/__init__.py +3 -0
- lalamo/model_import/decoder_configs/executorch.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- lalamo/model_import/huggingface_tokenizer_config.py +1 -3
- lalamo/model_import/loaders/executorch.py +10 -9
- lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo/model_import/loaders/utils.py +92 -0
- lalamo/model_import/model_specs/__init__.py +4 -1
- lalamo/model_import/model_specs/common.py +15 -12
- lalamo/model_import/model_specs/gpt_oss.py +21 -0
- lalamo/modules/__init__.py +35 -7
- lalamo/modules/activations.py +24 -14
- lalamo/modules/attention.py +73 -20
- lalamo/modules/common.py +8 -57
- lalamo/modules/decoder.py +48 -34
- lalamo/modules/decoder_layer.py +57 -43
- lalamo/modules/embedding.py +13 -19
- lalamo/modules/kv_cache.py +53 -16
- lalamo/modules/linear.py +260 -79
- lalamo/modules/mlp.py +395 -23
- lalamo/modules/normalization.py +2 -3
- lalamo/modules/rope.py +32 -21
- lalamo/modules/utils.py +10 -0
- lalamo/speculator/__init__.py +11 -0
- lalamo/speculator/common.py +22 -0
- lalamo/speculator/inference.py +75 -0
- lalamo/speculator/ngram.py +154 -0
- lalamo/speculator/utils.py +52 -0
- lalamo/utils.py +27 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/METADATA +11 -4
- lalamo-0.4.1.dist-info/RECORD +71 -0
- lalamo-0.3.4.dist-info/RECORD +0 -59
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/WHEEL +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/entry_points.txt +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
@@ -1,11 +1,26 @@
-from lalamo.
-from lalamo.
+from lalamo.language_model import LanguageModel
+from lalamo.message_processor import (
+    AssistantMessage,
+    ContentBlock,
+    Image,
+    Message,
+    SystemMessage,
+    ToolSchema,
+    UserMessage,
+)
+from lalamo.model_import import ModelSpec, import_model
 
-__version__ = "0.3.4"
+__version__ = "0.4.1"
 
 __all__ = [
-    "
-    "
+    "AssistantMessage",
+    "ContentBlock",
+    "Image",
+    "LanguageModel",
+    "Message",
     "ModelSpec",
+    "SystemMessage",
+    "ToolSchema",
+    "UserMessage",
     "import_model",
 ]
lalamo/data/huggingface_message.py
ADDED
@@ -0,0 +1,38 @@
+from collections.abc import Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, Self
+
+import cattrs
+import polars as pl
+
+from lalamo.message_processor import AssistantMessage, Message, UserMessage
+
+
+@dataclass(frozen=True)
+class HFMessage:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+
+    role: str
+    content: str
+
+    @classmethod
+    def from_dict(cls, obj: dict) -> Self:
+        return cls._converter.structure(obj, cls)
+
+    def as_message(self) -> Message:
+        match self.role:
+            case "user":
+                return UserMessage(self.content)
+            case "assistant":
+                return AssistantMessage(None, self.content)
+            case other:
+                raise ValueError(f"Cannot convert {other} message")
+
+def import_hf_parquet(path: Path | str) -> Iterable[list[Message]]:
+    path = Path(path)
+
+    dataframe = pl.scan_parquet(path).collect()
+
+    for conversation in dataframe.get_column("conversation").shuffle(1337):
+        yield [HFMessage.from_dict(message).as_message() for message in conversation]
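The importer above maps parquet rows straight onto lalamo messages. A minimal usage sketch, assuming a local parquet file (the path is made up) whose "conversation" column holds lists of {"role", "content"} records:

from lalamo.data.huggingface_message import import_hf_parquet

# Each yielded conversation is a list of UserMessage / AssistantMessage objects,
# shuffled with the fixed seed 1337 baked into the importer.
for conversation in import_hf_parquet("data/chat.parquet"):
    for message in conversation:
        print(type(message).__name__)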
lalamo/data/lalamo_completions.py
ADDED
@@ -0,0 +1,43 @@
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import IO, Any, ClassVar, Self
+
+import msgpack
+from cattrs.preconf.msgpack import MsgpackConverter
+from cattrs.preconf.msgpack import make_converter as make_msgpack_converter
+
+
+@dataclass(frozen=True)
+class LalamoCompletion:
+    _converter: ClassVar[MsgpackConverter] = make_msgpack_converter()
+
+    prefix_token_ids: list[int]
+    completion_token_ids: list[int]
+    completion_token_logits: list[dict[int, float]]
+
+    def __post_init__(self) -> None:
+        if len(self.completion_token_ids) != len(self.completion_token_logits):
+            raise ValueError(f"({len(self.completion_token_ids)=}) != ({len(self.completion_token_logits)=})")
+
+    def serialize(self) -> bytes:
+        return self._converter.dumps(self)
+
+    @classmethod
+    def deserialize(cls, data: bytes | IO[bytes]) -> Self:
+        if isinstance(data, bytes):
+            obj: Any = msgpack.unpackb(data, strict_map_key=False)
+        else:
+            obj = msgpack.unpack(data, strict_map_key=False)
+
+        return cls._converter.structure(obj, cls)
+
+    @classmethod
+    def deserialize_many(cls, data: bytes | IO[bytes]) -> Iterable[Self]:
+        if isinstance(data, bytes):
+            unpacker = msgpack.Unpacker(strict_map_key=False)
+            unpacker.feed(data)
+        else:
+            unpacker = msgpack.Unpacker(file_like=data, strict_map_key=False)
+
+        for obj in unpacker:
+            yield cls._converter.structure(obj, cls)
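LalamoCompletion stores per-token top logits as int-keyed dicts, which is why deserialization passes strict_map_key=False to msgpack. A round-trip sketch using only the API above; the token ids and logit values are made up:

from lalamo.data.lalamo_completions import LalamoCompletion

completion = LalamoCompletion(
    prefix_token_ids=[1, 2, 3],
    completion_token_ids=[4, 5],
    completion_token_logits=[{4: -0.1, 7: -2.3}, {5: -0.2, 9: -1.9}],
)

blob = completion.serialize()
assert LalamoCompletion.deserialize(blob) == completion

# serialize() output can be concatenated and streamed back with deserialize_many:
stream = b"".join(c.serialize() for c in (completion, completion))
assert list(LalamoCompletion.deserialize_many(stream)) == [completion, completion]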
lalamo/data/utils.py
ADDED
@@ -0,0 +1,8 @@
+from collections.abc import Iterable
+
+from lalamo.message_processor import Message, UserMessage
+
+
+def get_prefixes_ending_in_user_message(conversation: Iterable[Message]) -> list[list[Message]]:
+    conversation = list(conversation)
+    return [conversation[: i + 1] for i, msg in enumerate(conversation) if isinstance(msg, UserMessage)]
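Each prefix ends at a user turn, which is exactly the shape needed to regenerate every assistant reply in a recorded conversation. A worked example with placeholder message text:

from lalamo.data.utils import get_prefixes_ending_in_user_message
from lalamo.message_processor import AssistantMessage, UserMessage

conversation = [
    UserMessage("Hi"),
    AssistantMessage(None, "Hello!"),
    UserMessage("How are you?"),
    AssistantMessage(None, "Fine, thanks."),
]

prefixes = get_prefixes_ending_in_user_message(conversation)
assert [len(p) for p in prefixes] == [1, 3]  # [user] and [user, assistant, user]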
lalamo/language_model.py
CHANGED
@@ -7,33 +7,56 @@ from typing import NamedTuple, Self
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from einops import rearrange
+from jax import vmap
 from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
-from safetensors.flax import load_file
 from tokenizers import Tokenizer
 
 from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
 from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
-from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule,
+from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule, config_converter
+from lalamo.modules.common import ForwardPassMode
+from lalamo.modules.decoder import DecoderForwardPassConfig
 from lalamo.sampling import SamplingPolicy, make_policy
+from lalamo.utils import open_safetensors
 
 __all__ = [
+    "ForwardPassConfig",
     "GenerationConfig",
     "LanguageModel",
     "LanguageModelConfig",
 ]
 
 
+_COMPILED_PROMPT_LENGTHS = [512 * 2**i for i in range(10)]
+
+
+type ForwardPassConfig = DecoderForwardPassConfig
+
+
 class PrefillResults(NamedTuple):
-    last_token_logits: Float[Array, " vocabulary"]
-
+    last_token_logits: Float[Array, "batch vocabulary"]
+    last_token_indices: Int[Array, " batch"]
     kv_cache: KVCache
 
 
 class DecodingState(NamedTuple):
-    last_token_logits: Float[Array, " vocabulary"]
-
+    last_token_logits: Float[Array, "batch vocabulary"]
+    last_token_indices: Int[Array, " batch"]
     kv_cache: KVCache
-
+    stop_flags: Bool[Array, " batch"]
+
+
+class GenerationStepResults(NamedTuple):
+    token_ids: Int[Array, " batch"]
+    top_k_token_ids: Int[Array, " batch k"] | None
+    top_k_token_logits: Float[Array, " batch k"] | None
+
+
+class GenerationResults(NamedTuple):
+    token_ids: Int[Array, "batch response_tokens"]
+    top_k_token_ids: Int[Array, "batch response_tokens k"] | None
+    top_k_token_logits: Float[Array, "batch response_tokens k"] | None
 
 
 @dataclass(frozen=True)
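The new _COMPILED_PROMPT_LENGTHS table pins prefill shapes to ten power-of-two buckets (512, 1024, ..., 262144), so the JIT-compiled prefill is reused across prompts of similar length instead of recompiling per prompt; stream_tokens further down pads each prompt up to its bucket. A small sketch of the bucket choice, mirroring the min(...) expression used there (smallest_bucket is an illustrative helper, not part of the package):

_COMPILED_PROMPT_LENGTHS = [512 * 2**i for i in range(10)]

def smallest_bucket(input_length: int) -> int:
    # Smallest compiled prompt length that still fits the prompt.
    return min(length for length in _COMPILED_PROMPT_LENGTHS if length >= input_length)

assert smallest_bucket(1000) == 1024
assert smallest_bucket(512) == 512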
@@ -60,14 +83,15 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
     message_processor: MessageProcessor = eqx.field(static=True)
 
     @classmethod
-    def load(cls, path: Path | str
+    def load(cls, path: Path | str) -> Self:
         if isinstance(path, str):
             path = Path(path)
         with open(path / "config.json") as config_file:
             config_json = json.load(config_file)
         config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
-
-
+        with open_safetensors(path / "model.safetensors") as weights_dict:
+            weights = unflatten_parameters(weights_dict)
+        decoder = config.decoder_config.empty().import_weights(weights)
         tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
         message_processor = MessageProcessor(config.message_processor_config, tokenizer)
         return cls(config, decoder, message_processor)
@@ -76,17 +100,16 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
     def activation_precision(self) -> DTypeLike:
         return self.decoder.activation_precision
 
-    def export_weights(self
-        return self.decoder.export_weights(
+    def export_weights(self) -> ParameterTree:
+        return self.decoder.export_weights()
 
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         return replace(
             self,
-            decoder=self.decoder.import_weights(weights
+            decoder=self.decoder.import_weights(weights),
         )
 
     @property
@@ -99,14 +122,15 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
     @eqx.filter_jit
     def _prefill(
         self,
-        token_ids: Int[Array, " tokens"],
-
+        token_ids: Int[Array, "batch tokens"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
         kv_cache_capacity: int | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
     ) -> PrefillResults:
-
-        token_positions = jnp.arange(
+        batch_size, sequence_length = token_ids.shape
+        token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
         if kv_cache_capacity is not None:
-            kv_cache = self.decoder.init_static_kv_cache(kv_cache_capacity)
+            kv_cache = self.decoder.init_static_kv_cache(batch_size, kv_cache_capacity)
         else:
             kv_cache = None
 
@@ -115,52 +139,56 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             token_positions,
             kv_cache,
             return_updated_kv_cache=True,
-
+            lengths_without_padding=lengths_without_padding,
+            forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
+            forward_pass_config=forward_pass_config,
         )
 
-        if
-
+        if lengths_without_padding is not None:
+            last_logits_indices = lengths_without_padding - 1
         else:
-
+            last_logits_indices = jnp.array([sequence_length - 1] * batch_size, dtype=jnp.int32)
 
-        last_token_logits = decoder_outputs.logits
-        last_token_position = jnp.array(last_logits_index, dtype=jnp.int32)
+        last_token_logits = vmap(lambda logits, index: logits[index])(decoder_outputs.logits, last_logits_indices)
 
         assert decoder_outputs.updated_kv_cache is not None
         return PrefillResults(
             last_token_logits=last_token_logits,
-
+            last_token_indices=last_logits_indices,
             kv_cache=decoder_outputs.updated_kv_cache,
         )
 
     @eqx.filter_jit
     def generate_tokens(
         self,
-        prompt_token_ids: Int[Array, " prompt_tokens"],
+        prompt_token_ids: Int[Array, "batch prompt_tokens"],
         sampling_policy: SamplingPolicy | None = None,
-
+        prompt_lengths_without_padding: Int[Array, " batch"] | None = None,
         max_output_length: int = 8192,
         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+        num_top_logits_to_return: int | None = None,
         *,
         key: PRNGKeyArray | None = None,
-    ) ->
+    ) -> GenerationResults:
         if sampling_policy is None:
            sampling_policy = self.default_sampling_policy()
         if eos_token_ids is None:
             eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
 
-
+        batch_size, sequence_length = prompt_token_ids.shape
         prefill_results = self._prefill(
             prompt_token_ids,
-
-
+            prompt_lengths_without_padding,
+            sequence_length + max_output_length,
+            forward_pass_config=forward_pass_config,
         )
 
         initial_state = DecodingState(
             prefill_results.last_token_logits,
-            prefill_results.
+            prefill_results.last_token_indices,
             prefill_results.kv_cache,
-            jnp.
+            jnp.zeros(batch_size, dtype=jnp.bool),
         )
 
         if key is None:
@@ -170,49 +198,88 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         def loop_iteration(
             state: DecodingState,
             key: PRNGKeyArray,
-        ) -> tuple[DecodingState,
-            def sample_and_update() -> tuple[DecodingState,
-
-
-
-
-
+        ) -> tuple[DecodingState, GenerationStepResults]:
+            def sample_and_update() -> tuple[DecodingState, GenerationStepResults]:
+                upcasted_logits = state.last_token_logits.astype(jnp.float32)
+                processed_logits = vmap(sampling_policy.process_logits)(upcasted_logits)
+                next_token_ids = jax.random.categorical(key, processed_logits)
+                next_token_ids = jnp.where(state.stop_flags, jnp.zeros(batch_size, dtype=jnp.int32), next_token_ids)
+                if num_top_logits_to_return is not None:
+                    next_top_k_token_logits, next_top_k_token_ids = jax.lax.top_k(
+                        processed_logits,
+                        num_top_logits_to_return,
+                    )
+                else:
+                    next_top_k_token_ids = None
+                    next_top_k_token_logits = None
+                next_token_indices = state.last_token_indices + 1
+
+                stop_flags = state.stop_flags | jnp.any(next_token_ids[:, None] == eos_token_ids[None, :], axis=-1)
+
+                if batch_size == 1:
+                    forward_pass_mode = ForwardPassMode.SINGLE_TOKEN
+                else:
+                    forward_pass_mode = ForwardPassMode.MULTI_TOKEN
 
                 decoder_outputs = self.decoder(
-
-
+                    next_token_ids[:, None],
+                    next_token_indices[:, None],
                     state.kv_cache,
                     return_updated_kv_cache=True,
+                    forward_pass_mode=forward_pass_mode,
+                    forward_pass_config=forward_pass_config,
                 )
                 assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
                 new_state = DecodingState(
-                    decoder_outputs.logits.squeeze(),
-
+                    decoder_outputs.logits.squeeze(1),
+                    next_token_indices,
                     decoder_outputs.updated_kv_cache,
-
+                    stop_flags,
                 )
-                return new_state,
+                return new_state, GenerationStepResults(next_token_ids, next_top_k_token_ids, next_top_k_token_logits)
 
-            def pad_and_repeat_state() -> tuple[DecodingState,
-
-
+            def pad_and_repeat_state() -> tuple[DecodingState, GenerationStepResults]:
+                (batch_size,) = state.stop_flags.shape
+                pad_token = jnp.zeros(batch_size, dtype=jnp.int32)
+                if num_top_logits_to_return is not None:
+                    top_k_token_ids = jnp.zeros((batch_size, num_top_logits_to_return), dtype=jnp.int32)
+                    top_k_token_logits = jnp.zeros((batch_size, num_top_logits_to_return), dtype=jnp.float32)
+                else:
+                    top_k_token_ids = None
+                    top_k_token_logits = None
+                return state, GenerationStepResults(pad_token, top_k_token_ids, top_k_token_logits)
 
-            return jax.lax.cond(state.
+            return jax.lax.cond(jnp.all(state.stop_flags), pad_and_repeat_state, sample_and_update)
 
-        _,
+        _, generated = jax.lax.scan(loop_iteration, initial_state, keys)
 
-
+        token_ids = rearrange(generated.token_ids, "iteration batch -> batch iteration")
+
+        if num_top_logits_to_return is not None:
+            top_k_token_ids = rearrange(generated.top_k_token_ids, "iteration batch k -> batch iteration k")
+            top_k_token_logits = rearrange(generated.top_k_token_logits, "iteration batch k -> batch iteration k")
+        else:
+            top_k_token_ids = None
+            top_k_token_logits = None
+
+        return GenerationResults(token_ids, top_k_token_ids, top_k_token_logits)
 
     def reply(
         self,
         messages: Iterable[Message],
         sampling_policy: SamplingPolicy | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
         *,
         key: PRNGKeyArray | None = None,
     ) -> AssistantMessage:
         formatted_messages = self.message_processor.render_request(messages)
-        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
-        response_ids = self.generate_tokens(
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)[None, :]
+        response_ids = self.generate_tokens(
+            token_ids,
+            sampling_policy,
+            forward_pass_config=forward_pass_config,
+            key=key,
+        ).token_ids.squeeze(0)
         response_text = self.message_processor.detokenize(response_ids.tolist())
         return self.message_processor.parse_response(response_text)
 
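generate_tokens now works on a [batch, prompt_tokens] matrix: prompts are right-padded to a common width, true lengths go in prompt_lengths_without_padding, and finished rows keep emitting token id 0 until every row has hit an EOS token. A usage sketch, assuming lm is an already-loaded LanguageModel and that the token ids below are made up:

import jax
import jax.numpy as jnp

prompts = [[101, 7592], [101, 2129, 2024, 2017]]  # made-up token ids
width = max(len(p) for p in prompts)
token_ids = jnp.array([p + [0] * (width - len(p)) for p in prompts], dtype=jnp.int32)
lengths = jnp.array([len(p) for p in prompts], dtype=jnp.int32)

results = lm.generate_tokens(
    token_ids,
    prompt_lengths_without_padding=lengths,
    max_output_length=64,
    num_top_logits_to_return=5,
    key=jax.random.PRNGKey(0),
)
# results.token_ids has shape [batch, response_tokens]; the top_k_* fields
# hold the 5 highest logits (and their token ids) per generated position.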
@@ -220,21 +287,29 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         self,
         messages: Iterable[Message],
         sampling_policy: SamplingPolicy | None = None,
+        max_output_length: int = 8192,
+        forward_pass_config: ForwardPassConfig | None = None,
         *,
         key: PRNGKeyArray | None = None,
     ) -> Iterable[str]:
         formatted_messages = self.message_processor.render_request(messages)
         token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
-        for token_id in self.stream_tokens(
+        for token_id in self.stream_tokens(
+            token_ids,
+            sampling_policy,
+            max_output_length,
+            forward_pass_config=forward_pass_config,
+            key=key,
+        ):
             yield self.message_processor.detokenize([token_id.item()])
 
     def stream_tokens(
         self,
         prompt_token_ids: Int[Array, " prompt_tokens"],
         sampling_policy: SamplingPolicy | None = None,
-        prompt_length_without_padding: Int[Array, ""] | int | None = None,
         max_output_length: int = 8192,
         eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
         *,
         key: PRNGKeyArray | None = None,
     ) -> Iterable[Int[Array, ""]]:
@@ -244,10 +319,16 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
 
         (input_length,) = prompt_token_ids.shape
+
+        padded_input_length = min(length for length in _COMPILED_PROMPT_LENGTHS if length >= input_length)
+        padded_token_ids = jnp.zeros((padded_input_length,), dtype=jnp.int32)
+        padded_token_ids = padded_token_ids.at[:input_length].set(prompt_token_ids)
+
         prefill_results = self._prefill(
-
-
-
+            padded_token_ids[None, :],
+            jnp.array([input_length], dtype=jnp.int32),
+            padded_input_length + max_output_length,
+            forward_pass_config=forward_pass_config,
         )
 
         if key is None:
@@ -256,13 +337,14 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
 
         state = DecodingState(
             prefill_results.last_token_logits,
-            prefill_results.
+            prefill_results.last_token_indices,
             prefill_results.kv_cache,
-            jnp.array(0, dtype=jnp.bool),
+            jnp.array([0], dtype=jnp.bool),
         )
 
         for iter_key in keys:
-
+            upcasted_logits = state.last_token_logits.astype(jnp.float32)
+            processed_logits = sampling_policy.process_logits(upcasted_logits.squeeze(0))
             next_token_id = jax.random.categorical(iter_key, processed_logits)
 
             yield next_token_id
@@ -270,17 +352,18 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             if jnp.any(next_token_id == eos_token_ids):
                 return
 
-
+            next_token_indices = state.last_token_indices + 1
             decoder_outputs = self.decoder(
-                next_token_id.reshape(1),
-
+                next_token_id.reshape(1, 1),
+                next_token_indices.reshape(1, 1),
                 state.kv_cache,
                 return_updated_kv_cache=True,
+                forward_pass_config=forward_pass_config,
            )
             assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
             state = DecodingState(
-                decoder_outputs.logits.squeeze(),
-
+                decoder_outputs.logits.squeeze(1),
+                next_token_indices,
                 decoder_outputs.updated_kv_cache,
-                state.
+                state.stop_flags,
             )
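Taken together with the re-exports in lalamo/__init__.py, the high-level entry point is unchanged by the batching rework. A minimal end-to-end sketch, assuming a model directory previously produced by lalamo (the path is hypothetical):

import jax

from lalamo import LanguageModel, UserMessage

lm = LanguageModel.load("models/my-model")  # hypothetical export directory
answer = lm.reply([UserMessage("What is 2 + 2?")], key=jax.random.PRNGKey(0))
print(answer)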