ai-edge-torch-nightly 0.5.0.dev20250408__py3-none-any.whl → 0.5.0.dev20250409__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,7 +17,7 @@
17
17
 
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.gemma3 import gemma3
20
- from ai_edge_torch.generative.layers.experimental import kv_cache
20
+ from ai_edge_torch.generative.layers import kv_cache
21
21
  from ai_edge_torch.generative.utilities import converter
22
22
  from ai_edge_torch.generative.utilities import export_config
23
23
  import torch
@@ -58,7 +58,7 @@ def _create_export_config(
58
58
  )
59
59
  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
60
60
  export_config.decode_mask = decode_mask
61
- export_config.kvcache_cls = kv_cache.KVCacheTransposed
61
+ export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
62
62
  return export_config
63
63
 
64
64
 
@@ -18,9 +18,9 @@
18
18
  from typing import List, Optional, Tuple
19
19
 
20
20
  from ai_edge_torch.generative.layers import builder
21
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
22
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
22
23
  from ai_edge_torch.generative.layers.experimental import attention
23
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
24
24
  import ai_edge_torch.generative.layers.model_config as cfg
25
25
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
26
26
  from ai_edge_torch.generative.utilities import export_config as export_cfg
@@ -81,8 +81,8 @@ class DecoderBlock(attention.TransformerBlock):
81
81
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
82
82
  mask: Optional[torch.Tensor] = None,
83
83
  input_pos: Optional[torch.Tensor] = None,
84
- kv_cache: kv_utils.KVCacheEntryBase = None,
85
- ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntryBase]]:
84
+ kv_cache: kv_utils.KVCacheEntry = None,
85
+ ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
86
86
  """Forward function of the Gemma3Block.
87
87
 
88
88
  Exactly the same as TransformerBlock but we call the post-attention norm
@@ -241,13 +241,12 @@ class Decoder(nn.Module):
241
241
  self,
242
242
  tokens: torch.Tensor,
243
243
  input_pos: torch.Tensor,
244
- kv_cache: kv_utils.KVCacheBase,
244
+ kv_cache: kv_utils.KVCache,
245
245
  input_embeds: Optional[torch.Tensor] = None,
246
246
  mask: Optional[torch.Tensor] = None,
247
247
  image_indices: Optional[torch.Tensor] = None,
248
248
  export_config: Optional[export_cfg.ExportConfig] = None,
249
- ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
250
-
249
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
251
250
  pixel_mask = None
252
251
  if input_embeds is None:
253
252
  # token embeddings of shape (b, t, n_embd)
@@ -287,10 +286,10 @@ class Decoder(nn.Module):
287
286
  rope: List[Tuple[torch.Tensor, torch.Tensor]],
288
287
  mask: torch.Tensor | List[torch.Tensor],
289
288
  input_pos: torch.Tensor,
290
- kv_cache: kv_utils.KVCacheBase,
289
+ kv_cache: kv_utils.KVCache,
291
290
  pixel_mask: Optional[torch.Tensor] = None,
292
291
  export_config: Optional[export_cfg.ExportConfig] = None,
293
- ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
292
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
294
293
  """Forwards the model with input embeddings."""
295
294
  assert len(self.transformer_blocks) == len(kv_cache.caches), (
296
295
  "The number of transformer blocks and the number of KV cache entries"
@@ -326,7 +325,7 @@ class Decoder(nn.Module):
326
325
  x, kv_entry = block(x, rope[i], mask_entry, input_pos, kv_entry)
327
326
  if kv_entry:
328
327
  updated_kv_entries.append(kv_entry)
329
- updated_kv_cache = kv_utils.KVCacheBase(tuple(updated_kv_entries))
328
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
330
329
  if export_config is not None:
331
330
  if (
332
331
  torch.numel(input_pos) > 1
@@ -20,8 +20,8 @@ import os
20
20
  from typing import List, Optional, Tuple
21
21
 
22
22
  from ai_edge_torch.generative.examples.gemma3 import gemma3
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
24
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
24
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
25
25
  from ai_edge_torch.generative.utilities.experimental import verifier
26
26
  from gemma import config as gemma_config
27
27
  from gemma import model as gemma_model
@@ -94,7 +94,9 @@ class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
94
94
 
95
95
  def _init_kv_cache(self):
96
96
  """Returns an initialized KV cache."""
97
- return kv_utils.KVCacheTransposed.from_model_config(self.model.model.config)
97
+ return kv_utils.KVCache.from_model_config(
98
+ self.model.model.config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
99
+ )
98
100
 
99
101
  def forward(
100
102
  self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -22,8 +22,9 @@ at any time.
22
22
  from typing import Optional, Tuple, Union
23
23
 
24
24
  from ai_edge_torch.generative.layers import builder
25
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
25
26
  from ai_edge_torch.generative.layers import lora as lora_utils
26
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
27
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
27
28
  from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
28
29
  import ai_edge_torch.generative.layers.model_config as cfg
29
30
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -69,9 +70,9 @@ class TransformerBlock(nn.Module):
69
70
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
70
71
  mask: Optional[torch.Tensor] = None,
71
72
  input_pos: Optional[torch.Tensor] = None,
72
- kv_cache: kv_utils.KVCacheEntryBase = None,
73
+ kv_cache: kv_utils.KVCacheEntry = None,
73
74
  lora: Optional[lora_utils.LoRAEntry] = None,
74
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
75
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
75
76
  """Forward function of the TransformerBlock.
