litert-torch-nightly 0.9.0.dev20260202-py3-none-any.whl → 0.9.0.dev20260203-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- litert_torch/generative/export_hf/core/attention.py +86 -8
- litert_torch/generative/export_hf/core/attention_test.py +7 -2
- litert_torch/generative/export_hf/core/cache.py +112 -64
- litert_torch/generative/export_hf/core/cache_base.py +19 -2
- litert_torch/generative/export_hf/core/export_lib.py +55 -6
- litert_torch/generative/export_hf/core/exportable_module.py +30 -34
- litert_torch/generative/export_hf/core/exportable_module_config.py +39 -0
- litert_torch/generative/export_hf/core/split_cache/attention.py +28 -5
- litert_torch/generative/export_hf/core/split_cache/cache.py +113 -33
- litert_torch/generative/export_hf/core/split_cache/exportable_module.py +21 -14
- litert_torch/generative/export_hf/export.py +35 -2
- litert_torch/version.py +1 -1
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/METADATA +1 -1
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/RECORD +18 -17
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/WHEEL +0 -0
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/entry_points.txt +0 -0
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/licenses/LICENSE +0 -0
- {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/top_level.txt +0 -0
litert_torch/generative/export_hf/core/attention.py

```diff
@@ -14,13 +14,83 @@
 # ==============================================================================
 """Optimized Attention layer for HuggingFace integration."""
 
-
-from
+import math
+from typing import Optional
 import jaxtyping as jt
+from litert_torch.generative.custom_ops import bmm_4d as bmm_lib
 import torch
+import torch.nn.functional as F
 import transformers
 
 
+def scaled_dot_product_attention_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    head_size: int,
+    k_ts_idx: int,
+    v_ts_idx: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    softcap: Optional[float] = None,
+    alibi_bias: Optional[torch.Tensor] = None,
+):
+  """Scaled dot product attention with transposed key and value.
+
+  Args:
+    query: Query tensor, with shape [B, T, N, H].
+    key: Key tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, H, KV_LEN].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+    alibi_bias (torch.Tensor): optional alibi bias tensor.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention_transposed.
+  """
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  if alibi_bias is not None:
+    alibi_bias = alibi_bias * scale
+    if mask is None:
+      mask = alibi_bias
+    else:
+      mask = mask + alibi_bias
+
+  query = query * scale
+
+  assert mask is not None, "Mask should not be None!"
+  t = mask.shape[2]
+  if k_ts_idx == 2:
+    bmm_fn = bmm_lib.bmm_4d
+  else:
+    assert k_ts_idx == 3, "k_ts_idx must be 2 or 3."
+    bmm_fn = lambda x, y: torch.einsum("abth,abhs->abts", x, y)
+  logits = bmm_fn(query, key)
+
+  _, bk, gt, s = logits.shape
+  g = gt // t
+  logits = logits.reshape((bk, g, t, s))
+  if softcap is not None:
+    logits = torch.tanh(logits / softcap)
+    logits = logits * softcap
+
+  padded_logits = logits + mask
+  padded_logits = padded_logits.reshape(1, bk, gt, s)
+  probs = F.softmax(padded_logits, dim=-1).type_as(key)
+  if v_ts_idx == 3:
+    bmm_fn = bmm_lib.bmm_4d
+  else:
+    assert v_ts_idx == 2, "v_ts_idx must be 2 or 3."
+    bmm_fn = lambda x, y: torch.einsum("abts,absh->abth", x, y)
+  encoded = bmm_fn(probs, value)
+
+  return encoded  # 1, bk, gt, h
+
+
 def transposed_attention(
     module: torch.nn.Module,
     query: jt.Float[torch.Tensor, "b n t h"],
```
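For orientation, here is a minimal pure-torch sketch, with made-up sizes, of the tensor layouts the new helper expects on its einsum fallback path, i.e. `k_ts_idx=3` (key stored as "1 BK H S") and `v_ts_idx=2` (value stored as "1 BK S H"). It is not the library code: scaling is folded into the logits, and mask/softcap handling is omitted.

```python
import torch

# Illustrative sizes: BK = batch * kv_heads, GT = query_groups * tokens,
# H = head dim, S = cache length.
bk, gt, h, s = 2, 8, 16, 32

query = torch.randn(1, bk, gt, h)   # query already reshaped to 1, BK, GT, H
key = torch.randn(1, bk, h, s)      # k_ts_idx == 3: layout "1 BK H S"
value = torch.randn(1, bk, s, h)    # v_ts_idx == 2: layout "1 BK S H"

# Fallback path from the diff: plain einsum instead of the bmm_4d custom op.
logits = torch.einsum("abth,abhs->abts", query, key)      # 1, BK, GT, S
probs = torch.softmax(logits / h**0.5, dim=-1)
encoded = torch.einsum("abts,absh->abth", probs, value)   # 1, BK, GT, H
print(encoded.shape)
```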
```diff
@@ -46,20 +116,28 @@ def transposed_attention(
   Returns:
     The attention output tensor.
   """
-  del kwargs  # Unused in this implementation but required by the interface.
 
   b, n, seq_len, h = query.shape
   g = getattr(module, "num_key_value_groups", 1)
   num_query_groups = n // g
   # bnth -> b(kg)th -> 1(bk)(gt)h
   query = query.reshape(1, b * num_query_groups, g * seq_len, h)
+  key_ts_idx: int | None = kwargs.get("k_ts_idx", None)
+  value_ts_idx: int | None = kwargs.get("v_ts_idx", None)
+  if key_ts_idx is None or value_ts_idx is None:
+    raise ValueError(
+        "Timestamp indices not passed to attention module. The model is not"
+        " passing the kwargs correctly."
+    )
 
   # 1, bk, gt, h
-  sdpa_out =
-      query,
-      key,
-      value,
-      h,
+  sdpa_out = scaled_dot_product_attention_transposed(
+      query=query,
+      key=key,
+      value=value,
+      head_size=h,
+      k_ts_idx=key_ts_idx,
+      v_ts_idx=value_ts_idx,
       mask=attention_mask,
       scale=scaling,
       softcap=softcap,
```
litert_torch/generative/export_hf/core/attention_test.py

```diff
@@ -71,7 +71,7 @@ class DummyAttentionModule(torch.nn.Module):
     self.scaling = scaling
     self.softcap = softcap
 
-  def forward(self, query, key, value, attention_mask):
+  def forward(self, query, key, value, attention_mask, **kwargs):
     attention_interface = modeling_utils.ALL_ATTENTION_FUNCTIONS[
         self.attention_implementation
     ]
```
```diff
@@ -84,6 +84,7 @@ class DummyAttentionModule(torch.nn.Module):
         attention_mask,
         scaling=self.scaling,
         softcap=self.softcap,
+        **kwargs,
     )[0]
 
 
```
```diff
@@ -139,8 +140,12 @@ class AttentionTest(parameterized.TestCase):
         scaling=scl,
         softcap=scp,
     )
+    attention_kwargs = {
+        'k_ts_idx': 2,
+        'v_ts_idx': 3,
+    }
     expected = attn(query, key, value, mask)
-    actual = test_attn(query, key, value, mask)
+    actual = test_attn(query, key, value, mask, **attention_kwargs)
     self.assertTrue(
         torch.allclose(
             expected, actual, rtol=1e-2, atol=1e-2, equal_nan=True
```
litert_torch/generative/export_hf/core/cache.py

```diff
@@ -25,18 +25,30 @@ Shape annotations used here:
 """
 
 from typing import Any, List, Optional, Tuple
+
+import jaxtyping as jt
 import litert_torch.generative.custom_ops.dynamic_update_slice as tfl_dus
+from litert_torch.generative.export_hf.core import exportable_module_config
 import litert_torch.generative.export_hf.core.cache_base as cache_base_lib
-import jaxtyping as jt
 import torch
 import torch.utils._pytree as pytree
 
+ExportableModuleConfig = exportable_module_config.ExportableModuleConfig
+
 
 # Shape annotations for the cache entries.
-KeyCache =
-
-
-
+KeyCache = (
+    jt.Shaped[torch.Tensor, "1 BK S H"] | jt.Shaped[torch.Tensor, "1 BK H S"]
+)
+KeySlice = (
+    jt.Shaped[torch.Tensor, "1 BK T H"] | jt.Shaped[torch.Tensor, "1 BK H T"]
+)
+ValueCache = (
+    jt.Shaped[torch.Tensor, "1 BK H S"] | jt.Shaped[torch.Tensor, "1 BK S H"]
+)
+ValueSlice = (
+    jt.Shaped[torch.Tensor, "1 BK H T"] | jt.Shaped[torch.Tensor, "1 BK T H"]
+)
 
 
 def _get_slice_indices(
```
```diff
@@ -77,15 +89,11 @@ def _update_kv_impl(
     k_slice: KeySlice,
     v_slice: ValueSlice,
     cache_position: jt.Int32[torch.Tensor, "T"],
-
+    k_ts_idx: int,
+    v_ts_idx: int,
 ):
   """Updates the cache buffer using tfl.dynamic_update_slice."""
   cache_dim = 4
-  k_ts_idx = 2  # K Cache shape is 1 BK S H
-  v_ts_idx = 3  # V Cache shape is 1 BK H S
-  if reverse_kv:
-    k_ts_idx = 3  # K Cache shape is 1 BK H S
-    v_ts_idx = 2  # V Cache shape is 1 BK S H
   positions = cache_position[0]  # The position of the first input token.
   k_slice_indices = _get_slice_indices(positions.clone(), cache_dim, k_ts_idx)
   v_slice_indices = _get_slice_indices(positions.clone(), cache_dim, v_ts_idx)
```
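As a point of reference, the slice write that `_update_kv_impl` performs through the `tfl.dynamic_update_slice` custom op can be emulated in plain torch as below. This is only an illustration of how `k_ts_idx`/`v_ts_idx` select the axis that `cache_position` indexes into, not the actual op, and all sizes are made up.

```python
import torch

bk, s, h, t = 2, 16, 8, 3           # illustrative: cache length S, slice length T
k_cache = torch.zeros(1, bk, s, h)  # k_ts_idx == 2: layout "1 BK S H"
v_cache = torch.zeros(1, bk, h, s)  # v_ts_idx == 3: layout "1 BK H S"
k_slice = torch.randn(1, bk, t, h)
v_slice = torch.randn(1, bk, h, t)
cache_position = torch.arange(4, 4 + t)  # positions of the incoming tokens

pos = int(cache_position[0])  # start offset along the sequence axis
# Equivalent of a dynamic update slice with start indices (0, 0, pos, 0) for K...
k_cache[:, :, pos:pos + t, :] = k_slice
# ...and (0, 0, 0, pos) for V, whose sequence axis is axis 3.
v_cache[:, :, :, pos:pos + t] = v_slice
```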
```diff
@@ -109,27 +117,26 @@ class LiteRTLMCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
       key_cache: KeyCache,
       value_cache: ValueCache,
       batch_size: int = 1,
-
+      k_ts_idx: int = 2,
+      v_ts_idx: int = 3,
       **kwargs,
   ):
     super().__init__()
     self.keys = key_cache
     self.values = value_cache
-    self.
+    self.k_ts_idx = k_ts_idx  # The index of the sequence dimension in K cache.
+    self.v_ts_idx = v_ts_idx  # The index of the sequence dimension in V cache.
+    assert k_ts_idx in [2, 3]
+    assert v_ts_idx in [2, 3]
    self.is_initialized = True
 
     self.k_cache_shape = self.keys.shape
     self.v_cache_shape = self.values.shape
-    self.max_cache_len =
-        self.v_cache_shape[2] if reverse_kv else self.k_cache_shape[2]
-    )
+    self.max_cache_len = self.v_cache_shape[self.v_ts_idx]
     self.batch_size = batch_size
-    self.
-
-
-    self.head_dim = (
-        self.v_cache_shape[3] if reverse_kv else self.k_cache_shape[3]
-    )
+    v_head_dim_idx = 3 if self.v_ts_idx == 2 else 2
+    self.head_dim = self.v_cache_shape[v_head_dim_idx]
+
     self.additional_states = kwargs.get("additional_states", None)
 
     self.cumulative_length = 0
```
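A small sketch of constructing a layer with a transposed key cache under the new constructor. The import path `...export_hf.core.cache` is assumed from the file being diffed, and the sizes are illustrative only.

```python
import torch
from litert_torch.generative.export_hf.core import cache as cache_lib

bk, s, h = 4, 128, 64  # illustrative: batch * kv_heads, cache length, head dim
layer = cache_lib.LiteRTLMCacheLayer(
    key_cache=torch.zeros(1, bk, h, s),    # sequence axis last -> k_ts_idx=3
    value_cache=torch.zeros(1, bk, s, h),  # sequence axis at 2 -> v_ts_idx=2
    batch_size=1,
    k_ts_idx=3,
    v_ts_idx=2,
)
# max_cache_len comes from the value cache's sequence axis, head_dim from the other axis.
assert layer.max_cache_len == s and layer.head_dim == h
```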
```diff
@@ -137,6 +144,12 @@ class LiteRTLMCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
   def get_batch_size(self) -> int:
     return self.batch_size
 
+  def get_k_ts_idx(self) -> int:
+    return self.k_ts_idx
+
+  def get_v_ts_idx(self) -> int:
+    return self.v_ts_idx
+
   def lazy_initialization(self, key_states: torch.Tensor):
     # Since we don't support real lazy initialization, this function could only
     # be called by Cache.early_initialization, where uses a standard cache
```
```diff
@@ -162,13 +175,24 @@ class LiteRTLMCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
     value_states = value_states.to(self.values.dtype)
 
     if not cache_kwargs.get("kv_slice_preprocessed", False):
-
-
-
-
-
-
-
+      if self.k_ts_idx == 3:
+        key_target_shape = (1, -1, self.head_dim, seq_len)
+        key_states = key_states.permute(0, 1, 3, 2).reshape(*key_target_shape)
+      elif self.k_ts_idx == 2:
+        key_target_shape = (1, -1, seq_len, self.head_dim)
+        key_states = key_states.reshape(*key_target_shape)
+      else:
+        raise ValueError(f"Unsupported k_ts_idx: {self.k_ts_idx}")
+      if self.v_ts_idx == 3:
+        value_target_shape = (1, -1, self.head_dim, seq_len)
+        value_states = value_states.permute(0, 1, 3, 2).reshape(
+            *value_target_shape
+        )
+      elif self.v_ts_idx == 2:
+        value_target_shape = (1, -1, seq_len, self.head_dim)
+        value_states = value_states.reshape(*value_target_shape)
+      else:
+        raise ValueError(f"Unsupported v_ts_idx: {self.v_ts_idx}")
 
     cache_position: jt.Int32[torch.Tensor, "T"] = cache_kwargs.get(
         "cache_position"
```
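The permute/reshape above is what folds the usual `[B, N, T, H]` key/value states into the flattened single-batch layouts. A toy sketch of the `k_ts_idx == 3` branch, with illustrative sizes:

```python
import torch

b, n, t, h = 1, 4, 3, 8        # batch, kv heads, new tokens, head dim
key_states = torch.randn(b, n, t, h)

# k_ts_idx == 3: target layout is "1 BK H T" (sequence axis last).
key_slice = key_states.permute(0, 1, 3, 2).reshape(1, -1, h, t)
print(key_slice.shape)  # torch.Size([1, 4, 8, 3]) -- B folded into the BK axis
```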
```diff
@@ -182,7 +206,8 @@ class LiteRTLMCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
         key_states,
         value_states,
         cache_position,
-        self.
+        self.k_ts_idx,
+        self.v_ts_idx,
     )
     return self.keys, self.values
 
```
```diff
@@ -203,32 +228,52 @@ class LiteRTLMCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
       cls,
       model_config,
       layer_index,
-
-      batch_size=1,
-      reverse_kv=False,
+      export_config: ExportableModuleConfig,
   ):
     """Infers the KV cache shape from the model config."""
     del layer_index  # Unused.
+    cache_length = export_config.cache_length
+    batch_size = export_config.batch_size
+    k_ts_idx = export_config.k_ts_idx
+    v_ts_idx = export_config.v_ts_idx
     num_kv_heads = model_config.num_key_value_heads
     embed_size_per_head = (
         getattr(model_config, "head_dim", None)
         or model_config.hidden_size // model_config.num_attention_heads
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if k_ts_idx == 2:
+      k_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          cache_length,
+          embed_size_per_head,
+      )
+    elif k_ts_idx == 3:
+      k_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          embed_size_per_head,
+          cache_length,
+      )
+    else:
+      raise ValueError(f"Unsupported k_ts_idx: {k_ts_idx}")
+    if v_ts_idx == 2:
+      v_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          cache_length,
+          embed_size_per_head,
+      )
+    elif v_ts_idx == 3:
+      v_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          embed_size_per_head,
+          cache_length,
+      )
+    else:
+      raise ValueError(f"Unsupported v_ts_idx: {v_ts_idx}")
     return k_cache_shape, v_cache_shape
 
   @classmethod
```
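For concreteness, the shape inference above works out as follows for made-up config values (illustrative only, not taken from any real model):

```python
# Illustrative numbers only.
batch_size, cache_length = 1, 1280
num_kv_heads, embed_size_per_head = 4, 256
k_ts_idx, v_ts_idx = 2, 3  # K laid out as "1 BK S H", V as "1 BK H S"

k_cache_shape = (1, batch_size * num_kv_heads, cache_length, embed_size_per_head)
v_cache_shape = (1, batch_size * num_kv_heads, embed_size_per_head, cache_length)
print(k_cache_shape)  # (1, 4, 1280, 256)
print(v_cache_shape)  # (1, 4, 256, 1280)
```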
```diff
@@ -236,18 +281,22 @@ class LiteRTLMCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
       cls,
       model_config,
       layer_index,
-
-      batch_size=1,
-      reverse_kv=False,
+      export_config: ExportableModuleConfig,
       **kwargs,
   ) -> "LiteRTLMCacheLayer":
     """Creates a KV cache from the model config."""
     k_cache_shape, v_cache_shape = cls._infer_cache_shape_from_config(
-        model_config, layer_index,
+        model_config, layer_index, export_config
     )
     keys = torch.zeros(k_cache_shape, dtype=torch.float32)
     values = torch.zeros(v_cache_shape, dtype=torch.float32)
-    return cls(
+    return cls(
+        keys,
+        values,
+        k_ts_idx=export_config.k_ts_idx,
+        v_ts_idx=export_config.v_ts_idx,
+        **kwargs,
+    )
 
 
 @cache_base_lib.register_cache_implementation
```
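The per-layer factory is now driven entirely by the new `ExportableModuleConfig` (defined in `exportable_module_config.py`, which this diff adds but does not show). A hedged sketch, assuming the config exposes `cache_length`, `batch_size`, `k_ts_idx` and `v_ts_idx` as constructor arguments, and using a stand-in model config with only the attributes the shape inference reads:

```python
import types
from litert_torch.generative.export_hf.core import cache as cache_lib
from litert_torch.generative.export_hf.core import exportable_module_config

# Stand-in model config; a real transformers config object would also work.
model_config = types.SimpleNamespace(
    num_key_value_heads=4, head_dim=256, hidden_size=2048, num_attention_heads=8
)
# Assumed constructor surface; only the four fields read in this diff are shown.
export_config = exportable_module_config.ExportableModuleConfig(
    cache_length=1280, batch_size=1, k_ts_idx=2, v_ts_idx=3
)
layer = cache_lib.LiteRTLMCacheLayer.create_from_config(
    model_config, layer_index=0, export_config=export_config
)
print(layer.k_cache_shape, layer.v_cache_shape)
```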
```diff
@@ -258,9 +307,7 @@ class LiteRTLMCache(cache_base_lib.LiteRTLMCacheMixin):
   def create_from_config(
       cls,
       model_config,
-
-      batch_size=1,
-      reverse_kv=False,
+      export_config: ExportableModuleConfig,
       **kwargs,
   ) -> "LiteRTLMCache":
     """Creates a KV cache from the model config."""
```
```diff
@@ -271,9 +318,8 @@ class LiteRTLMCache(cache_base_lib.LiteRTLMCacheMixin):
         LiteRTLMCacheLayer.create_from_config(
             model_config,
             layer_index,
-
-
-            reverse_kv=reverse_kv,
+            export_config,
+            **kwargs,
         )
     )
     return cls(layers)
```
```diff
@@ -281,7 +327,7 @@ class LiteRTLMCache(cache_base_lib.LiteRTLMCacheMixin):
 
 def _flatten_kvc_t(
     kvc: LiteRTLMCache,
-) -> Tuple[List[torch.Tensor], Tuple[List[str], Tuple[int, int,
+) -> Tuple[List[torch.Tensor], Tuple[List[str], Tuple[int, int, int, int]]]:
   """Flattens the cache into a list of tensors."""
   flattened = []
   flat_names = []
```
```diff
@@ -289,22 +335,23 @@ def _flatten_kvc_t(
   layer_0 = kvc.layers[0]
   assert isinstance(layer_0, cache_base_lib.LiteRTLMCacheLayerMixin)
   batch_size = layer_0.get_batch_size()
-
+  k_ts_idx = layer_0.get_k_ts_idx()
+  v_ts_idx = layer_0.get_v_ts_idx()
   for i, layer in enumerate(kvc.layers):
     flattened.append(layer.keys)
     flat_names.append(f"k_{i}")
     flattened.append(layer.values)
     flat_names.append(f"v_{i}")
-  return flattened, (flat_names, (batch_size, num_layers,
+  return flattened, (flat_names, (batch_size, num_layers, k_ts_idx, v_ts_idx))
 
 
 def _unflatten_kvc_t(
     values: List[torch.Tensor],
-    context: Tuple[List[str], Tuple[int, int,
+    context: Tuple[List[str], Tuple[int, int, int, int]],
 ) -> LiteRTLMCache:
   """Unflattens the cache from a list of tensors."""
   flat_names = context[0]
-  batch_size, num_layers,
+  batch_size, num_layers, k_ts_idx, v_ts_idx = context[1]
   layers = []
   for i in range(num_layers):
     k_cache_idx = flat_names.index(f"k_{i}")
```
```diff
@@ -314,7 +361,8 @@ def _unflatten_kvc_t(
             key_cache=values[k_cache_idx],
             value_cache=values[v_cache_idx],
             batch_size=batch_size,
-
+            k_ts_idx=k_ts_idx,
+            v_ts_idx=v_ts_idx,
         )
     )
   obj = LiteRTLMCache(layers)
```
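Because the pytree context now records the layout indices, a cache that is flattened for export and rebuilt on the other side keeps its K/V layouts. A hedged round-trip sketch using the two module-private helpers (their registration with `torch.utils._pytree` is outside this diff, and the import path is assumed):

```python
import torch
from litert_torch.generative.export_hf.core import cache as cache_lib

layer = cache_lib.LiteRTLMCacheLayer(
    key_cache=torch.zeros(1, 4, 128, 64),    # "1 BK S H" -> default k_ts_idx=2
    value_cache=torch.zeros(1, 4, 64, 128),  # "1 BK H S" -> default v_ts_idx=3
)
kv_cache = cache_lib.LiteRTLMCache([layer])

# Module-private helpers, used here only to show what the new
# (batch_size, num_layers, k_ts_idx, v_ts_idx) context carries.
flat, context = cache_lib._flatten_kvc_t(kv_cache)
_, (batch_size, num_layers, k_ts_idx, v_ts_idx) = context
rebuilt = cache_lib._unflatten_kvc_t(flat, context)
assert rebuilt.layers[0].get_k_ts_idx() == k_ts_idx == 2
assert rebuilt.layers[0].get_v_ts_idx() == v_ts_idx == 3
```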
litert_torch/generative/export_hf/core/cache_base.py

```diff
@@ -15,8 +15,11 @@
 """Base class for cache."""
 
 import abc
+from litert_torch.generative.export_hf.core import exportable_module_config
 from transformers import cache_utils
 
+ExportableModuleConfig = exportable_module_config.ExportableModuleConfig
+
 
 class LiteRTLMCacheLayerMixin(cache_utils.CacheLayerMixin, abc.ABC):
   """Optimized Cache layer class mixin for HuggingFace integration."""
```
```diff
@@ -26,10 +29,24 @@ class LiteRTLMCacheLayerMixin(cache_utils.CacheLayerMixin, abc.ABC):
     """Returns the batch size of the cache."""
     ...
 
+  @abc.abstractmethod
+  def get_k_ts_idx(self) -> int:
+    """Returns the index of the sequence dimension in K cache."""
+    ...
+
+  @abc.abstractmethod
+  def get_v_ts_idx(self) -> int:
+    """Returns the index of the sequence dimension in V cache."""
+    ...
+
   @classmethod
   @abc.abstractmethod
   def create_from_config(
-      cls,
+      cls,
+      model_config,
+      layer_index,
+      export_config: ExportableModuleConfig,
+      **kwargs
   ) -> "LiteRTLMCacheLayerMixin":
     ...
 
```
```diff
@@ -40,7 +57,7 @@ class LiteRTLMCacheMixin(cache_utils.Cache, abc.ABC):
   @classmethod
   @abc.abstractmethod
   def create_from_config(
-      cls, model_config,
+      cls, model_config, export_config: ExportableModuleConfig, **kwargs
   ) -> "LiteRTLMCacheMixin":
     """Creates a KV cache from the model config."""
     ...
```
litert_torch/generative/export_hf/core/export_lib.py

```diff
@@ -26,6 +26,7 @@ from litert_torch.generative.export_hf.core import exportable_module
 from litert_torch.generative.export_hf.core import patches as _
 from litert_torch.generative.export_hf.core import utils
 from litert_torch.generative.export_hf.core.external_emb import exportable_module as external_emb_module
+from litert_torch.generative.export_hf.core.external_rope import exportable_module as external_rope_module
 from litert_torch.generative.export_hf.core.external_rope import preprocess_model as external_rope_preprocess_model
 from litert_torch.generative.export_hf.core.mu import mu_pass_lib
 from litert_torch.generative.export_hf.core.split_cache import attention as _
```
```diff
@@ -34,6 +35,7 @@ from litert_torch.generative.tools import tokenizer_to_sentencepiece_lib as toke
 from litert_torch.odml_torch.experimental import torch_tfl
 import torch
 import transformers
+
 from ai_edge_quantizer import quantizer as quantizer_lib
 from ai_edge_quantizer import recipe as recipe_lib
 
```
```diff
@@ -174,12 +176,10 @@ def export_text_prefill_decode_model(
   prefill_module_cls, decode_module_cls = get_prefill_decode_exportable_cls(
       export_config
   )
-  prefill_module = prefill_module_cls(model)
-  decode_module = decode_module_cls(model)
+  prefill_module = prefill_module_cls(model, export_config)
+  decode_module = decode_module_cls(model, export_config)
   converter = converter_utils.Converter()
-  sample_prefill_inputs = prefill_module.get_sample_inputs(
-      text_model_config, export_config
-  )
+  sample_prefill_inputs = prefill_module.get_sample_inputs(text_model_config)
   for signature_name, (
       sample_prefill_inputs,
       prefill_dynamic_shapes,
```
```diff
@@ -213,7 +213,7 @@ def export_text_prefill_decode_model(
       sample_kwargs=sample_prefill_inputs,
   )
   sample_decode_inputs, decode_dynamic_shapes = decode_module.get_sample_inputs(
-      text_model_config
+      text_model_config
   )['decode']
   if has_dynamic_shape:
     print('Exporting decode_module...')
```
```diff
@@ -337,6 +337,55 @@ def export_embedder_model(
   return model_path
 
 
+def export_auxiliary_model(
+    model,
+    text_model_config,
+    export_config: exportable_module.ExportableModuleConfig,
+    work_dir: str,
+    quantization_recipe: str | None = None,
+):
+  """Exports auxiliary model."""
+  del quantization_recipe  # Unused.
+  converter = converter_utils.Converter()
+  # RoPE
+  rope_module = external_rope_module.RoPEEmbedder(model)
+  sample_inputs = rope_module.get_sample_inputs(
+      text_model_config, export_config
+  )
+  for signature_name, (sample_input, _) in sample_inputs.items():
+    converter.add_signature(
+        signature_name,
+        rope_module.eval(),
+        sample_kwargs=sample_input,
+    )
+  # Attention Mask
+  attention_mask_module = split_cache_module.SplitAttentionMaskBuilder(model)
+  sample_inputs = attention_mask_module.get_sample_inputs(
+      text_model_config, export_config
+  )
+  for signature_name, (sample_input, _) in sample_inputs.items():
+    converter.add_signature(
+        signature_name,
+        attention_mask_module.eval(),
+        sample_kwargs=sample_input,
+    )
+  # Cache Update
+  cache_update_module = split_cache_module.CacheUpdate(model)
+  sample_inputs = cache_update_module.get_sample_inputs(
+      text_model_config, export_config
+  )
+  for signature_name, (sample_input, _) in sample_inputs.items():
+    converter.add_signature(
+        signature_name,
+        cache_update_module.eval(),
+        sample_kwargs=sample_input,
+    )
+  lrt_model = converter.convert(strict_export=False)
+  model_path = os.path.join(work_dir, 'auxiliary.tflite')
+  lrt_model.export(model_path)
+  return model_path
+
+
 def export_tokenizer(
     tokenizer,
     work_dir: str,
```
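Finally, a hedged sketch of calling the new entry point. Here `model`, `text_model_config` and `export_config` stand for the same objects the existing text-model export path already constructs; they are not built in this snippet, and only the keyword names shown in the diff are relied on.

```python
from litert_torch.generative.export_hf.core import export_lib

# `model`, `text_model_config`, and `export_config` are placeholders for the
# objects passed to export_text_prefill_decode_model() elsewhere in export.py.
aux_path = export_lib.export_auxiliary_model(
    model=model,
    text_model_config=text_model_config,
    export_config=export_config,
    work_dir="/tmp/litert_export",
)
print(aux_path)  # .../auxiliary.tflite
```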
|