ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (42)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
  9. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  10. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
  12. ai_edge_torch/generative/layers/attention.py +60 -63
  13. ai_edge_torch/generative/layers/builder.py +4 -2
  14. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  15. ai_edge_torch/generative/layers/model_config.py +1 -0
  16. ai_edge_torch/generative/layers/normalization.py +158 -0
  17. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  18. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  19. ai_edge_torch/generative/test/test_loader.py +1 -1
  20. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  21. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  22. ai_edge_torch/generative/test/utils.py +54 -0
  23. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  24. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  25. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  26. ai_edge_torch/version.py +1 -1
  27. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
  29. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  30. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  31. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  32. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  33. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  34. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  35. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  36. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  37. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  38. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  39. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
  42. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
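Most of the file churn above is a promotion of the experimental examples into stable paths (phi, smallm) plus removal of the old experimental/ and phi2 trees. For downstream code this is purely an import-path change; a hypothetical caller-side update (the new module path is taken from the file list above, the snippet itself is illustrative):

    # Before (0.3.0.dev20240910), the Phi-2 example lived under experimental/:
    #   from ai_edge_torch.generative.examples.experimental.phi import phi2
    # After (0.3.0.dev20240912), it is a stable example package:
    from ai_edge_torch.generative.examples.phi import phi2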
ai_edge_torch/generative/layers/attention.py
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Common building blocks for Attention layer.
 
-from typing import Optional, Tuple
+"""Common building blocks for Attention layer."""
 
-import ai_edge_torch.generative.layers.builder as builder
-from ai_edge_torch.generative.layers.kv_cache import KVCache
+from typing import Optional, Tuple, Union
+
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
-from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import torch
 from torch import nn
 
@@ -62,7 +62,6 @@ class TransformerBlock(nn.Module):
       config (cfg.ModelConfig): the configuration object for this transformer
         block.
     """
-
     super().__init__()
     self.pre_atten_norm = builder.build_norm(
         config.embedding_dim, config.pre_attention_norm_config
@@ -71,7 +70,6 @@ class TransformerBlock(nn.Module):
         config.batch_size,
         config.embedding_dim,
         config.attn_config,
-        config.kv_cache_max,
         config.enable_hlfb,
     )
     self.post_atten_norm = builder.build_norm(
@@ -86,7 +84,8 @@ class TransformerBlock(nn.Module):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
    """Forward function of the TransformerBlock.
 
    Args:
@@ -94,24 +93,34 @@ class TransformerBlock(nn.Module):
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.
 
    Returns:
-      output activation from this transformer block.
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
    """
-
+    kv = None
    if self.config.parallel_residual:
      x_norm = self.pre_atten_norm(x)
-      attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      if kv_cache is None:
+        attn_out = atten_func_out
+      else:
+        attn_out, kv = atten_func_out
      ff_out = self.ff(x_norm)
      output = x + attn_out + ff_out
    else:
      x_norm = self.pre_atten_norm(x)
-      attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      if kv_cache is None:
+        attn_out = atten_func_out
+      else:
+        attn_out, kv = atten_func_out
      x = x + attn_out
      x_norm = self.post_atten_norm(x)
      output = x + self.ff(x_norm)
 
-    return output
+    return output if kv is None else (output, kv)
 
 
 class CausalSelfAttention(nn.Module):
@@ -121,7 +130,6 @@ class CausalSelfAttention(nn.Module):
      batch_size: int,
      dim: int,
      config: cfg.AttentionConfig,
-      kv_cache_max: int,
      enable_hlfb: bool,
  ) -> None:
    """Initialize an instance of CausalSelfAttention.
@@ -130,8 +138,6 @@ class CausalSelfAttention(nn.Module):
      batch_size (int): batch size of the input tensor.
      dim (int): causal attention's input/output dimmension.
      config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if
-        enabled.
      enable_hlfb (bool): whether hlfb is enabled or not.
    """
    super().__init__()
@@ -147,21 +153,13 @@ class CausalSelfAttention(nn.Module):
    self.output_projection = nn.Linear(
        output_shape, dim, bias=config.output_proj_use_bias
    )
-
-    # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
-    if config.enable_kv_cache:
-      self.kv_cache = KVCache(
-          batch_size,
-          kv_cache_max,
-          config.num_query_groups,
-          config.head_dim,
-          enable_hlfb,
-      )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
+    self.config = config
+    self.enable_hlfb = enable_hlfb
+    self.sdpa_func = (
+        sdpa.scaled_dot_product_attention_with_hlfb
+        if enable_hlfb
+        else sdpa.scaled_dot_product_attention
+    )
 
  def forward(
      self,
@@ -169,7 +167,8 @@ class CausalSelfAttention(nn.Module):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
    """Forward function of the CausalSelfAttention layer, which can support
 
    MQA, GQA and MHA.
@@ -179,9 +178,11 @@ class CausalSelfAttention(nn.Module):
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
 
    Returns:
-      output activation from this self attention layer.
+      output activation from this self attention layer, and the updated
+      KV Cach Entry (if passed in).
    """
    # Batch size, sequence length, embedding dimensionality.
    B, T, E = x.size()
@@ -224,9 +225,11 @@ class CausalSelfAttention(nn.Module):
      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
      q, k = _embed_rope(q, k, n_elem, rope)
 
-    if self.kv_cache is not None:
-      # TODO(haoliang): Handle when execeeding max sequence length.
-      k, v = self.kv_cache.update_cache(input_pos, k, v)
+    if kv_cache is not None:
+      kv_cache = kv_utils.update(
+          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+      )
+      k, v = kv_cache.k_cache, kv_cache.v_cache
 
    y = self.sdpa_func(
        q,
@@ -240,7 +243,7 @@ class CausalSelfAttention(nn.Module):
 
    # Compute the output projection.
    y = self.output_projection(y)
-    return y
+    return y if kv_cache is None else (y, kv_cache)
 
 
 class SelfAttention(CausalSelfAttention):
@@ -251,16 +254,19 @@ class SelfAttention(CausalSelfAttention):
      x: torch.Tensor,
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
    """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
 
    Args:
      x (torch.Tensor): the input tensor.
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
 
    Returns:
-      output activation from this self attention layer.
+      output activation from this self attention layer, and the updated
+      KV Cach Entry (if passed in).
    """
    B, T, _ = x.size()
    return super().forward(
@@ -279,9 +285,8 @@ class CrossAttention(nn.Module):
      query_dim: int,
      cross_dim: int,
      config: cfg.AttentionConfig,
-      kv_cache_max: int,
      enable_hlfb: bool,
-  ) -> None:
+  ):
    """Initialize an instance of CrossAttention.
 
    Args:
@@ -289,8 +294,6 @@ class CrossAttention(nn.Module):
      query_dim (int): query tensor's dimension.
      cross_dim (int): cross attention's dimensions, for key and value tensors.
      config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if
-        enabled.
      enable_hlfb (bool): whether hlfb is enabled or not.
    """
    super().__init__()
@@ -309,21 +312,11 @@ class CrossAttention(nn.Module):
        query_dim, query_dim, bias=config.output_proj_use_bias
    )
 
-    self.kv_cache = None
-    # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
-    if config.enable_kv_cache:
-      self.kv_cache = KVCache(
-          batch_size,
-          kv_cache_max,
-          config.num_query_groups,
-          self.config.head_dim,
-          enable_hlfb,
-      )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
+    self.sdpa_func = (
+        sdpa.scaled_dot_product_attention_with_hlfb
+        if enable_hlfb
+        else sdpa.scaled_dot_product_attention
+    )
 
  def forward(
      self,
@@ -332,6 +325,7 @@ class CrossAttention(nn.Module):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
  ):
    """Forward function of the CrossAttention layer.
 
@@ -342,6 +336,7 @@ class CrossAttention(nn.Module):
      mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
        [B, n_heads, target_seq_len, source_seq_len].
      input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
 
    Returns:
      output activation from this cross attention layer.
@@ -363,9 +358,11 @@ class CrossAttention(nn.Module):
      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
      q, k = _embed_rope(q, k, n_elem, rope)
 
-    if self.kv_cache is not None:
-      # TODO(haoliang): Handle when execeeding max sequence length.
-      k, v = self.kv_cache.update_cache(input_pos, k, v)
+    if kv_cache is not None:
+      kv_cache = kv_utils.update(
+          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+      )
+      k, v = kv_cache.k_cache, kv_cache.v_cache
    if mask is None:
      mask = torch.zeros(
          (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
@@ -375,4 +372,4 @@ class CrossAttention(nn.Module):
 
    # Compute the output projection.
    y = self.output_projection(y)
-    return y
+    return y if kv_cache is None else (y, kv_cache)
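Net effect of the attention.py changes: attention layers no longer own KV buffers; the caller allocates a KVCacheEntry, passes it into forward, and receives the updated entry back alongside the activation. A minimal sketch of the new calling convention, assuming `config` is a populated cfg.ModelConfig from one of the example models (rope and mask are placeholders a real caller would fill in):

    import torch

    from ai_edge_torch.generative.layers import attention
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    block = attention.TransformerBlock(config)  # config: assumed cfg.ModelConfig
    entry = kv_utils.KVCacheEntry.from_model_config(config)

    x = torch.zeros((1, 8, config.embedding_dim))
    input_pos = torch.arange(8)
    rope, mask = None, None  # placeholders for rotary embedding and causal mask

    # Stateless call: returns just the output activation.
    out = block(x, rope, mask, input_pos)

    # Externalized cache: returns (tensor, updated KVCacheEntry), which the
    # caller must thread into the next decode step.
    out, entry = block(x, rope, mask, input_pos, kv_cache=entry)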
ai_edge_torch/generative/layers/builder.py
@@ -59,9 +59,11 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
        zero_centered_gamma=config.zero_centered,
    )
  elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return nn.LayerNorm(dim, eps=config.epsilon)
+    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
  elif config.type == cfg.NormalizationType.GROUP_NORM:
-    return nn.GroupNorm(config.group_num, dim, config.epsilon)
+    return normalization.GroupNorm(
+        config.group_num, dim, config.epsilon, config.enable_hlfb
+    )
  else:
    raise ValueError("Unsupported norm type.")
 
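build_norm now dispatches LAYER_NORM and GROUP_NORM to the package's own HLFB-aware implementations in the new normalization.py (added in this release) rather than torch.nn. A sketch of opting in, assuming NormalizationConfig accepts its fields as keyword arguments and using the enable_hlfb field added in the model_config.py hunk further below:

    from ai_edge_torch.generative.layers import builder
    import ai_edge_torch.generative.layers.model_config as cfg

    norm_config = cfg.NormalizationConfig(
        type=cfg.NormalizationType.LAYER_NORM,
        enable_hlfb=True,  # new field; routes to normalization.LayerNorm
    )
    # Returns normalization.LayerNorm(dim, epsilon, enable_hlfb) instead of
    # torch.nn.LayerNorm, so the norm carries an HLFB annotation at export.
    layer_norm = builder.build_norm(128, norm_config)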
ai_edge_torch/generative/layers/kv_cache.py
@@ -12,72 +12,181 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# `nn.Module` which implements a KV cache.
 
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+"""Utility functions for externalized KV Cache."""
+
+import dataclasses
+from typing import List, Tuple
+
+from ai_edge_torch import hlfb
+from ai_edge_torch.generative.layers import model_config
 import torch
-from torch import nn
+import torch.utils._pytree as pytree
 
 
-class KVCache(nn.Module):
+@dataclasses.dataclass
+class KVCacheEntry:
+  """A single cache entry that includes K and V caches.
 
-  def __init__(
-      self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False
-  ):
-    """Initializes the KVCache layer.
+  The chaches are built based on the provided config with the shape of
+  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  """
 
-    Args:
-      batch_size (int): batch size. Currently only batch size 1 is supported.
-      kv_cache_max (int): the max length of KV cache.
-      n_heads (int): number of kv heads.
-      head_dim (int): the head dimension size.
-      enable_hlfb (bool): whether hlfb is enabled or not.
-    """
-    super().__init__()
-    cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
-    self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
-    self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
-    self.enable_hlfb = enable_hlfb
-    self.kv_cache_max = kv_cache_max
+  k_cache: torch.Tensor
+  v_cache: torch.Tensor
 
-  def update_cache(self, input_pos, k_val, v_val):
-    """Update an entry in the KV cache.
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCacheEntry":
+    """Build an instance of the class based on model config."""
+    shape = (
+        1,  # Batch dimmension.
+        config.kv_cache_max,
+        config.attn_config.num_query_groups,
+        config.attn_config.head_dim,
+    )
+    k = torch.zeros(shape, dtype=dtype, device=device)
+    v = torch.zeros(shape, dtype=dtype, device=device)
+    obj = cls(k_cache=k, v_cache=v)
+    return obj
 
-    Args:
-      input_pos (torch.Tensor): the input position.
-      k_val (torch.Tensor): the new `key` value.
-      v_val (torch.Tensor): the new `value` value.
 
-    Returns:
-      The updated key and value tensor.
-    """
-    if self.enable_hlfb:
-      return self.update_cache_with_hlfb(input_pos, k_val, v_val)
+@dataclasses.dataclass
+class KVCache:
+  """A utility class for holding KV cache entries per layer."""
 
-    updated_k = self.k_cache.index_copy_(1, input_pos, k_val)
-    updated_v = self.v_cache.index_copy_(1, input_pos, v_val)
-    # Here we need a clone otherwise dynamo export will fail.
-    return torch.clone(updated_k), torch.clone(updated_v)
+  caches: Tuple[KVCacheEntry, ...]
 
-  def update_cache_with_hlfb(self, input_pos, k_val, v_val):
-    """Update an entry in the KV cache and enable high-level function boundary.
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCache":
+    """Build an instance of the class based on model config.
 
    Args:
-      input_pos (torch.Tensor): the input position.
-      k_val (torch.Tensor): the new `key` value.
-      v_val (torch.Tensor): the new `value` value.
+      config (ModelConfig): Model config used for building the cache.
+      dtype (torch.dtype, optional): The data type of the cache tensor.
+        Defaults to torch.float32.
+      device (torch.device, optional): The device placement of the cache
+        tensors. Defaults to None.
 
    Returns:
-      The updated key and value tensor.
+      KVCache: The created cache object.
    """
+    caches = [
+        KVCacheEntry.from_model_config(config, dtype, device)
+        for _ in range(config.num_layers)
+    ]
+    obj = cls(caches=tuple(caches))
+    return obj
 
-    builder = StableHLOCompositeBuilder(
-        name="odml.update_kv_cache", attr={"kv_cache_max": self.kv_cache_max}
-    )
-    k_cache, v_cache, input_pos, k_val, v_val = builder.mark_inputs(
-        self.k_cache, self.v_cache, input_pos, k_val, v_val
+  def flatten(self) -> List[torch.Tensor]:
+    """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+    flattened, _ = _flatten_kvc(self)
+    return flattened
+
+
+def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
+  flattened = []
+  flat_names = []
+  none_names = []
+  for i, kv_entry in enumerate(kvc.caches):
+    flattened.append(kv_entry.k_cache)
+    flat_names.append(f"k_{i}")
+    flattened.append(kv_entry.v_cache)
+    flat_names.append(f"v_{i}")
+  return flattened, [flat_names, none_names]
+
+
+def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
+  flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+  return [
+      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+  ], flat_names
+
+
+def _unflatten_kvc(
+    values: List[torch.Tensor], context: Tuple[List, List]
+) -> KVCache:
+  assert len(values) % 2 == 0, "Found odd number of K and V entries."
+  num_layers = len(values) // 2
+  flat_names = context[0]
+  kv_entries = []
+  for i in range(num_layers):
+    k_cache_idx = flat_names.index(f"k_{i}")
+    v_cache_idx = flat_names.index(f"v_{i}")
+    kv_entries.append(
+        KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
    )
-    updated_k = k_cache.index_copy_(1, input_pos, k_val)
-    updated_v = v_cache.index_copy_(1, input_pos, v_val)
-    updated_k, updated_v = builder.mark_outputs(updated_k, updated_v)
-    return updated_k, updated_v
+  obj = KVCache(tuple(kv_entries))
+  return obj
+
+
+pytree.register_pytree_node(
+    KVCache,
+    _flatten_kvc,
+    _unflatten_kvc,
+    flatten_with_keys_fn=_flatten_kvc_with_keys,
+    serialized_type_name="",
+)
+
+
+def update(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+    enable_hlfb: bool = True,
+) -> KVCacheEntry:
+  """Out of place update of Cache buffer.
+
+  Args:
+    cache (KVCacheEntry): The original cache buffer.
+    input_pos (torch.Tensor): The update slice positions.
+    k_slice (torch.Tensor): The K slice to be updated in the new cache.
+    v_slice (torch.Tensor): The V slice to be updated in the new cache.
+    enable_hlfb (bool, optional): Whether the op is annotated for export with
+      High Level Function Boundary. Defaults to True.
+
+  Returns:
+    KVCacheEntry: The updated KVCache entry based on the passed inputs.
+  """
+  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
+  return update_func(cache, input_pos, k_slice, v_slice)
+
+
+def _update_kv_base_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer without High Level Function Boundary annotation."""
+  k = cache.k_cache.index_copy(1, input_pos, k_slice)
+  v = cache.v_cache.index_copy(1, input_pos, v_slice)
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
+
+
+def _update_kv_hlfb_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer with High Level Function Boundary annotation."""
+  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
+  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
+      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
+  )
+  k = k_cache.index_copy(1, input_pos, k_slice)
+  v = v_cache.index_copy(1, input_pos, v_slice)
+  k, v = builder.mark_outputs(k, v)
+  return KVCacheEntry(k, v)
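Taken together, kv_cache.py swaps the stateful nn.Module cache for a plain dataclass that is updated out of place and registered as a pytree, so a whole model's cache can cross the export boundary as ordinary inputs and outputs. A usage sketch, assuming `config` is a populated model_config.ModelConfig:

    import torch

    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    cache = kv_utils.KVCache.from_model_config(config)  # one entry per layer

    # One decode step at position 5 for layer 0; the slice shape matches the
    # (1, seq_len, num_query_groups, head_dim) cache layout described above.
    entry = cache.caches[0]
    input_pos = torch.tensor([5])
    k_slice = torch.zeros((1, 1) + tuple(entry.k_cache.shape[2:]))
    v_slice = torch.zeros_like(k_slice)

    # update() is out of place: it returns a fresh KVCacheEntry and leaves the
    # original untouched, unlike the removed index_copy_-based nn.Module.
    new_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)

    # The pytree registration lets export machinery flatten the cache into
    # plain tensors ordered k_0, v_0, k_1, v_1, ...
    flat = cache.flatten()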
ai_edge_torch/generative/layers/model_config.py
@@ -104,6 +104,7 @@ class NormalizationConfig:
  """Normalizater parameters."""
 
  type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
  epsilon: float = 1e-5
  zero_centered: bool = False
  # Number of groups used in group normalization.
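The lone model_config.py change adds enable_hlfb to NormalizationConfig. Since it defaults to False, existing configs keep the previous normalization behavior unless they opt in; a short sketch (assuming the dataclass-style keyword construction used elsewhere in the package):

    import ai_edge_torch.generative.layers.model_config as cfg

    norm = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
    assert norm.enable_hlfb is False  # opt-in, so old configs are unaffected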