lalamo-0.3.3.tar.gz → lalamo-0.4.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. {lalamo-0.3.3 → lalamo-0.4.0}/PKG-INFO +11 -4
  2. lalamo-0.4.0/lalamo/__init__.py +26 -0
  3. lalamo-0.4.0/lalamo/data/__init__.py +8 -0
  4. lalamo-0.4.0/lalamo/data/huggingface_message.py +38 -0
  5. lalamo-0.4.0/lalamo/data/lalamo_completions.py +43 -0
  6. lalamo-0.4.0/lalamo/data/utils.py +8 -0
  7. lalamo-0.4.0/lalamo/language_model.py +369 -0
  8. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/main.py +271 -43
  9. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/message_processor.py +11 -1
  10. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/common.py +17 -7
  11. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/__init__.py +3 -0
  12. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/executorch.py +12 -6
  13. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  14. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  15. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  16. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  17. lalamo-0.4.0/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  18. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  19. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  20. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  21. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  22. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/huggingface_tokenizer_config.py +1 -4
  23. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/loaders/executorch.py +10 -9
  24. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/loaders/huggingface.py +104 -9
  25. lalamo-0.4.0/lalamo/model_import/loaders/utils.py +92 -0
  26. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/__init__.py +4 -1
  27. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/common.py +15 -12
  28. lalamo-0.4.0/lalamo/model_import/model_specs/gpt_oss.py +21 -0
  29. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/__init__.py +35 -7
  30. lalamo-0.4.0/lalamo/modules/activations.py +40 -0
  31. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/attention.py +73 -20
  32. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/common.py +8 -57
  33. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/decoder.py +48 -34
  34. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/decoder_layer.py +57 -43
  35. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/embedding.py +13 -19
  36. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/kv_cache.py +53 -16
  37. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/linear.py +260 -79
  38. lalamo-0.4.0/lalamo/modules/mlp.py +484 -0
  39. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/normalization.py +2 -3
  40. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/rope.py +32 -21
  41. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/utils.py +10 -0
  42. lalamo-0.4.0/lalamo/speculator/__init__.py +11 -0
  43. lalamo-0.4.0/lalamo/speculator/common.py +22 -0
  44. lalamo-0.4.0/lalamo/speculator/inference.py +75 -0
  45. lalamo-0.4.0/lalamo/speculator/ngram.py +154 -0
  46. lalamo-0.4.0/lalamo/speculator/utils.py +52 -0
  47. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/utils.py +27 -0
  48. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo.egg-info/PKG-INFO +11 -4
  49. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo.egg-info/SOURCES.txt +13 -0
  50. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo.egg-info/requires.txt +11 -5
  51. {lalamo-0.3.3 → lalamo-0.4.0}/pyproject.toml +31 -20
  52. {lalamo-0.3.3 → lalamo-0.4.0}/tests/test_generation.py +48 -22
  53. {lalamo-0.3.3 → lalamo-0.4.0}/tests/test_huggingface_models.py +20 -14
  54. lalamo-0.4.0/tests/test_moe.py +58 -0
  55. {lalamo-0.3.3 → lalamo-0.4.0}/tests/test_registry_abc.py +2 -3
  56. lalamo-0.3.3/lalamo/__init__.py +0 -11
  57. lalamo-0.3.3/lalamo/language_model.py +0 -286
  58. lalamo-0.3.3/lalamo/modules/activations.py +0 -30
  59. lalamo-0.3.3/lalamo/modules/mlp.py +0 -112
  60. {lalamo-0.3.3 → lalamo-0.4.0}/LICENSE +0 -0
  61. {lalamo-0.3.3 → lalamo-0.4.0}/README.md +0 -0
  62. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/common.py +0 -0
  63. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/__init__.py +0 -0
  64. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/decoder_configs/common.py +0 -0
  65. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/huggingface_generation_config.py +0 -0
  66. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/loaders/__init__.py +0 -0
  67. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/loaders/common.py +0 -0
  68. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/deepseek.py +0 -0
  69. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/gemma.py +0 -0
  70. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/huggingface.py +0 -0
  71. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/llama.py +0 -0
  72. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/mistral.py +0 -0
  73. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/pleias.py +0 -0
  74. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/polaris.py +0 -0
  75. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/qwen.py +0 -0
  76. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/model_import/model_specs/reka.py +0 -0
  77. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/modules/torch_interop.py +0 -0
  78. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/quantization.py +0 -0
  79. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/registry_abc.py +0 -0
  80. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo/sampling.py +0 -0
  81. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo.egg-info/dependency_links.txt +0 -0
  82. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo.egg-info/entry_points.txt +0 -0
  83. {lalamo-0.3.3 → lalamo-0.4.0}/lalamo.egg-info/top_level.txt +0 -0
  84. {lalamo-0.3.3 → lalamo-0.4.0}/setup.cfg +0 -0
  85. {lalamo-0.3.3 → lalamo-0.4.0}/tests/test_model_spec.py +0 -0
  86. {lalamo-0.3.3 → lalamo-0.4.0}/tests/test_parameter_tree.py +0 -0
{lalamo-0.3.3 → lalamo-0.4.0}/PKG-INFO

@@ -1,17 +1,16 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.3.3
+Version: 0.4.0
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: cattrs>=24.1.2
+Requires-Dist: cattrs[msgpack]>=24.1.2
 Requires-Dist: click>=8.1.8
 Requires-Dist: einops>=0.8.0
 Requires-Dist: equinox>=0.11.11
 Requires-Dist: huggingface-hub[hf-transfer]>=0.27.1
-Requires-Dist: jax>=0.4.38; sys_platform == "darwin"
-Requires-Dist: jax[cuda]>=0.4.38; sys_platform == "linux"
+Requires-Dist: jax>=0.7.2
 Requires-Dist: jaxtyping>=0.2.36
 Requires-Dist: jinja2>=3.1.6
 Requires-Dist: ml-dtypes>=0.5.1
@@ -21,6 +20,14 @@ Requires-Dist: thefuzz>=0.22.1
 Requires-Dist: tokenizers>=0.21.2
 Requires-Dist: typer>=0.15.1
 Requires-Dist: safetensors>=0.6.2
+Requires-Dist: polars>=1.33.1
+Requires-Dist: xxhash>=3.5.0
+Provides-Extra: cpu
+Requires-Dist: jax[cpu]>=0.7.2; extra == "cpu"
+Provides-Extra: cuda
+Requires-Dist: jax[cuda]>=0.7.2; extra == "cuda"
+Provides-Extra: tpu
+Requires-Dist: jax[tpu]>=0.7.2; extra == "tpu"
 Dynamic: license-file
 
 <p align="center">
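
Note on the dependency changes above: the JAX accelerator flavor is no longer pinned via `sys_platform` markers but selected through the new `cpu`, `cuda`, and `tpu` extras. A minimal sketch of the resulting install commands (the exact invocation depends on your environment):

    pip install "lalamo[cpu]"   # CPU-only JAX
    pip install "lalamo[cuda]"  # JAX with CUDA support
    pip install "lalamo[tpu]"   # JAX with TPU support

Installing without an extra pulls in the bare `jax>=0.7.2` dependency, which runs on CPU only.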
lalamo-0.4.0/lalamo/__init__.py

@@ -0,0 +1,26 @@
+from lalamo.language_model import LanguageModel
+from lalamo.message_processor import (
+    AssistantMessage,
+    ContentBlock,
+    Image,
+    Message,
+    SystemMessage,
+    ToolSchema,
+    UserMessage,
+)
+from lalamo.model_import import ModelSpec, import_model
+
+__version__ = "0.4.0"
+
+__all__ = [
+    "AssistantMessage",
+    "ContentBlock",
+    "Image",
+    "LanguageModel",
+    "Message",
+    "ModelSpec",
+    "SystemMessage",
+    "ToolSchema",
+    "UserMessage",
+    "import_model",
+]
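
The rewritten package root re-exports the conversation types and model entry points, so downstream code no longer needs to import from submodules. A minimal usage sketch, assuming a model directory previously produced by `import_model` (the path is hypothetical, and the message constructors are assumed to take the text as their first positional argument, as `UserMessage` does in `huggingface_message.py` below):

    from lalamo import LanguageModel, SystemMessage, UserMessage

    model = LanguageModel.load("models/exported-model")  # hypothetical export directory
    reply = model.reply([
        SystemMessage("You are a helpful assistant."),
        UserMessage("Hello!"),
    ])
    print(reply)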
lalamo-0.4.0/lalamo/data/__init__.py

@@ -0,0 +1,8 @@
+from .huggingface_message import import_hf_parquet
+from .utils import get_prefixes_ending_in_user_message
+
+__all__ = [
+    "get_prefixes_ending_in_user_message",
+    "import_hf_parquet",
+]
+
lalamo-0.4.0/lalamo/data/huggingface_message.py

@@ -0,0 +1,38 @@
+from collections.abc import Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, Self
+
+import cattrs
+import polars as pl
+
+from lalamo.message_processor import AssistantMessage, Message, UserMessage
+
+
+@dataclass(frozen=True)
+class HFMessage:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+
+    role: str
+    content: str
+
+    @classmethod
+    def from_dict(cls, obj: dict) -> Self:
+        return cls._converter.structure(obj, cls)
+
+    def as_message(self) -> Message:
+        match self.role:
+            case "user":
+                return UserMessage(self.content)
+            case "assistant":
+                return AssistantMessage(None, self.content)
+            case other:
+                raise ValueError(f"Cannot convert {other} message")
+
+
+def import_hf_parquet(path: Path | str) -> Iterable[list[Message]]:
+    path = Path(path)
+
+    dataframe = pl.scan_parquet(path).collect()
+
+    for conversation in dataframe.get_column("conversation").shuffle(1337):
+        yield [HFMessage.from_dict(message).as_message() for message in conversation]
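
`import_hf_parquet` collects the whole Parquet file into memory, shuffles the `conversation` column with the fixed seed 1337, and converts each row into lalamo `Message` objects; roles other than "user" and "assistant" raise a `ValueError`. A usage sketch (the dataset path is hypothetical; the file must have a `conversation` column of `{role, content}` records):

    from lalamo.data import import_hf_parquet

    for conversation in import_hf_parquet("data/chats.parquet"):  # hypothetical file
        for message in conversation:
            print(type(message).__name__, message)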
lalamo-0.4.0/lalamo/data/lalamo_completions.py

@@ -0,0 +1,43 @@
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import IO, Any, ClassVar, Self
+
+import msgpack
+from cattrs.preconf.msgpack import MsgpackConverter
+from cattrs.preconf.msgpack import make_converter as make_msgpack_converter
+
+
+@dataclass(frozen=True)
+class LalamoCompletion:
+    _converter: ClassVar[MsgpackConverter] = make_msgpack_converter()
+
+    prefix_token_ids: list[int]
+    completion_token_ids: list[int]
+    completion_token_logits: list[dict[int, float]]
+
+    def __post_init__(self) -> None:
+        if len(self.completion_token_ids) != len(self.completion_token_logits):
+            raise ValueError(f"({len(self.completion_token_ids)=}) != ({len(self.completion_token_logits)=})")
+
+    def serialize(self) -> bytes:
+        return self._converter.dumps(self)
+
+    @classmethod
+    def deserialize(cls, data: bytes | IO[bytes]) -> Self:
+        if isinstance(data, bytes):
+            obj: Any = msgpack.unpackb(data, strict_map_key=False)
+        else:
+            obj = msgpack.unpack(data, strict_map_key=False)
+
+        return cls._converter.structure(obj, cls)
+
+    @classmethod
+    def deserialize_many(cls, data: bytes | IO[bytes]) -> Iterable[Self]:
+        if isinstance(data, bytes):
+            unpacker = msgpack.Unpacker(strict_map_key=False)
+            unpacker.feed(data)
+        else:
+            unpacker = msgpack.Unpacker(file_like=data, strict_map_key=False)
+
+        for obj in unpacker:
+            yield cls._converter.structure(obj, cls)
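
`LalamoCompletion` stores per-step logits as `dict[int, float]` keyed by token id, which is why every unpack call passes `strict_map_key=False`: msgpack-python only accepts `str`/`bytes` map keys by default. A round-trip sketch with made-up values:

    from lalamo.data.lalamo_completions import LalamoCompletion

    completion = LalamoCompletion(
        prefix_token_ids=[1, 2, 3],
        completion_token_ids=[4, 5],
        completion_token_logits=[{4: -0.1, 7: -2.3}, {5: -0.4, 9: -1.9}],
    )
    blob = completion.serialize()
    assert LalamoCompletion.deserialize(blob) == completion

Concatenating several `serialize()` outputs into one stream also works: `deserialize_many` feeds the buffer (or file-like object) through a single `msgpack.Unpacker` and yields one `LalamoCompletion` per record.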
lalamo-0.4.0/lalamo/data/utils.py

@@ -0,0 +1,8 @@
+from collections.abc import Iterable
+
+from lalamo.message_processor import Message, UserMessage
+
+
+def get_prefixes_ending_in_user_message(conversation: Iterable[Message]) -> list[list[Message]]:
+    conversation = list(conversation)
+    return [conversation[: i + 1] for i, msg in enumerate(conversation) if isinstance(msg, UserMessage)]
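
Each returned prefix ends at a `UserMessage`, so every prefix is directly usable as a generation prompt. A small worked example: for a conversation `[user, assistant, user]` the function returns two prefixes, of lengths 1 and 3:

    from lalamo.data import get_prefixes_ending_in_user_message
    from lalamo.message_processor import AssistantMessage, UserMessage

    conversation = [
        UserMessage("Hi"),
        AssistantMessage(None, "Hello!"),
        UserMessage("How are you?"),
    ]
    prefixes = get_prefixes_ending_in_user_message(conversation)
    assert [len(p) for p in prefixes] == [1, 3]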
lalamo-0.4.0/lalamo/language_model.py

@@ -0,0 +1,369 @@
+import json
+from collections.abc import Iterable
+from dataclasses import dataclass, replace
+from pathlib import Path
+from typing import NamedTuple, Self
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from einops import rearrange
+from jax import vmap
+from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+from tokenizers import Tokenizer
+
+from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
+from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
+from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule, config_converter
+from lalamo.modules.common import ForwardPassMode
+from lalamo.modules.decoder import DecoderForwardPassConfig
+from lalamo.sampling import SamplingPolicy, make_policy
+from lalamo.utils import open_safetensors
+
+__all__ = [
+    "ForwardPassConfig",
+    "GenerationConfig",
+    "LanguageModel",
+    "LanguageModelConfig",
+]
+
+
+_COMPILED_PROMPT_LENGTHS = [512 * 2**i for i in range(10)]
+
+
+type ForwardPassConfig = DecoderForwardPassConfig
+
+
+class PrefillResults(NamedTuple):
+    last_token_logits: Float[Array, "batch vocabulary"]
+    last_token_indices: Int[Array, " batch"]
+    kv_cache: KVCache
+
+
+class DecodingState(NamedTuple):
+    last_token_logits: Float[Array, "batch vocabulary"]
+    last_token_indices: Int[Array, " batch"]
+    kv_cache: KVCache
+    stop_flags: Bool[Array, " batch"]
+
+
+class GenerationStepResults(NamedTuple):
+    token_ids: Int[Array, " batch"]
+    top_k_token_ids: Int[Array, " batch k"] | None
+    top_k_token_logits: Float[Array, " batch k"] | None
+
+
+class GenerationResults(NamedTuple):
+    token_ids: Int[Array, "batch response_tokens"]
+    top_k_token_ids: Int[Array, "batch response_tokens k"] | None
+    top_k_token_logits: Float[Array, "batch response_tokens k"] | None
+
+
+@dataclass(frozen=True)
+class GenerationConfig:
+    stop_token_ids: tuple[int, ...]
+    temperature: float | None
+    top_k: int | None
+    top_p: float | None
+    banned_tokens: tuple[int, ...] | None
+
+    def default_policy(self) -> SamplingPolicy:
+        return make_policy(self.temperature, self.top_k, self.top_p, self.banned_tokens)
+
+
+@dataclass(frozen=True)
+class LanguageModelConfig:
+    decoder_config: DecoderConfig
+    message_processor_config: MessageProcessorConfig
+    generation_config: GenerationConfig
+
+
+class LanguageModel(LalamoModule[LanguageModelConfig]):
+    decoder: Decoder
+    message_processor: MessageProcessor = eqx.field(static=True)
+
+    @classmethod
+    def load(cls, path: Path | str) -> Self:
+        if isinstance(path, str):
+            path = Path(path)
+        with open(path / "config.json") as config_file:
+            config_json = json.load(config_file)
+        config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
+        with open_safetensors(path / "model.safetensors") as weights_dict:
+            weights = unflatten_parameters(weights_dict)
+            decoder = config.decoder_config.empty().import_weights(weights)
+        tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
+        message_processor = MessageProcessor(config.message_processor_config, tokenizer)
+        return cls(config, decoder, message_processor)
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.decoder.activation_precision
+
+    def export_weights(self) -> ParameterTree:
+        return self.decoder.export_weights()
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        return replace(
+            self,
+            decoder=self.decoder.import_weights(weights),
+        )
+
+    @property
+    def stop_token_ids(self) -> tuple[int, ...]:
+        return self.config.generation_config.stop_token_ids
+
+    def default_sampling_policy(self) -> SamplingPolicy:
+        return self.config.generation_config.default_policy()
+
+    @eqx.filter_jit
+    def _prefill(
+        self,
+        token_ids: Int[Array, "batch tokens"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        kv_cache_capacity: int | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+    ) -> PrefillResults:
+        batch_size, sequence_length = token_ids.shape
+        token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
+        if kv_cache_capacity is not None:
+            kv_cache = self.decoder.init_static_kv_cache(batch_size, kv_cache_capacity)
+        else:
+            kv_cache = None
+
+        decoder_outputs = self.decoder(
+            token_ids,
+            token_positions,
+            kv_cache,
+            return_updated_kv_cache=True,
+            lengths_without_padding=lengths_without_padding,
+            forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
+            forward_pass_config=forward_pass_config,
+        )
+
+        if lengths_without_padding is not None:
+            last_logits_indices = lengths_without_padding - 1
+        else:
+            last_logits_indices = jnp.array([sequence_length - 1] * batch_size, dtype=jnp.int32)
+
+        last_token_logits = vmap(lambda logits, index: logits[index])(decoder_outputs.logits, last_logits_indices)
+
+        assert decoder_outputs.updated_kv_cache is not None
+        return PrefillResults(
+            last_token_logits=last_token_logits,
+            last_token_indices=last_logits_indices,
+            kv_cache=decoder_outputs.updated_kv_cache,
+        )
+
+    @eqx.filter_jit
+    def generate_tokens(
+        self,
+        prompt_token_ids: Int[Array, "batch prompt_tokens"],
+        sampling_policy: SamplingPolicy | None = None,
+        prompt_lengths_without_padding: Int[Array, " batch"] | None = None,
+        max_output_length: int = 8192,
+        eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+        num_top_logits_to_return: int | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> GenerationResults:
+        if sampling_policy is None:
+            sampling_policy = self.default_sampling_policy()
+        if eos_token_ids is None:
+            eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
+
+        batch_size, sequence_length = prompt_token_ids.shape
+        prefill_results = self._prefill(
+            prompt_token_ids,
+            prompt_lengths_without_padding,
+            sequence_length + max_output_length,
+            forward_pass_config=forward_pass_config,
+        )
+
+        initial_state = DecodingState(
+            prefill_results.last_token_logits,
+            prefill_results.last_token_indices,
+            prefill_results.kv_cache,
+            jnp.zeros(batch_size, dtype=jnp.bool),
+        )
+
+        if key is None:
+            key = jax.random.PRNGKey(0)
+        keys = jax.random.split(key, num=max_output_length)
+
+        def loop_iteration(
+            state: DecodingState,
+            key: PRNGKeyArray,
+        ) -> tuple[DecodingState, GenerationStepResults]:
+            def sample_and_update() -> tuple[DecodingState, GenerationStepResults]:
+                upcasted_logits = state.last_token_logits.astype(jnp.float32)
+                processed_logits = vmap(sampling_policy.process_logits)(upcasted_logits)
+                next_token_ids = jax.random.categorical(key, processed_logits)
+                next_token_ids = jnp.where(state.stop_flags, jnp.zeros(batch_size, dtype=jnp.int32), next_token_ids)
+                if num_top_logits_to_return is not None:
+                    next_top_k_token_logits, next_top_k_token_ids = jax.lax.top_k(
+                        processed_logits,
+                        num_top_logits_to_return,
+                    )
+                else:
+                    next_top_k_token_ids = None
+                    next_top_k_token_logits = None
+                next_token_indices = state.last_token_indices + 1
+
+                stop_flags = state.stop_flags | jnp.any(next_token_ids[:, None] == eos_token_ids[None, :], axis=-1)
+
+                if batch_size == 1:
+                    forward_pass_mode = ForwardPassMode.SINGLE_TOKEN
+                else:
+                    forward_pass_mode = ForwardPassMode.MULTI_TOKEN
+
+                decoder_outputs = self.decoder(
+                    next_token_ids[:, None],
+                    next_token_indices[:, None],
+                    state.kv_cache,
+                    return_updated_kv_cache=True,
+                    forward_pass_mode=forward_pass_mode,
+                    forward_pass_config=forward_pass_config,
+                )
+                assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+                new_state = DecodingState(
+                    decoder_outputs.logits.squeeze(1),
+                    next_token_indices,
+                    decoder_outputs.updated_kv_cache,
+                    stop_flags,
+                )
+                return new_state, GenerationStepResults(next_token_ids, next_top_k_token_ids, next_top_k_token_logits)
+
+            def pad_and_repeat_state() -> tuple[DecodingState, GenerationStepResults]:
+                (batch_size,) = state.stop_flags.shape
+                pad_token = jnp.zeros(batch_size, dtype=jnp.int32)
+                if num_top_logits_to_return is not None:
+                    top_k_token_ids = jnp.zeros((batch_size, num_top_logits_to_return), dtype=jnp.int32)
+                    top_k_token_logits = jnp.zeros((batch_size, num_top_logits_to_return), dtype=jnp.float32)
+                else:
+                    top_k_token_ids = None
+                    top_k_token_logits = None
+                return state, GenerationStepResults(pad_token, top_k_token_ids, top_k_token_logits)
+
+            return jax.lax.cond(jnp.all(state.stop_flags), pad_and_repeat_state, sample_and_update)
+
+        _, generated = jax.lax.scan(loop_iteration, initial_state, keys)
+
+        token_ids = rearrange(generated.token_ids, "iteration batch -> batch iteration")
+
+        if num_top_logits_to_return is not None:
+            top_k_token_ids = rearrange(generated.top_k_token_ids, "iteration batch k -> batch iteration k")
+            top_k_token_logits = rearrange(generated.top_k_token_logits, "iteration batch k -> batch iteration k")
+        else:
+            top_k_token_ids = None
+            top_k_token_logits = None
+
+        return GenerationResults(token_ids, top_k_token_ids, top_k_token_logits)
+
+    def reply(
+        self,
+        messages: Iterable[Message],
+        sampling_policy: SamplingPolicy | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> AssistantMessage:
+        formatted_messages = self.message_processor.render_request(messages)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)[None, :]
+        response_ids = self.generate_tokens(
+            token_ids,
+            sampling_policy,
+            forward_pass_config=forward_pass_config,
+            key=key,
+        ).token_ids.squeeze(0)
+        response_text = self.message_processor.detokenize(response_ids.tolist())
+        return self.message_processor.parse_response(response_text)
+
+    def stream_reply_text(
+        self,
+        messages: Iterable[Message],
+        sampling_policy: SamplingPolicy | None = None,
+        max_output_length: int = 8192,
+        forward_pass_config: ForwardPassConfig | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> Iterable[str]:
+        formatted_messages = self.message_processor.render_request(messages)
+        token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
+        for token_id in self.stream_tokens(
+            token_ids,
+            sampling_policy,
+            max_output_length,
+            forward_pass_config=forward_pass_config,
+            key=key,
+        ):
+            yield self.message_processor.detokenize([token_id.item()])
+
+    def stream_tokens(
+        self,
+        prompt_token_ids: Int[Array, " prompt_tokens"],
+        sampling_policy: SamplingPolicy | None = None,
+        max_output_length: int = 8192,
+        eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        forward_pass_config: ForwardPassConfig | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> Iterable[Int[Array, ""]]:
+        if sampling_policy is None:
+            sampling_policy = self.default_sampling_policy()
+        if eos_token_ids is None:
+            eos_token_ids = jnp.array(self.stop_token_ids, dtype=jnp.int32)
+
+        (input_length,) = prompt_token_ids.shape
+
+        padded_input_length = min(length for length in _COMPILED_PROMPT_LENGTHS if length >= input_length)
+        padded_token_ids = jnp.zeros((padded_input_length,), dtype=jnp.int32)
+        padded_token_ids = padded_token_ids.at[:input_length].set(prompt_token_ids)
+
+        prefill_results = self._prefill(
+            padded_token_ids[None, :],
+            jnp.array([input_length], dtype=jnp.int32),
+            padded_input_length + max_output_length,
+            forward_pass_config=forward_pass_config,
+        )
+
+        if key is None:
+            key = jax.random.PRNGKey(0)
+        keys = jax.random.split(key, num=max_output_length)
+
+        state = DecodingState(
+            prefill_results.last_token_logits,
+            prefill_results.last_token_indices,
+            prefill_results.kv_cache,
+            jnp.array([0], dtype=jnp.bool),
+        )
+
+        for iter_key in keys:
+            upcasted_logits = state.last_token_logits.astype(jnp.float32)
+            processed_logits = sampling_policy.process_logits(upcasted_logits.squeeze(0))
+            next_token_id = jax.random.categorical(iter_key, processed_logits)
+
+            yield next_token_id
+
+            if jnp.any(next_token_id == eos_token_ids):
+                return
+
+            next_token_indices = state.last_token_indices + 1
+            decoder_outputs = self.decoder(
+                next_token_id.reshape(1, 1),
+                next_token_indices.reshape(1, 1),
+                state.kv_cache,
+                return_updated_kv_cache=True,
+                forward_pass_config=forward_pass_config,
+            )
+            assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+            state = DecodingState(
+                decoder_outputs.logits.squeeze(1),
+                next_token_indices,
+                decoder_outputs.updated_kv_cache,
+                state.stop_flags,
+            )
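
Taken together: `stream_tokens` pads the prompt up to the nearest bucket in `_COMPILED_PROMPT_LENGTHS` (512, 1024, ..., 262144 tokens) so that repeated calls reuse a small set of compiled prefill shapes, then decodes token by token in plain Python, while `generate_tokens` runs the whole decode loop inside a jitted `jax.lax.scan`. A usage sketch of the streaming path (the model path is hypothetical; note that when `key` is omitted both paths fall back to the fixed `jax.random.PRNGKey(0)`):

    import jax
    from lalamo import LanguageModel, UserMessage

    model = LanguageModel.load("models/exported-model")  # hypothetical export directory
    for chunk in model.stream_reply_text(
        [UserMessage("Explain KV caches in one sentence.")],
        key=jax.random.PRNGKey(42),
    ):
        print(chunk, end="", flush=True)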