ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/attention.py
@@ -12,16 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Common building blocks for Attention layer.

- from typing import Optional, Tuple
+ """Common building blocks for Attention layer."""

- import ai_edge_torch.generative.layers.builder as builder
- from ai_edge_torch.generative.layers.kv_cache import KVCache
+ from typing import Optional, Tuple, Union
+
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
- from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
- from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
  import torch
  from torch import nn

@@ -55,29 +55,35 @@ def _embed_rope(

  class TransformerBlock(nn.Module):

- def __init__(self, config: cfg.ModelConfig) -> None:
+ def __init__(
+ self,
+ config: cfg.TransformerBlockConfig,
+ model_config: cfg.ModelConfig,
+ ) -> None:
  """Initialize an instance of the TransformerBlock.

  Args:
- config (cfg.ModelConfig): the configuration object for this transformer
- block.
+ config (cfg.TransformerBlockConfig): the configuration object for this
+ transformer block.
+ model_config (cfg.ModelConfig): the configuration object for the model
+ this transformer block belongs to.
  """
-
  super().__init__()
  self.pre_atten_norm = builder.build_norm(
- config.embedding_dim, config.pre_attention_norm_config
+ model_config.embedding_dim,
+ config.pre_attention_norm_config,
  )
  self.atten_func = CausalSelfAttention(
- config.batch_size,
- config.embedding_dim,
+ model_config.batch_size,
+ model_config.embedding_dim,
  config.attn_config,
- config.kv_cache_max,
- config.enable_hlfb,
+ model_config.enable_hlfb,
  )
  self.post_atten_norm = builder.build_norm(
- config.embedding_dim, config.post_attention_norm_config
+ model_config.embedding_dim,
+ config.post_attention_norm_config,
  )
- self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+ self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
  self.config = config

  def forward(
@@ -86,7 +92,8 @@ class TransformerBlock(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: kv_utils.KVCacheEntry = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the TransformerBlock.

  Args:
@@ -94,24 +101,34 @@
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  mask (torch.Tensor): the optional mask tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): the optional kv cache entry.

  Returns:
- output activation from this transformer block.
+ output activation from this transformer block, and updated kv cache (if
+ passed in).
  """
-
+ kv = None
  if self.config.parallel_residual:
  x_norm = self.pre_atten_norm(x)
- attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+ atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+ if kv_cache is None:
+ attn_out = atten_func_out
+ else:
+ attn_out, kv = atten_func_out
  ff_out = self.ff(x_norm)
  output = x + attn_out + ff_out
  else:
  x_norm = self.pre_atten_norm(x)
- attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+ atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+ if kv_cache is None:
+ attn_out = atten_func_out
+ else:
+ attn_out, kv = atten_func_out
  x = x + attn_out
  x_norm = self.post_atten_norm(x)
  output = x + self.ff(x_norm)

- return output
+ return output if kv is None else (output, kv)


  class CausalSelfAttention(nn.Module):
@@ -121,7 +138,6 @@ class CausalSelfAttention(nn.Module):
  batch_size: int,
  dim: int,
  config: cfg.AttentionConfig,
- kv_cache_max: int,
  enable_hlfb: bool,
  ) -> None:
  """Initialize an instance of CausalSelfAttention.
@@ -130,12 +146,9 @@
  batch_size (int): batch size of the input tensor.
  dim (int): causal attention's input/output dimmension.
  config (cfg.AttentionConfig): attention specific configurations.
- kv_cache_max (int): determines the size of the KV Cache buffer, if
- enabled.
  enable_hlfb (bool): whether hlfb is enabled or not.
  """
  super().__init__()
- self.config = config
  self.kv_cache = None
  self.batch_size = batch_size
  qkv_shape = (
@@ -147,21 +160,13 @@
  self.output_projection = nn.Linear(
  output_shape, dim, bias=config.output_proj_use_bias
  )
-
- # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
- if config.enable_kv_cache:
- self.kv_cache = KVCache(
- batch_size,
- kv_cache_max,
- config.num_query_groups,
- config.head_dim,
- enable_hlfb,
- )
-
- if enable_hlfb:
- self.sdpa_func = scaled_dot_product_attention_with_hlfb
- else:
- self.sdpa_func = scaled_dot_product_attention
+ self.config = config
+ self.enable_hlfb = enable_hlfb
+ self.sdpa_func = (
+ sdpa.scaled_dot_product_attention_with_hlfb
+ if enable_hlfb
+ else sdpa.scaled_dot_product_attention
+ )

  def forward(
  self,
@@ -169,7 +174,8 @@ class CausalSelfAttention(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the CausalSelfAttention layer, which can support

  MQA, GQA and MHA.
@@ -179,9 +185,11 @@
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  mask (torch.Tensor): the optional mask tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
- output activation from this self attention layer.
+ output activation from this self attention layer, and the updated
+ KV Cach Entry (if passed in).
  """
  # Batch size, sequence length, embedding dimensionality.
  B, T, E = x.size()
@@ -224,9 +232,11 @@
  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
  q, k = _embed_rope(q, k, n_elem, rope)

- if self.kv_cache is not None:
- # TODO(haoliang): Handle when execeeding max sequence length.
- k, v = self.kv_cache.update_cache(input_pos, k, v)
+ if kv_cache is not None:
+ kv_cache = kv_utils.update(
+ kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+ )
+ k, v = kv_cache.k_cache, kv_cache.v_cache

  y = self.sdpa_func(
  q,
@@ -240,7 +250,7 @@

  # Compute the output projection.
  y = self.output_projection(y)
- return y
+ return y if kv_cache is None else (y, kv_cache)


  class SelfAttention(CausalSelfAttention):
@@ -251,16 +261,19 @@ class SelfAttention(CausalSelfAttention):
  x: torch.Tensor,
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

  Args:
  x (torch.Tensor): the input tensor.
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
- output activation from this self attention layer.
+ output activation from this self attention layer, and the updated
+ KV Cach Entry (if passed in).
  """
  B, T, _ = x.size()
  return super().forward(
@@ -279,9 +292,8 @@ class CrossAttention(nn.Module):
  query_dim: int,
  cross_dim: int,
  config: cfg.AttentionConfig,
- kv_cache_max: int,
  enable_hlfb: bool,
- ) -> None:
+ ):
  """Initialize an instance of CrossAttention.

  Args:
@@ -289,8 +301,6 @@
  query_dim (int): query tensor's dimension.
  cross_dim (int): cross attention's dimensions, for key and value tensors.
  config (cfg.AttentionConfig): attention specific configurations.
- kv_cache_max (int): determines the size of the KV Cache buffer, if
- enabled.
  enable_hlfb (bool): whether hlfb is enabled or not.
  """
  super().__init__()
@@ -309,21 +319,11 @@
  query_dim, query_dim, bias=config.output_proj_use_bias
  )

- self.kv_cache = None
- # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
- if config.enable_kv_cache:
- self.kv_cache = KVCache(
- batch_size,
- kv_cache_max,
- config.num_query_groups,
- self.config.head_dim,
- enable_hlfb,
- )
-
- if enable_hlfb:
- self.sdpa_func = scaled_dot_product_attention_with_hlfb
- else:
- self.sdpa_func = scaled_dot_product_attention
+ self.sdpa_func = (
+ sdpa.scaled_dot_product_attention_with_hlfb
+ if enable_hlfb
+ else sdpa.scaled_dot_product_attention
+ )

  def forward(
  self,
@@ -332,6 +332,7 @@ class CrossAttention(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
  ):
  """Forward function of the CrossAttention layer.

@@ -342,6 +343,7 @@
  mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
  [B, n_heads, target_seq_len, source_seq_len].
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
  output activation from this cross attention layer.
@@ -363,9 +365,11 @@
  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
  q, k = _embed_rope(q, k, n_elem, rope)

- if self.kv_cache is not None:
- # TODO(haoliang): Handle when execeeding max sequence length.
- k, v = self.kv_cache.update_cache(input_pos, k, v)
+ if kv_cache is not None:
+ kv_cache = kv_utils.update(
+ kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+ )
+ k, v = kv_cache.k_cache, kv_cache.v_cache
  if mask is None:
  mask = torch.zeros(
  (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
@@ -375,4 +379,4 @@

  # Compute the output projection.
  y = self.output_projection(y)
- return y
+ return y if kv_cache is None else (y, kv_cache)
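The attention.py changes above replace the per-module KV cache buffers with an externalized kv_utils.KVCacheEntry that callers pass into forward and receive back, updated, together with the activation. A minimal caller-side sketch of that convention follows; run_blocks and its arguments are illustrative placeholders rather than package APIs, and only the TransformerBlock.forward signature and the KVCache/KVCacheEntry types are taken from this diff.

from typing import List, Optional, Tuple

import torch
from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import kv_cache as kv_utils


def run_blocks(
    blocks: List[attention.TransformerBlock],
    x: torch.Tensor,
    rope: Tuple[torch.Tensor, torch.Tensor],
    mask: torch.Tensor,
    input_pos: torch.Tensor,
    kv_cache: Optional[kv_utils.KVCache] = None,
) -> Tuple[torch.Tensor, Optional[kv_utils.KVCache]]:
  """Runs a stack of transformer blocks, threading one KVCacheEntry per layer."""
  if kv_cache is None:
    # Without a cache entry, each block returns just the activation.
    for block in blocks:
      x = block(x, rope, mask, input_pos)
    return x, None
  updated_entries = []
  for block, entry in zip(blocks, kv_cache.caches):
    # With a cache entry, each block returns (activation, updated entry).
    x, new_entry = block(x, rope, mask, input_pos, entry)
    updated_entries.append(new_entry)
  return x, kv_utils.KVCache(tuple(updated_entries))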
ai_edge_torch/generative/layers/builder.py
@@ -59,9 +59,11 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
  zero_centered_gamma=config.zero_centered,
  )
  elif config.type == cfg.NormalizationType.LAYER_NORM:
- return nn.LayerNorm(dim, eps=config.epsilon)
+ return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
  elif config.type == cfg.NormalizationType.GROUP_NORM:
- return nn.GroupNorm(config.group_num, dim, config.epsilon)
+ return normalization.GroupNorm(
+ config.group_num, dim, config.epsilon, config.enable_hlfb
+ )
  else:
  raise ValueError("Unsupported norm type.")

@@ -71,7 +73,7 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):

  Args:
  dim (int): dimension of the input tensor.
- config (`ModelConfig` object): the model configuration.
+ config (`FeedForwardConfig` object): the model configuration.

  Returns:
  The constructed `nn.Module` feedforward layer.
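With the builder.py change, build_norm constructs the new HLFB-aware normalization.LayerNorm/GroupNorm wrappers and forwards the enable_hlfb flag from NormalizationConfig. A minimal sketch, using only the fields visible in this diff (the 128-dimensional size is an arbitrary example value):

from ai_edge_torch.generative.layers import builder
import ai_edge_torch.generative.layers.model_config as cfg

# Layer normalization over a 128-dim embedding, using the HLFB-annotated path.
norm_config = cfg.NormalizationConfig(
    type=cfg.NormalizationType.LAYER_NORM,
    enable_hlfb=True,
)
norm = builder.build_norm(128, norm_config)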
ai_edge_torch/generative/layers/kv_cache.py
@@ -12,72 +12,184 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # `nn.Module` which implements a KV cache.

- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+ """Utility functions for externalized KV Cache."""
+
+ import dataclasses
+ from typing import List, Tuple
+
+ from ai_edge_torch import hlfb
+ from ai_edge_torch.generative.layers import model_config
  import torch
- from torch import nn
+ import torch.utils._pytree as pytree

+ BATCH_SIZE = 1

- class KVCache(nn.Module):

- def __init__(
- self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False
- ):
- """Initializes the KVCache layer.
+ @dataclasses.dataclass
+ class KVCacheEntry:
+ """A single cache entry that includes K and V caches.

- Args:
- batch_size (int): batch size. Currently only batch size 1 is supported.
- kv_cache_max (int): the max length of KV cache.
- n_heads (int): number of kv heads.
- head_dim (int): the head dimension size.
- enable_hlfb (bool): whether hlfb is enabled or not.
- """
- super().__init__()
- cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
- self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
- self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
- self.enable_hlfb = enable_hlfb
- self.kv_cache_max = kv_cache_max
+ The chaches are built based on the provided config with the shape of
+ (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+ """

- def update_cache(self, input_pos, k_val, v_val):
- """Update an entry in the KV cache.
+ k_cache: torch.Tensor
+ v_cache: torch.Tensor

- Args:
- input_pos (torch.Tensor): the input position.
- k_val (torch.Tensor): the new `key` value.
- v_val (torch.Tensor): the new `value` value.
+ @classmethod
+ def from_model_config(
+ cls,
+ kv_cache_max: int,
+ config: model_config.AttentionConfig,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = None,
+ ) -> "KVCacheEntry":
+ """Build an instance of the class based on model config."""
+ shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+ k = torch.zeros(shape, dtype=dtype, device=device)
+ v = torch.zeros(shape, dtype=dtype, device=device)
+ obj = cls(k_cache=k, v_cache=v)
+ return obj

- Returns:
- The updated key and value tensor.
- """
- if self.enable_hlfb:
- return self.update_cache_with_hlfb(input_pos, k_val, v_val)

- updated_k = self.k_cache.index_copy_(1, input_pos, k_val)
- updated_v = self.v_cache.index_copy_(1, input_pos, v_val)
- # Here we need a clone otherwise dynamo export will fail.
- return torch.clone(updated_k), torch.clone(updated_v)
+ @dataclasses.dataclass
+ class KVCache:
+ """A utility class for holding KV cache entries per layer."""

- def update_cache_with_hlfb(self, input_pos, k_val, v_val):
- """Update an entry in the KV cache and enable high-level function boundary.
+ caches: Tuple[KVCacheEntry, ...]
+
+ @classmethod
+ def from_model_config(
+ cls,
+ config: model_config.ModelConfig,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = None,
+ ) -> "KVCache":
+ """Build an instance of the class based on model config.

  Args:
- input_pos (torch.Tensor): the input position.
- k_val (torch.Tensor): the new `key` value.
- v_val (torch.Tensor): the new `value` value.
+ config (ModelConfig): Model config used for building the cache.
+ dtype (torch.dtype, optional): The data type of the cache tensor.
+ Defaults to torch.float32.
+ device (torch.device, optional): The device placement of the cache
+ tensors. Defaults to None.

  Returns:
- The updated key and value tensor.
+ KVCache: The created cache object.
  """
+ caches = [
+ KVCacheEntry.from_model_config(
+ config.kv_cache_max,
+ config.block_config(idx).attn_config,
+ dtype,
+ device,
+ )
+ for idx in range(config.num_layers)
+ ]
+ obj = cls(caches=tuple(caches))
+ return obj

- builder = StableHLOCompositeBuilder(
- name="odml.update_kv_cache", attr={"kv_cache_max": self.kv_cache_max}
- )
- k_cache, v_cache, input_pos, k_val, v_val = builder.mark_inputs(
- self.k_cache, self.v_cache, input_pos, k_val, v_val
+ def flatten(self) -> List[torch.Tensor]:
+ """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+ flattened, _ = _flatten_kvc(self)
+ return flattened
+
+
+ def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
+ flattened = []
+ flat_names = []
+ none_names = []
+ for i, kv_entry in enumerate(kvc.caches):
+ flattened.append(kv_entry.k_cache)
+ flat_names.append(f"k_{i}")
+ flattened.append(kv_entry.v_cache)
+ flat_names.append(f"v_{i}")
+ return flattened, [flat_names, none_names]
+
+
+ def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
+ flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+ return [
+ (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+ ], flat_names
+
+
+ def _unflatten_kvc(
+ values: List[torch.Tensor], context: Tuple[List, List]
+ ) -> KVCache:
+ assert len(values) % 2 == 0, "Found odd number of K and V entries."
+ num_layers = len(values) // 2
+ flat_names = context[0]
+ kv_entries = []
+ for i in range(num_layers):
+ k_cache_idx = flat_names.index(f"k_{i}")
+ v_cache_idx = flat_names.index(f"v_{i}")
+ kv_entries.append(
+ KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
  )
- updated_k = k_cache.index_copy_(1, input_pos, k_val)
- updated_v = v_cache.index_copy_(1, input_pos, v_val)
- updated_k, updated_v = builder.mark_outputs(updated_k, updated_v)
- return updated_k, updated_v
+ obj = KVCache(tuple(kv_entries))
+ return obj
+
+
+ pytree.register_pytree_node(
+ KVCache,
+ _flatten_kvc,
+ _unflatten_kvc,
+ flatten_with_keys_fn=_flatten_kvc_with_keys,
+ serialized_type_name="",
+ )
+
+
+ def update(
+ cache: KVCacheEntry,
+ input_pos: torch.Tensor,
+ k_slice: torch.Tensor,
+ v_slice: torch.Tensor,
+ enable_hlfb: bool = True,
+ ) -> KVCacheEntry:
+ """Out of place update of Cache buffer.
+
+ Args:
+ cache (KVCacheEntry): The original cache buffer.
+ input_pos (torch.Tensor): The update slice positions.
+ k_slice (torch.Tensor): The K slice to be updated in the new cache.
+ v_slice (torch.Tensor): The V slice to be updated in the new cache.
+ enable_hlfb (bool, optional): Whether the op is annotated for export with
+ High Level Function Boundary. Defaults to True.
+
+ Returns:
+ KVCacheEntry: The updated KVCache entry based on the passed inputs.
+ """
+ update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
+ return update_func(cache, input_pos, k_slice, v_slice)
+
+
+ def _update_kv_base_impl(
+ cache: KVCacheEntry,
+ input_pos: torch.Tensor,
+ k_slice: torch.Tensor,
+ v_slice: torch.Tensor,
+ ) -> KVCacheEntry:
+ """Update the cache buffer without High Level Function Boundary annotation."""
+ k = cache.k_cache.index_copy(1, input_pos, k_slice)
+ v = cache.v_cache.index_copy(1, input_pos, v_slice)
+ updated_cache = KVCacheEntry(k, v)
+ return updated_cache
+
+
+ def _update_kv_hlfb_impl(
+ cache: KVCacheEntry,
+ input_pos: torch.Tensor,
+ k_slice: torch.Tensor,
+ v_slice: torch.Tensor,
+ ) -> KVCacheEntry:
+ """Update the cache buffer with High Level Function Boundary annotation."""
+ builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
+ k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
+ cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
+ )
+ k = k_cache.index_copy(1, input_pos, k_slice)
+ v = v_cache.index_copy(1, input_pos, v_slice)
+ k, v = builder.mark_outputs(k, v)
+ return KVCacheEntry(k, v)
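kv_cache.py now models the cache as plain dataclasses (KVCacheEntry per layer, KVCache for the whole model) registered as a pytree, with a functional, out-of-place update instead of in-place buffer mutation. A minimal sketch of the new update path; the shapes are arbitrary illustrative values, and enable_hlfb=False selects the plain (non-annotated) implementation:

import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils

kv_cache_max, num_query_groups, head_dim = 16, 1, 4
entry = kv_utils.KVCacheEntry(
    k_cache=torch.zeros(1, kv_cache_max, num_query_groups, head_dim),
    v_cache=torch.zeros(1, kv_cache_max, num_query_groups, head_dim),
)

# New K/V slices for two positions, e.g. a two-token prefill step.
input_pos = torch.tensor([0, 1])
k_slice = torch.randn(1, 2, num_query_groups, head_dim)
v_slice = torch.randn(1, 2, num_query_groups, head_dim)

# `update` is out of place: the original entry is left untouched and a new
# KVCacheEntry containing the written slices is returned.
entry = kv_utils.update(entry, input_pos, k_slice, v_slice, enable_hlfb=False)

Because the update returns a fresh entry, a decode loop carries the returned entry forward between steps rather than relying on mutated module buffers.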
ai_edge_torch/generative/layers/model_config.py
@@ -16,7 +16,7 @@
  from dataclasses import dataclass
  from dataclasses import field
  import enum
- from typing import Optional, Sequence
+ from typing import Optional, Sequence, Union


  @enum.unique
@@ -85,8 +85,8 @@ class AttentionConfig:
  relative_attention_max_distance: int = 0
  # Softcap on the output logits.
  logit_softcap: Optional[float] = None
- # The types of attention used in the layers of the model.
- attn_types: Optional[Sequence[AttentionType]] = None
+ # The type of attention.
+ attn_type: Optional[AttentionType] = None
  # The size of the sliding window used for local attention.
  sliding_window_size: Optional[int] = None

@@ -104,6 +104,7 @@ class NormalizationConfig:
  """Normalizater parameters."""

  type: NormalizationType = NormalizationType.NONE
+ enable_hlfb: bool = False
  epsilon: float = 1e-5
  zero_centered: bool = False
  # Number of groups used in group normalization.
@@ -129,13 +130,8 @@ class FeedForwardConfig:


  @dataclass
- class ModelConfig:
- """Base configurations for building a transformer architecture."""
-
- vocab_size: int
- num_layers: int
- max_seq_len: int
- embedding_dim: int
+ class TransformerBlockConfig:
+ """TransformerBlock module's parameters."""

  attn_config: AttentionConfig
  ff_config: FeedForwardConfig
@@ -147,15 +143,33 @@ class ModelConfig:
  post_attention_norm_config: NormalizationConfig = field(
  default_factory=NormalizationConfig
  )
+ # If set to True, only attn_config.pre_attention_norm is applied to the input
+ # and the decode's output is computed as `output = input + attn_out + ff_out`
+ # where attention and feed forward are called with pre_attention_norm's
+ # output.
+ parallel_residual: bool = False
+ # The Attention computation will include relative positional bias.
+ relative_attention: bool = False
+
+
+ @dataclass
+ class ModelConfig:
+ """Base configurations for building a transformer architecture."""
+
+ vocab_size: int
+ num_layers: int
+ max_seq_len: int
+ embedding_dim: int
+
+ # TransformerBlockConfig for each layer block. If a single
+ # TransformerBlockConfig is provided, it will be used for all layers.
+ block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]
+
  # The normalization applied before LM head.
  final_norm_config: NormalizationConfig = field(
  default_factory=NormalizationConfig
  )

- # If set to True, only pre_attention_norm is applied to the input and the
- # decode's output is computed as `output = input + attn_out + ff_out` where
- # attention and feed forward are called with pre_attention_norm's output.
- parallel_residual: bool = False
  # Use bias term within LLM's HEAD.
  lm_head_use_bias: bool = False
  # Whether to turn on high-level function boundary.
@@ -164,9 +178,6 @@ class ModelConfig:
  # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
  kv_cache_max_len: int = 0

- # The Attention computation will include relative positional bias.
- relative_attention: bool = False
-
  # Default batch size of the exported model. Default value is 1.
  batch_size: int = 1

@@ -177,5 +188,13 @@ class ModelConfig:
  def kv_cache_max(self) -> int:
  if self.kv_cache_max_len > 0:
  return self.kv_cache_max_len
- else:
- return self.max_seq_len
+ return self.max_seq_len
+
+ def block_config(self, idx: int) -> TransformerBlockConfig:
+ if isinstance(self.block_configs, TransformerBlockConfig):
+ return self.block_configs
+ if idx < 0 or idx >= len(self.block_configs):
+ raise ValueError(
+ f"Index {idx} is out of range for layer configs: {self.block_configs}"
+ )
+ return self.block_configs[idx]
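model_config.py moves the per-layer knobs into TransformerBlockConfig; ModelConfig.block_configs accepts either one shared block config or a per-layer sequence, resolved through block_config(idx). The sketch below assumes an already-built cfg.ModelConfig (for example from one of the bundled example models); describe_layers and build_cache are illustrative helpers, not package APIs:

import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg


def describe_layers(model_config: cfg.ModelConfig) -> None:
  # block_config(idx) returns the shared TransformerBlockConfig when a single
  # config is set, or the idx-th entry when a per-layer sequence is provided.
  for idx in range(model_config.num_layers):
    block = model_config.block_config(idx)
    print(idx, block.attn_config.attn_type, block.parallel_residual)


def build_cache(model_config: cfg.ModelConfig) -> kv_utils.KVCache:
  # One KVCacheEntry per layer, each shaped (1, kv_cache_max, groups, head_dim).
  return kv_utils.KVCache.from_model_config(model_config, dtype=torch.float32)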