ai-edge-torch-nightly 0.4.0.dev20250305__py3-none-any.whl → 0.4.0.dev20250306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/layers/experimental/kv_cache.py +48 -16
- ai_edge_torch/generative/test/test_kv_cache.py +62 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250305.dist-info → ai_edge_torch_nightly-0.4.0.dev20250306.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250305.dist-info → ai_edge_torch_nightly-0.4.0.dev20250306.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.4.0.dev20250305.dist-info → ai_edge_torch_nightly-0.4.0.dev20250306.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250305.dist-info → ai_edge_torch_nightly-0.4.0.dev20250306.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250305.dist-info → ai_edge_torch_nightly-0.4.0.dev20250306.dist-info}/top_level.txt +0 -0
@@ -19,8 +19,8 @@ This is an experimental implementation and is subject to change at any time.
|
|
19
19
|
"""
|
20
20
|
|
21
21
|
import dataclasses
|
22
|
-
|
23
|
-
|
22
|
+
import functools
|
23
|
+
from typing import Any, List, Tuple, Type
|
24
24
|
from ai_edge_torch.generative.layers import model_config
|
25
25
|
from ai_edge_torch.generative.layers.experimental import types
|
26
26
|
from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
|
@@ -46,7 +46,7 @@ class KVCacheEntryBase:
|
|
46
46
|
v_shape: Tuple[int, ...],
|
47
47
|
dtype: torch.dtype = torch.float32,
|
48
48
|
device: torch.device = None,
|
49
|
-
)
|
49
|
+
):
|
50
50
|
"""Build an instance of the class based on model config."""
|
51
51
|
k = torch.zeros(k_shape, dtype=dtype, device=device)
|
52
52
|
v = torch.zeros(v_shape, dtype=dtype, device=device)
|
@@ -61,7 +61,7 @@ class KVCacheEntryBase:
|
|
61
61
|
dtype: torch.dtype = torch.float32,
|
62
62
|
device: torch.device = None,
|
63
63
|
batch_size: int = 1,
|
64
|
-
)
|
64
|
+
):
|
65
65
|
"""Build an instance of the class based on model config."""
|
66
66
|
shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
|
67
67
|
return cls._from_model_config(shape, shape, dtype, device)
|
@@ -87,7 +87,7 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
|
|
87
87
|
dtype: torch.dtype = torch.float32,
|
88
88
|
device: torch.device = None,
|
89
89
|
batch_size: int = 1,
|
90
|
-
)
|
90
|
+
):
|
91
91
|
"""Build an instance of the class based on model config."""
|
92
92
|
k_shape = (
|
93
93
|
batch_size,
|
@@ -104,6 +104,35 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
|
|
104
104
|
return cls._from_model_config(k_shape, v_shape, dtype, device)
|
105
105
|
|
106
106
|
|
107
|
+
def _flatten_kv_entry(
|
108
|
+
kv_e: KVCacheEntryBase,
|
109
|
+
) -> Tuple[List[torch.Tensor], Any]:
|
110
|
+
return ([kv_e.k_cache, kv_e.v_cache], None)
|
111
|
+
|
112
|
+
|
113
|
+
def _unflatten_kv_entry(
|
114
|
+
kv_entry_ty: Type[KVCacheEntryBase],
|
115
|
+
values: List[torch.Tensor],
|
116
|
+
unused_context: Any,
|
117
|
+
) -> KVCacheEntryBase:
|
118
|
+
return kv_entry_ty(*values)
|
119
|
+
|
120
|
+
|
121
|
+
pytree.register_pytree_node(
|
122
|
+
KVCacheEntryTransposed,
|
123
|
+
_flatten_kv_entry,
|
124
|
+
functools.partial(_unflatten_kv_entry, KVCacheEntryTransposed),
|
125
|
+
serialized_type_name="",
|
126
|
+
)
|
127
|
+
|
128
|
+
pytree.register_pytree_node(
|
129
|
+
KVCacheEntryBase,
|
130
|
+
_flatten_kv_entry,
|
131
|
+
functools.partial(_unflatten_kv_entry, KVCacheEntryBase),
|
132
|
+
serialized_type_name="",
|
133
|
+
)
|
134
|
+
|
135
|
+
|
107
136
|
@dataclasses.dataclass
|
108
137
|
class KVCacheBase:
|
109
138
|
"""A utility class for holding KV cache entries per layer."""
|
@@ -118,7 +147,7 @@ class KVCacheBase:
|
|
118
147
|
dtype: torch.dtype = torch.float32,
|
119
148
|
device: torch.device = None,
|
120
149
|
batch_size: int = 1,
|
121
|
-
)
|
150
|
+
):
|
122
151
|
caches = [
|
123
152
|
kv_entry_cls.from_model_config(
|
124
153
|
config.kv_cache_max,
|
@@ -139,7 +168,7 @@ class KVCacheBase:
|
|
139
168
|
dtype: torch.dtype = torch.float32,
|
140
169
|
device: torch.device = None,
|
141
170
|
batch_size: int = 1,
|
142
|
-
)
|
171
|
+
):
|
143
172
|
"""Build an instance of the class based on model config.
|
144
173
|
|
145
174
|
Args:
|
@@ -179,7 +208,7 @@ class KVCacheBTNH(KVCacheBase):
|
|
179
208
|
dtype: torch.dtype = torch.float32,
|
180
209
|
device: torch.device = None,
|
181
210
|
batch_size: int = 1,
|
182
|
-
)
|
211
|
+
):
|
183
212
|
return cls._from_model_config(
|
184
213
|
KVCacheEntryBTNH,
|
185
214
|
config=config,
|
@@ -199,7 +228,7 @@ class KVCacheTransposed(KVCacheBase):
|
|
199
228
|
dtype: torch.dtype = torch.float32,
|
200
229
|
device: torch.device = None,
|
201
230
|
batch_size: int = 1,
|
202
|
-
)
|
231
|
+
):
|
203
232
|
return cls._from_model_config(
|
204
233
|
KVCacheEntryTransposed,
|
205
234
|
config=config,
|
@@ -229,7 +258,10 @@ def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
|
|
229
258
|
|
230
259
|
|
231
260
|
def _unflatten_kvc(
|
232
|
-
|
261
|
+
kv_ty: Type[KVCacheBase],
|
262
|
+
kv_entry_type: Type[KVCacheEntryBase],
|
263
|
+
values: List[torch.Tensor],
|
264
|
+
context: Tuple[List, List],
|
233
265
|
) -> KVCacheBase:
|
234
266
|
assert len(values) % 2 == 0, "Found odd number of K and V entries."
|
235
267
|
num_layers = len(values) // 2
|
@@ -239,18 +271,18 @@ def _unflatten_kvc(
|
|
239
271
|
k_cache_idx = flat_names.index(f"k_{i}")
|
240
272
|
v_cache_idx = flat_names.index(f"v_{i}")
|
241
273
|
kv_entries.append(
|
242
|
-
|
243
|
-
k_cache=values[k_cache_idx], v_cache=values[v_cache_idx]
|
244
|
-
)
|
274
|
+
kv_entry_type(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
|
245
275
|
)
|
246
|
-
obj =
|
276
|
+
obj = kv_ty(tuple(kv_entries))
|
247
277
|
return obj
|
248
278
|
|
249
279
|
|
250
280
|
pytree.register_pytree_node(
|
251
281
|
KVCacheTransposed,
|
252
282
|
_flatten_kvc,
|
253
|
-
|
283
|
+
functools.partial(
|
284
|
+
_unflatten_kvc, KVCacheTransposed, KVCacheEntryTransposed
|
285
|
+
),
|
254
286
|
flatten_with_keys_fn=_flatten_kvc_with_keys,
|
255
287
|
serialized_type_name="",
|
256
288
|
)
|
@@ -258,7 +290,7 @@ pytree.register_pytree_node(
|
|
258
290
|
pytree.register_pytree_node(
|
259
291
|
KVCacheBase,
|
260
292
|
_flatten_kvc,
|
261
|
-
_unflatten_kvc,
|
293
|
+
functools.partial(_unflatten_kvc, KVCacheBase, KVCacheEntryBase),
|
262
294
|
flatten_with_keys_fn=_flatten_kvc_with_keys,
|
263
295
|
serialized_type_name="",
|
264
296
|
)
|
@@ -16,8 +16,10 @@
|
|
16
16
|
"""A suite of tests to validate KV Cache layer."""
|
17
17
|
|
18
18
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
19
|
+
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
|
19
20
|
import ai_edge_torch.generative.layers.model_config as cfg
|
20
21
|
import torch
|
22
|
+
import torch.utils._pytree as pytree
|
21
23
|
|
22
24
|
from absl.testing import absltest as googletest
|
23
25
|
|
@@ -115,6 +117,66 @@ class TestKVLayers(googletest.TestCase):
|
|
115
117
|
self.assertEqual(input_specs[0].arg.name, "kv_k_0")
|
116
118
|
self.assertEqual(input_specs[1].arg.name, "kv_v_0")
|
117
119
|
|
120
|
+
def test_pytree_roundtrip_experimental_kv_cache_base(self):
|
121
|
+
NUM_LAYERS = 4
|
122
|
+
config = self._get_test_config(
|
123
|
+
num_layers=NUM_LAYERS,
|
124
|
+
head_dim=2,
|
125
|
+
num_query_groups=1,
|
126
|
+
kv_cache_max_len=4,
|
127
|
+
)
|
128
|
+
kv = kv_utils_experimental.KVCacheBase.from_model_config(
|
129
|
+
config, batch_size=1
|
130
|
+
)
|
131
|
+
flat, treespec = pytree.tree_flatten(kv)
|
132
|
+
self.assertLen(flat, NUM_LAYERS * 2)
|
133
|
+
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
134
|
+
self.assertEqual(kv, kv_unflat)
|
135
|
+
|
136
|
+
def test_pytree_roundtrip_experimental_kv_cache_derived(self):
|
137
|
+
NUM_LAYERS = 4
|
138
|
+
config = self._get_test_config(
|
139
|
+
num_layers=NUM_LAYERS,
|
140
|
+
head_dim=2,
|
141
|
+
num_query_groups=1,
|
142
|
+
kv_cache_max_len=4,
|
143
|
+
)
|
144
|
+
kv = kv_utils_experimental.KVCacheTransposed.from_model_config(
|
145
|
+
config, batch_size=1
|
146
|
+
)
|
147
|
+
flat, treespec = pytree.tree_flatten(kv)
|
148
|
+
self.assertLen(flat, NUM_LAYERS * 2)
|
149
|
+
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
150
|
+
self.assertEqual(kv, kv_unflat)
|
151
|
+
|
152
|
+
def test_pytree_roundtrip_experimental_kv_entry_base(self):
|
153
|
+
attn_config = cfg.AttentionConfig(
|
154
|
+
num_heads=1, head_dim=1, num_query_groups=1
|
155
|
+
)
|
156
|
+
kv = kv_utils_experimental.KVCacheEntryBase.from_model_config(
|
157
|
+
32, attn_config
|
158
|
+
)
|
159
|
+
flat, treespec = pytree.tree_flatten(kv)
|
160
|
+
self.assertLen(flat, 2)
|
161
|
+
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
162
|
+
self.assertEqual(kv, kv_unflat)
|
163
|
+
self.assertIsInstance(kv_unflat, kv_utils_experimental.KVCacheEntryBase)
|
164
|
+
|
165
|
+
def test_pytree_roundtrip_experimental_kv_entry_derived(self):
|
166
|
+
attn_config = cfg.AttentionConfig(
|
167
|
+
num_heads=1, head_dim=1, num_query_groups=1
|
168
|
+
)
|
169
|
+
kv = kv_utils_experimental.KVCacheEntryTransposed.from_model_config(
|
170
|
+
32, attn_config
|
171
|
+
)
|
172
|
+
flat, treespec = pytree.tree_flatten(kv)
|
173
|
+
self.assertLen(flat, 2)
|
174
|
+
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
175
|
+
self.assertEqual(kv, kv_unflat)
|
176
|
+
self.assertIsInstance(
|
177
|
+
kv_unflat, kv_utils_experimental.KVCacheEntryTransposed
|
178
|
+
)
|
179
|
+
|
118
180
|
|
119
181
|
if __name__ == "__main__":
|
120
182
|
googletest.main()
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.4.0.
|
3
|
+
Version: 0.4.0.dev20250306
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=3TrrWqlr-XarP5R47N_A6I7W4epX_4Iuuv5YVSKn-rQ,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -151,7 +151,7 @@ ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIr
|
|
151
151
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
|
152
152
|
ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
|
153
153
|
ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
|
154
|
-
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=
|
154
|
+
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=0H-Rqtm6ArMxchHSv3eeX8W3AryoF73EFEpGNfjciK8,9996
|
155
155
|
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
|
156
156
|
ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
|
157
157
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -167,7 +167,7 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
|
|
167
167
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
168
168
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
169
169
|
ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
|
170
|
-
ai_edge_torch/generative/test/test_kv_cache.py,sha256=
|
170
|
+
ai_edge_torch/generative/test/test_kv_cache.py,sha256=MBPS-0bDXB0tQSKHa1XwDQeVIfabRbc8JQA99h9fzlQ,5961
|
171
171
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
172
172
|
ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
|
173
173
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
|
@@ -233,8 +233,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
233
233
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
234
234
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
235
235
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
236
|
-
ai_edge_torch_nightly-0.4.0.
|
237
|
-
ai_edge_torch_nightly-0.4.0.
|
238
|
-
ai_edge_torch_nightly-0.4.0.
|
239
|
-
ai_edge_torch_nightly-0.4.0.
|
240
|
-
ai_edge_torch_nightly-0.4.0.
|
236
|
+
ai_edge_torch_nightly-0.4.0.dev20250306.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
237
|
+
ai_edge_torch_nightly-0.4.0.dev20250306.dist-info/METADATA,sha256=AGXvfH7AuPCCnrW0vgLAcW2e2jUQLqx1berQqniMFc0,1966
|
238
|
+
ai_edge_torch_nightly-0.4.0.dev20250306.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
239
|
+
ai_edge_torch_nightly-0.4.0.dev20250306.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
240
|
+
ai_edge_torch_nightly-0.4.0.dev20250306.dist-info/RECORD,,
|
File without changes
|
File without changes
|