76
77
 
77
78
  Args:
@@ -79,7 +80,7 @@ class TransformerBlock(nn.Module):
79
80
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
80
81
  mask (torch.Tensor): the optional mask tensor.
81
82
  input_pos (torch.Tensor): the optional input position tensor.
82
- kv_cache (KVCacheEntryBase): the optional kv cache entry.
83
+ kv_cache (KVCacheEntry): the optional kv cache entry.
83
84
  lora (LoRAEntry): the optional lora entry.
84
85
 
85
86
  Returns:
@@ -154,9 +155,9 @@ class CausalSelfAttention(nn.Module):
154
155
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
155
156
  mask: Optional[torch.Tensor] = None,
156
157
  input_pos: Optional[torch.Tensor] = None,
157
- kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
158
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
158
159
  lora: Optional[lora_utils.LoRAEntry] = None,
159
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
160
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
160
161
  """Forward function of the CausalSelfAttention layer, which can support
161
162
 
162
163
  MQA, GQA and MHA.
@@ -166,8 +167,7 @@ class CausalSelfAttention(nn.Module):
166
167
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
167
168
  mask (torch.Tensor): the optional mask tensor.
168
169
  input_pos (torch.Tensor): the optional input position tensor.
169
- kv_cache (KVCacheEntryBase): the KV cache entry corresponding to this
170
- module.
170
+ kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
171
171
  lora (LoRAEntry): the optional lora entry.
172
172
 
173
173
  Returns:
@@ -237,7 +237,7 @@ class CausalSelfAttention(nn.Module):
237
237
  ) # 1, bk, h, s
238
238
 
239
239
  if kv_cache is not None:
240
- kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
240
+ kv_cache = kv_utils_experimental.update(kv_cache, input_pos, k, v)
241
241
  k, v = kv_cache.k_cache, kv_cache.v_cache
242
242
 
243
243
  sdpa_out = self.sdpa_func(
@@ -18,304 +18,33 @@
18
18
  This is an experimental implementation and is subject to change at any time.
19
19
  """
20
20
 
21
- import dataclasses
22
- import functools
23
- from typing import Any, List, Tuple, Type
24
- from ai_edge_torch.generative.layers import model_config
25
- from ai_edge_torch.generative.layers.experimental import types
26
21
  from ai_edge_torch.generative.custom_ops import dynamic_update_slice as dus_utils
22
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
27
23
  import torch
