lalamo 0.3.4__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.3.4 → lalamo-0.4.0}/PKG-INFO +11 -4
- lalamo-0.4.0/lalamo/__init__.py +26 -0
- lalamo-0.4.0/lalamo/data/__init__.py +8 -0
- lalamo-0.4.0/lalamo/data/huggingface_message.py +38 -0
- lalamo-0.4.0/lalamo/data/lalamo_completions.py +43 -0
- lalamo-0.4.0/lalamo/data/utils.py +8 -0
- lalamo-0.4.0/lalamo/language_model.py +369 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/main.py +271 -43
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/message_processor.py +11 -1
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/common.py +10 -6
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/__init__.py +3 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/executorch.py +12 -6
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo-0.4.0/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/huggingface_tokenizer_config.py +1 -3
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/loaders/executorch.py +10 -9
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo-0.4.0/lalamo/model_import/loaders/utils.py +92 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/__init__.py +4 -1
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/common.py +15 -12
- lalamo-0.4.0/lalamo/model_import/model_specs/gpt_oss.py +21 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/__init__.py +35 -7
- lalamo-0.4.0/lalamo/modules/activations.py +40 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/attention.py +73 -20
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/common.py +8 -57
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/decoder.py +48 -34
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/decoder_layer.py +57 -43
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/embedding.py +13 -19
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/kv_cache.py +53 -16
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/linear.py +260 -79
- lalamo-0.4.0/lalamo/modules/mlp.py +484 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/normalization.py +2 -3
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/rope.py +32 -21
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/utils.py +10 -0
- lalamo-0.4.0/lalamo/speculator/__init__.py +11 -0
- lalamo-0.4.0/lalamo/speculator/common.py +22 -0
- lalamo-0.4.0/lalamo/speculator/inference.py +75 -0
- lalamo-0.4.0/lalamo/speculator/ngram.py +154 -0
- lalamo-0.4.0/lalamo/speculator/utils.py +52 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/utils.py +27 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo.egg-info/PKG-INFO +11 -4
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo.egg-info/SOURCES.txt +13 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo.egg-info/requires.txt +11 -5
- {lalamo-0.3.4 → lalamo-0.4.0}/pyproject.toml +31 -20
- {lalamo-0.3.4 → lalamo-0.4.0}/tests/test_generation.py +48 -22
- {lalamo-0.3.4 → lalamo-0.4.0}/tests/test_huggingface_models.py +20 -14
- lalamo-0.4.0/tests/test_moe.py +58 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/tests/test_registry_abc.py +2 -3
- lalamo-0.3.4/lalamo/__init__.py +0 -11
- lalamo-0.3.4/lalamo/language_model.py +0 -286
- lalamo-0.3.4/lalamo/modules/activations.py +0 -30
- lalamo-0.3.4/lalamo/modules/mlp.py +0 -112
- {lalamo-0.3.4 → lalamo-0.4.0}/LICENSE +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/README.md +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/common.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/common.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/huggingface.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/quantization.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/registry_abc.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo/sampling.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/setup.cfg +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/tests/test_model_spec.py +0 -0
- {lalamo-0.3.4 → lalamo-0.4.0}/tests/test_parameter_tree.py +0 -0

{lalamo-0.3.4 → lalamo-0.4.0}/PKG-INFO

```diff
@@ -1,17 +1,16 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.3.4
+Version: 0.4.0
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: cattrs>=24.1.2
+Requires-Dist: cattrs[msgpack]>=24.1.2
 Requires-Dist: click>=8.1.8
 Requires-Dist: einops>=0.8.0
 Requires-Dist: equinox>=0.11.11
 Requires-Dist: huggingface-hub[hf-transfer]>=0.27.1
-Requires-Dist: jax>=0.4.38
-Requires-Dist: jax[cuda]>=0.4.38; sys_platform == "linux"
+Requires-Dist: jax>=0.7.2
 Requires-Dist: jaxtyping>=0.2.36
 Requires-Dist: jinja2>=3.1.6
 Requires-Dist: ml-dtypes>=0.5.1
@@ -21,6 +20,14 @@ Requires-Dist: thefuzz>=0.22.1
 Requires-Dist: tokenizers>=0.21.2
 Requires-Dist: typer>=0.15.1
 Requires-Dist: safetensors>=0.6.2
+Requires-Dist: polars>=1.33.1
+Requires-Dist: xxhash>=3.5.0
+Provides-Extra: cpu
+Requires-Dist: jax[cpu]>=0.7.2; extra == "cpu"
+Provides-Extra: cuda
+Requires-Dist: jax[cuda]>=0.7.2; extra == "cuda"
+Provides-Extra: tpu
+Requires-Dist: jax[tpu]>=0.7.2; extra == "tpu"
 Dynamic: license-file
 
 <p align="center">
```
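Note the dependency change above: `jax[cuda]` is no longer installed unconditionally on Linux; accelerator support is now opt-in via the new `cpu`, `cuda`, and `tpu` extras (e.g. `pip install "lalamo[cuda]"`). A minimal sanity check after installation, using only standard JAX calls:

```python
# Hedged sketch: confirm which JAX backend the chosen extra provides.
# jax.default_backend() and jax.devices() are standard JAX APIs.
import jax

print(jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"
print(jax.devices())
```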

lalamo-0.4.0/lalamo/__init__.py (new file)

```diff
@@ -0,0 +1,26 @@
+from lalamo.language_model import LanguageModel
+from lalamo.message_processor import (
+    AssistantMessage,
+    ContentBlock,
+    Image,
+    Message,
+    SystemMessage,
+    ToolSchema,
+    UserMessage,
+)
+from lalamo.model_import import ModelSpec, import_model
+
+__version__ = "0.4.0"
+
+__all__ = [
+    "AssistantMessage",
+    "ContentBlock",
+    "Image",
+    "LanguageModel",
+    "Message",
+    "ModelSpec",
+    "SystemMessage",
+    "ToolSchema",
+    "UserMessage",
+    "import_model",
+]
```
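The new package root re-exports the public API, so downstream code no longer needs to reach into submodules. A minimal sketch of what the diff above implies, assuming the package is installed as lalamo==0.4.0:

```python
# Everything shown in __all__ above is importable from the package root.
import lalamo
from lalamo import LanguageModel, UserMessage, import_model

print(lalamo.__version__)  # "0.4.0"
```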

lalamo-0.4.0/lalamo/data/huggingface_message.py (new file)

```diff
@@ -0,0 +1,38 @@
+from collections.abc import Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, Self
+
+import cattrs
+import polars as pl
+
+from lalamo.message_processor import AssistantMessage, Message, UserMessage
+
+
+@dataclass(frozen=True)
+class HFMessage:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+
+    role: str
+    content: str
+
+    @classmethod
+    def from_dict(cls, obj: dict) -> Self:
+        return cls._converter.structure(obj, cls)
+
+    def as_message(self) -> Message:
+        match self.role:
+            case "user":
+                return UserMessage(self.content)
+            case "assistant":
+                return AssistantMessage(None, self.content)
+            case other:
+                raise ValueError(f"Cannot convert {other} message")
+
+def import_hf_parquet(path: Path | str) -> Iterable[list[Message]]:
+    path = Path(path)
+
+    dataframe = pl.scan_parquet(path).collect()
+
+    for conversation in dataframe.get_column("conversation").shuffle(1337):
+        yield [HFMessage.from_dict(message).as_message() for message in conversation]
```
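A hedged usage sketch for `import_hf_parquet`: the file name below is a placeholder, and it assumes a parquet file whose `conversation` column holds lists of `{"role": ..., "content": ...}` records, which is the shape `HFMessage.from_dict` expects.

```python
from lalamo.data.huggingface_message import import_hf_parquet

# "chat.parquet" is a hypothetical local file, e.g. one shard of a
# HuggingFace chat dataset downloaded ahead of time.
for conversation in import_hf_parquet("chat.parquet"):
    # Each item is a list[Message] with roles mapped to UserMessage /
    # AssistantMessage; conversations arrive in a fixed shuffled order.
    print(len(conversation), type(conversation[0]).__name__)
```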

lalamo-0.4.0/lalamo/data/lalamo_completions.py (new file)

```diff
@@ -0,0 +1,43 @@
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import IO, Any, ClassVar, Self
+
+import msgpack
+from cattrs.preconf.msgpack import MsgpackConverter
+from cattrs.preconf.msgpack import make_converter as make_msgpack_converter
+
+
+@dataclass(frozen=True)
+class LalamoCompletion:
+    _converter: ClassVar[MsgpackConverter] = make_msgpack_converter()
+
+    prefix_token_ids: list[int]
+    completion_token_ids: list[int]
+    completion_token_logits: list[dict[int, float]]
+
+    def __post_init__(self) -> None:
+        if len(self.completion_token_ids) != len(self.completion_token_logits):
+            raise ValueError(f"({len(self.completion_token_ids)=}) != ({len(self.completion_token_logits)=})")
+
+    def serialize(self) -> bytes:
+        return self._converter.dumps(self)
+
+    @classmethod
+    def deserialize(cls, data: bytes | IO[bytes]) -> Self:
+        if isinstance(data, bytes):
+            obj: Any = msgpack.unpackb(data, strict_map_key=False)
+        else:
+            obj = msgpack.unpack(data, strict_map_key=False)
+
+        return cls._converter.structure(obj, cls)
+
+    @classmethod
+    def deserialize_many(cls, data: bytes | IO[bytes]) -> Iterable[Self]:
+        if isinstance(data, bytes):
+            unpacker = msgpack.Unpacker(strict_map_key=False)
+            unpacker.feed(data)
+        else:
+            unpacker = msgpack.Unpacker(file_like=data, strict_map_key=False)
+
+        for obj in unpacker:
+            yield cls._converter.structure(obj, cls)
```
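A round-trip sketch for `LalamoCompletion`: `serialize` emits one msgpack message per record, so a concatenated stream of records can be recovered with `deserialize_many`. The token ids and logits below are toy values.

```python
from lalamo.data.lalamo_completions import LalamoCompletion

records = [
    LalamoCompletion(
        prefix_token_ids=[1, 2, 3],
        completion_token_ids=[4, 5],
        completion_token_logits=[{4: -0.1, 7: -2.3}, {5: -0.2, 9: -1.9}],
    ),
    LalamoCompletion(
        prefix_token_ids=[1],
        completion_token_ids=[6],
        completion_token_logits=[{6: -0.05}],
    ),
]

# Concatenated msgpack messages form a valid stream.
blob = b"".join(record.serialize() for record in records)
assert list(LalamoCompletion.deserialize_many(blob)) == records
```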

lalamo-0.4.0/lalamo/data/utils.py (new file)

```diff
@@ -0,0 +1,8 @@
+from collections.abc import Iterable
+
+from lalamo.message_processor import Message, UserMessage
+
+
+def get_prefixes_ending_in_user_message(conversation: Iterable[Message]) -> list[list[Message]]:
+    conversation = list(conversation)
+    return [conversation[: i + 1] for i, msg in enumerate(conversation) if isinstance(msg, UserMessage)]
```
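A small worked example of `get_prefixes_ending_in_user_message`: each user turn produces one prefix ending at that turn, which is the natural unit for regenerating completions from logged chats. The message contents are made up; `AssistantMessage(None, ...)` follows the constructor usage shown in `huggingface_message.py` above.

```python
from lalamo.data.utils import get_prefixes_ending_in_user_message
from lalamo.message_processor import AssistantMessage, UserMessage

conversation = [
    UserMessage("Hi!"),
    AssistantMessage(None, "Hello! How can I help?"),
    UserMessage("Tell me a joke."),
]

prefixes = get_prefixes_ending_in_user_message(conversation)
assert [len(prefix) for prefix in prefixes] == [1, 3]
```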

lalamo-0.4.0/lalamo/language_model.py (new file)

```diff
@@ -0,0 +1,369 @@
+import json
+from collections.abc import Iterable
+from dataclasses import dataclass, replace
+from pathlib import Path
+from typing import NamedTuple, Self
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from einops import rearrange
+from jax import vmap
+from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+from tokenizers import Tokenizer
+
+from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
+from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
+from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule, config_converter
+from lalamo.modules.common import ForwardPassMode
+from lalamo.modules.decoder import DecoderForwardPassConfig
+from lalamo.sampling import SamplingPolicy, make_policy
+from lalamo.utils import open_safetensors
+
+__all__ = [
+    "ForwardPassConfig",
+    "GenerationConfig",
+    "LanguageModel",
+    "LanguageModelConfig",
+]
+
+
+_COMPILED_PROMPT_LENGTHS = [512 * 2**i for i in range(10)]
+
+
+type ForwardPassConfig = DecoderForwardPassConfig
+
+
+class PrefillResults(NamedTuple):
+    last_token_logits: Float[Array, "batch vocabulary"]
+    last_token_indices: Int[Array, " batch"]
+    kv_cache: KVCache
+
+
+class DecodingState(NamedTuple):
+    last_token_logits: Float[Array, "batch vocabulary"]
+    last_token_indices: Int[Array, " batch"]
+    kv_cache: KVCache
+    stop_flags: Bool[Array, " batch"]
+
+
+class GenerationStepResults(NamedTuple):
+    token_ids: Int[Array, " batch"]
+    top_k_token_ids: Int[Array, " batch k"] | None
+    top_k_token_logits: Float[Array, " batch k"] | None
+
+
+class GenerationResults(NamedTuple):
+    token_ids: Int[Array, "batch response_tokens"]
+    top_k_token_ids: Int[Array, "batch response_tokens k"] | None
+    top_k_token_logits: Float[Array, "batch response_tokens k"] | None
+
+
+@dataclass(frozen=True)
+class GenerationConfig:
+    stop_token_ids: tuple[int, ...]
+    temperature: float | None
+    top_k: int | None
+    top_p: float | None
+    banned_tokens: tuple[int, ...] | None
+
+    def default_policy(self) -> SamplingPolicy:
+        return make_policy(self.temperature, self.top_k, self.top_p, self.banned_tokens)
+
+
+@dataclass(frozen=True)
+class LanguageModelConfig:
+    decoder_config: DecoderConfig
+    message_processor_config: MessageProcessorConfig
+    generation_config: GenerationConfig
+
+
+class LanguageModel(LalamoModule[LanguageModelConfig]):
+    decoder: Decoder
+    message_processor: MessageProcessor = eqx.field(static=True)
+
+    @classmethod
+    def load(cls, path: Path | str) -> Self:
+        if isinstance(path, str):
+            path = Path(path)
+        with open(path / "config.json") as config_file:
+            config_json = json.load(config_file)
+        config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
+        with open_safetensors(path / "model.safetensors") as weights_dict:
+            weights = unflatten_parameters(weights_dict)
+            decoder = config.decoder_config.empty().import_weights(weights)
+        tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
+        message_processor = MessageProcessor(config.message_processor_config, tokenizer)
+        return cls(config, decoder, message_processor)
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.decoder.activation_precision
+
+    def export_weights(self) -> ParameterTree:
+        return self.decoder.export_weights()
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        return replace(
+            self,
+            decoder=self.decoder.import_weights(weights),
+        )
+
+    @property
+    def stop_token_ids(self) -> tuple[int, ...]:
+        return self.config.generation_config.stop_token_ids
+
+    def default_sampling_policy(self) -> SamplingPolicy:
+        return self.config.generation_config.default_policy()
+
+    @eqx.filter_jit
+    def _prefill(
+        self,
+        token_ids: Int[Array, "batch tokens"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        kv_cache_capacity: int | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+    ) -> PrefillResults:
+        batch_size, sequence_length = token_ids.shape
+        token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
+        if kv_cache_capacity is not None:
+            kv_cache = self.decoder.init_static_kv_cache(batch_size, kv_cache_capacity)
+        else:
+            kv_cache = None
+
+        decoder_outputs = self.decoder(
+            token_ids,
+            token_positions,
+            kv_cache,
+            return_updated_kv_cache=True,
+            lengths_without_padding=lengths_without_padding,
+            forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
+            forward_pass_config=forward_pass_config,
+        )
+
+        if lengths_without_padding is not None:
+            last_logits_indices = lengths_without_padding - 1
+        else:
+            last_logits_indices = jnp.array([sequence_length - 1] * batch_size, dtype=jnp.int32)
+
+        last_token_logits = vmap(lambda logits, index: logits[index])(decoder_outputs.logits, last_logits_indices)
+
+        assert decoder_outputs.updated_kv_cache is not None
+        return PrefillResults(
+            last_token_logits=last_token_logits,
+            last_token_indices=last_logits_indices,
+            kv_cache=decoder_outputs.updated_kv_cache,
+        )
+
+    @eqx.filter_jit
+    def generate_tokens(
+        self,
+        prompt_token_ids: Int[Array, "batch prompt_tokens"],
+        sampling_policy: SamplingPolicy | None = None,
+        prompt_lengths_without_padding: Int[Array, " batch"] | None = None,
+        max_output_length: int = 8192,
+        eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+        num_top_logits_to_return: int | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> GenerationResults:
+        if sampling_policy is None:
+            sampling_policy = self.default_sampling_policy()
+        if eos_token_ids is None:
+            eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
+
+        batch_size, sequence_length = prompt_token_ids.shape
+        prefill_results = self._prefill(
+            prompt_token_ids,
+            prompt_lengths_without_padding,
+            sequence_length + max_output_length,
+            forward_pass_config=forward_pass_config,
+        )
+
+        initial_state = DecodingState(
+            prefill_results.last_token_logits,
+            prefill_results.last_token_indices,
+            prefill_results.kv_cache,
+            jnp.zeros(batch_size, dtype=jnp.bool),
+        )
+
+        if key is None:
+            key = jax.random.PRNGKey(0)
+        keys = jax.random.split(key, num=max_output_length)
+
+        def loop_iteration(
+            state: DecodingState,
+            key: PRNGKeyArray,
+        ) -> tuple[DecodingState, GenerationStepResults]:
+            def sample_and_update() -> tuple[DecodingState, GenerationStepResults]:
+                upcasted_logits = state.last_token_logits.astype(jnp.float32)
+                processed_logits = vmap(sampling_policy.process_logits)(upcasted_logits)
+                next_token_ids = jax.random.categorical(key, processed_logits)
+                next_token_ids = jnp.where(state.stop_flags, jnp.zeros(batch_size, dtype=jnp.int32), next_token_ids)
+                if num_top_logits_to_return is not None:
+                    next_top_k_token_logits, next_top_k_token_ids = jax.lax.top_k(
+                        processed_logits,
+                        num_top_logits_to_return,
+                    )
+                else:
+                    next_top_k_token_ids = None
+                    next_top_k_token_logits = None
+                next_token_indices = state.last_token_indices + 1
+
+                stop_flags = state.stop_flags | jnp.any(next_token_ids[:, None] == eos_token_ids[None, :], axis=-1)
+
+                if batch_size == 1:
+                    forward_pass_mode = ForwardPassMode.SINGLE_TOKEN
+                else:
+                    forward_pass_mode = ForwardPassMode.MULTI_TOKEN
+
+                decoder_outputs = self.decoder(
+                    next_token_ids[:, None],
+                    next_token_indices[:, None],
+                    state.kv_cache,
+                    return_updated_kv_cache=True,
+                    forward_pass_mode=forward_pass_mode,
+                    forward_pass_config=forward_pass_config,
+                )
+                assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+                new_state = DecodingState(
+                    decoder_outputs.logits.squeeze(1),
+                    next_token_indices,
+                    decoder_outputs.updated_kv_cache,
+                    stop_flags,
+                )
+                return new_state, GenerationStepResults(next_token_ids, next_top_k_token_ids, next_top_k_token_logits)
+
+            def pad_and_repeat_state() -> tuple[DecodingState, GenerationStepResults]:
+                (batch_size,) = state.stop_flags.shape
+                pad_token = jnp.zeros(batch_size, dtype=jnp.int32)
+                if num_top_logits_to_return is not None:
+                    top_k_token_ids = jnp.zeros((batch_size, num_top_logits_to_return), dtype=jnp.int32)
+                    top_k_token_logits = jnp.zeros((batch_size, num_top_logits_to_return), dtype=jnp.float32)
+                else:
+                    top_k_token_ids = None
+                    top_k_token_logits = None
+                return state, GenerationStepResults(pad_token, top_k_token_ids, top_k_token_logits)
+
+            return jax.lax.cond(jnp.all(state.stop_flags), pad_and_repeat_state, sample_and_update)
+
+        _, generated = jax.lax.scan(loop_iteration, initial_state, keys)
+
+        token_ids = rearrange(generated.token_ids, "iteration batch -> batch iteration")
+
+        if num_top_logits_to_return is not None:
+            top_k_token_ids = rearrange(generated.top_k_token_ids, "iteration batch k -> batch iteration k")
+            top_k_token_logits = rearrange(generated.top_k_token_logits, "iteration batch k -> batch iteration k")
+        else:
+            top_k_token_ids = None
+            top_k_token_logits = None
+
+        return GenerationResults(token_ids, top_k_token_ids, top_k_token_logits)
+
+    def reply(
+        self,
+        messages: Iterable[Message],
+        sampling_policy: SamplingPolicy | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> AssistantMessage:
+        formatted_messages = self.message_processor.render_request(messages)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)[None, :]
+        response_ids = self.generate_tokens(
+            token_ids,
+            sampling_policy,
+            forward_pass_config=forward_pass_config,
+            key=key,
+        ).token_ids.squeeze(0)
+        response_text = self.message_processor.detokenize(response_ids.tolist())
+        return self.message_processor.parse_response(response_text)
+
+    def stream_reply_text(
+        self,
+        messages: Iterable[Message],
+        sampling_policy: SamplingPolicy | None = None,
+        max_output_length: int = 8192,
+        forward_pass_config: ForwardPassConfig | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> Iterable[str]:
+        formatted_messages = self.message_processor.render_request(messages)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
+        for token_id in self.stream_tokens(
+            token_ids,
+            sampling_policy,
+            max_output_length,
+            forward_pass_config=forward_pass_config,
+            key=key,
+        ):
+            yield self.message_processor.detokenize([token_id.item()])
+
+    def stream_tokens(
+        self,
+        prompt_token_ids: Int[Array, " prompt_tokens"],
+        sampling_policy: SamplingPolicy | None = None,
+        max_output_length: int = 8192,
+        eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> Iterable[Int[Array, ""]]:
+        if sampling_policy is None:
+            sampling_policy = self.default_sampling_policy()
+        if eos_token_ids is None:
+            eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
+
+        (input_length,) = prompt_token_ids.shape
+
+        padded_input_length = min(length for length in _COMPILED_PROMPT_LENGTHS if length >= input_length)
+        padded_token_ids = jnp.zeros((padded_input_length,), dtype=jnp.int32)
+        padded_token_ids = padded_token_ids.at[:input_length].set(prompt_token_ids)
+
+        prefill_results = self._prefill(
+            padded_token_ids[None, :],
+            jnp.array([input_length], dtype=jnp.int32),
+            padded_input_length + max_output_length,
+            forward_pass_config=forward_pass_config,
+        )
+
+        if key is None:
+            key = jax.random.PRNGKey(0)
+        keys = jax.random.split(key, num=max_output_length)
+
+        state = DecodingState(
+            prefill_results.last_token_logits,
+            prefill_results.last_token_indices,
+            prefill_results.kv_cache,
+            jnp.array([0], dtype=jnp.bool),
+        )
+
+        for iter_key in keys:
+            upcasted_logits = state.last_token_logits.astype(jnp.float32)
+            processed_logits = sampling_policy.process_logits(upcasted_logits.squeeze(0))
+            next_token_id = jax.random.categorical(iter_key, processed_logits)
+
+            yield next_token_id
+
+            if jnp.any(next_token_id == eos_token_ids):
+                return
+
+            next_token_indices = state.last_token_indices + 1
+            decoder_outputs = self.decoder(
+                next_token_id.reshape(1, 1),
+                next_token_indices.reshape(1, 1),
+                state.kv_cache,
+                return_updated_kv_cache=True,
+                forward_pass_config=forward_pass_config,
+            )
+            assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+            state = DecodingState(
+                decoder_outputs.logits.squeeze(1),
+                next_token_indices,
+                decoder_outputs.updated_kv_cache,
+                state.stop_flags,
+            )
```
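Taken together, the new `LanguageModel` replaces the 0.3.4 module with a loadable, JIT-compiled generation pipeline. A hedged end-to-end sketch; `./exported-model` is a placeholder for a directory containing the `config.json`, `model.safetensors`, and `tokenizer.json` files that `load` reads:

```python
from lalamo import LanguageModel, UserMessage

model = LanguageModel.load("./exported-model")

# Single-shot generation: returns a parsed AssistantMessage.
reply = model.reply([UserMessage("What is the capital of France?")])
print(reply)

# Incremental decoding: stream_reply_text yields detokenized text pieces
# until a stop token is sampled or max_output_length is reached.
for piece in model.stream_reply_text([UserMessage("Tell me a story.")]):
    print(piece, end="", flush=True)
```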