ai-edge-torch-nightly 0.4.0.dev20250407__py3-none-any.whl → 0.5.0.dev20250409__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/gemma3/decoder.py +8 -9
- ai_edge_torch/generative/examples/gemma3/verify_util.py +4 -2
- ai_edge_torch/generative/layers/experimental/attention.py +9 -9
- ai_edge_torch/generative/layers/experimental/kv_cache.py +13 -284
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +6 -9
- ai_edge_torch/generative/layers/experimental/types.py +3 -0
- ai_edge_torch/generative/layers/kv_cache.py +81 -14
- ai_edge_torch/generative/test/test_kv_cache.py +12 -19
- ai_edge_torch/generative/utilities/converter.py +8 -3
- ai_edge_torch/generative/utilities/export_config.py +3 -1
- ai_edge_torch/odml_torch/export.py +30 -6
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250407.dist-info → ai_edge_torch_nightly-0.5.0.dev20250409.dist-info}/METADATA +4 -1
- {ai_edge_torch_nightly-0.4.0.dev20250407.dist-info → ai_edge_torch_nightly-0.5.0.dev20250409.dist-info}/RECORD +18 -18
- {ai_edge_torch_nightly-0.4.0.dev20250407.dist-info → ai_edge_torch_nightly-0.5.0.dev20250409.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250407.dist-info → ai_edge_torch_nightly-0.5.0.dev20250409.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250407.dist-info → ai_edge_torch_nightly-0.5.0.dev20250409.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from absl import app
|
19
19
|
from ai_edge_torch.generative.examples.gemma3 import gemma3
|
20
|
-
from ai_edge_torch.generative.layers
|
20
|
+
from ai_edge_torch.generative.layers import kv_cache
|
21
21
|
from ai_edge_torch.generative.utilities import converter
|
22
22
|
from ai_edge_torch.generative.utilities import export_config
|
23
23
|
import torch
|
@@ -58,7 +58,7 @@ def _create_export_config(
|
|
58
58
|
)
|
59
59
|
decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
|
60
60
|
export_config.decode_mask = decode_mask
|
61
|
-
export_config.
|
61
|
+
export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
|
62
62
|
return export_config
|
63
63
|
|
64
64
|
|
@@ -18,9 +18,9 @@
|
|
18
18
|
from typing import List, Optional, Tuple
|
19
19
|
|
20
20
|
from ai_edge_torch.generative.layers import builder
|
21
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
21
22
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
22
23
|
from ai_edge_torch.generative.layers.experimental import attention
|
23
|
-
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
|
24
24
|
import ai_edge_torch.generative.layers.model_config as cfg
|
25
25
|
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
|
26
26
|
from ai_edge_torch.generative.utilities import export_config as export_cfg
|
@@ -81,8 +81,8 @@ class DecoderBlock(attention.TransformerBlock):
|
|
81
81
|
rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
82
82
|
mask: Optional[torch.Tensor] = None,
|
83
83
|
input_pos: Optional[torch.Tensor] = None,
|
84
|
-
kv_cache: kv_utils.
|
85
|
-
) -> Tuple[torch.Tensor, Optional[kv_utils.
|
84
|
+
kv_cache: kv_utils.KVCacheEntry = None,
|
85
|
+
) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
|
86
86
|
"""Forward function of the Gemma3Block.
|
87
87
|
|
88
88
|
Exactly the same as TransformerBlock but we call the post-attention norm
|
@@ -241,13 +241,12 @@ class Decoder(nn.Module):
|
|
241
241
|
self,
|
242
242
|
tokens: torch.Tensor,
|
243
243
|
input_pos: torch.Tensor,
|
244
|
-
kv_cache: kv_utils.
|
244
|
+
kv_cache: kv_utils.KVCache,
|
245
245
|
input_embeds: Optional[torch.Tensor] = None,
|
246
246
|
mask: Optional[torch.Tensor] = None,
|
247
247
|
image_indices: Optional[torch.Tensor] = None,
|
248
248
|
export_config: Optional[export_cfg.ExportConfig] = None,
|
249
|
-
) -> dict[torch.Tensor, kv_utils.
|
250
|
-
|
249
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
251
250
|
pixel_mask = None
|
252
251
|
if input_embeds is None:
|
253
252
|
# token embeddings of shape (b, t, n_embd)
|
@@ -287,10 +286,10 @@ class Decoder(nn.Module):
|
|
287
286
|
rope: List[Tuple[torch.Tensor, torch.Tensor]],
|
288
287
|
mask: torch.Tensor | List[torch.Tensor],
|
289
288
|
input_pos: torch.Tensor,
|
290
|
-
kv_cache: kv_utils.
|
289
|
+
kv_cache: kv_utils.KVCache,
|
291
290
|
pixel_mask: Optional[torch.Tensor] = None,
|
292
291
|
export_config: Optional[export_cfg.ExportConfig] = None,
|
293
|
-
) -> dict[torch.Tensor, kv_utils.
|
292
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
294
293
|
"""Forwards the model with input embeddings."""
|
295
294
|
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
296
295
|
"The number of transformer blocks and the number of KV cache entries"
|
@@ -326,7 +325,7 @@ class Decoder(nn.Module):
|
|
326
325
|
x, kv_entry = block(x, rope[i], mask_entry, input_pos, kv_entry)
|
327
326
|
if kv_entry:
|
328
327
|
updated_kv_entries.append(kv_entry)
|
329
|
-
updated_kv_cache = kv_utils.
|
328
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
|
330
329
|
if export_config is not None:
|
331
330
|
if (
|
332
331
|
torch.numel(input_pos) > 1
|
@@ -20,8 +20,8 @@ import os
|
|
20
20
|
from typing import List, Optional, Tuple
|
21
21
|
|
22
22
|
from ai_edge_torch.generative.examples.gemma3 import gemma3
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
23
24
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
24
|
-
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
|
25
25
|
from ai_edge_torch.generative.utilities.experimental import verifier
|
26
26
|
from gemma import config as gemma_config
|
27
27
|
from gemma import model as gemma_model
|
@@ -94,7 +94,9 @@ class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
|
|
94
94
|
|
95
95
|
def _init_kv_cache(self):
|
96
96
|
"""Returns an initialized KV cache."""
|
97
|
-
return kv_utils.
|
97
|
+
return kv_utils.KVCache.from_model_config(
|
98
|
+
self.model.model.config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
|
99
|
+
)
|
98
100
|
|
99
101
|
def forward(
|
100
102
|
self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
|
@@ -22,8 +22,9 @@ at any time.
|
|
22
22
|
from typing import Optional, Tuple, Union
|
23
23
|
|
24
24
|
from ai_edge_torch.generative.layers import builder
|
25
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
25
26
|
from ai_edge_torch.generative.layers import lora as lora_utils
|
26
|
-
from ai_edge_torch.generative.layers.experimental import kv_cache as
|
27
|
+
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
|
27
28
|
from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
|
28
29
|
import ai_edge_torch.generative.layers.model_config as cfg
|
29
30
|
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
|
@@ -69,9 +70,9 @@ class TransformerBlock(nn.Module):
|
|
69
70
|
rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
70
71
|
mask: Optional[torch.Tensor] = None,
|
71
72
|
input_pos: Optional[torch.Tensor] = None,
|
72
|
-
kv_cache: kv_utils.
|
73
|
+
kv_cache: kv_utils.KVCacheEntry = None,
|
73
74
|
lora: Optional[lora_utils.LoRAEntry] = None,
|
74
|
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.
|
75
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
|
75
76
|
"""Forward function of the TransformerBlock.
|
76
77
|
|
77
78
|
Args:
|
@@ -79,7 +80,7 @@ class TransformerBlock(nn.Module):
|
|
79
80
|
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
|
80
81
|
mask (torch.Tensor): the optional mask tensor.
|
81
82
|
input_pos (torch.Tensor): the optional input position tensor.
|
82
|
-
kv_cache (
|
83
|
+
kv_cache (KVCacheEntry): the optional kv cache entry.
|
83
84
|
lora (LoRAEntry): the optional lora entry.
|
84
85
|
|
85
86
|
Returns:
|
@@ -154,9 +155,9 @@ class CausalSelfAttention(nn.Module):
|
|
154
155
|
rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
155
156
|
mask: Optional[torch.Tensor] = None,
|
156
157
|
input_pos: Optional[torch.Tensor] = None,
|
157
|
-
kv_cache: Optional[kv_utils.
|
158
|
+
kv_cache: Optional[kv_utils.KVCacheEntry] = None,
|
158
159
|
lora: Optional[lora_utils.LoRAEntry] = None,
|
159
|
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.
|
160
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
|
160
161
|
"""Forward function of the CausalSelfAttention layer, which can support
|
161
162
|
|
162
163
|
MQA, GQA and MHA.
|
@@ -166,8 +167,7 @@ class CausalSelfAttention(nn.Module):
|
|
166
167
|
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
|
167
168
|
mask (torch.Tensor): the optional mask tensor.
|
168
169
|
input_pos (torch.Tensor): the optional input position tensor.
|
169
|
-
kv_cache (
|
170
|
-
module.
|
170
|
+
kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
|
171
171
|
lora (LoRAEntry): the optional lora entry.
|
172
172
|
|
173
173
|
Returns:
|
@@ -237,7 +237,7 @@ class CausalSelfAttention(nn.Module):
|
|
237
237
|
) # 1, bk, h, s
|
238
238
|
|
239
239
|
if kv_cache is not None:
|
240
|
-
kv_cache =
|
240
|
+
kv_cache = kv_utils_experimental.update(kv_cache, input_pos, k, v)
|
241
241
|
k, v = kv_cache.k_cache, kv_cache.v_cache
|
242
242
|
|
243
243
|
sdpa_out = self.sdpa_func(
|
@@ -18,304 +18,33 @@
|
|
18
18
|
This is an experimental implementation and is subject to change at any time.
|
19
19
|
"""
|
20
20
|
|
21
|
-
import dataclasses
|
22
|
-
import functools
|
23
|
-
from typing import Any, List, Tuple, Type
|
24
|
-
from ai_edge_torch.generative.layers import model_config
|
25
|
-
from ai_edge_torch.generative.layers.experimental import types
|
26
21
|
from ai_edge_torch.generative.custom_ops import dynamic_update_slice as dus_utils
|
22
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
27
23
|
import torch
|
28
|
-
import torch.utils._pytree as pytree
|
29
|
-
|
30
|
-
|
31
|
-
@dataclasses.dataclass
|
32
|
-
class KVCacheEntryBase:
|
33
|
-
"""A single cache entry that includes K and V caches.
|
34
|
-
|
35
|
-
The chaches are built based on the provided config with the shape of
|
36
|
-
(batch_size, kv_cache_max, num_query_groups, head_dim).
|
37
|
-
"""
|
38
|
-
|
39
|
-
k_cache: torch.Tensor
|
40
|
-
v_cache: torch.Tensor
|
41
|
-
|
42
|
-
@classmethod
|
43
|
-
def _from_model_config(
|
44
|
-
cls,
|
45
|
-
k_shape: Tuple[int, ...],
|
46
|
-
v_shape: Tuple[int, ...],
|
47
|
-
dtype: torch.dtype = torch.float32,
|
48
|
-
device: torch.device = None,
|
49
|
-
):
|
50
|
-
"""Build an instance of the class based on model config."""
|
51
|
-
k = torch.zeros(k_shape, dtype=dtype, device=device)
|
52
|
-
v = torch.zeros(v_shape, dtype=dtype, device=device)
|
53
|
-
obj = cls(k_cache=k, v_cache=v)
|
54
|
-
return obj
|
55
|
-
|
56
|
-
@classmethod
|
57
|
-
def from_model_config(
|
58
|
-
cls,
|
59
|
-
kv_cache_max: int,
|
60
|
-
config: model_config.AttentionConfig,
|
61
|
-
dtype: torch.dtype = torch.float32,
|
62
|
-
device: torch.device = None,
|
63
|
-
batch_size: int = 1,
|
64
|
-
):
|
65
|
-
"""Build an instance of the class based on model config."""
|
66
|
-
shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
|
67
|
-
return cls._from_model_config(shape, shape, dtype, device)
|
68
|
-
|
69
|
-
|
70
|
-
@dataclasses.dataclass
|
71
|
-
class KVCacheEntryBTNH(KVCacheEntryBase):
|
72
|
-
k_type = types.BTNH()
|
73
|
-
v_type = types.BTNH()
|
74
|
-
|
75
|
-
|
76
|
-
@dataclasses.dataclass
|
77
|
-
class KVCacheEntryTransposed(KVCacheEntryBase):
|
78
|
-
|
79
|
-
k_type = types.BNTH()
|
80
|
-
v_type = types.BNHT()
|
81
|
-
|
82
|
-
@classmethod
|
83
|
-
def from_model_config(
|
84
|
-
cls,
|
85
|
-
kv_cache_max: int,
|
86
|
-
config: model_config.AttentionConfig,
|
87
|
-
dtype: torch.dtype = torch.float32,
|
88
|
-
device: torch.device = None,
|
89
|
-
batch_size: int = 1,
|
90
|
-
):
|
91
|
-
"""Build an instance of the class based on model config."""
|
92
|
-
k_shape = (
|
93
|
-
batch_size,
|
94
|
-
config.num_query_groups,
|
95
|
-
kv_cache_max,
|
96
|
-
config.head_dim,
|
97
|
-
) # b, k, s, h
|
98
|
-
v_shape = (
|
99
|
-
batch_size,
|
100
|
-
config.num_query_groups,
|
101
|
-
config.head_dim,
|
102
|
-
kv_cache_max,
|
103
|
-
) # b, k, h, s
|
104
|
-
return cls._from_model_config(k_shape, v_shape, dtype, device)
|
105
|
-
|
106
|
-
|
107
|
-
def _flatten_kv_entry(
|
108
|
-
kv_e: KVCacheEntryBase,
|
109
|
-
) -> Tuple[List[torch.Tensor], Any]:
|
110
|
-
return ([kv_e.k_cache, kv_e.v_cache], None)
|
111
|
-
|
112
|
-
|
113
|
-
def _unflatten_kv_entry(
|
114
|
-
kv_entry_ty: Type[KVCacheEntryBase],
|
115
|
-
values: List[torch.Tensor],
|
116
|
-
unused_context: Any,
|
117
|
-
) -> KVCacheEntryBase:
|
118
|
-
return kv_entry_ty(*values)
|
119
|
-
|
120
|
-
|
121
|
-
pytree.register_pytree_node(
|
122
|
-
KVCacheEntryTransposed,
|
123
|
-
_flatten_kv_entry,
|
124
|
-
functools.partial(_unflatten_kv_entry, KVCacheEntryTransposed),
|
125
|
-
serialized_type_name="",
|
126
|
-
)
|
127
|
-
|
128
|
-
pytree.register_pytree_node(
|
129
|
-
KVCacheEntryBase,
|
130
|
-
_flatten_kv_entry,
|
131
|
-
functools.partial(_unflatten_kv_entry, KVCacheEntryBase),
|
132
|
-
serialized_type_name="",
|
133
|
-
)
|
134
|
-
|
135
|
-
|
136
|
-
@dataclasses.dataclass
|
137
|
-
class KVCacheBase:
|
138
|
-
"""A utility class for holding KV cache entries per layer."""
|
139
|
-
|
140
|
-
caches: Tuple[KVCacheEntryBase, ...]
|
141
|
-
|
142
|
-
@classmethod
|
143
|
-
def _from_model_config(
|
144
|
-
cls,
|
145
|
-
kv_entry_cls,
|
146
|
-
config: model_config.ModelConfig,
|
147
|
-
dtype: torch.dtype = torch.float32,
|
148
|
-
device: torch.device = None,
|
149
|
-
batch_size: int = 1,
|
150
|
-
):
|
151
|
-
caches = [
|
152
|
-
kv_entry_cls.from_model_config(
|
153
|
-
config.kv_cache_max,
|
154
|
-
config.block_config(idx).attn_config,
|
155
|
-
dtype,
|
156
|
-
device,
|
157
|
-
batch_size,
|
158
|
-
)
|
159
|
-
for idx in range(config.num_layers)
|
160
|
-
]
|
161
|
-
obj = cls(caches=tuple(caches))
|
162
|
-
return obj
|
163
|
-
|
164
|
-
@classmethod
|
165
|
-
def from_model_config(
|
166
|
-
cls,
|
167
|
-
config: model_config.ModelConfig,
|
168
|
-
dtype: torch.dtype = torch.float32,
|
169
|
-
device: torch.device = None,
|
170
|
-
batch_size: int = 1,
|
171
|
-
):
|
172
|
-
"""Build an instance of the class based on model config.
|
173
|
-
|
174
|
-
Args:
|
175
|
-
config (ModelConfig): Model config used for building the cache.
|
176
|
-
dtype (torch.dtype, optional): The data type of the cache tensor.
|
177
|
-
Defaults to torch.float32.
|
178
|
-
device (torch.device, optional): The device placement of the cache
|
179
|
-
tensors. Defaults to None.
|
180
|
-
batch_size (int, optional): The batch size of the cache tensors.
|
181
|
-
Defaults to 1.
|
182
|
-
|
183
|
-
Returns:
|
184
|
-
KVCacheBase: The created cache object.
|
185
|
-
"""
|
186
|
-
assert batch_size == 1, "Batch size must be 1 for KV Cache."
|
187
|
-
return cls._from_model_config(
|
188
|
-
KVCacheEntryBase,
|
189
|
-
config=config,
|
190
|
-
dtype=dtype,
|
191
|
-
device=device,
|
192
|
-
batch_size=batch_size,
|
193
|
-
)
|
194
|
-
|
195
|
-
def flatten(self) -> List[torch.Tensor]:
|
196
|
-
"""Flatten the cache entries into a list of tensors with order k_i, v_i."""
|
197
|
-
flattened, _ = _flatten_kvc(self)
|
198
|
-
return flattened
|
199
|
-
|
200
|
-
|
201
|
-
@dataclasses.dataclass
|
202
|
-
class KVCacheBTNH(KVCacheBase):
|
203
|
-
|
204
|
-
@classmethod
|
205
|
-
def from_model_config(
|
206
|
-
cls,
|
207
|
-
config: model_config.ModelConfig,
|
208
|
-
dtype: torch.dtype = torch.float32,
|
209
|
-
device: torch.device = None,
|
210
|
-
batch_size: int = 1,
|
211
|
-
):
|
212
|
-
return cls._from_model_config(
|
213
|
-
KVCacheEntryBTNH,
|
214
|
-
config=config,
|
215
|
-
dtype=dtype,
|
216
|
-
device=device,
|
217
|
-
batch_size=batch_size,
|
218
|
-
)
|
219
|
-
|
220
|
-
|
221
|
-
@dataclasses.dataclass
|
222
|
-
class KVCacheTransposed(KVCacheBase):
|
223
|
-
|
224
|
-
@classmethod
|
225
|
-
def from_model_config(
|
226
|
-
cls,
|
227
|
-
config: model_config.ModelConfig,
|
228
|
-
dtype: torch.dtype = torch.float32,
|
229
|
-
device: torch.device = None,
|
230
|
-
batch_size: int = 1,
|
231
|
-
):
|
232
|
-
return cls._from_model_config(
|
233
|
-
KVCacheEntryTransposed,
|
234
|
-
config=config,
|
235
|
-
dtype=dtype,
|
236
|
-
device=device,
|
237
|
-
batch_size=batch_size,
|
238
|
-
)
|
239
|
-
|
240
|
-
|
241
|
-
def _flatten_kvc(kvc: KVCacheBase) -> Tuple[List[str], List[str]]:
|
242
|
-
flattened = []
|
243
|
-
flat_names = []
|
244
|
-
none_names = []
|
245
|
-
for i, kv_entry in enumerate(kvc.caches):
|
246
|
-
flattened.append(kv_entry.k_cache)
|
247
|
-
flat_names.append(f"k_{i}")
|
248
|
-
flattened.append(kv_entry.v_cache)
|
249
|
-
flat_names.append(f"v_{i}")
|
250
|
-
return flattened, [flat_names, none_names]
|
251
|
-
|
252
|
-
|
253
|
-
def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
|
254
|
-
flattened, (flat_names, none_names) = _flatten_kvc(kvc)
|
255
|
-
return [
|
256
|
-
(pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
|
257
|
-
], flat_names
|
258
|
-
|
259
|
-
|
260
|
-
def _unflatten_kvc(
|
261
|
-
kv_ty: Type[KVCacheBase],
|
262
|
-
kv_entry_type: Type[KVCacheEntryBase],
|
263
|
-
values: List[torch.Tensor],
|
264
|
-
context: Tuple[List, List],
|
265
|
-
) -> KVCacheBase:
|
266
|
-
assert len(values) % 2 == 0, "Found odd number of K and V entries."
|
267
|
-
num_layers = len(values) // 2
|
268
|
-
flat_names = context[0]
|
269
|
-
kv_entries = []
|
270
|
-
for i in range(num_layers):
|
271
|
-
k_cache_idx = flat_names.index(f"k_{i}")
|
272
|
-
v_cache_idx = flat_names.index(f"v_{i}")
|
273
|
-
kv_entries.append(
|
274
|
-
kv_entry_type(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
|
275
|
-
)
|
276
|
-
obj = kv_ty(tuple(kv_entries))
|
277
|
-
return obj
|
278
|
-
|
279
|
-
|
280
|
-
pytree.register_pytree_node(
|
281
|
-
KVCacheTransposed,
|
282
|
-
_flatten_kvc,
|
283
|
-
functools.partial(
|
284
|
-
_unflatten_kvc, KVCacheTransposed, KVCacheEntryTransposed
|
285
|
-
),
|
286
|
-
flatten_with_keys_fn=_flatten_kvc_with_keys,
|
287
|
-
serialized_type_name="",
|
288
|
-
)
|
289
|
-
|
290
|
-
pytree.register_pytree_node(
|
291
|
-
KVCacheBase,
|
292
|
-
_flatten_kvc,
|
293
|
-
functools.partial(_unflatten_kvc, KVCacheBase, KVCacheEntryBase),
|
294
|
-
flatten_with_keys_fn=_flatten_kvc_with_keys,
|
295
|
-
serialized_type_name="",
|
296
|
-
)
|
297
24
|
|
298
25
|
|
299
26
|
def update(
|
300
|
-
cache:
|
27
|
+
cache: kv_utils.KVCacheEntry,
|
301
28
|
input_pos: torch.Tensor,
|
302
29
|
k_slice: torch.Tensor,
|
303
30
|
v_slice: torch.Tensor,
|
304
|
-
) ->
|
31
|
+
) -> kv_utils.KVCacheEntry:
|
305
32
|
"""Out of place update of Cache buffer.
|
306
33
|
|
307
34
|
Args:
|
308
|
-
cache (
|
35
|
+
cache (kv_utils.KVCacheEntry): The original cache buffer.
|
309
36
|
input_pos (torch.Tensor): The update slice positions.
|
310
37
|
k_slice (torch.Tensor): The K slice to be updated in the new cache.
|
311
38
|
v_slice (torch.Tensor): The V slice to be updated in the new cache.
|
312
39
|
|
313
40
|
Returns:
|
314
|
-
|
41
|
+
kv_utils.KVCacheEntry: The updated KVCacheBase entry based on the passed
|
315
42
|
inputs.
|
316
43
|
"""
|
317
|
-
|
318
|
-
|
44
|
+
assert (
|
45
|
+
cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
|
46
|
+
), "KV entry must have transposed layout."
|
47
|
+
return _update_kv_impl_transposed(cache, input_pos, k_slice, v_slice)
|
319
48
|
|
320
49
|
|
321
50
|
def _get_slice_indices(
|
@@ -338,12 +67,12 @@ def _get_slice_indices(
|
|
338
67
|
return slice_indices
|
339
68
|
|
340
69
|
|
341
|
-
def
|
342
|
-
cache:
|
70
|
+
def _update_kv_impl_transposed(
|
71
|
+
cache: kv_utils.KVCacheEntry,
|
343
72
|
input_pos: torch.Tensor,
|
344
73
|
k_slice: torch.Tensor,
|
345
74
|
v_slice: torch.Tensor,
|
346
|
-
) ->
|
75
|
+
) -> kv_utils.KVCacheEntry:
|
347
76
|
"""Update the cache buffer with High Level Function Boundary annotation."""
|
348
77
|
cache_dim = 4
|
349
78
|
k_ts_idx = 2
|
@@ -357,4 +86,4 @@ def _update_kv_impl(
|
|
357
86
|
v = dus_utils.dynamic_update_slice(
|
358
87
|
cache.v_cache, v_slice, [x for x in v_slice_indices]
|
359
88
|
)
|
360
|
-
return
|
89
|
+
return kv_utils.KVCacheEntry(k, v, cache.kv_layout)
|
@@ -19,7 +19,7 @@ import math
|
|
19
19
|
from typing import Optional
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
|
22
|
-
from ai_edge_torch.generative.layers
|
22
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
23
23
|
from ai_edge_torch.generative.layers.experimental import types
|
24
24
|
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
25
25
|
from multipledispatch import dispatch
|
@@ -28,7 +28,7 @@ import torch.nn.functional as F
|
|
28
28
|
|
29
29
|
|
30
30
|
def scaled_dot_product_attention(
|
31
|
-
kv: kv_utils.
|
31
|
+
kv: kv_utils.KVCacheEntry,
|
32
32
|
query: torch.Tensor,
|
33
33
|
key: torch.Tensor,
|
34
34
|
value: torch.Tensor,
|
@@ -37,10 +37,10 @@ def scaled_dot_product_attention(
|
|
37
37
|
scale: Optional[float] = None,
|
38
38
|
softcap: Optional[float] = None,
|
39
39
|
):
|
40
|
-
if hasattr(kv, "
|
40
|
+
if hasattr(kv, "kv_layout"):
|
41
41
|
return _sdpa(
|
42
|
-
kv.
|
43
|
-
kv.
|
42
|
+
kv.kv_layout[0](), # key layout
|
43
|
+
kv.kv_layout[1](), # value layout
|
44
44
|
query=query,
|
45
45
|
key=key,
|
46
46
|
value=value,
|
@@ -49,10 +49,7 @@ def scaled_dot_product_attention(
|
|
49
49
|
scale=scale,
|
50
50
|
softcap=softcap,
|
51
51
|
)
|
52
|
-
raise ValueError(
|
53
|
-
f"SDPA for K type {type(kv.caches[0].k_type)} and V type"
|
54
|
-
f" {type(kv.caches[0].v_type)} not supported."
|
55
|
-
)
|
52
|
+
raise ValueError("No kv_layout attribute found in kv.")
|
56
53
|
|
57
54
|
|
58
55
|
@dispatch(types.BNTH, types.BNHT)
|
@@ -62,6 +62,9 @@ class TensorDimensionMeta(type):
|
|
62
62
|
def __repr__(cls):
|
63
63
|
return f'{cls.__name__}'
|
64
64
|
|
65
|
+
def __iter__(cls):
|
66
|
+
return iter(getattr(cls, 'dimensions'))
|
67
|
+
|
65
68
|
|
66
69
|
def create_tensor_dimension_order_class(dims: Tuple[TensorDims]):
|
67
70
|
"""Creates a TensorDimensionMeta class with the specified dimensions.
|
@@ -16,24 +16,58 @@
|
|
16
16
|
"""Utility functions for externalized KV Cache."""
|
17
17
|
|
18
18
|
import dataclasses
|
19
|
-
from typing import List, Tuple
|
19
|
+
from typing import Any, List, Tuple
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.custom_ops.dynamic_update_slice import dynamic_update_slice
|
22
22
|
from ai_edge_torch.generative.layers import model_config
|
23
|
+
from ai_edge_torch.generative.layers.experimental import types
|
23
24
|
import torch
|
24
25
|
import torch.utils._pytree as pytree
|
25
26
|
|
26
27
|
|
28
|
+
KVLayout = Tuple[types.TensorDimensionMeta, types.TensorDimensionMeta]
|
29
|
+
|
30
|
+
# Define common layouts for KV Cache.
|
31
|
+
KV_LAYOUT_DEFAULT = (types.BTNH, types.BTNH)
|
32
|
+
KV_LAYOUT_TRANSPOSED = (types.BNTH, types.BNHT)
|
33
|
+
|
34
|
+
|
27
35
|
@dataclasses.dataclass
|
28
36
|
class KVCacheEntry:
|
29
37
|
"""A single cache entry that includes K and V caches.
|
30
38
|
|
31
|
-
The
|
32
|
-
(batch_size=1, kv_cache_max, num_query_groups, head_dim).
|
39
|
+
The cache layout can be customized based on different use cases.
|
33
40
|
"""
|
34
41
|
|
35
42
|
k_cache: torch.Tensor
|
36
43
|
v_cache: torch.Tensor
|
44
|
+
kv_layout: KVLayout = KV_LAYOUT_DEFAULT
|
45
|
+
|
46
|
+
@classmethod
|
47
|
+
def construct_kv_shape_from_layout(
|
48
|
+
cls,
|
49
|
+
shape_spec: types.TensorDimensionMeta,
|
50
|
+
kv_cache_max: int,
|
51
|
+
config: model_config.AttentionConfig,
|
52
|
+
batch_size: int,
|
53
|
+
) -> List[int]:
|
54
|
+
"""Constructs the shape of the key or value cache entry based on
|
55
|
+
|
56
|
+
the specified layout.
|
57
|
+
"""
|
58
|
+
output_shape = []
|
59
|
+
for dim_spec in shape_spec:
|
60
|
+
if dim_spec is types.TensorDims.BATCH:
|
61
|
+
output_shape.append(batch_size)
|
62
|
+
elif dim_spec is types.TensorDims.SEQUENCE:
|
63
|
+
output_shape.append(kv_cache_max)
|
64
|
+
elif dim_spec is types.TensorDims.NUM_HEADS:
|
65
|
+
output_shape.append(config.num_query_groups)
|
66
|
+
elif dim_spec is types.TensorDims.HEAD_DIM:
|
67
|
+
output_shape.append(config.head_dim)
|
68
|
+
else:
|
69
|
+
raise ValueError(f"Unsupported dimension spec: {dim_spec}")
|
70
|
+
return output_shape
|
37
71
|
|
38
72
|
@classmethod
|
39
73
|
def from_model_config(
|
@@ -41,14 +75,20 @@ class KVCacheEntry:
|
|
41
75
|
kv_cache_max: int,
|
42
76
|
config: model_config.AttentionConfig,
|
43
77
|
dtype: torch.dtype = torch.float32,
|
44
|
-
device: torch.device = None,
|
78
|
+
device: torch.device | None = None,
|
45
79
|
batch_size: int = 1,
|
80
|
+
kv_layout: KVLayout = KV_LAYOUT_DEFAULT,
|
46
81
|
) -> "KVCacheEntry":
|
47
82
|
"""Build an instance of the class based on model config."""
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
83
|
+
k_shape = cls.construct_kv_shape_from_layout(
|
84
|
+
kv_layout[0], kv_cache_max, config, batch_size
|
85
|
+
)
|
86
|
+
v_shape = cls.construct_kv_shape_from_layout(
|
87
|
+
kv_layout[1], kv_cache_max, config, batch_size
|
88
|
+
)
|
89
|
+
k = torch.zeros(k_shape, dtype=dtype, device=device)
|
90
|
+
v = torch.zeros(v_shape, dtype=dtype, device=device)
|
91
|
+
obj = cls(k_cache=k, v_cache=v, kv_layout=kv_layout)
|
52
92
|
return obj
|
53
93
|
|
54
94
|
|
@@ -63,8 +103,9 @@ class KVCache:
|
|
63
103
|
cls,
|
64
104
|
config: model_config.ModelConfig,
|
65
105
|
dtype: torch.dtype = torch.float32,
|
66
|
-
device: torch.device = None,
|
106
|
+
device: torch.device | None = None,
|
67
107
|
batch_size: int = 1,
|
108
|
+
kv_layout: KVLayout = KV_LAYOUT_DEFAULT,
|
68
109
|
) -> "KVCache":
|
69
110
|
"""Build an instance of the class based on model config.
|
70
111
|
|
@@ -89,6 +130,7 @@ class KVCache:
|
|
89
130
|
dtype,
|
90
131
|
device,
|
91
132
|
batch_size,
|
133
|
+
kv_layout,
|
92
134
|
)
|
93
135
|
for idx in range(config.num_layers)
|
94
136
|
]
|
@@ -104,7 +146,7 @@ class KVCache:
|
|
104
146
|
def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
|
105
147
|
flattened = []
|
106
148
|
flat_names = []
|
107
|
-
none_names = []
|
149
|
+
none_names = [kvc.caches[0].kv_layout]
|
108
150
|
for i, kv_entry in enumerate(kvc.caches):
|
109
151
|
flattened.append(kv_entry.k_cache)
|
110
152
|
flat_names.append(f"k_{i}")
|
@@ -121,22 +163,48 @@ def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
|
|
121
163
|
|
122
164
|
|
123
165
|
def _unflatten_kvc(
|
124
|
-
values: List[torch.Tensor],
|
166
|
+
values: List[torch.Tensor],
|
167
|
+
context: Tuple[List, List],
|
125
168
|
) -> KVCache:
|
126
169
|
assert len(values) % 2 == 0, "Found odd number of K and V entries."
|
127
170
|
num_layers = len(values) // 2
|
128
171
|
flat_names = context[0]
|
172
|
+
kv_layout = context[1][0]
|
129
173
|
kv_entries = []
|
130
174
|
for i in range(num_layers):
|
131
175
|
k_cache_idx = flat_names.index(f"k_{i}")
|
132
176
|
v_cache_idx = flat_names.index(f"v_{i}")
|
133
177
|
kv_entries.append(
|
134
|
-
KVCacheEntry(
|
178
|
+
KVCacheEntry(
|
179
|
+
k_cache=values[k_cache_idx],
|
180
|
+
v_cache=values[v_cache_idx],
|
181
|
+
kv_layout=kv_layout,
|
182
|
+
)
|
135
183
|
)
|
136
184
|
obj = KVCache(tuple(kv_entries))
|
137
185
|
return obj
|
138
186
|
|
139
187
|
|
188
|
+
def _flatten_kv_entry(
|
189
|
+
kv_e: KVCacheEntry,
|
190
|
+
) -> Tuple[List[torch.Tensor], Any]:
|
191
|
+
return ([kv_e.k_cache, kv_e.v_cache], kv_e.kv_layout)
|
192
|
+
|
193
|
+
|
194
|
+
def _unflatten_kv_entry(
|
195
|
+
values: List[torch.Tensor],
|
196
|
+
context: Any,
|
197
|
+
) -> KVCacheEntry:
|
198
|
+
return KVCacheEntry(*values, kv_layout=context)
|
199
|
+
|
200
|
+
|
201
|
+
pytree.register_pytree_node(
|
202
|
+
KVCacheEntry,
|
203
|
+
_flatten_kv_entry,
|
204
|
+
_unflatten_kv_entry,
|
205
|
+
serialized_type_name="",
|
206
|
+
)
|
207
|
+
|
140
208
|
pytree.register_pytree_node(
|
141
209
|
KVCache,
|
142
210
|
_flatten_kvc,
|
@@ -145,7 +213,6 @@ pytree.register_pytree_node(
|
|
145
213
|
serialized_type_name="",
|
146
214
|
)
|
147
215
|
|
148
|
-
|
149
216
|
def update(
|
150
217
|
cache: KVCacheEntry,
|
151
218
|
input_pos: torch.Tensor,
|
@@ -204,5 +271,5 @@ def _update_kv_impl(
|
|
204
271
|
k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
|
205
272
|
v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
|
206
273
|
|
207
|
-
updated_cache = KVCacheEntry(k, v)
|
274
|
+
updated_cache = KVCacheEntry(k, v, cache.kv_layout)
|
208
275
|
return updated_cache
|
@@ -16,7 +16,6 @@
|
|
16
16
|
"""A suite of tests to validate KV Cache layer."""
|
17
17
|
|
18
18
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
19
|
-
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
|
20
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
21
20
|
import torch
|
22
21
|
import torch.utils._pytree as pytree
|
@@ -117,7 +116,7 @@ class TestKVLayers(googletest.TestCase):
|
|
117
116
|
self.assertEqual(input_specs[0].arg.name, "kv_k_0")
|
118
117
|
self.assertEqual(input_specs[1].arg.name, "kv_v_0")
|
119
118
|
|
120
|
-
def
|
119
|
+
def test_pytree_roundtrip_kv_cache(self):
|
121
120
|
NUM_LAYERS = 4
|
122
121
|
config = self._get_test_config(
|
123
122
|
num_layers=NUM_LAYERS,
|
@@ -125,15 +124,13 @@ class TestKVLayers(googletest.TestCase):
|
|
125
124
|
num_query_groups=1,
|
126
125
|
kv_cache_max_len=4,
|
127
126
|
)
|
128
|
-
kv =
|
129
|
-
config, batch_size=1
|
130
|
-
)
|
127
|
+
kv = kv_utils.KVCache.from_model_config(config, batch_size=1)
|
131
128
|
flat, treespec = pytree.tree_flatten(kv)
|
132
129
|
self.assertLen(flat, NUM_LAYERS * 2)
|
133
130
|
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
134
131
|
self.assertEqual(kv, kv_unflat)
|
135
132
|
|
136
|
-
def
|
133
|
+
def test_pytree_roundtrip_kv_cache_derived(self):
|
137
134
|
NUM_LAYERS = 4
|
138
135
|
config = self._get_test_config(
|
139
136
|
num_layers=NUM_LAYERS,
|
@@ -141,41 +138,37 @@ class TestKVLayers(googletest.TestCase):
|
|
141
138
|
num_query_groups=1,
|
142
139
|
kv_cache_max_len=4,
|
143
140
|
)
|
144
|
-
kv =
|
145
|
-
config, batch_size=1
|
141
|
+
kv = kv_utils.KVCache.from_model_config(
|
142
|
+
config, batch_size=1, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
|
146
143
|
)
|
147
144
|
flat, treespec = pytree.tree_flatten(kv)
|
148
145
|
self.assertLen(flat, NUM_LAYERS * 2)
|
149
146
|
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
150
147
|
self.assertEqual(kv, kv_unflat)
|
151
148
|
|
152
|
-
def
|
149
|
+
def test_pytree_roundtrip_kv_entry(self):
|
153
150
|
attn_config = cfg.AttentionConfig(
|
154
151
|
num_heads=1, head_dim=1, num_query_groups=1
|
155
152
|
)
|
156
|
-
kv =
|
157
|
-
32, attn_config
|
158
|
-
)
|
153
|
+
kv = kv_utils.KVCacheEntry.from_model_config(32, attn_config)
|
159
154
|
flat, treespec = pytree.tree_flatten(kv)
|
160
155
|
self.assertLen(flat, 2)
|
161
156
|
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
162
157
|
self.assertEqual(kv, kv_unflat)
|
163
|
-
self.assertIsInstance(kv_unflat,
|
158
|
+
self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
|
164
159
|
|
165
|
-
def
|
160
|
+
def test_pytree_roundtrip_kv_entry_derived(self):
|
166
161
|
attn_config = cfg.AttentionConfig(
|
167
162
|
num_heads=1, head_dim=1, num_query_groups=1
|
168
163
|
)
|
169
|
-
kv =
|
170
|
-
32, attn_config
|
164
|
+
kv = kv_utils.KVCacheEntry.from_model_config(
|
165
|
+
32, attn_config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
|
171
166
|
)
|
172
167
|
flat, treespec = pytree.tree_flatten(kv)
|
173
168
|
self.assertLen(flat, 2)
|
174
169
|
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
175
170
|
self.assertEqual(kv, kv_unflat)
|
176
|
-
self.assertIsInstance(
|
177
|
-
kv_unflat, kv_utils_experimental.KVCacheEntryTransposed
|
178
|
-
)
|
171
|
+
self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
|
179
172
|
|
180
173
|
|
181
174
|
if __name__ == "__main__":
|
@@ -20,6 +20,7 @@ import pathlib
|
|
20
20
|
from typing import Optional, Union
|
21
21
|
from absl import flags
|
22
22
|
from ai_edge_torch._convert import converter as converter_utils
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
23
24
|
from ai_edge_torch.generative.layers import lora as lora_utils
|
24
25
|
import ai_edge_torch.generative.layers.model_config as cfg
|
25
26
|
from ai_edge_torch.generative.quantize import quant_recipes
|
@@ -218,9 +219,13 @@ def _export_helper(
|
|
218
219
|
[[0] for _ in range(export_config.decode_batch_size)], dtype=torch.int
|
219
220
|
)
|
220
221
|
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
221
|
-
prefill_kv =
|
222
|
-
|
223
|
-
|
222
|
+
prefill_kv = kv_utils.KVCache.from_model_config(
|
223
|
+
config, kv_layout=export_config.kvcache_layout
|
224
|
+
)
|
225
|
+
decode_kv = kv_utils.KVCache.from_model_config(
|
226
|
+
config,
|
227
|
+
batch_size=export_config.decode_batch_size,
|
228
|
+
kv_layout=export_config.kvcache_layout,
|
224
229
|
)
|
225
230
|
|
226
231
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
@@ -32,7 +32,9 @@ class ExportConfig:
|
|
32
32
|
# Attention masks given as inputs to the model.
|
33
33
|
prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
|
34
34
|
decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
|
35
|
-
# The KV Cache
|
35
|
+
# The KV Cache layout for K and V buffers in attention.
|
36
|
+
kvcache_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT
|
37
|
+
# TODO(b/409373223): The KV Cache class for K and V buffers in attention.
|
36
38
|
kvcache_cls: type = kv_utils.KVCache
|
37
39
|
# The batch size of the decode signature.
|
38
40
|
decode_batch_size: int = 1
|
@@ -209,7 +209,10 @@ class MlirLowered:
|
|
209
209
|
|
210
210
|
def get_text(self, enable_debug_info=False):
|
211
211
|
return str(
|
212
|
-
self.module.operation.get_asm(
|
212
|
+
self.module.operation.get_asm(
|
213
|
+
enable_debug_info=enable_debug_info,
|
214
|
+
large_elements_limit=16,
|
215
|
+
)
|
213
216
|
)
|
214
217
|
|
215
218
|
@property
|
@@ -326,8 +329,24 @@ def _convert_q_dq_per_channel_args_to_list(
|
|
326
329
|
|
327
330
|
def exported_program_to_mlir(
|
328
331
|
exported_program: torch.export.ExportedProgram,
|
332
|
+
*,
|
333
|
+
ir_context: ir.Context | None = None,
|
334
|
+
_pre_lower_pass: (
|
335
|
+
Callable[[torch.export.ExportedProgram], None] | None
|
336
|
+
) = None,
|
329
337
|
) -> MlirLowered:
|
330
|
-
"""Lower the exported program to MLIR.
|
338
|
+
"""Lower the exported program to MLIR.
|
339
|
+
|
340
|
+
Args:
|
341
|
+
exported_program: The exported program to lower.
|
342
|
+
ir_context: The MLIR context to use. If not provided, a new context will be
|
343
|
+
created.
|
344
|
+
_pre_lower_pass: A function to run on exported program before lowering.
|
345
|
+
|
346
|
+
Returns:
|
347
|
+
The lowered MLIR module, metadata, and weight tensors bundle from exported
|
348
|
+
program.
|
349
|
+
"""
|
331
350
|
exported_program = fx_infra.safe_run_decompositions(
|
332
351
|
exported_program,
|
333
352
|
fx_infra.decomp.pre_lower_decomp(),
|
@@ -340,10 +359,16 @@ def exported_program_to_mlir(
|
|
340
359
|
# Do not call run_decompositions after applying the passes.
|
341
360
|
_convert_q_dq_per_channel_args_to_list(exported_program)
|
342
361
|
|
343
|
-
|
362
|
+
if _pre_lower_pass:
|
363
|
+
_pre_lower_pass(exported_program)
|
364
|
+
|
365
|
+
if not ir_context:
|
366
|
+
ir_context = export_utils.create_ir_context()
|
367
|
+
|
368
|
+
with ir_context, ir.Location.unknown():
|
344
369
|
|
345
370
|
module = ir.Module.create()
|
346
|
-
lctx = LoweringContext(
|
371
|
+
lctx = LoweringContext(ir_context, module)
|
347
372
|
interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
|
348
373
|
ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
|
349
374
|
exported_program
|
@@ -382,7 +407,6 @@ def exported_program_to_mlir(
|
|
382
407
|
|
383
408
|
main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
384
409
|
temp_func.erase()
|
385
|
-
|
386
410
|
module.operation.verify()
|
387
411
|
|
388
412
|
input_signature = []
|
@@ -422,5 +446,5 @@ def exported_program_to_mlir(
|
|
422
446
|
for tensor_meta in _get_output_metas(exported_program)
|
423
447
|
]
|
424
448
|
return MlirLowered(
|
425
|
-
|
449
|
+
ir_context, module, state_dict, input_signature, output_signature
|
426
450
|
)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.5.0.dev20250409
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -25,6 +25,9 @@ License-File: LICENSE
|
|
25
25
|
Requires-Dist: numpy
|
26
26
|
Requires-Dist: scipy
|
27
27
|
Requires-Dist: safetensors
|
28
|
+
Requires-Dist: multipledispatch
|
29
|
+
Requires-Dist: transformers
|
30
|
+
Requires-Dist: kagglehub
|
28
31
|
Requires-Dist: tabulate
|
29
32
|
Requires-Dist: torch>=2.4.0
|
30
33
|
Requires-Dist: tf-nightly>=2.19.0.dev20250101
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=DEYqmCDZNmwuMxnxrFvcTEaDp6Z_BVHJaZMYjVQ2ijU,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -65,12 +65,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
|
|
65
65
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
66
66
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
67
67
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
68
|
-
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
|
69
|
-
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
|
68
|
+
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
|
69
|
+
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=n6ZQfqNEHuOhY7Pu21bb8Eax8yn2Sx5osTKJKmhonXY,15659
|
70
70
|
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=5PEt0aWJ5wkUBvMoWFOJ-C48ZhG7uCVb8PCKQtZ8Fvw,6485
|
71
71
|
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
|
72
72
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
73
|
-
ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=
|
73
|
+
ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
|
74
74
|
ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
75
75
|
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=A4uLUdqvU1NKo3seqZlWSS3fqYahnEKqNBQBJO6yXvE,1762
|
76
76
|
ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
|
@@ -153,17 +153,17 @@ ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDq
|
|
153
153
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
154
154
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
155
155
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
156
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
156
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=9kkFpB9msgUDStFxEyQYYsavKPP4Dgqb_NFcd4hA4aU,8502
|
157
157
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
158
158
|
ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
|
159
159
|
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
|
160
160
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
161
161
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
|
162
162
|
ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
|
163
|
-
ai_edge_torch/generative/layers/experimental/attention.py,sha256=
|
164
|
-
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=
|
165
|
-
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256
|
166
|
-
ai_edge_torch/generative/layers/experimental/types.py,sha256=
|
163
|
+
ai_edge_torch/generative/layers/experimental/attention.py,sha256=oW8cxv0pXcesnyGz6bXacRmlvHPfKNnJnls_Qb4L_aQ,8968
|
164
|
+
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=PlgL2bNNKasu3wFr3Iu9wbATWluWZt3_s4tzglJu2tM,2942
|
165
|
+
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=-ztTIgdec5gXkOVe6FXk3PMeS2HoL6-mBfDBdjQIcLQ,2808
|
166
|
+
ai_edge_torch/generative/layers/experimental/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
|
167
167
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
168
168
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
|
169
169
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
@@ -177,7 +177,7 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
|
|
177
177
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
178
178
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
179
179
|
ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
|
180
|
-
ai_edge_torch/generative/test/test_kv_cache.py,sha256=
|
180
|
+
ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
|
181
181
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
182
182
|
ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
|
183
183
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
|
@@ -185,8 +185,8 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3Gy
|
|
185
185
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
186
186
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
187
187
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
188
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
189
|
-
ai_edge_torch/generative/utilities/export_config.py,sha256
|
188
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=87Tzj-gLydx8_xnHxKlCbMmM1XHShstpKi8RH3xY7Xw,9757
|
189
|
+
ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
|
190
190
|
ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
|
191
191
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
|
192
192
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
@@ -210,7 +210,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
|
|
210
210
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
211
211
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
212
212
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
213
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
213
|
+
ai_edge_torch/odml_torch/export.py,sha256=rxsyVagQgb-DDIVtwZwSTSVFINqwIZleOOfmPkBoPKg,14817
|
214
214
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
215
215
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
216
216
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -243,8 +243,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
243
243
|
ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
|
244
244
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
245
245
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
246
|
-
ai_edge_torch_nightly-0.
|
247
|
-
ai_edge_torch_nightly-0.
|
248
|
-
ai_edge_torch_nightly-0.
|
249
|
-
ai_edge_torch_nightly-0.
|
250
|
-
ai_edge_torch_nightly-0.
|
246
|
+
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
247
|
+
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/METADATA,sha256=kZwo6E79HLuM7_4E-Yw9erTzOnAAzio3Vy45hXNiC48,2051
|
248
|
+
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
249
|
+
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
250
|
+
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/RECORD,,
|
File without changes
|
File without changes
|