28
- import torch.utils._pytree as pytree
29
-
30
-
31
- @dataclasses.dataclass
32
- class KVCacheEntryBase:
33
- """A single cache entry that includes K and V caches.
34
-
35
- The chaches are built based on the provided config with the shape of
36
- (batch_size, kv_cache_max, num_query_groups, head_dim).
37
- """
38
-
39
- k_cache: torch.Tensor
40
- v_cache: torch.Tensor
41
-
42
- @classmethod
43
- def _from_model_config(
44
- cls,
45
- k_shape: Tuple[int, ...],
46
- v_shape: Tuple[int, ...],
47
- dtype: torch.dtype = torch.float32,
48
- device: torch.device = None,
49
- ):
50
- """Build an instance of the class based on model config."""
51
- k = torch.zeros(k_shape, dtype=dtype, device=device)
52
- v = torch.zeros(v_shape, dtype=dtype, device=device)
53
- obj = cls(k_cache=k, v_cache=v)
54
- return obj
55
-
56
- @classmethod
57
- def from_model_config(
58
- cls,
59
- kv_cache_max: int,
60
- config: model_config.AttentionConfig,
61
- dtype: torch.dtype = torch.float32,
62
- device: torch.device = None,
63
- batch_size: int = 1,
64
- ):
65
- """Build an instance of the class based on model config."""
66
- shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
67
- return cls._from_model_config(shape, shape, dtype, device)
68
-
69
-
70
- @dataclasses.dataclass
71
- class KVCacheEntryBTNH(KVCacheEntryBase):
72
- k_type = types.BTNH()
73
- v_type = types.BTNH()
74
-
75
-
76
- @dataclasses.dataclass
77
- class KVCacheEntryTransposed(KVCacheEntryBase):
78
-
79
- k_type = types.BNTH()
80
- v_type = types.BNHT()
81
-
82
- @classmethod
83
- def from_model_config(
84
- cls,
85
- kv_cache_max: int,
86
- config: model_config.AttentionConfig,
87
- dtype: torch.dtype = torch.float32,
88
- device: torch.device = None,
89
- batch_size: int = 1,
90
- ):
91
- """Build an instance of the class based on model config."""
92
- k_shape = (
93
- batch_size,
94
- config.num_query_groups,
95
- kv_cache_max,
96
- config.head_dim,
97
- ) # b, k, s, h
98
- v_shape = (
99
- batch_size,
100
- config.num_query_groups,
101
- config.head_dim,
102
- kv_cache_max,
103
- ) # b, k, h, s
104
- return cls._from_model_config(k_shape, v_shape, dtype, device)
105
-
106
-
107
- def _flatten_kv_entry(
108
- kv_e: KVCacheEntryBase,
109
- ) -> Tuple[List[torch.Tensor], Any]:
110
- return ([kv_e.k_cache, kv_e.v_cache], None)
111
-
112
-
113
- def _unflatten_kv_entry(
114
- kv_entry_ty: Type[KVCacheEntryBase],
115
- values: List[torch.Tensor],
116
- unused_context: Any,
117
- ) -> KVCacheEntryBase:
118
- return kv_entry_ty(*values)
119
-
120
-
121
- pytree.register_pytree_node(
122
- KVCacheEntryTransposed,
123
- _flatten_kv_entry,
124
- functools.partial(_unflatten_kv_entry, KVCacheEntryTransposed),
125
- serialized_type_name="",
126
- )
127
-
128
- pytree.register_pytree_node(
129
- KVCacheEntryBase,
130
- _flatten_kv_entry,
131
- functools.partial(_unflatten_kv_entry, KVCacheEntryBase),
132
- serialized_type_name="",
133
- )
134
-
135
-
136
- @dataclasses.dataclass
137
- class KVCacheBase:
138
- """A utility class for holding KV cache entries per layer."""
139
-
140
- caches: Tuple[KVCacheEntryBase, ...]
141
-
142
- @classmethod
143
- def _from_model_config(
144
- cls,
145
- kv_entry_cls,
146
- config: model_config.ModelConfig,
147
- dtype: torch.dtype = torch.float32,
148
- device: torch.device = None,
149
- batch_size: int = 1,
150
- ):
151
- caches = [
152
- kv_entry_cls.from_model_config(
153
- config.kv_cache_max,
154
- config.block_config(idx).attn_config,
155
- dtype,
156
- device,
157
- batch_size,
158
- )
159
- for idx in range(config.num_layers)
160
- ]
161
- obj = cls(caches=tuple(caches))
162
- return obj
163
-
164
- @classmethod
165
- def from_model_config(
166
- cls,
167
- config: model_config.ModelConfig,
168
- dtype: torch.dtype = torch.float32,
169
- device: torch.device = None,
170
- batch_size: int = 1,
171
- ):
172
- """Build an instance of the class based on model config.
173
-
174
- Args:
175
- config (ModelConfig): Model config used for building the cache.
176
- dtype (torch.dtype, optional): The data type of the cache tensor.
177
- Defaults to torch.float32.
178
- device (torch.device, optional): The device placement of the cache
179
- tensors. Defaults to None.
180
- batch_size (int, optional): The batch size of the cache tensors.
181
- Defaults to 1.
182
-
183
- Returns:
184
- KVCacheBase: The created cache object.
185
- """
186
- assert batch_size == 1, "Batch size must be 1 for KV Cache."
187
- return cls._from_model_config(
188
- KVCacheEntryBase,
189
- config=config,
190
- dtype=dtype,
191
- device=device,
192
- batch_size=batch_size,
193
- )
194
-
195
- def flatten(self) -> List[torch.Tensor]:
196
- """Flatten the cache entries into a list of tensors with order k_i, v_i."""
197
- flattened, _ = _flatten_kvc(self)
198
- return flattened
199
-
200
-
201
- @dataclasses.dataclass
202
- class KVCacheBTNH(KVCacheBase):
203
-
204
- @classmethod
205
- def from_model_config(
206
- cls,
207
- config: model_config.ModelConfig,
208
- dtype: torch.dtype = torch.float32,
209
- device: torch.device = None,
210
- batch_size: int = 1,
211
- ):
212
- return cls._from_model_config(
213
- KVCacheEntryBTNH,
214
- config=config,
215
- dtype=dtype,
216
- device=device,
217
- batch_size=batch_size,
218
- )
219
-
220
-
221
- @dataclasses.dataclass
222
- class KVCacheTransposed(KVCacheBase):
223
-
224
- @classmethod
225
- def from_model_config(
226
- cls,
227
- config: model_config.ModelConfig,
228
- dtype: torch.dtype = torch.float32,
229
- device: torch.device = None,
230
- batch_size: int = 1,
231
- ):
232
- return cls._from_model_config(
233
- KVCacheEntryTransposed,
234
- config=config,
235
- dtype=dtype,
236
- device=device,
237
- batch_size=batch_size,
238
- )
239
-
240
-
241
- def _flatten_kvc(kvc: KVCacheBase) -> Tuple[List[str], List[str]]:
242
- flattened = []
243
- flat_names = []
244
- none_names = []
245
- for i, kv_entry in enumerate(kvc.caches):
246
- flattened.append(kv_entry.k_cache)
247
- flat_names.append(f"k_{i}")
248
- flattened.append(kv_entry.v_cache)
249
- flat_names.append(f"v_{i}")
250
- return flattened, [flat_names, none_names]
251
-
252
-
253
- def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
254
- flattened, (flat_names, none_names) = _flatten_kvc(kvc)
255
- return [
256
- (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
257
- ], flat_names
258
-
259
-
260
- def _unflatten_kvc(
261
- kv_ty: Type[KVCacheBase],
262
- kv_entry_type: Type[KVCacheEntryBase],
263
- values: List[torch.Tensor],
264
- context: Tuple[List, List],
265
- ) -> KVCacheBase:
266
- assert len(values) % 2 == 0, "Found odd number of K and V entries."
267
- num_layers = len(values) // 2
268
- flat_names = context[0]
269
- kv_entries = []
270
- for i in range(num_layers):
271
- k_cache_idx = flat_names.index(f"k_{i}")
272
- v_cache_idx = flat_names.index(f"v_{i}")
273
- kv_entries.append(
274
- kv_entry_type(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
275
- )
276
- obj = kv_ty(tuple(kv_entries))
277
- return obj
278
-
279
-
280
- pytree.register_pytree_node(
281
- KVCacheTransposed,
282
- _flatten_kvc,
283
- functools.partial(
284
- _unflatten_kvc, KVCacheTransposed, KVCacheEntryTransposed
285
- ),
286
- flatten_with_keys_fn=_flatten_kvc_with_keys,
287
- serialized_type_name="",
288
- )
289
-
290
- pytree.register_pytree_node(
291
- KVCacheBase,
292
- _flatten_kvc,
293
- functools.partial(_unflatten_kvc, KVCacheBase, KVCacheEntryBase),
294
- flatten_with_keys_fn=_flatten_kvc_with_keys,
295
- serialized_type_name="",
296
- )
297
24
 
298
25
 
299
26
  def update(
300
- cache: KVCacheEntryBase,
27
+ cache: kv_utils.KVCacheEntry,
301
28
  input_pos: torch.Tensor,
302
29
  k_slice: torch.Tensor,
303
30
  v_slice: torch.Tensor,
304
- ) -> KVCacheEntryBase:
31
+ ) -> kv_utils.KVCacheEntry:
305
32
  """Out of place update of Cache buffer.
306
33
 
307
34
  Args:
308
- cache (KVCacheEntryBase): The original cache buffer.
35
+ cache (kv_utils.KVCacheEntry): The original cache buffer.
309
36
  input_pos (torch.Tensor): The update slice positions.
310
37
  k_slice (torch.Tensor): The K slice to be updated in the new cache.
311
38
  v_slice (torch.Tensor): The V slice to be updated in the new cache.
312
39
 
313
40
  Returns:
314
- KVCacheEntryBase: The updated KVCacheBase entry based on the passed
41
+ kv_utils.KVCacheEntry: The updated KVCacheBase entry based on the passed
315
42
  inputs.
316
43
  """
317
- update_kv_cache = _update_kv_impl
318
- return update_kv_cache(cache, input_pos, k_slice, v_slice)
44
+ assert (
45
+ cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
46
+ ), "KV entry must have transposed layout."
47
+ return _update_kv_impl_transposed(cache, input_pos, k_slice, v_slice)
319
48
 
320
49
 
321
50
  def _get_slice_indices(
@@ -338,12 +67,12 @@ def _get_slice_indices(
338
67
  return slice_indices
339
68
 
340
69
 
341
- def _update_kv_impl(
342
- cache: KVCacheEntryTransposed,
70
+ def _update_kv_impl_transposed(
71
+ cache: kv_utils.KVCacheEntry,
343
72
  input_pos: torch.Tensor,
344
73
  k_slice: torch.Tensor,
345
74
  v_slice: torch.Tensor,
346
- ) -> KVCacheEntryTransposed:
75
+ ) -> kv_utils.KVCacheEntry:
347
76
  """Update the cache buffer with High Level Function Boundary annotation."""
348
77
  cache_dim = 4
349
78
  k_ts_idx = 2
@@ -357,4 +86,4 @@ def _update_kv_impl(
357
86
  v = dus_utils.dynamic_update_slice(
358
87
  cache.v_cache, v_slice, [x for x in v_slice_indices]
359
88
  )
360
- return KVCacheEntryTransposed(k, v)
89
+ return kv_utils.KVCacheEntry(k, v, cache.kv_layout)
@@ -19,7 +19,7 @@ import math
19
19
  from typing import Optional
20
20
 
21
21
  from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
22
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
22
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
23
  from ai_edge_torch.generative.layers.experimental import types
24
24
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
25
25
  from multipledispatch import dispatch
@@ -28,7 +28,7 @@ import torch.nn.functional as F
28
28
 
29
29
 
30
30
  def scaled_dot_product_attention(
31
- kv: kv_utils.KVCacheBase,
31
+ kv: kv_utils.KVCacheEntry,
32
32
  query: torch.Tensor,
33
33
  key: torch.Tensor,
34
34
  value: torch.Tensor,
@@ -37,10 +37,10 @@ def scaled_dot_product_attention(
37
37
  scale: Optional[float] = None,
38
38
  softcap: Optional[float] = None,
39
39
  ):
40
- if hasattr(kv, "k_type") and hasattr(kv, "v_type"):
40
+ if hasattr(kv, "kv_layout"):
41
41
  return _sdpa(
42
- kv.k_type,
43
- kv.v_type,
42
+ kv.kv_layout[0](), # key layout
43
+ kv.kv_layout[1](), # value layout
44
44
  query=query,
45
45
  key=key,
46
46
  value=value,
@@ -49,10 +49,7 @@ def scaled_dot_product_attention(
49
49
  scale=scale,
50
50
  softcap=softcap,
51
51
  )
52
- raise ValueError(
53
- f"SDPA for K type {type(kv.caches[0].k_type)} and V type"
54
- f" {type(kv.caches[0].v_type)} not supported."
55
- )
52
+ raise ValueError("No kv_layout attribute found in kv.")
56
53
 
57
54
 
58
55
  @dispatch(types.BNTH, types.BNHT)
@@ -62,6 +62,9 @@ class TensorDimensionMeta(type):
62
62
  def __repr__(cls):
63
63
  return f'{cls.__name__}'
64
64
 
65
+ def __iter__(cls):
66
+ return iter(getattr(cls, 'dimensions'))
67
+
65
68
 
66
69
  def create_tensor_dimension_order_class(dims: Tuple[TensorDims]):
67
70
  """Creates a TensorDimensionMeta class with the specified dimensions.
@@ -16,24 +16,58 @@
16
16
  """Utility functions for externalized KV Cache."""
17
17
 
18
18
  import dataclasses
19
- from typing import List, Tuple
19
+ from typing import Any, List, Tuple
20
20
 
21
21
  from ai_edge_torch.generative.custom_ops.dynamic_update_slice import dynamic_update_slice
22
22
  from ai_edge_torch.generative.layers import model_config
23
+ from ai_edge_torch.generative.layers.experimental import types
23
24
  import torch
24
25
  import torch.utils._pytree as pytree
25
26
 
26
27
 
28
+ KVLayout = Tuple[types.TensorDimensionMeta, types.TensorDimensionMeta]
29
+
30
+ # Define common layouts for KV Cache.
31
+ KV_LAYOUT_DEFAULT = (types.BTNH, types.BTNH)
32
+ KV_LAYOUT_TRANSPOSED = (types.BNTH, types.BNHT)
33
+
34
+
27
35
  @dataclasses.dataclass
28
36
  class KVCacheEntry:
29
37
  """A single cache entry that includes K and V caches.
30
38
 
31
- The chaches are built based on the provided config with the shape of
32
- (batch_size=1, kv_cache_max, num_query_groups, head_dim).
39
+ The cache layout can be customized based on different use cases.
33
40
  """
34
41
 
35
42
  k_cache: torch.Tensor
36
43
  v_cache: torch.Tensor
44
+ kv_layout: KVLayout = KV_LAYOUT_DEFAULT
45
+
46
+ @classmethod
47
+ def construct_kv_shape_from_layout(
48
+ cls,
49
+ shape_spec: types.TensorDimensionMeta,
50
+ kv_cache_max: int,
51
+ config: model_config.AttentionConfig,
52
+ batch_size: int,
53
+ ) -> List[int]:
54
+ """Constructs the shape of the key or value cache entry based on
55
+
56
+ the specified layout.
57
+ """
58
+ output_shape = []
59
+ for dim_spec in shape_spec:
60
+ if dim_spec is types.TensorDims.BATCH:
61
+ output_shape.append(batch_size)
62
+ elif dim_spec is types.TensorDims.SEQUENCE:
63
+ output_shape.append(kv_cache_max)
64
+ elif dim_spec is types.TensorDims.NUM_HEADS:
65
+ output_shape.append(config.num_query_groups)
66
+ elif dim_spec is types.TensorDims.HEAD_DIM:
67
+ output_shape.append(config.head_dim)
68
+ else:
69
+ raise ValueError(f"Unsupported dimension spec: {dim_spec}")
70
+ return output_shape
37
71
 
38
72
  @classmethod
39
73
  def from_model_config(
@@ -41,14 +75,20 @@ class KVCacheEntry:
41
75
  kv_cache_max: int,
42
76
  config: model_config.AttentionConfig,
43
77
  dtype: torch.dtype = torch.float32,
44
- device: torch.device = None,
78
+ device: torch.device | None = None,
45
79
  batch_size: int = 1,
80
+ kv_layout: KVLayout = KV_LAYOUT_DEFAULT,
46
81
  ) -> "KVCacheEntry":
47
82
  """Build an instance of the class based on model config."""
48
- shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
49
- k = torch.zeros(shape, dtype=dtype, device=device)
50
- v = torch.zeros(shape, dtype=dtype, device=device)
51
- obj = cls(k_cache=k, v_cache=v)
83
+ k_shape = cls.construct_kv_shape_from_layout(
84
+ kv_layout[0], kv_cache_max, config, batch_size
85
+ )
86
+ v_shape = cls.construct_kv_shape_from_layout(
87
+ kv_layout[1], kv_cache_max, config, batch_size
88
+ )
89
+ k = torch.zeros(k_shape, dtype=dtype, device=device)
90
+ v = torch.zeros(v_shape, dtype=dtype, device=device)
91
+ obj = cls(k_cache=k, v_cache=v, kv_layout=kv_layout)
52
92
  return obj
53
93
 
54
94
 
@@ -63,8 +103,9 @@ class KVCache:
63
103
  cls,
64
104
  config: model_config.ModelConfig,
65
105
  dtype: torch.dtype = torch.float32,
66
- device: torch.device = None,
106
+ device: torch.device | None = None,
67
107
  batch_size: int = 1,
108
+ kv_layout: KVLayout = KV_LAYOUT_DEFAULT,
68
109
  ) -> "KVCache":
69
110
  """Build an instance of the class based on model config.
70
111
 
@@ -89,6 +130,7 @@ class KVCache:
89
130
  dtype,
90
131
  device,
91
132
  batch_size,
133
+ kv_layout,
92
134
  )
93
135
  for idx in range(config.num_layers)
94
136
  ]
@@ -104,7 +146,7 @@ class KVCache:
104
146
  def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
105
147
  flattened = []
106
148
  flat_names = []
107
- none_names = []
149
+ none_names = [kvc.caches[0].kv_layout]
108
150
  for i, kv_entry in enumerate(kvc.caches):
109
151
  flattened.append(kv_entry.k_cache)
110
152
  flat_names.append(f"k_{i}")
@@ -121,22 +163,48 @@ def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
121
163
 
122
164
 
123
165
  def _unflatten_kvc(
124
- values: List[torch.Tensor], context: Tuple[List, List]
166
+ values: List[torch.Tensor],
167
+ context: Tuple[List, List],
125
168
  ) -> KVCache:
126
169
  assert len(values) % 2 == 0, "Found odd number of K and V entries."
127
170
  num_layers = len(values) // 2
128
171
  flat_names = context[0]
172
+ kv_layout = context[1][0]
129
173
  kv_entries = []
130
174
  for i in range(num_layers):
131
175
  k_cache_idx = flat_names.index(f"k_{i}")
132
176
  v_cache_idx = flat_names.index(f"v_{i}")
133
177
  kv_entries.append(
134
- KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
178
+ KVCacheEntry(
179
+ k_cache=values[k_cache_idx],
180
+ v_cache=values[v_cache_idx],
181
+ kv_layout=kv_layout,
182
+ )
135
183
  )
136
184
  obj = KVCache(tuple(kv_entries))
137
185
  return obj
138
186
 
139
187
 
188
+ def _flatten_kv_entry(
189
+ kv_e: KVCacheEntry,
190
+ ) -> Tuple[List[torch.Tensor], Any]:
191
+ return ([kv_e.k_cache, kv_e.v_cache], kv_e.kv_layout)
192
+
193
+
194
+ def _unflatten_kv_entry(
195
+ values: List[torch.Tensor],
196
+ context: Any,
197
+ ) -> KVCacheEntry:
198
+ return KVCacheEntry(*values, kv_layout=context)
199
+
200
+
201
+ pytree.register_pytree_node(
202
+ KVCacheEntry,
203
+ _flatten_kv_entry,
204
+ _unflatten_kv_entry,
205
+ serialized_type_name="",
206
+ )
207
+
140
208
  pytree.register_pytree_node(
141
209
  KVCache,
142
210
  _flatten_kvc,
@@ -145,7 +213,6 @@ pytree.register_pytree_node(
145
213
  serialized_type_name="",
146
214
  )
147
215
 
148
-
149
216
  def update(
150
217
  cache: KVCacheEntry,
151
218
  input_pos: torch.Tensor,
@@ -204,5 +271,5 @@ def _update_kv_impl(
204
271
  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
205
272
  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
206
273
 
207
- updated_cache = KVCacheEntry(k, v)
274
+ updated_cache = KVCacheEntry(k, v, cache.kv_layout)
208
275
  return updated_cache
@@ -16,7 +16,6 @@
16
16
  """A suite of tests to validate KV Cache layer."""
17
17
 
18
18
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
19
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
20
19
  import ai_edge_torch.generative.layers.model_config as cfg
21
20
  import torch
22
21
  import torch.utils._pytree as pytree
@@ -117,7 +116,7 @@ class TestKVLayers(googletest.TestCase):
117
116
  self.assertEqual(input_specs[0].arg.name, "kv_k_0")
118
117
  self.assertEqual(input_specs[1].arg.name, "kv_v_0")
119
118
 
120
- def test_pytree_roundtrip_experimental_kv_cache_base(self):
119
+ def test_pytree_roundtrip_kv_cache(self):
121
120
  NUM_LAYERS = 4
122
121
  config = self._get_test_config(
123
122
  num_layers=NUM_LAYERS,
@@ -125,15 +124,13 @@ class TestKVLayers(googletest.TestCase):
125
124
  num_query_groups=1,
126
125
  kv_cache_max_len=4,
127
126
  )
128
- kv = kv_utils_experimental.KVCacheBase.from_model_config(
129
- config, batch_size=1
130
- )
127
+ kv = kv_utils.KVCache.from_model_config(config, batch_size=1)
131
128
  flat, treespec = pytree.tree_flatten(kv)
132
129
  self.assertLen(flat, NUM_LAYERS * 2)
133
130
  kv_unflat = pytree.tree_unflatten(flat, treespec)
134
131
  self.assertEqual(kv, kv_unflat)
135
132
 
136
- def test_pytree_roundtrip_experimental_kv_cache_derived(self):
133
+ def test_pytree_roundtrip_kv_cache_derived(self):
137
134
  NUM_LAYERS = 4
138
135
  config = self._get_test_config(
139
136
  num_layers=NUM_LAYERS,
@@ -141,41 +138,37 @@ class TestKVLayers(googletest.TestCase):
141
138
  num_query_groups=1,
142
139
  kv_cache_max_len=4,
143
140
  )
144
- kv = kv_utils_experimental.KVCacheTransposed.from_model_config(
145
- config, batch_size=1
141
+ kv = kv_utils.KVCache.from_model_config(
142
+ config, batch_size=1, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
146
143
  )
147
144
  flat, treespec = pytree.tree_flatten(kv)
148
145
  self.assertLen(flat, NUM_LAYERS * 2)
149
146
  kv_unflat = pytree.tree_unflatten(flat, treespec)
150
147
  self.assertEqual(kv, kv_unflat)
151
148
 
152
- def test_pytree_roundtrip_experimental_kv_entry_base(self):
149
+ def test_pytree_roundtrip_kv_entry(self):
153
150
  attn_config = cfg.AttentionConfig(
154
151
  num_heads=1, head_dim=1, num_query_groups=1
155
152
  )
156
- kv = kv_utils_experimental.KVCacheEntryBase.from_model_config(
157
- 32, attn_config
158
- )
153
+ kv = kv_utils.KVCacheEntry.from_model_config(32, attn_config)
159
154
  flat, treespec = pytree.tree_flatten(kv)
160
155
  self.assertLen(flat, 2)
161
156
  kv_unflat = pytree.tree_unflatten(flat, treespec)
162
157
  self.assertEqual(kv, kv_unflat)
163
- self.assertIsInstance(kv_unflat, kv_utils_experimental.KVCacheEntryBase)
158
+ self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
164
159
 
165
- def test_pytree_roundtrip_experimental_kv_entry_derived(self):
160
+ def test_pytree_roundtrip_kv_entry_derived(self):
166
161
  attn_config = cfg.AttentionConfig(
167
162
  num_heads=1, head_dim=1, num_query_groups=1
168
163
  )
169
- kv = kv_utils_experimental.KVCacheEntryTransposed.from_model_config(
170
- 32, attn_config
164
+ kv = kv_utils.KVCacheEntry.from_model_config(
165
+ 32, attn_config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
171
166
  )
172
167
  flat, treespec = pytree.tree_flatten(kv)
173
168
  self.assertLen(flat, 2)
174
169
  kv_unflat = pytree.tree_unflatten(flat, treespec)
175
170
  self.assertEqual(kv, kv_unflat)
176
- self.assertIsInstance(
177
- kv_unflat, kv_utils_experimental.KVCacheEntryTransposed
178
- )
171
+ self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
179
172
 
180
173
 
181
174
  if __name__ == "__main__":
@@ -20,6 +20,7 @@ import pathlib
20
20
  from typing import Optional, Union
21
21
  from absl import flags
22
22
  from ai_edge_torch._convert import converter as converter_utils
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
24
  from ai_edge_torch.generative.layers import lora as lora_utils
24
25
  import ai_edge_torch.generative.layers.model_config as cfg
25
26
  from ai_edge_torch.generative.quantize import quant_recipes
@@ -218,9 +219,13 @@ def _export_helper(
218
219
  [[0] for _ in range(export_config.decode_batch_size)], dtype=torch.int
219
220
  )
220
221
  decode_input_pos = torch.tensor([0], dtype=torch.int)
221
- prefill_kv = export_config.kvcache_cls.from_model_config(config)
222
- decode_kv = export_config.kvcache_cls.from_model_config(
223
- config, batch_size=export_config.decode_batch_size
222
+ prefill_kv = kv_utils.KVCache.from_model_config(
223
+ config, kv_layout=export_config.kvcache_layout
224
+ )
225
+ decode_kv = kv_utils.KVCache.from_model_config(
226
+ config,
227
+ batch_size=export_config.decode_batch_size,
228
+ kv_layout=export_config.kvcache_layout,
224
229
  )
225
230
 
226
231
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
@@ -32,7 +32,9 @@ class ExportConfig:
32
32
  # Attention masks given as inputs to the model.
33
33
  prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
34
34
  decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
35
- # The KV Cache class for K and V buffers in attention.
35
+ # The KV Cache layout for K and V buffers in attention.
36
+ kvcache_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT
37
+ # TODO(b/409373223): The KV Cache class for K and V buffers in attention.
36
38
  kvcache_cls: type = kv_utils.KVCache
37
39
  # The batch size of the decode signature.
38
40
  decode_batch_size: int = 1
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250408"
16
+ __version__ = "0.5.0.dev20250409"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250408
3
+ Version: 0.5.0.dev20250409
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -25,7 +25,9 @@ License-File: LICENSE
25
25
  Requires-Dist: numpy
26
26
  Requires-Dist: scipy
27
27
  Requires-Dist: safetensors
28
- Requires-Dist: multipledispatchtransformerskagglehub
28
+ Requires-Dist: multipledispatch
29
+ Requires-Dist: transformers
30
+ Requires-Dist: kagglehub
29
31
  Requires-Dist: tabulate
30
32
  Requires-Dist: torch>=2.4.0
31
33
  Requires-Dist: tf-nightly>=2.19.0.dev20250101
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=Lf4c2aVfixNX2KTgdqQTLOGBdi0vVxNOkJuNt4SvQ8c,706
5
+ ai_edge_torch/version.py,sha256=DEYqmCDZNmwuMxnxrFvcTEaDp6Z_BVHJaZMYjVQ2ijU,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -65,12 +65,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
65
65
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
66
66
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
67
67
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
68
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=6Dkxi7Vs8xBaqMif00ATQSr_hTPhYXMdDqHwzOsAzq8,2952
69
- ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=__kpzArZ0mLfX7IzpHPmYFuhKTP9uI_9Lrzk_EfFDlE,15701
68
+ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
69
+ ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=n6ZQfqNEHuOhY7Pu21bb8Eax8yn2Sx5osTKJKmhonXY,15659
70
70
  ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=5PEt0aWJ5wkUBvMoWFOJ-C48ZhG7uCVb8PCKQtZ8Fvw,6485
71
71
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
72
72
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
73
- ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=u30qiZu3HJCTt5noWqtf9PgGLKQ87ke4Zpa4cpG6-As,8883
73
+ ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
74
74
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
75
75
  ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=A4uLUdqvU1NKo3seqZlWSS3fqYahnEKqNBQBJO6yXvE,1762
76
76
  ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
@@ -153,17 +153,17 @@ ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDq
153
153
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
154
154
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
155
155
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
156
- ai_edge_torch/generative/layers/kv_cache.py,sha256=zjdovWqgEKtx7cvbA0apOwXaNft5AXxNTbJhBT4CXyg,6541
156
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=9kkFpB9msgUDStFxEyQYYsavKPP4Dgqb_NFcd4hA4aU,8502
157
157
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
158
158
  ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
159
159
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
160
160
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
161
161
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
162
162
  ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
163
- ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
164
- ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=uXUxiQjPndXYZVGKgm9FxzHgQDal8GdY7cUZDpc_Sno,9997
165
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFW0iGcZjTuej6VFIkwdSY28fIQi_KTAVdT8gWNmq7o,2880
166
- ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
163
+ ai_edge_torch/generative/layers/experimental/attention.py,sha256=oW8cxv0pXcesnyGz6bXacRmlvHPfKNnJnls_Qb4L_aQ,8968
164
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=PlgL2bNNKasu3wFr3Iu9wbATWluWZt3_s4tzglJu2tM,2942
165
+ ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=-ztTIgdec5gXkOVe6FXk3PMeS2HoL6-mBfDBdjQIcLQ,2808
166
+ ai_edge_torch/generative/layers/experimental/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
167
167
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
168
168
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
169
169
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -177,7 +177,7 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
177
177
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
178
178
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
179
179
  ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
180
- ai_edge_torch/generative/test/test_kv_cache.py,sha256=MBPS-0bDXB0tQSKHa1XwDQeVIfabRbc8JQA99h9fzlQ,5961
180
+ ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
181
181
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
182
182
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
183
183
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
@@ -185,8 +185,8 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3Gy
185
185
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
186
186
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
187
187
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
188
- ai_edge_torch/generative/utilities/converter.py,sha256=ycXDcd3ZE-EdjksDjHi4ru3JpfhtrfOompg_990qvWI,9607
189
- ai_edge_torch/generative/utilities/export_config.py,sha256=-UuukWqUUj8RM8lTtMCa_PD6SqCZv97i4BMiJA2zBPg,1491
188
+ ai_edge_torch/generative/utilities/converter.py,sha256=87Tzj-gLydx8_xnHxKlCbMmM1XHShstpKi8RH3xY7Xw,9757
189
+ ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
190
190
  ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
191
191
  ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
192
192
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -243,8 +243,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
243
243
  ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
244
244
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
245
245
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
246
- ai_edge_torch_nightly-0.5.0.dev20250408.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
247
- ai_edge_torch_nightly-0.5.0.dev20250408.dist-info/METADATA,sha256=-Bw-LUn9l-B66aMZiFiUiYBifr1B6Fr86LU8KXtBieo,2019
248
- ai_edge_torch_nightly-0.5.0.dev20250408.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
249
- ai_edge_torch_nightly-0.5.0.dev20250408.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
250
- ai_edge_torch_nightly-0.5.0.dev20250408.dist-info/RECORD,,
246
+ ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
247
+ ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/METADATA,sha256=kZwo6E79HLuM7_4E-Yw9erTzOnAAzio3Vy45hXNiC48,2051
248
+ ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
249
+ ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
250
+ ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/RECORD,,