lalamo-0.2.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo-0.2.1/LICENSE +21 -0
- lalamo-0.2.1/PKG-INFO +74 -0
- lalamo-0.2.1/README.md +52 -0
- lalamo-0.2.1/lalamo/__init__.py +11 -0
- lalamo-0.2.1/lalamo/common.py +60 -0
- lalamo-0.2.1/lalamo/language_model.py +263 -0
- lalamo-0.2.1/lalamo/main.py +299 -0
- lalamo-0.2.1/lalamo/quantization.py +92 -0
- lalamo-0.2.1/lalamo/utils.py +55 -0
- lalamo-0.2.1/lalamo.egg-info/PKG-INFO +74 -0
- lalamo-0.2.1/lalamo.egg-info/SOURCES.txt +17 -0
- lalamo-0.2.1/lalamo.egg-info/dependency_links.txt +1 -0
- lalamo-0.2.1/lalamo.egg-info/entry_points.txt +2 -0
- lalamo-0.2.1/lalamo.egg-info/requires.txt +17 -0
- lalamo-0.2.1/lalamo.egg-info/top_level.txt +1 -0
- lalamo-0.2.1/pyproject.toml +119 -0
- lalamo-0.2.1/setup.cfg +4 -0
- lalamo-0.2.1/tests/test_generation.py +178 -0
- lalamo-0.2.1/tests/test_huggingface_models.py +81 -0
lalamo-0.2.1/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Mirai Tech Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
lalamo-0.2.1/PKG-INFO
ADDED
@@ -0,0 +1,74 @@
+Metadata-Version: 2.4
+Name: lalamo
+Version: 0.2.1
+Summary: JAX library for optimization and export of models for use with the UZU inference engine.
+Requires-Python: <4,>=3.12
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: cattrs>=24.1.2
+Requires-Dist: click>=8.1.8
+Requires-Dist: einops>=0.8.0
+Requires-Dist: equinox>=0.11.11
+Requires-Dist: huggingface-hub[hf-transfer]>=0.27.1
+Requires-Dist: jax>=0.4.38; sys_platform == "darwin"
+Requires-Dist: jax[cuda]>=0.4.38; sys_platform == "linux"
+Requires-Dist: jaxtyping>=0.2.36
+Requires-Dist: ml-dtypes>=0.5.1
+Requires-Dist: optax>=0.2.4
+Requires-Dist: rich>=14.0.0
+Requires-Dist: thefuzz>=0.22.1
+Requires-Dist: typer>=0.15.1
+Dynamic: license-file
+
+<p align="center">
+  <picture>
+    <img alt="Mirai" src="https://artifacts.trymirai.com/social/github/lalamo-header.jpg" style="max-width: 100%;">
+  </picture>
+</p>
+
+<a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
+<a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
+<a href="mailto:alexey@getmirai.co,dima@getmirai.co,aleksei@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
+<a href="https://docs.trymirai.com/components/models"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
+[](LICENSE)
+
+# lalamo
+
+A set of tools for adapting Large Language Models to on-device inference using the [uzu](https://github.com/trymirai/uzu) inference engine.
+
+## Quick Start
+
+To get the list of [supported models](https://trymirai.com/models), run:
+
+```bash
+uv run lalamo list-models
+```
+
+To convert a model, run:
+
+```bash
+uv run lalamo convert MODEL_REPO --precision float16
+```
+
+After that, you can find the converted model in the `models` folder. For more options see `uv run lalamo convert --help`.
+
+## Model Support
+
+To add support for a new model, write the corresponding [ModelSpec](lalamo/model_import/model_specs), as shown in the example below:
+
+```python
+ModelSpec(
+    vendor="Google",
+    family="Gemma-3",
+    name="Gemma-3-1B-Instruct",
+    size="1B",
+    quantization=None,
+    repo="google/gemma-3-1b-it",
+    config_type=HFGemma3TextConfig,
+    config_file_name="config.json",
+    weights_file_names=huggingface_weight_files(1),
+    weights_type=WeightsType.SAFETENSORS,
+    tokenizer_files=HUGGINGFACE_TOKENIZER_FILES,
+    use_cases=tuple(),
+)
+```
lalamo-0.2.1/README.md
ADDED
@@ -0,0 +1,52 @@
+<p align="center">
+  <picture>
+    <img alt="Mirai" src="https://artifacts.trymirai.com/social/github/lalamo-header.jpg" style="max-width: 100%;">
+  </picture>
+</p>
+
+<a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
+<a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
+<a href="mailto:alexey@getmirai.co,dima@getmirai.co,aleksei@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
+<a href="https://docs.trymirai.com/components/models"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
+[](LICENSE)
+
+# lalamo
+
+A set of tools for adapting Large Language Models to on-device inference using the [uzu](https://github.com/trymirai/uzu) inference engine.
+
+## Quick Start
+
+To get the list of [supported models](https://trymirai.com/models), run:
+
+```bash
+uv run lalamo list-models
+```
+
+To convert a model, run:
+
+```bash
+uv run lalamo convert MODEL_REPO --precision float16
+```
+
+After that, you can find the converted model in the `models` folder. For more options see `uv run lalamo convert --help`.
+
+## Model Support
+
+To add support for a new model, write the corresponding [ModelSpec](lalamo/model_import/model_specs), as shown in the example below:
+
+```python
+ModelSpec(
+    vendor="Google",
+    family="Gemma-3",
+    name="Gemma-3-1B-Instruct",
+    size="1B",
+    quantization=None,
+    repo="google/gemma-3-1b-it",
+    config_type=HFGemma3TextConfig,
+    config_file_name="config.json",
+    weights_file_names=huggingface_weight_files(1),
+    weights_type=WeightsType.SAFETENSORS,
+    tokenizer_files=HUGGINGFACE_TOKENIZER_FILES,
+    use_cases=tuple(),
+)
+```
lalamo-0.2.1/lalamo/common.py
ADDED
@@ -0,0 +1,60 @@
+from collections.abc import Iterable, Mapping
+
+import jax.numpy as jnp
+from jaxtyping import Array, DTypeLike
+
+__all__ = [
+    "DEFAULT_PRECISION",
+    "ParameterDict",
+    "ParameterPath",
+]
+
+DEFAULT_PRECISION: DTypeLike = jnp.bfloat16
+
+
+type NestedParameters = Mapping[str, Array | NestedParameters] | Iterable[Array | NestedParameters]
+
+
+class ParameterDict(dict[str, Array]):
+    def __init__(self, **kwargs: Array | NestedParameters | Iterable[Array | NestedParameters]) -> None:
+        super().__init__(self._flatten(kwargs))
+
+    def __setitem__(
+        self,
+        key: str,
+        value: Array | NestedParameters | Iterable[Array | NestedParameters],
+    ) -> None:
+        key = ParameterPath(key)
+
+        if isinstance(value, Array):
+            super().__setitem__(key, value)
+            return
+
+        for subkey, subvalue in self._flatten(value).items():
+            super().__setitem__(key / subkey, subvalue)
+
+    @classmethod
+    def _flatten(cls, nested_parameters: NestedParameters) -> dict[str, Array]:
+        result: dict[str, Array] = {}
+        if not isinstance(nested_parameters, Mapping):
+            nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
+        for key, value in nested_parameters.items():
+            key_path = ParameterPath(key)
+            if isinstance(value, Array):
+                result[key_path] = value
+            else:
+                result.update({key_path / subkey: subvalue for subkey, subvalue in cls._flatten(value).items()})
+        return result
+
+
+class ParameterPath(str):
+    __slots__ = ()
+
+    @property
+    def components(self) -> tuple[str, ...]:
+        return tuple(self.split("."))
+
+    def __truediv__(self, other: str | int) -> "ParameterPath":
+        if not self:
+            return ParameterPath(str(other))
+        return ParameterPath(self + "." + str(other))
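For orientation, here is a short usage sketch of the flattening and path composition implemented above. It is not part of the package contents; the key names and array shapes are made up for illustration:

```python
import jax.numpy as jnp

from lalamo.common import ParameterDict, ParameterPath

# Nested mappings and iterables are flattened into dot-separated key paths.
params = ParameterDict(
    embedding=jnp.zeros((4, 8)),
    layers=[{"weight": jnp.ones((8, 8))}, {"weight": jnp.ones((8, 8))}],
)
assert set(params) == {"embedding", "layers.0.weight", "layers.1.weight"}

# ParameterPath is a str subclass: "/" appends a component, and
# .components splits the path back apart on ".".
path = ParameterPath("layers") / 0 / "weight"
assert path == "layers.0.weight"
assert path.components == ("layers", "0", "weight")
```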
lalamo-0.2.1/lalamo/language_model.py
ADDED
@@ -0,0 +1,263 @@
+from abc import abstractmethod
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import NamedTuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
+
+from lalamo.modules import Decoder, KVCache
+
+__all__ = [
+    "BanTokensPolicy",
+    "CompositePolicy",
+    "GreedyPolicy",
+    "LanguageModel",
+    "SamplingPolicy",
+    "TemperaturePolicy",
+    "TopKPolicy",
+    "TopPPolicy",
+]
+
+
+class SamplingPolicy(eqx.Module):
+    @abstractmethod
+    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]: ...
+
+    def __call__(self, logits: Float[Array, " vocabulary"], *, key: PRNGKeyArray) -> Int[Array, ""]:
+        return jax.random.categorical(key, self.process_logits(logits))
+
+
+class GreedyPolicy(SamplingPolicy):
+    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+        max_logit_value = jnp.max(logits)
+        return jnp.where(logits == max_logit_value, 1.0, -jnp.inf)
+
+
+class TemperaturePolicy(SamplingPolicy):
+    temperature: float = eqx.field(static=True)
+
+    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+        return logits / self.temperature
+
+
+class TopKPolicy(SamplingPolicy):
+    k: int = eqx.field(static=True)
+
+    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+        top_k_logits, _ = jax.lax.top_k(logits, self.k)
+        min_logit_val = jnp.min(top_k_logits)
+        return jnp.where(logits >= min_logit_val, logits, -jnp.inf)
+
+
+class TopPPolicy(SamplingPolicy):
+    p: float = eqx.field(static=True)
+
+    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+        sorted_indices = jnp.argsort(logits, descending=True)
+        sorted_logits = logits[sorted_indices]
+        cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits))
+
+        to_remove = cumulative_probs > self.p
+        to_remove = jnp.roll(to_remove, 1)
+        to_remove = to_remove.at[0].set(False)
+        to_remove = jnp.zeros_like(to_remove).at[sorted_indices].set(to_remove)  # scatter the sorted-order mask back to vocabulary order
+        return jnp.where(to_remove, -jnp.inf, logits)
+
+
+class BanTokensPolicy(SamplingPolicy):
+    banned_tokens: list[int] = eqx.field(static=True)
+
+    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+        banned_tokens_indices = jnp.asarray(self.banned_tokens, dtype=jnp.int32)
+        return logits.at[banned_tokens_indices].set(-jnp.inf)
+
+
+class CompositePolicy(SamplingPolicy):
+    policies: list[SamplingPolicy] = eqx.field(static=True)
+
+    def process_logits(self, logits: Float[Array, " vocabulary"]) -> Float[Array, " vocabulary"]:
+        for policy in self.policies:
+            logits = policy.process_logits(logits)
+        return logits
+
+
+class PrefillResults(NamedTuple):
+    last_token_logits: Float[Array, " vocabulary"]
+    last_token_position: Int[Array, ""]
+    kv_cache: KVCache
+
+
+class DecodingState(NamedTuple):
+    last_token_logits: Float[Array, " vocabulary"]
+    last_token_position: Int[Array, ""]
+    kv_cache: KVCache
+    stop_flag: Bool[Array, ""]
+
+
+@dataclass(frozen=True)
+class LanguageModel:
+    decoder: Decoder
+
+    def _prefill(
+        self,
+        token_ids: Int[Array, " tokens"],
+        length_without_padding: Int[Array, ""] | int | None = None,
+        kv_cache_capacity: int | None = None,
+    ) -> PrefillResults:
+        (num_tokens,) = token_ids.shape
+        token_positions = jnp.arange(num_tokens, dtype=jnp.int32)
+        if kv_cache_capacity is not None:
+            kv_cache = self.decoder.init_static_kv_cache(kv_cache_capacity)
+        else:
+            kv_cache = None
+
+        decoder_outputs = self.decoder(
+            token_ids,
+            token_positions,
+            kv_cache,
+            return_updated_kv_cache=True,
+            length_without_padding=length_without_padding,
+        )
+
+        if length_without_padding is not None:
+            last_logits_index = length_without_padding - 1
+        else:
+            last_logits_index = num_tokens - 1
+
+        last_token_logits = decoder_outputs.logits[last_logits_index, :]
+        last_token_position = jnp.array(last_logits_index, dtype=jnp.int32)
+
+        assert decoder_outputs.updated_kv_cache is not None
+        return PrefillResults(
+            last_token_logits=last_token_logits,
+            last_token_position=last_token_position,
+            kv_cache=decoder_outputs.updated_kv_cache,
+        )
+
+    def generate(
+        self,
+        prompt_token_ids: Int[Array, " prompt_tokens"],
+        sampling_policy: SamplingPolicy | None = None,
+        prompt_length_without_padding: Int[Array, ""] | int | None = None,
+        max_output_length: int = 8192,
+        eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> Int[Array, " response_tokens"]:
+        if sampling_policy is None:
+            sampling_policy = TemperaturePolicy(temperature=1.0)
+
+        (input_length,) = prompt_token_ids.shape
+        prefill_results = self._prefill(
+            prompt_token_ids,
+            prompt_length_without_padding,
+            input_length + max_output_length,
+        )
+
+        initial_state = DecodingState(
+            prefill_results.last_token_logits,
+            prefill_results.last_token_position,
+            prefill_results.kv_cache,
+            jnp.array(0, dtype=jnp.bool),
+        )
+
+        if key is None:
+            key = jax.random.PRNGKey(0)
+        keys = jax.random.split(key, num=max_output_length)
+
+        def loop_iteration(
+            state: DecodingState,
+            key: PRNGKeyArray,
+        ) -> tuple[DecodingState, Int[Array, ""]]:
+            def sample_and_update() -> tuple[DecodingState, Int[Array, ""]]:
+                processed_logits = sampling_policy.process_logits(state.last_token_logits)
+                next_token_id = jax.random.categorical(key, processed_logits)
+                next_token_position = state.last_token_position + 1
+
+                if eos_token_ids is not None:
+                    stop_flag = state.stop_flag | jnp.any(next_token_id == eos_token_ids)
+                else:
+                    stop_flag = state.stop_flag
+
+                decoder_outputs = self.decoder(
+                    next_token_id.reshape(1),
+                    next_token_position.reshape(1),
+                    state.kv_cache,
+                    return_updated_kv_cache=True,
+                )
+                assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+                new_state = DecodingState(
+                    decoder_outputs.logits.squeeze(),
+                    next_token_position,
+                    decoder_outputs.updated_kv_cache,
+                    stop_flag,
+                )
+                return new_state, next_token_id
+
+            def pad_and_repeat_state() -> tuple[DecodingState, Int[Array, ""]]:
+                pad_token = jnp.array(0, dtype=jnp.int32)
+                return state, pad_token
+
+            return jax.lax.cond(state.stop_flag, pad_and_repeat_state, sample_and_update)
+
+        _, tokens = jax.lax.scan(loop_iteration, initial_state, keys)
+
+        return tokens
+
+    def stream(
+        self,
+        prompt_token_ids: Int[Array, " prompt_tokens"],
+        sampling_policy: SamplingPolicy | None = None,
+        prompt_length_without_padding: Int[Array, ""] | int | None = None,
+        max_output_length: int = 8192,
+        eos_token_ids: Int[Array, " eos_tokens"] | None = None,
+        *,
+        key: PRNGKeyArray | None = None,
+    ) -> Iterable[Int[Array, ""]]:
+        if sampling_policy is None:
+            sampling_policy = TemperaturePolicy(temperature=1.0)
+
+        (input_length,) = prompt_token_ids.shape
+        prefill_results = self._prefill(
+            prompt_token_ids,
+            prompt_length_without_padding,
+            input_length + max_output_length,
+        )
+
+        if key is None:
+            key = jax.random.PRNGKey(0)
+        keys = jax.random.split(key, num=max_output_length)
+
+        state = DecodingState(
+            prefill_results.last_token_logits,
+            prefill_results.last_token_position,
+            prefill_results.kv_cache,
+            jnp.array(0, dtype=jnp.bool),
+        )
+
+        for iter_key in keys:
+            processed_logits = sampling_policy.process_logits(state.last_token_logits)
+            next_token_id = jax.random.categorical(iter_key, processed_logits)
+
+            yield next_token_id
+
+            if eos_token_ids is not None and jnp.any(next_token_id == eos_token_ids):
+                return
+
+            next_token_position = state.last_token_position + 1
+            decoder_outputs = self.decoder(
+                next_token_id.reshape(1),
+                next_token_position.reshape(1),
+                state.kv_cache,
+                return_updated_kv_cache=True,
+            )
+            assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+            state = DecodingState(
+                decoder_outputs.logits.squeeze(),
+                next_token_position,
+                decoder_outputs.updated_kv_cache,
+                state.stop_flag,
+            )
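To close, here is a sketch of driving `LanguageModel.generate` with a composed sampling policy. It is illustrative rather than part of the release: the `decoder` is assumed to come from lalamo's model-loading machinery, which lies outside this diff, and the prompt tokens, EOS id, and sampling hyperparameters are placeholders:

```python
import jax
import jax.numpy as jnp
from jaxtyping import Array, Int

from lalamo.language_model import (
    BanTokensPolicy,
    CompositePolicy,
    LanguageModel,
    TemperaturePolicy,
    TopPPolicy,
)
from lalamo.modules import Decoder


def sample_response(decoder: Decoder, prompt_token_ids: Int[Array, " tokens"]) -> Int[Array, " response_tokens"]:
    model = LanguageModel(decoder=decoder)
    # CompositePolicy applies its sub-policies to the logits in list order:
    # mask banned tokens, rescale by temperature, then nucleus-filter.
    policy = CompositePolicy(
        policies=[
            BanTokensPolicy(banned_tokens=[0]),
            TemperaturePolicy(temperature=0.7),
            TopPPolicy(p=0.9),
        ],
    )
    return model.generate(
        prompt_token_ids,
        sampling_policy=policy,
        max_output_length=64,
        eos_token_ids=jnp.array([2], dtype=jnp.int32),  # placeholder EOS id
        key=jax.random.PRNGKey(0),
    )
```

`generate` runs the whole decode loop under `jax.lax.scan` and returns a fixed-length token array padded after EOS, while `stream` yields tokens one at a time and stops at the first EOS.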