ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (50)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  8. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  9. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  10. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  11. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  12. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  16. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  17. ai_edge_torch/generative/layers/attention.py +77 -73
  18. ai_edge_torch/generative/layers/builder.py +5 -3
  19. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  20. ai_edge_torch/generative/layers/model_config.py +38 -19
  21. ai_edge_torch/generative/layers/normalization.py +158 -0
  22. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  23. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  24. ai_edge_torch/generative/test/test_loader.py +1 -1
  25. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  26. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  27. ai_edge_torch/generative/test/utils.py +54 -0
  28. ai_edge_torch/generative/utilities/loader.py +15 -15
  29. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  31. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  32. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  33. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  34. ai_edge_torch/version.py +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
  37. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  38. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  40. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  41. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  42. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  43. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  44. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  45. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  46. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  47. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/kv_cache.py
@@ -12,72 +12,184 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # `nn.Module` which implements a KV cache.

- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+ """Utility functions for externalized KV Cache."""
+
+ import dataclasses
+ from typing import List, Tuple
+
+ from ai_edge_torch import hlfb
+ from ai_edge_torch.generative.layers import model_config
  import torch
- from torch import nn
+ import torch.utils._pytree as pytree

+ BATCH_SIZE = 1

- class KVCache(nn.Module):

-   def __init__(
-       self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False
-   ):
-     """Initializes the KVCache layer.
+ @dataclasses.dataclass
+ class KVCacheEntry:
+   """A single cache entry that includes K and V caches.

-     Args:
-       batch_size (int): batch size. Currently only batch size 1 is supported.
-       kv_cache_max (int): the max length of KV cache.
-       n_heads (int): number of kv heads.
-       head_dim (int): the head dimension size.
-       enable_hlfb (bool): whether hlfb is enabled or not.
-     """
-     super().__init__()
-     cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
-     self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
-     self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
-     self.enable_hlfb = enable_hlfb
-     self.kv_cache_max = kv_cache_max
+   The chaches are built based on the provided config with the shape of
+   (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+   """

-   def update_cache(self, input_pos, k_val, v_val):
-     """Update an entry in the KV cache.
+   k_cache: torch.Tensor
+   v_cache: torch.Tensor

-     Args:
-       input_pos (torch.Tensor): the input position.
-       k_val (torch.Tensor): the new `key` value.
-       v_val (torch.Tensor): the new `value` value.
+   @classmethod
+   def from_model_config(
+       cls,
+       kv_cache_max: int,
+       config: model_config.AttentionConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheEntry":
+     """Build an instance of the class based on model config."""
+     shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+     k = torch.zeros(shape, dtype=dtype, device=device)
+     v = torch.zeros(shape, dtype=dtype, device=device)
+     obj = cls(k_cache=k, v_cache=v)
+     return obj

-     Returns:
-       The updated key and value tensor.
-     """
-     if self.enable_hlfb:
-       return self.update_cache_with_hlfb(input_pos, k_val, v_val)

-     updated_k = self.k_cache.index_copy_(1, input_pos, k_val)
-     updated_v = self.v_cache.index_copy_(1, input_pos, v_val)
-     # Here we need a clone otherwise dynamo export will fail.
-     return torch.clone(updated_k), torch.clone(updated_v)
+ @dataclasses.dataclass
+ class KVCache:
+   """A utility class for holding KV cache entries per layer."""

-   def update_cache_with_hlfb(self, input_pos, k_val, v_val):
-     """Update an entry in the KV cache and enable high-level function boundary.
+   caches: Tuple[KVCacheEntry, ...]
+
+   @classmethod
+   def from_model_config(
+       cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCache":
+     """Build an instance of the class based on model config.

      Args:
-       input_pos (torch.Tensor): the input position.
-       k_val (torch.Tensor): the new `key` value.
-       v_val (torch.Tensor): the new `value` value.
+       config (ModelConfig): Model config used for building the cache.
+       dtype (torch.dtype, optional): The data type of the cache tensor.
+         Defaults to torch.float32.
+       device (torch.device, optional): The device placement of the cache
+         tensors. Defaults to None.

      Returns:
-       The updated key and value tensor.
+       KVCache: The created cache object.
      """
+     caches = [
+         KVCacheEntry.from_model_config(
+             config.kv_cache_max,
+             config.block_config(idx).attn_config,
+             dtype,
+             device,
+         )
+         for idx in range(config.num_layers)
+     ]
+     obj = cls(caches=tuple(caches))
+     return obj

-     builder = StableHLOCompositeBuilder(
-         name="odml.update_kv_cache", attr={"kv_cache_max": self.kv_cache_max}
-     )
-     k_cache, v_cache, input_pos, k_val, v_val = builder.mark_inputs(
-         self.k_cache, self.v_cache, input_pos, k_val, v_val
+   def flatten(self) -> List[torch.Tensor]:
+     """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+     flattened, _ = _flatten_kvc(self)
+     return flattened
+
+
+ def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
+   flattened = []
+   flat_names = []
+   none_names = []
+   for i, kv_entry in enumerate(kvc.caches):
+     flattened.append(kv_entry.k_cache)
+     flat_names.append(f"k_{i}")
+     flattened.append(kv_entry.v_cache)
+     flat_names.append(f"v_{i}")
+   return flattened, [flat_names, none_names]
+
+
+ def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
+   flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+   return [
+       (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+   ], flat_names
+
+
+ def _unflatten_kvc(
+     values: List[torch.Tensor], context: Tuple[List, List]
+ ) -> KVCache:
+   assert len(values) % 2 == 0, "Found odd number of K and V entries."
+   num_layers = len(values) // 2
+   flat_names = context[0]
+   kv_entries = []
+   for i in range(num_layers):
+     k_cache_idx = flat_names.index(f"k_{i}")
+     v_cache_idx = flat_names.index(f"v_{i}")
+     kv_entries.append(
+         KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
      )
-     updated_k = k_cache.index_copy_(1, input_pos, k_val)
-     updated_v = v_cache.index_copy_(1, input_pos, v_val)
-     updated_k, updated_v = builder.mark_outputs(updated_k, updated_v)
-     return updated_k, updated_v
+   obj = KVCache(tuple(kv_entries))
+   return obj
+
+
+ pytree.register_pytree_node(
+     KVCache,
+     _flatten_kvc,
+     _unflatten_kvc,
+     flatten_with_keys_fn=_flatten_kvc_with_keys,
+     serialized_type_name="",
+ )
+
+
+ def update(
+     cache: KVCacheEntry,
+     input_pos: torch.Tensor,
+     k_slice: torch.Tensor,
+     v_slice: torch.Tensor,
+     enable_hlfb: bool = True,
+ ) -> KVCacheEntry:
+   """Out of place update of Cache buffer.
+
+   Args:
+     cache (KVCacheEntry): The original cache buffer.
+     input_pos (torch.Tensor): The update slice positions.
+     k_slice (torch.Tensor): The K slice to be updated in the new cache.
+     v_slice (torch.Tensor): The V slice to be updated in the new cache.
+     enable_hlfb (bool, optional): Whether the op is annotated for export with
+       High Level Function Boundary. Defaults to True.
+
+   Returns:
+     KVCacheEntry: The updated KVCache entry based on the passed inputs.
+   """
+   update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
+   return update_func(cache, input_pos, k_slice, v_slice)
+
+
+ def _update_kv_base_impl(
+     cache: KVCacheEntry,
+     input_pos: torch.Tensor,
+     k_slice: torch.Tensor,
+     v_slice: torch.Tensor,
+ ) -> KVCacheEntry:
+   """Update the cache buffer without High Level Function Boundary annotation."""
+   k = cache.k_cache.index_copy(1, input_pos, k_slice)
+   v = cache.v_cache.index_copy(1, input_pos, v_slice)
+   updated_cache = KVCacheEntry(k, v)
+   return updated_cache
+
+
+ def _update_kv_hlfb_impl(
+     cache: KVCacheEntry,
+     input_pos: torch.Tensor,
+     k_slice: torch.Tensor,
+     v_slice: torch.Tensor,
+ ) -> KVCacheEntry:
+   """Update the cache buffer with High Level Function Boundary annotation."""
+   builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
+   k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
+       cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
+   )
+   k = k_cache.index_copy(1, input_pos, k_slice)
+   v = v_cache.index_copy(1, input_pos, v_slice)
+   k, v = builder.mark_outputs(k, v)
+   return KVCacheEntry(k, v)
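For orientation, the sketch below (not part of the diff) shows one way the externalized cache API added above might be driven. The config values are placeholders that mirror the minimal test config added in test_kv_cache.py further down; only the shapes matter.

```python
import torch

from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg

# Minimal two-layer config; field values are illustrative only.
attn_config = cfg.AttentionConfig(num_heads=1, head_dim=4, num_query_groups=1)
block_config = cfg.TransformerBlockConfig(attn_config=attn_config, ff_config=None)
config = cfg.ModelConfig(
    vocab_size=None,
    num_layers=2,
    max_seq_len=None,
    embedding_dim=4,
    block_configs=block_config,
    kv_cache_max_len=8,
)

# Zero-initialized caches, one KVCacheEntry of shape (1, 8, 1, 4) per layer.
kv = kv_utils.KVCache.from_model_config(config)

# Out-of-place write of a single-token K/V slice at position 1 of layer 0.
# enable_hlfb=False keeps the plain index_copy path for this eager run.
input_pos = torch.tensor([1])
k_slice = torch.ones(1, 1, 1, 4)  # (batch, seq, num_query_groups, head_dim)
v_slice = torch.ones(1, 1, 1, 4)
entry_0 = kv_utils.update(
    kv.caches[0], input_pos, k_slice, v_slice, enable_hlfb=False
)
```

Each call returns a new KVCacheEntry rather than mutating a buffer in place, which is what lets the cache be passed in and out of the exported model as ordinary inputs and outputs.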
ai_edge_torch/generative/layers/model_config.py
@@ -16,7 +16,7 @@
  from dataclasses import dataclass
  from dataclasses import field
  import enum
- from typing import Optional, Sequence
+ from typing import Optional, Sequence, Union


  @enum.unique
@@ -85,8 +85,8 @@ class AttentionConfig:
    relative_attention_max_distance: int = 0
    # Softcap on the output logits.
    logit_softcap: Optional[float] = None
-   # The types of attention used in the layers of the model.
-   attn_types: Optional[Sequence[AttentionType]] = None
+   # The type of attention.
+   attn_type: Optional[AttentionType] = None
    # The size of the sliding window used for local attention.
    sliding_window_size: Optional[int] = None

@@ -104,6 +104,7 @@ class NormalizationConfig:
    """Normalizater parameters."""

    type: NormalizationType = NormalizationType.NONE
+   enable_hlfb: bool = False
    epsilon: float = 1e-5
    zero_centered: bool = False
    # Number of groups used in group normalization.
@@ -129,13 +130,8 @@ class FeedForwardConfig:


  @dataclass
- class ModelConfig:
-   """Base configurations for building a transformer architecture."""
-
-   vocab_size: int
-   num_layers: int
-   max_seq_len: int
-   embedding_dim: int
+ class TransformerBlockConfig:
+   """TransformerBlock module's parameters."""

    attn_config: AttentionConfig
    ff_config: FeedForwardConfig
@@ -147,15 +143,33 @@ class ModelConfig:
    post_attention_norm_config: NormalizationConfig = field(
        default_factory=NormalizationConfig
    )
+   # If set to True, only attn_config.pre_attention_norm is applied to the input
+   # and the decode's output is computed as `output = input + attn_out + ff_out`
+   # where attention and feed forward are called with pre_attention_norm's
+   # output.
+   parallel_residual: bool = False
+   # The Attention computation will include relative positional bias.
+   relative_attention: bool = False
+
+
+ @dataclass
+ class ModelConfig:
+   """Base configurations for building a transformer architecture."""
+
+   vocab_size: int
+   num_layers: int
+   max_seq_len: int
+   embedding_dim: int
+
+   # TransformerBlockConfig for each layer block. If a single
+   # TransformerBlockConfig is provided, it will be used for all layers.
+   block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]
+
    # The normalization applied before LM head.
    final_norm_config: NormalizationConfig = field(
        default_factory=NormalizationConfig
    )

-   # If set to True, only pre_attention_norm is applied to the input and the
-   # decode's output is computed as `output = input + attn_out + ff_out` where
-   # attention and feed forward are called with pre_attention_norm's output.
-   parallel_residual: bool = False

    # Use bias term within LLM's HEAD.
    lm_head_use_bias: bool = False
@@ -164,9 +178,6 @@ class ModelConfig:
    # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
    kv_cache_max_len: int = 0

-   # The Attention computation will include relative positional bias.
-   relative_attention: bool = False
-

    # Default batch size of the exported model. Default value is 1.
    batch_size: int = 1
@@ -177,5 +188,13 @@ class ModelConfig:
    def kv_cache_max(self) -> int:
      if self.kv_cache_max_len > 0:
        return self.kv_cache_max_len
-     else:
-       return self.max_seq_len
+     return self.max_seq_len
+
+   def block_config(self, idx: int) -> TransformerBlockConfig:
+     if isinstance(self.block_configs, TransformerBlockConfig):
+       return self.block_configs
+     if idx < 0 or idx >= len(self.block_configs):
+       raise ValueError(
+           f"Index {idx} is out of range for layer configs: {self.block_configs}"
+       )
+     return self.block_configs[idx]
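A small, hypothetical sketch of the new per-layer configuration (not part of the diff): a single TransformerBlockConfig can be shared by every layer, or a sequence can supply one config per layer, and ModelConfig.block_config(idx) resolves either form. Field values below are placeholders; ff_config is left as None purely for brevity, as in the unit tests.

```python
import ai_edge_torch.generative.layers.model_config as cfg

# One block config reused for all layers (placeholder values).
shared_block = cfg.TransformerBlockConfig(
    attn_config=cfg.AttentionConfig(num_heads=8, head_dim=64, num_query_groups=1),
    ff_config=None,
)
config = cfg.ModelConfig(
    vocab_size=32000,
    num_layers=4,
    max_seq_len=1024,
    embedding_dim=512,
    block_configs=shared_block,  # or a tuple with one entry per layer
)

# block_config() returns the shared config for any layer index.
assert config.block_config(3) is shared_block
# With a per-layer tuple, an out-of-range index raises ValueError instead.
```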
ai_edge_torch/generative/layers/normalization.py
@@ -14,7 +14,10 @@
  # ==============================================================================
  # Common normalization layers.

+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
  import torch
+ from torch import nn
+ import torch.nn.functional as F


  # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
        return output * (1 + self.weight)
      else:
        return output * self.weight
+
+
+ class GroupNorm(torch.nn.Module):
+
+   def __init__(
+       self,
+       group_num: int,
+       dim: int,
+       eps: float = 1e-5,
+       enable_hlfb: bool = False,
+   ):
+     """Initialize the GroupNorm layer.
+
+     Args:
+       group_num (int): Number of groups to separate the channels into.
+       dim (int): Dimension of the input tensor.
+       eps (float): A small float value to ensure numerical stability (default:
+         1e-6).
+       enable_hlfb (bool): Whether to convert this normalization into a single
+         op.
+     """
+     super().__init__()
+     self.enable_hlfb = enable_hlfb
+     self.group_num = group_num
+     self.eps = eps
+     self.weight = torch.nn.Parameter(torch.ones(dim))
+     self.bias = torch.nn.Parameter(torch.ones(dim))
+
+   def forward(self, x):
+     """Running the forward pass of GroupNorm layer.
+
+     Args:
+       x (torch.Tensor): input tensor.
+
+     Returns:
+       torch.Tensor: output tensor after applying GroupNorm.
+     """
+     if self.enable_hlfb:
+       return group_norm_with_hlfb(
+           x,
+           self.weight,
+           self.bias,
+           self.group_num,
+           self.eps,
+       )
+     else:
+       return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+
+
+ class LayerNorm(torch.nn.Module):
+
+   def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+     """Initialize the LayerNorm layer.
+
+     Args:
+       dim (int): dimension of the input tensor.
+       eps (float): A small float value to ensure numerical stability (default:
+         1e-6).
+       enable_hlfb (bool): Whether to convert this normalization into a single
+         op.
+     """
+     super().__init__()
+     self.enable_hlfb = enable_hlfb
+     self.eps = eps
+     self.weight = torch.nn.Parameter(torch.ones(dim))
+     self.bias = torch.nn.Parameter(torch.ones(dim))
+
+   def forward(self, x):
+     """Running the forward pass of LayerNorm layer.
+
+     Args:
+       x (torch.Tensor): input tensor.
+
+     Returns:
+       torch.Tensor: output tensor after applying LayerNorm.
+     """
+     if self.enable_hlfb:
+       return layer_norm_with_hlfb(
+           x,
+           self.weight,
+           self.bias,
+           self.eps,
+       )
+     else:
+       return F.layer_norm(
+           x,
+           x.shape,
+           self.weight.broadcast_to(x.shape),
+           self.bias.broadcast_to(x.shape),
+           self.eps,
+       )
+
+
+ def group_norm_with_hlfb(
+     x: torch.Tensor,
+     w: torch.Tensor,
+     b: torch.Tensor,
+     num_groups: int,
+     eps: float,
+ ):
+   """Group Normalization with high-level function boundary enabled.
+
+   Args:
+     x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
+     w (torch.Tensor): The weight tensor for the normalization.
+     b (torch.Tensor): The bias tensor for the normalization.
+     num_groups (int): Number of groups to separate the channels into.
+     eps (float): A small float value to ensure numerical stability.
+
+   Returns:
+     The output tensor of Group Normalization.
+   """
+   x = torch.permute(x, (0, 2, 3, 1))
+
+   builder = StableHLOCompositeBuilder(
+       name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+   )
+   x, w, b = builder.mark_inputs(x, w, b)
+   x = torch.permute(x, (0, 3, 1, 2))
+   y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
+   y = torch.permute(y, (0, 2, 3, 1))
+   y = builder.mark_outputs(y)
+
+   y = torch.permute(y, (0, 3, 1, 2))
+   return y
+
+
+ def layer_norm_with_hlfb(
+     x: torch.Tensor,
+     w: torch.Tensor,
+     b: torch.Tensor,
+     eps: float,
+ ):
+   """Layer Normalization with high-level function boundary enabled.
+
+   Args:
+     x (torch.Tensor): Input tensor for Layer Normalization.
+     w (torch.Tensor): The weight tensor for the normalization.
+     b (torch.Tensor): The bias tensor for the normalization.
+     eps (float): A small float value to ensure numerical stability.
+
+   Returns:
+     The output tensor of Layer Normalization.
+   """
+   builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+   x, w, b = builder.mark_inputs(x, w, b)
+   y = F.layer_norm(
+       x,
+       x.shape,
+       weight=w.broadcast_to(x.shape),
+       bias=b.broadcast_to(x.shape),
+       eps=eps,
+   )
+   y = builder.mark_outputs(y)
+   return y
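A brief usage sketch (not part of the diff) for the new GroupNorm and LayerNorm layers; tensor shapes are illustrative. With the default enable_hlfb=False they fall through to torch.nn.functional; setting enable_hlfb=True routes through the *_with_hlfb helpers above, which wrap the op in an "odml.group_norm" / "odml.layer_norm" StableHLO composite for conversion.

```python
import torch

from ai_edge_torch.generative.layers import normalization

# Plain eager usage; flip enable_hlfb=True when the layer should be lowered
# as a single composite op during conversion.
group_norm = normalization.GroupNorm(group_num=4, dim=16, enable_hlfb=False)
layer_norm = normalization.LayerNorm(dim=16, enable_hlfb=False)

x_bchw = torch.randn(2, 16, 8, 8)  # GroupNorm's HLFB path documents BCHW input
x_seq = torch.randn(2, 10, 16)

print(group_norm(x_bchw).shape)  # torch.Size([2, 16, 8, 8])
print(layer_norm(x_seq).shape)   # torch.Size([2, 10, 16])
```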
ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
        config.attention_batch_size,
        config.dim,
        config.attention_config,
-       0,
        enable_hlfb=config.enable_hlfb,
    )

@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
        config.query_dim,
        config.cross_dim,
        config.attention_config,
-       0,
        enable_hlfb=config.enable_hlfb,
    )

ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py}
@@ -12,19 +12,17 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # A suite of tests to validate experimental external KV Cache layers and models.

- from ai_edge_torch.generative.examples.experimental.gemma import gemma
- from ai_edge_torch.generative.examples.experimental.phi import phi2
- from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+ """A suite of tests to validate KV Cache layer."""
+
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch

  from absl.testing import absltest as googletest


- class TestExternalKVLayers(googletest.TestCase):
+ class TestKVLayers(googletest.TestCase):

    def _get_test_config(
        self, num_layers, head_dim, num_query_groups, kv_cache_max_len
@@ -32,14 +30,16 @@ class TestExternalKVLayers(googletest.TestCase):
      attn_config = cfg.AttentionConfig(
          num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
      )
+     block_config = cfg.TransformerBlockConfig(
+         attn_config=attn_config, ff_config=None
+     )
      config = cfg.ModelConfig(
          kv_cache_max_len=kv_cache_max_len,
          embedding_dim=head_dim,
-         attn_config=attn_config,
+         block_configs=block_config,
          num_layers=num_layers,
          max_seq_len=None,
          vocab_size=None,
-         ff_config=None,
      )
      return config

@@ -54,7 +54,7 @@ class TestExternalKVLayers(googletest.TestCase):
          num_query_groups=NUM_QG,
          kv_cache_max_len=KV_LEN,
      )
-     kv = kv_utils.EKVCache.from_model_config(config)
+     kv = kv_utils.KVCache.from_model_config(config)
      entry = kv.caches[0]
      # single-slice update
      input_pos = torch.tensor([1])
@@ -88,14 +88,14 @@ class TestExternalKVLayers(googletest.TestCase):
    def test_serialization(self):
      class TestModel(torch.nn.Module):

-       def forward(self, kv: kv_utils.EKVCache) -> kv_utils.EKVCache:
+       def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
          updated_kv_entries = [
              kv_utils.KVCacheEntry(
                  torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
              )
              for entry in kv.caches
          ]
-         return kv_utils.EKVCache(updated_kv_entries)
+         return kv_utils.KVCache(updated_kv_entries)

      N = 1
      HEAD_DIM = 2
@@ -107,7 +107,7 @@ class TestExternalKVLayers(googletest.TestCase):
          num_query_groups=NUM_QG,
          kv_cache_max_len=KV_LEN,
      )
-     kv = kv_utils.EKVCache.from_model_config(config)
+     kv = kv_utils.KVCache.from_model_config(config)
      model = TestModel()
      exported_program = torch.export.export(model, (kv,))
      input_specs = exported_program.graph_signature.input_specs
@@ -116,17 +116,5 @@ class TestExternalKVLayers(googletest.TestCase):
      self.assertEqual(input_specs[1].arg.name, "kv_v_0")


- class TestExternalKVModels(googletest.TestCase):
-
-   def test_can_build_gemma(self):
-     gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
-
-   def test_can_build_phi2(self):
-     phi2.define_and_run(checkpoint_path=None, test_model=True)
-
-   def test_can_build_tinyllama(self):
-     tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
-
-
  if __name__ == "__main__":
    googletest.main()
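Because KVCache is registered as a pytree in the kv_cache.py hunk above, torch.export flattens it into per-layer k_i / v_i leaves, which is where the "kv_k_0" / "kv_v_0" input names asserted in this test come from. A hypothetical, self-contained check of that ordering (not part of the diff):

```python
import torch
import torch.utils._pytree as pytree

from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg

# Same minimal two-layer config as in the sketch after the kv_cache.py hunk.
attn_config = cfg.AttentionConfig(num_heads=1, head_dim=4, num_query_groups=1)
block_config = cfg.TransformerBlockConfig(attn_config=attn_config, ff_config=None)
config = cfg.ModelConfig(
    vocab_size=None,
    num_layers=2,
    max_seq_len=None,
    embedding_dim=4,
    block_configs=block_config,
    kv_cache_max_len=8,
)

kv = kv_utils.KVCache.from_model_config(config)
leaves, spec = pytree.tree_flatten(kv)

assert len(leaves) == 2 * config.num_layers  # k_0, v_0, k_1, v_1
assert leaves[0] is kv.caches[0].k_cache     # K before V, layer by layer
assert isinstance(pytree.tree_unflatten(leaves, spec), kv_utils.KVCache)
```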
ai_edge_torch/generative/test/test_loader.py
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
      safetensors.torch.save_file(test_weights, file_path)
      cfg = tiny_llama.get_model_config()
      cfg.num_layers = 1
-     model = tiny_llama.TinyLLamma(cfg)
+     model = tiny_llama.TinyLlama(cfg)

      loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
      # if returns successfully, it means all the tensors were initiallized.