ai-edge-torch-nightly 0.4.0.dev20250305__py3-none-any.whl → 0.4.0.dev20250307__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
ai_edge_torch/generative/layers/experimental/kv_cache.py CHANGED
@@ -19,8 +19,8 @@ This is an experimental implementation and is subject to change at any time.
  """
 
  import dataclasses
- from typing import List, Tuple
-
+ import functools
+ from typing import Any, List, Tuple, Type
  from ai_edge_torch.generative.layers import model_config
  from ai_edge_torch.generative.layers.experimental import types
  from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
@@ -46,7 +46,7 @@ class KVCacheEntryBase:
  v_shape: Tuple[int, ...],
  dtype: torch.dtype = torch.float32,
  device: torch.device = None,
- ) -> "KVCacheEntryBase":
+ ):
  """Build an instance of the class based on model config."""
  k = torch.zeros(k_shape, dtype=dtype, device=device)
  v = torch.zeros(v_shape, dtype=dtype, device=device)
@@ -61,7 +61,7 @@ class KVCacheEntryBase:
  dtype: torch.dtype = torch.float32,
  device: torch.device = None,
  batch_size: int = 1,
- ) -> "KVCacheEntryBase":
+ ):
  """Build an instance of the class based on model config."""
  shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
  return cls._from_model_config(shape, shape, dtype, device)
@@ -87,7 +87,7 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
  dtype: torch.dtype = torch.float32,
  device: torch.device = None,
  batch_size: int = 1,
- ) -> "KVCacheEntryBase":
+ ):
  """Build an instance of the class based on model config."""
  k_shape = (
  batch_size,
@@ -104,6 +104,35 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
  return cls._from_model_config(k_shape, v_shape, dtype, device)
 
 
+ def _flatten_kv_entry(
+ kv_e: KVCacheEntryBase,
+ ) -> Tuple[List[torch.Tensor], Any]:
+ return ([kv_e.k_cache, kv_e.v_cache], None)
+
+
+ def _unflatten_kv_entry(
+ kv_entry_ty: Type[KVCacheEntryBase],
+ values: List[torch.Tensor],
+ unused_context: Any,
+ ) -> KVCacheEntryBase:
+ return kv_entry_ty(*values)
+
+
+ pytree.register_pytree_node(
+ KVCacheEntryTransposed,
+ _flatten_kv_entry,
+ functools.partial(_unflatten_kv_entry, KVCacheEntryTransposed),
+ serialized_type_name="",
+ )
+
+ pytree.register_pytree_node(
+ KVCacheEntryBase,
+ _flatten_kv_entry,
+ functools.partial(_unflatten_kv_entry, KVCacheEntryBase),
+ serialized_type_name="",
+ )
+
+
  @dataclasses.dataclass
  class KVCacheBase:
  """A utility class for holding KV cache entries per layer."""
@@ -118,7 +147,7 @@ class KVCacheBase:
  dtype: torch.dtype = torch.float32,
  device: torch.device = None,
  batch_size: int = 1,
- ) -> "KVCacheBase":
+ ):
  caches = [
  kv_entry_cls.from_model_config(
  config.kv_cache_max,
@@ -139,7 +168,7 @@ class KVCacheBase:
  dtype: torch.dtype = torch.float32,
  device: torch.device = None,
  batch_size: int = 1,
- ) -> "KVCacheBase":
+ ):
  """Build an instance of the class based on model config.
 
  Args:
@@ -179,7 +208,7 @@ class KVCacheBTNH(KVCacheBase):
  dtype: torch.dtype = torch.float32,
  device: torch.device = None,
  batch_size: int = 1,
- ) -> "KVCacheBTNH":
+ ):
  return cls._from_model_config(
  KVCacheEntryBTNH,
  config=config,
@@ -199,7 +228,7 @@ class KVCacheTransposed(KVCacheBase):
  dtype: torch.dtype = torch.float32,
  device: torch.device = None,
  batch_size: int = 1,
- ) -> "KVCacheBTNH":
+ ):
  return cls._from_model_config(
  KVCacheEntryTransposed,
  config=config,
@@ -229,7 +258,10 @@ def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
 
 
  def _unflatten_kvc(
- values: List[torch.Tensor], context: Tuple[List, List]
+ kv_ty: Type[KVCacheBase],
+ kv_entry_type: Type[KVCacheEntryBase],
+ values: List[torch.Tensor],
+ context: Tuple[List, List],
  ) -> KVCacheBase:
  assert len(values) % 2 == 0, "Found odd number of K and V entries."
  num_layers = len(values) // 2
@@ -239,18 +271,18 @@ def _unflatten_kvc(
  k_cache_idx = flat_names.index(f"k_{i}")
  v_cache_idx = flat_names.index(f"v_{i}")
  kv_entries.append(
- KVCacheEntryBase(
- k_cache=values[k_cache_idx], v_cache=values[v_cache_idx]
- )
+ kv_entry_type(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
  )
- obj = KVCacheBase(tuple(kv_entries))
+ obj = kv_ty(tuple(kv_entries))
  return obj
 
 
  pytree.register_pytree_node(
  KVCacheTransposed,
  _flatten_kvc,
- _unflatten_kvc,
+ functools.partial(
+ _unflatten_kvc, KVCacheTransposed, KVCacheEntryTransposed
+ ),
  flatten_with_keys_fn=_flatten_kvc_with_keys,
  serialized_type_name="",
  )
@@ -258,7 +290,7 @@ pytree.register_pytree_node(
  pytree.register_pytree_node(
  KVCacheBase,
  _flatten_kvc,
- _unflatten_kvc,
+ functools.partial(_unflatten_kvc, KVCacheBase, KVCacheEntryBase),
  flatten_with_keys_fn=_flatten_kvc_with_keys,
  serialized_type_name="",
  )
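Parameterizing _unflatten_kvc by both the cache class and the entry class means a flatten/unflatten round trip now reproduces KVCacheTransposed and KVCacheEntryTransposed instead of collapsing everything to the base classes. One practical consequence, shown as a hedged usage sketch rather than code from the diff (the helper name is made up): generic tree utilities such as tree_map can transform every cached tensor at once while keeping the concrete cache type.

import torch
import torch.utils._pytree as pytree


def cast_kv_cache(kv, dtype: torch.dtype = torch.float16):
  # tree_map flattens `kv` into its tensor leaves, casts each one, and
  # unflattens with the registered treespec, so the result is the same
  # KVCache subclass as the input (e.g. KVCacheTransposed stays transposed).
  return pytree.tree_map(lambda t: t.to(dtype), kv)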
ai_edge_torch/generative/test/test_kv_cache.py CHANGED
@@ -16,8 +16,10 @@
  """A suite of tests to validate KV Cache layer."""
 
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
+ import torch.utils._pytree as pytree
 
  from absl.testing import absltest as googletest
 
@@ -115,6 +117,66 @@ class TestKVLayers(googletest.TestCase):
  self.assertEqual(input_specs[0].arg.name, "kv_k_0")
  self.assertEqual(input_specs[1].arg.name, "kv_v_0")
 
+ def test_pytree_roundtrip_experimental_kv_cache_base(self):
+ NUM_LAYERS = 4
+ config = self._get_test_config(
+ num_layers=NUM_LAYERS,
+ head_dim=2,
+ num_query_groups=1,
+ kv_cache_max_len=4,
+ )
+ kv = kv_utils_experimental.KVCacheBase.from_model_config(
+ config, batch_size=1
+ )
+ flat, treespec = pytree.tree_flatten(kv)
+ self.assertLen(flat, NUM_LAYERS * 2)
+ kv_unflat = pytree.tree_unflatten(flat, treespec)
+ self.assertEqual(kv, kv_unflat)
+
+ def test_pytree_roundtrip_experimental_kv_cache_derived(self):
+ NUM_LAYERS = 4
+ config = self._get_test_config(
+ num_layers=NUM_LAYERS,
+ head_dim=2,
+ num_query_groups=1,
+ kv_cache_max_len=4,
+ )
+ kv = kv_utils_experimental.KVCacheTransposed.from_model_config(
+ config, batch_size=1
+ )
+ flat, treespec = pytree.tree_flatten(kv)
+ self.assertLen(flat, NUM_LAYERS * 2)
+ kv_unflat = pytree.tree_unflatten(flat, treespec)
+ self.assertEqual(kv, kv_unflat)
+
+ def test_pytree_roundtrip_experimental_kv_entry_base(self):
+ attn_config = cfg.AttentionConfig(
+ num_heads=1, head_dim=1, num_query_groups=1
+ )
+ kv = kv_utils_experimental.KVCacheEntryBase.from_model_config(
+ 32, attn_config
+ )
+ flat, treespec = pytree.tree_flatten(kv)
+ self.assertLen(flat, 2)
+ kv_unflat = pytree.tree_unflatten(flat, treespec)
+ self.assertEqual(kv, kv_unflat)
+ self.assertIsInstance(kv_unflat, kv_utils_experimental.KVCacheEntryBase)
+
+ def test_pytree_roundtrip_experimental_kv_entry_derived(self):
+ attn_config = cfg.AttentionConfig(
+ num_heads=1, head_dim=1, num_query_groups=1
+ )
+ kv = kv_utils_experimental.KVCacheEntryTransposed.from_model_config(
+ 32, attn_config
+ )
+ flat, treespec = pytree.tree_flatten(kv)
+ self.assertLen(flat, 2)
+ kv_unflat = pytree.tree_unflatten(flat, treespec)
+ self.assertEqual(kv, kv_unflat)
+ self.assertIsInstance(
+ kv_unflat, kv_utils_experimental.KVCacheEntryTransposed
+ )
+
 
  if __name__ == "__main__":
  googletest.main()
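For a quick check outside the test harness, the entry-level round trip asserted by the new tests can be reproduced in a few lines; a minimal sketch that only uses calls appearing in the diff above:

import torch.utils._pytree as pytree

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental

# Build one transposed KV entry, flatten it to its two tensors, and rebuild it.
attn_config = cfg.AttentionConfig(num_heads=1, head_dim=1, num_query_groups=1)
entry = kv_utils_experimental.KVCacheEntryTransposed.from_model_config(
    32, attn_config
)
leaves, spec = pytree.tree_flatten(entry)
assert len(leaves) == 2
rebuilt = pytree.tree_unflatten(leaves, spec)
assert isinstance(rebuilt, kv_utils_experimental.KVCacheEntryTransposed)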
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.4.0.dev20250305"
+ __version__ = "0.4.0.dev20250307"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.4.0.dev20250305
+ Version: 0.4.0.dev20250307
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=KSuxuG4iAZMdD9Pi3Eg36fWk8j5YbKwBegeGyh08BIg,706
5
+ ai_edge_torch/version.py,sha256=B2NHP6eHVBZO3ZheJ-I5AunC4KYgMmwSGOtNsYjpXqw,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -151,7 +151,7 @@ ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIr
151
151
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
152
152
  ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
153
153
  ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
154
- ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=VN4gn4ylaVOwaTR5EXKv0YTVgpQ850bmjGLCgCCI1ps,9267
154
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=0H-Rqtm6ArMxchHSv3eeX8W3AryoF73EFEpGNfjciK8,9996
155
155
  ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
156
156
  ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
157
157
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -167,7 +167,7 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
167
167
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
168
168
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
169
169
  ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
170
- ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
170
+ ai_edge_torch/generative/test/test_kv_cache.py,sha256=MBPS-0bDXB0tQSKHa1XwDQeVIfabRbc8JQA99h9fzlQ,5961
171
171
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
172
172
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
173
173
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
@@ -233,8 +233,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
233
233
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
234
234
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
235
235
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
236
- ai_edge_torch_nightly-0.4.0.dev20250305.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
237
- ai_edge_torch_nightly-0.4.0.dev20250305.dist-info/METADATA,sha256=PFPy_Qd9oHyUNQFnd04iTXM5UKD0R_azKA3QLYAWr8o,1966
238
- ai_edge_torch_nightly-0.4.0.dev20250305.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
239
- ai_edge_torch_nightly-0.4.0.dev20250305.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
240
- ai_edge_torch_nightly-0.4.0.dev20250305.dist-info/RECORD,,
236
+ ai_edge_torch_nightly-0.4.0.dev20250307.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
237
+ ai_edge_torch_nightly-0.4.0.dev20250307.dist-info/METADATA,sha256=p765RF7LPV48LOEn4_itVn84dtUjz-C41k11Id26TyI,1966
238
+ ai_edge_torch_nightly-0.4.0.dev20250307.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
239
+ ai_edge_torch_nightly-0.4.0.dev20250307.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
240
+ ai_edge_torch_nightly-0.4.0.dev20250307.dist-info/RECORD,,