litert-torch-nightly 0.9.0.dev20260202__py3-none-any.whl → 0.9.0.dev20260203__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- litert_torch/generative/export_hf/core/attention.py +86 -8
- litert_torch/generative/export_hf/core/attention_test.py +7 -2
- litert_torch/generative/export_hf/core/cache.py +112 -64
- litert_torch/generative/export_hf/core/cache_base.py +19 -2
- litert_torch/generative/export_hf/core/export_lib.py +55 -6
- litert_torch/generative/export_hf/core/exportable_module.py +30 -34
- litert_torch/generative/export_hf/core/exportable_module_config.py +39 -0
- litert_torch/generative/export_hf/core/split_cache/attention.py +28 -5
- litert_torch/generative/export_hf/core/split_cache/cache.py +113 -33
- litert_torch/generative/export_hf/core/split_cache/exportable_module.py +21 -14
- litert_torch/generative/export_hf/export.py +35 -2
- litert_torch/version.py +1 -1
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/METADATA +1 -1
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/RECORD +18 -17
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/WHEEL +0 -0
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/entry_points.txt +0 -0
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/licenses/LICENSE +0 -0
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/top_level.txt +0 -0

litert_torch/generative/export_hf/core/exportable_module.py CHANGED

@@ -15,38 +15,35 @@
 """Exportable modules."""
 
 import abc
-import dataclasses
 from litert_torch.generative.export_hf.core import cache as _
 from litert_torch.generative.export_hf.core import cache_base as kv_cache_lib
+from litert_torch.generative.export_hf.core import exportable_module_config
 from litert_torch.generative.export_hf.core import utils
 import torch
 
 
-
-class ExportableModuleConfig:
-  """Config for exportable modules."""
+ExportableModuleConfig = exportable_module_config.ExportableModuleConfig
 
-  batch_size: int = 1
-  cache_length: int = 1280
-  prefill_lengths: list[int] = dataclasses.field(default_factory=lambda: [128])
-  # For dynamic shape
-  cache_length_dim: torch.export.Dim | None = None
-  prefill_length_dim: torch.export.Dim | None = None
 
-
-
-  externalize_rope: bool = False
-  split_cache: bool = False
+class ExportableModuleBase(torch.nn.Module, abc.ABC):
+  """Base class for exportable modules."""
 
-
+  def __init__(self, export_config: ExportableModuleConfig):
+    super().__init__()
+    self._export_config = export_config
 
+  @property
+  def export_config(self) -> ExportableModuleConfig:
+    return self._export_config
 
-
-
+  def attention_kwargs(self):
+    k_ts_idx = self.export_config.k_ts_idx
+    v_ts_idx = self.export_config.v_ts_idx
+    return {"k_ts_idx": k_ts_idx, "v_ts_idx": v_ts_idx}
 
   @abc.abstractmethod
   def get_sample_inputs(
-      self, model_config
+      self, model_config
   ) -> dict[str, tuple[dict[str, torch.Tensor], dict[str, torch.export.Dim]]]:
     """Returns the sample inputs for the model."""
     ...
@@ -55,8 +52,10 @@ class ExportableModuleBase(torch.nn.Module, abc.ABC):
 class LiteRTExportableModuleForDecoderOnlyLM(ExportableModuleBase):
   """Base class for exportable modules for decoder-only LM."""
 
-  def __init__(
-
+  def __init__(
+      self, model: torch.nn.Module, export_config: ExportableModuleConfig
+  ):
+    super().__init__(export_config)
     self.model = model
 
   def adapt_inputs(
@@ -108,16 +107,13 @@ class LiteRTExportableModuleForDecoderOnlyLM(ExportableModuleBase):
     })
     return ret
 
-  def get_sample_kv_cache(
-      self, model_config, export_config: ExportableModuleConfig
-  ):
+  def get_sample_kv_cache(self, model_config):
     """Returns the input sample KV cache for the model."""
+    export_config = self.export_config
     num_layers = model_config.num_hidden_layers
-    batch_size = export_config.batch_size
-    cache_length = export_config.cache_length
     kv_cache = kv_cache_lib.CACHE_REGISTRY[
         export_config.cache_implementation
-    ].create_from_config(model_config,
+    ].create_from_config(model_config, export_config)
     inputs = {"kv_cache": kv_cache}
     if export_config.cache_length_dim is not None:
       all_k_shapes = tuple(
@@ -150,6 +146,7 @@ class LiteRTExportableModuleForDecoderOnlyLMPrefill(
       mask,
   ):
     inputs = self.adapt_inputs(tokens, None, input_pos, kv_cache, mask)
+    inputs |= self.attention_kwargs()
     output = self.model(**inputs)
     return {"kv_cache": output.past_key_values}
 
@@ -165,11 +162,10 @@ class LiteRTExportableModuleForDecoderOnlyLMPrefill(
     )
     return tokens, tokens_dynamic_shape
 
-  def get_sample_inputs(
-
-  ):
+  def get_sample_inputs(self, model_config):
+    export_config = self.export_config
     kv_cache_inputs, kv_cache_dynamic_shapes = self.get_sample_kv_cache(
-        model_config
+        model_config
     )
     batch_size = export_config.batch_size
     cache_length = export_config.cache_length
@@ -218,6 +214,7 @@ class LiteRTExportableModuleForDecoderOnlyLMGenerate(
       mask,
   ):
     inputs = self.adapt_inputs(tokens, None, input_pos, kv_cache, mask)
+    inputs |= self.attention_kwargs()
     output = self.model(**inputs)
     return {"kv_cache": output.past_key_values, "logits": output.logits}
 
@@ -231,11 +228,10 @@ class LiteRTExportableModuleForDecoderOnlyLMGenerate(
     tokens_dynamic_shape = {"tokens": None} if decode_length_dim else {}
     return tokens, tokens_dynamic_shape
 
-  def get_sample_inputs(
-
-  ):
+  def get_sample_inputs(self, model_config):
+    export_config = self.export_config
     kv_cache_inputs, kv_cache_dynamic_shapes = self.get_sample_kv_cache(
-        model_config
+        model_config
     )
     batch_size = export_config.batch_size
    cache_length = export_config.cache_length
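
The net effect of the exportable_module.py changes is that the export configuration now lives on the module itself, and the cache layout indices are merged into the keyword arguments forwarded to the wrapped HF model via `inputs |= self.attention_kwargs()`. Below is a minimal sketch of that pattern with stand-in classes; `DummyConfig` and `DummyModel` are hypothetical and not part of the package.

    import torch


    class DummyConfig:
      """Stand-in for ExportableModuleConfig; only the fields used here."""

      k_ts_idx = 2
      v_ts_idx = 3


    class DummyModel(torch.nn.Module):
      """Stand-in for a HF decoder whose attention consumes extra kwargs."""

      def forward(self, tokens, **kwargs):
        return {"tokens": tokens, "received": kwargs}


    class Wrapper(torch.nn.Module):
      """Mirrors the base-class pattern: hold the config, merge attention kwargs."""

      def __init__(self, model, export_config):
        super().__init__()
        self.model = model
        self._export_config = export_config

      def attention_kwargs(self):
        cfg = self._export_config
        return {"k_ts_idx": cfg.k_ts_idx, "v_ts_idx": cfg.v_ts_idx}

      def forward(self, tokens):
        inputs = {"tokens": tokens}
        inputs |= self.attention_kwargs()  # same merge as in the diff
        return self.model(**inputs)


    out = Wrapper(DummyModel(), DummyConfig())(torch.zeros(1, 4, dtype=torch.long))
    assert out["received"] == {"k_ts_idx": 2, "v_ts_idx": 3}
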

litert_torch/generative/export_hf/core/exportable_module_config.py ADDED

@@ -0,0 +1,39 @@
+# Copyright 2025 The LiteRT Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Exportable modules."""
+
+import dataclasses
+import torch
+
+
+@dataclasses.dataclass
+class ExportableModuleConfig:
+  """Config for exportable modules."""
+
+  batch_size: int = 1
+  cache_length: int = 1280
+  prefill_lengths: list[int] = dataclasses.field(default_factory=lambda: [128])
+  # For dynamic shape
+  cache_length_dim: torch.export.Dim | None = None
+  prefill_length_dim: torch.export.Dim | None = None
+
+  # Export configs
+  externalize_embedder: bool = False
+  externalize_rope: bool = False
+
+  split_cache: bool = False
+  cache_implementation: str = "LiteRTLMCache"
+  k_ts_idx: int = 2
+  v_ts_idx: int = 3
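
For reference, a split-cache export configuration built from this dataclass might look like the sketch below; the numeric values are illustrative, and only fields defined in the file above are used.

    from litert_torch.generative.export_hf.core import exportable_module_config

    config = exportable_module_config.ExportableModuleConfig(
        batch_size=1,
        cache_length=4096,           # illustrative; the default is 1280
        prefill_lengths=[128, 512],
        externalize_embedder=True,
        externalize_rope=True,
        split_cache=True,
        cache_implementation="LiteRTLMSplitCache",
        k_ts_idx=2,  # sequence axis of the key cache -> (1, BK, S, H)
        v_ts_idx=3,  # sequence axis of the value cache -> (1, BK, H, S)
    )
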

litert_torch/generative/export_hf/core/split_cache/attention.py CHANGED

@@ -17,6 +17,7 @@
 import math
 from typing import Optional
 
+from litert_torch.generative.custom_ops import bmm_4d as bmm_lib
 from litert_torch.generative.export_hf.core.split_cache import cache as kv_cache_lib
 import torch
 import torch.nn.functional as F
@@ -28,6 +29,8 @@ def _scaled_dot_product_attention(
     key_cache: kv_cache_lib.KeyCacheEntry,
     value_cache: kv_cache_lib.ValueCacheEntry,
     head_size: int,
+    k_ts_idx: int,
+    v_ts_idx: int,
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
@@ -40,6 +43,8 @@ def _scaled_dot_product_attention(
     value_cache: A tuple of Value tensor. 1(bk)sh
     head_size (int): head dimension.
     mask (torch.Tensor): the optional mask tensor.
+    k_ts_idx (int): the timestamp index of the key tensor.
+    v_ts_idx (int): the timestamp index of the value tensor.
     scale (float): the optional scale factor.
     softcap (float): the optional softcap for the logits.
 
@@ -60,8 +65,13 @@
   assert mask is not None, "Mask should not be None!"
   t = mask.shape[2]
 
-
-
+  if k_ts_idx == 2:
+    bmm_fn = bmm_lib.bmm_4d
+  else:
+    assert k_ts_idx == 3, "k_ts_idx must be 2 or 3."
+    bmm_fn = lambda x, y: torch.einsum("abth,abhs->abts", x, y)
+  logits0 = bmm_fn(query, key_past)
+  logits1 = bmm_fn(query, key)
   logits = torch.cat([logits0, logits1], dim=-1)
 
   _, bk, gt, s = logits.shape
@@ -76,8 +86,13 @@ def _scaled_dot_product_attention(
   probs = F.softmax(padded_logits, dim=-1).type_as(key)
   probs0, probs1 = probs[..., :-t], probs[..., -t:]
 
-
-
+  if v_ts_idx == 3:
+    bmm_fn = bmm_lib.bmm_4d
+  else:
+    assert v_ts_idx == 2, "v_ts_idx must be 2 or 3."
+    bmm_fn = lambda x, y: torch.einsum("abts,absh->abth", x, y)
+  encoded0 = bmm_fn(probs0, value_past)
+  encoded1 = bmm_fn(probs1, value)
   encoded = encoded0 + encoded1
 
   return encoded  # 1, bk, gt, h
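
The branches above choose a batched matmul that matches the stored cache layout: the custom `bmm_4d` op when the layout already fits it, otherwise an explicit einsum. Below is a small self-contained check of the two einsum contractions used here, with plain `torch.matmul` standing in for `bmm_4d` (treated as a 4-D batched matmul, which is an assumption about that op).

    import torch

    a, b = 1, 8          # leading dims: 1 and B*K in the diff's annotations
    t, h, s = 4, 64, 16  # query time steps, head dim, cached length

    query = torch.randn(a, b, t, h)

    # k_ts_idx == 3: keys stored as (1, BK, H, S); contract over H.
    key_hs = torch.randn(a, b, h, s)
    logits = torch.einsum("abth,abhs->abts", query, key_hs)
    assert torch.allclose(logits, query @ key_hs, atol=1e-5)

    # v_ts_idx == 2: values stored as (1, BK, S, H); contract over S.
    probs = torch.softmax(logits, dim=-1)
    value_sh = torch.randn(a, b, s, h)
    encoded = torch.einsum("abts,absh->abth", probs, value_sh)
    assert torch.allclose(encoded, probs @ value_sh, atol=1e-5)
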
@@ -94,7 +109,6 @@ def split_cache_attention(
     **kwargs,  # You need to accept **kwargs as models will pass other args
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
   """ODML transposed attention implementation for NPU."""
-  del kwargs
 
   b, n, seq_len, h = query.shape
   if hasattr(module, "num_key_value_groups"):
@@ -102,6 +116,13 @@ def split_cache_attention(
   else:
     g = 1
   num_query_groups = n // g
+  k_ts_idx: int | None = kwargs.get("k_ts_idx", None)
+  v_ts_idx: int | None = kwargs.get("v_ts_idx", None)
+  if k_ts_idx is None or v_ts_idx is None:
+    raise ValueError(
+        "Timestamp indices not passed to attention module. The model is not"
+        " passing the kwargs correctly."
+    )
   # bnth -> b(kg)th -> 1(bk)(gt)h
   query = query.reshape(1, b * num_query_groups, g * seq_len, h)
 
@@ -113,6 +134,8 @@ def split_cache_attention(
       mask=attention_mask,
       scale=scaling,
       softcap=softcap,
+      k_ts_idx=k_ts_idx,
+      v_ts_idx=v_ts_idx,
   )  # 1, bk, gt, h
   sdpa_out = sdpa_out.reshape(b, -1, seq_len, h).permute(0, 2, 1, 3)
   return sdpa_out, None

litert_torch/generative/export_hf/core/split_cache/cache.py CHANGED

@@ -25,16 +25,26 @@ Shape annotations used here:
 """
 
 from typing import Any, List, Optional, Self, Tuple
-import litert_torch.generative.export_hf.core.cache_base as cache_base_lib
 import jaxtyping as jt
+from litert_torch.generative.export_hf.core import exportable_module_config
+import litert_torch.generative.export_hf.core.cache_base as cache_base_lib
 import torch
 import torch.utils._pytree as pytree
 
+ExportableModuleConfig = exportable_module_config.ExportableModuleConfig
 
-KeyCache =
-
-
-
+KeyCache = (
+    jt.Shaped[torch.Tensor, "1 BK H S"] | jt.Shaped[torch.Tensor, "1 BK S H"]
+)
+KeySlice = (
+    jt.Shaped[torch.Tensor, "1 BK H T"] | jt.Shaped[torch.Tensor, "1 BK T H"]
+)
+ValueCache = (
+    jt.Shaped[torch.Tensor, "1 BK S H"] | jt.Shaped[torch.Tensor, "1 BK H S"]
+)
+ValueSlice = (
+    jt.Shaped[torch.Tensor, "1 BK T H"] | jt.Shaped[torch.Tensor, "1 BK H T"]
+)
 
 KeyCacheEntry = Tuple[KeyCache, KeySlice | None]
 ValueCacheEntry = Tuple[ValueCache, ValueSlice | None]
@@ -51,6 +61,8 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
       key_cache: KeyCacheEntry,
       value_cache: ValueCacheEntry,
       batch_size: int = 1,
+      k_ts_idx: int = 2,
+      v_ts_idx: int = 3,
       **kwargs,
   ):
     super().__init__()
@@ -62,12 +74,16 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
     self.values = value_cache
     self.is_initialized = True
 
+    self.k_ts_idx = k_ts_idx
+    self.v_ts_idx = v_ts_idx
+
     self.k_cache_shape = self.keys[0].shape
     self.v_cache_shape = self.values[0].shape
-    self.max_cache_len = self.k_cache_shape[
+    self.max_cache_len = self.k_cache_shape[self.k_ts_idx]
     self.batch_size = batch_size
-    self.
-
+    self.head_dim = (
+        self.k_cache_shape[2] if self.k_ts_idx == 3 else self.k_cache_shape[3]
+    )
     self.additional_states = kwargs.get("additional_states", None)
 
     self.cumulative_length = 0
@@ -75,6 +91,12 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
   def get_batch_size(self) -> int:
     return self.batch_size
 
+  def get_k_ts_idx(self) -> int:
+    return self.k_ts_idx
+
+  def get_v_ts_idx(self) -> int:
+    return self.v_ts_idx
+
   def lazy_initialization(self, key_states: torch.Tensor):
     # Since we don't support real lazy initialization, this function could only
     # be called by Cache.early_initialization, where uses a standard cache
@@ -97,12 +119,25 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
 
     value_states = value_states.to(self.values[0].dtype)
 
-
-
-
-
-
-
+    if self.k_ts_idx == 2:
+      key_states = key_states.reshape(
+          1, -1, seq_len, self.head_dim
+      )  # 1, bk, s, h
+    else:
+      assert self.k_ts_idx == 3, "k_ts_idx must be 2 or 3."
+      key_states = key_states.permute(0, 1, 3, 2).reshape(
+          1, -1, self.head_dim, seq_len
+      )  # 1, bk, h, s
+
+    if self.v_ts_idx == 2:
+      value_states = value_states.reshape(
+          1, -1, seq_len, self.head_dim
+      )  # 1, bk, s, h
+    else:
+      assert self.v_ts_idx == 3, "v_ts_idx must be 2 or 3."
+      value_states = value_states.permute(0, 1, 3, 2).reshape(
+          1, -1, self.head_dim, seq_len
+      )  # 1, bk, h, s
 
     self.keys = (self.keys[0], key_states)
     self.values = (self.values[0], value_states)
@@ -123,37 +158,68 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
 
   @classmethod
   def _infer_cache_shape_from_config(
-      cls,
+      cls,
+      model_config,
+      layer_index,
+      export_config: ExportableModuleConfig,
+      **kwargs,
   ):
     """Infers the KV cache shape from the model config."""
     del layer_index  # Unused.
+    del kwargs  # Unused.
+    cache_length = export_config.cache_length
+    batch_size = export_config.batch_size
+    k_ts_idx = export_config.k_ts_idx
+    v_ts_idx = export_config.v_ts_idx
     num_kv_heads = model_config.num_key_value_heads
     embed_size_per_head = (
         getattr(model_config, "head_dim", None)
         or model_config.hidden_size // model_config.num_attention_heads
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
+    if k_ts_idx == 2:
+      k_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          cache_length,
+          embed_size_per_head,
+      )
+    else:
+      assert k_ts_idx == 3, "k_ts_idx must be 2 or 3."
+      k_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          embed_size_per_head,
+          cache_length,
+      )
+    if v_ts_idx == 2:
+      v_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          cache_length,
+          embed_size_per_head,
+      )
+    else:
+      assert v_ts_idx == 3, "v_ts_idx must be 2 or 3."
+      v_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          embed_size_per_head,
+          cache_length,
+      )
     return k_cache_shape, v_cache_shape
 
   @classmethod
   def create_from_config(
-      cls,
+      cls,
+      model_config,
+      layer_index,
+      export_config: ExportableModuleConfig,
+      **kwargs,
   ) -> Self:
     """Creates a KV cache from the model config."""
     k_cache_shape, v_cache_shape = cls._infer_cache_shape_from_config(
-        model_config, layer_index,
+        model_config, layer_index, export_config, **kwargs
     )
     keys = torch.zeros(k_cache_shape, dtype=torch.float32)
     values = torch.zeros(v_cache_shape, dtype=torch.float32)
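
With the default indices (k_ts_idx=2, v_ts_idx=3), the shape inference above places the sequence axis of the key cache at position 2 and that of the value cache at position 3. Below is a short stand-alone sketch of the same computation; the model parameters are illustrative.

    # Illustrative model parameters: 8 KV heads, head_dim 128, batch 1, 1280 slots.
    batch_size, num_kv_heads, head_dim, cache_length = 1, 8, 128, 1280


    def cache_shape(ts_idx: int) -> tuple[int, int, int, int]:
      """ts_idx is the axis that holds the sequence (timestamp) dimension."""
      if ts_idx == 2:
        return (1, batch_size * num_kv_heads, cache_length, head_dim)
      assert ts_idx == 3, "ts_idx must be 2 or 3."
      return (1, batch_size * num_kv_heads, head_dim, cache_length)


    print(cache_shape(2))  # (1, 8, 1280, 128): key cache with k_ts_idx=2
    print(cache_shape(3))  # (1, 8, 128, 1280): value cache with v_ts_idx=3
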
@@ -165,14 +231,22 @@ class LiteRTLMSplitCache(cache_base_lib.LiteRTLMCacheMixin):
   """Optimized Cache class for HuggingFace integration."""
 
   @classmethod
-  def create_from_config(
+  def create_from_config(
+      cls,
+      model_config,
+      export_config: ExportableModuleConfig,
+      **kwargs,
+  ) -> Self:
     """Creates a KV cache from the model config."""
     num_layers = model_config.num_hidden_layers
     layers = []
     for layer_index in range(num_layers):
       layers.append(
           LiteRTLMSplitCacheLayer.create_from_config(
-              model_config,
+              model_config,
+              layer_index,
+              export_config,
+              **kwargs,
           )
       )
     return cls(layers)
@@ -188,6 +262,8 @@ def _flatten_kvc_t(
   layer_0 = kvc.layers[0]
   assert isinstance(layer_0, cache_base_lib.LiteRTLMCacheLayerMixin)
   batch_size = layer_0.get_batch_size()
+  k_ts_idx = layer_0.get_k_ts_idx()
+  v_ts_idx = layer_0.get_v_ts_idx()
   for i, cache_layer in enumerate(kvc.layers):
     flattened.append(cache_layer.keys[0])
     flat_names.append(f"k_{i}")
@@ -199,16 +275,18 @@ def _flatten_kvc_t(
     flat_names.append(f"k_{i}_slice")
     flattened.append(cache_layer.values[1])
     flat_names.append(f"v_{i}_slice")
-  return flattened, [flat_names, (batch_size, num_layers)]
+  return flattened, [flat_names, (batch_size, num_layers, k_ts_idx, v_ts_idx)]
 
 
 def _unflatten_kvc_t(
     values: List[torch.Tensor],
-    context: Tuple[List[str], Tuple[int, int]],
+    context: Tuple[List[str], Tuple[int, int, int, int]],
 ) -> LiteRTLMSplitCache:
   """Unflattens the KV cache from a list of tensors."""
   flat_names = context[0]
   batch_size = context[1][0]
+  k_ts_idx = context[1][2]
+  v_ts_idx = context[1][3]
   num_layers = context[1][1]
   kv_entries = []
   for i in range(num_layers):
@@ -231,6 +309,8 @@ def _unflatten_kvc_t(
             key_cache=(k_cache, k_cache_update),
             value_cache=(v_cache, v_cache_update),
             batch_size=batch_size,
+            k_ts_idx=k_ts_idx,
+            v_ts_idx=v_ts_idx,
         )
     )
   obj = LiteRTLMSplitCache(kv_entries)
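
The flatten/unflatten pair now carries the two layout indices in the pytree context next to batch size and layer count, so a cache rebuilt from flattened tensors keeps its layout. The round trip reduced to its essentials, using a toy container (hypothetical names, not the package's types):

    from typing import List, Tuple

    import torch


    class ToyCacheLayer:
      """Toy stand-in: one tensor plus the two layout indices."""

      def __init__(self, keys: torch.Tensor, k_ts_idx: int, v_ts_idx: int):
        self.keys = keys
        self.k_ts_idx = k_ts_idx
        self.v_ts_idx = v_ts_idx


    def flatten(layer: ToyCacheLayer) -> Tuple[List[torch.Tensor], Tuple[int, int]]:
      # Tensors go into the flat list; layout indices ride in the static context.
      return [layer.keys], (layer.k_ts_idx, layer.v_ts_idx)


    def unflatten(values: List[torch.Tensor], context: Tuple[int, int]) -> ToyCacheLayer:
      return ToyCacheLayer(values[0], context[0], context[1])


    leaves, ctx = flatten(ToyCacheLayer(torch.zeros(1, 8, 1280, 128), 2, 3))
    rebuilt = unflatten(leaves, ctx)
    assert (rebuilt.k_ts_idx, rebuilt.v_ts_idx) == (2, 3)
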

litert_torch/generative/export_hf/core/split_cache/exportable_module.py CHANGED

@@ -14,12 +14,12 @@
 # ==============================================================================
 """Exportable module for split cache attention models."""
 
+import copy
 from litert_torch.generative.export_hf.core import cache as base_cache_lib
 from litert_torch.generative.export_hf.core import exportable_module as base_exportable_module
 from litert_torch.generative.export_hf.core import utils
 from litert_torch.generative.export_hf.core.split_cache import attention_mask
 from litert_torch.generative.export_hf.core.split_cache import cache as kv_cache_lib
-import numpy as np
 import torch
 from torch import nn
 
@@ -64,9 +64,9 @@ class LiteRTSplitCacheExportableModuleForDecoderOnlyLM(
     ret['inputs_embeds'] = embeddings
 
     ret.update({
-        'position_ids':
+        'position_ids': torch.arange(embeddings.shape[1])[None, :],
         'past_key_values': kv_cache,
-        'cache_position':
+        'cache_position': torch.arange(embeddings.shape[1]),
         'attention_mask': masks,
         # Other common settings
         'use_cache': True,
@@ -164,6 +164,7 @@ class LiteRTSplitCacheExportableModuleForDecoderOnlyLMPrefill(
         mask,
         kv_cache,
     )
+    inputs |= self.attention_kwargs()
     output = self.model(**inputs)
     output_cache = output.past_key_values
     return self.post_process_kv_cache(output_cache)
@@ -171,9 +172,9 @@ class LiteRTSplitCacheExportableModuleForDecoderOnlyLMPrefill(
   def get_sample_inputs(
       self,
       model_config,
-      export_config: base_exportable_module.ExportableModuleConfig,
   ):
-
+    export_config = self.export_config
+    kv_cache_inputs, _ = self.get_sample_kv_cache(model_config)
 
     sample_inputs = {}
     for prefill_length in export_config.prefill_lengths:
@@ -207,6 +208,7 @@ class LiteRTSplitCacheExportableModuleForDecoderOnlyLMGenerate(
         mask,
         kv_cache,
     )
+    inputs |= self.attention_kwargs()
     output = self.model(**inputs)
     output_cache = output.past_key_values
     ret = self.post_process_kv_cache(output_cache)
@@ -216,9 +218,9 @@ class LiteRTSplitCacheExportableModuleForDecoderOnlyLMGenerate(
   def get_sample_inputs(
       self,
       model_config,
-      export_config: base_exportable_module.ExportableModuleConfig,
   ):
-
+    export_config = self.export_config
+    kv_cache_inputs, _ = self.get_sample_kv_cache(model_config)
     sample_inputs = {
         **kv_cache_inputs,
         **self._get_input(
@@ -322,13 +324,20 @@ class CacheUpdate(torch.nn.Module):
     return {'kv_cache': kv_cache}
 
   @classmethod
-  def _get_input(
+  def _get_input(
+      cls,
+      model_config,
+      input_length,
+      export_config: base_exportable_module.ExportableModuleConfig,
+  ):
     """Gets sample inputs for the model."""
     kv_cache = base_cache_lib.LiteRTLMCache.create_from_config(
-        model_config,
+        model_config, export_config
     )
+    slice_export_config = copy.deepcopy(export_config)
+    slice_export_config.cache_length = input_length
     kv_slice = base_cache_lib.LiteRTLMCache.create_from_config(
-        model_config,
+        model_config, slice_export_config
     )
     return {
         'kv_cache': kv_cache,
@@ -348,15 +357,13 @@ class CacheUpdate(torch.nn.Module):
       inputs = cls._get_input(
           model_config,
           prefill_length,
-          export_config
-          export_config.batch_size,
+          export_config,
       )
       sample_inputs[f'prefill_cache_update_{prefill_length}'] = (inputs, {})
     decode_inputs = cls._get_input(
         model_config,
         1,
-        export_config
-        export_config.batch_size,
+        export_config,
     )
     sample_inputs['decode_cache_update'] = (decode_inputs, {})
     return sample_inputs
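
The cache-update sample inputs above reuse the full export config for the persistent cache, and a deep copy with `cache_length` overridden to the slice length for the incoming update. The pattern in isolation, with a dataclass standing in for ExportableModuleConfig:

    import copy
    import dataclasses


    @dataclasses.dataclass
    class ConfigStub:
      """Stand-in carrying only the field that gets overridden."""

      cache_length: int = 1280


    export_config = ConfigStub(cache_length=4096)

    # The persistent cache keeps the full length; the slice spans only the update.
    slice_export_config = copy.deepcopy(export_config)
    slice_export_config.cache_length = 128  # e.g. one prefill chunk

    assert export_config.cache_length == 4096
    assert slice_export_config.cache_length == 128
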

litert_torch/generative/export_hf/export.py CHANGED

@@ -30,7 +30,10 @@ def export(
     cache_length=4096,
     quantization_recipe: str = 'dynamic_wi8_afp32',
     enable_dynamic_shape: bool = False,
-
+    externalize_embedder: bool = False,
+    key_ts_idx: int = 2,
+    value_ts_idx: int = 3,
+    split_cache: bool = False,
     auto_model_override: str | None = None,
     # target_accelerator: str | None = None,
     trust_remote_code: bool = False,
@@ -46,6 +49,8 @@ def export(
       auto_model_override=auto_model_override,
   )
   del config  # Unused.
+  if split_cache and not externalize_embedder:
+    raise ValueError('Split cache requires externalize embedder to be enabled.')
   export_config = exportable_module.ExportableModuleConfig(
       batch_size=1,
       prefill_lengths=prefill_lengths,
@@ -56,17 +61,45 @@ def export(
       cache_length_dim=torch.export.Dim('cache_length')
       if enable_dynamic_shape
       else None,
-      externalize_embedder=
+      externalize_embedder=externalize_embedder,
+      k_ts_idx=key_ts_idx,
+      v_ts_idx=value_ts_idx,
+      split_cache=split_cache,
+      externalize_rope=split_cache,
+      cache_implementation='LiteRTLMSplitCache'
+      if split_cache
+      else 'LiteRTLMCache',
   )
   export_lib.export_text_prefill_decode_model(
       pt_model, text_model_config, export_config, work_dir, quantization_recipe
   )
   gc.collect()
+  if externalize_embedder:
+    export_lib.export_embedder_model(
+        pt_model,
+        text_model_config,
+        export_config,
+        work_dir,
+        quantization_recipe,
+    )
+    gc.collect()
+  if split_cache:
+    export_lib.export_auxiliary_model(
+        pt_model,
+        text_model_config,
+        export_config,
+        work_dir,
+        quantization_recipe,
+    )
+    gc.collect()
   tokenizer_model_path = export_lib.export_tokenizer(tokenizer, work_dir)
   tflite_model_path = os.path.join(
       work_dir,
       'model_quantized.tflite' if quantization_recipe else 'model.tflite',
   )
+  if externalize_embedder or split_cache:
+    # TODO(weiyiw): Add support for packaging models.
+    return
   litert_lm_builder.package_model(
       pt_model,
       tokenizer,
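
Taken together, the new export() parameters gate which artifacts are produced: a split-cache export requires an externalized embedder, switches the cache implementation, also externalizes RoPE, and (per the TODO) skips packaging for now. A minimal sketch of that added control flow; the helper function below is illustrative and not part of the package's API.

    def resolve_export_options(split_cache: bool, externalize_embedder: bool) -> dict:
      """Mirrors the option wiring this release adds to export()."""
      if split_cache and not externalize_embedder:
        raise ValueError('Split cache requires externalize embedder to be enabled.')
      return {
          'cache_implementation': (
              'LiteRTLMSplitCache' if split_cache else 'LiteRTLMCache'
          ),
          'externalize_rope': split_cache,
          # Packaging is skipped when the embedder is external or the cache is split.
          'package_model': not (externalize_embedder or split_cache),
      }


    print(resolve_export_options(split_cache=True, externalize_embedder=True))
    # {'cache_implementation': 'LiteRTLMSplitCache', 'externalize_rope': True,
    #  'package_model': False}
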

litert_torch/version.py CHANGED