ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
- ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
- ai_edge_torch/generative/examples/t5/t5.py +35 -22
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
- ai_edge_torch/generative/layers/attention.py +77 -73
- ai_edge_torch/generative/layers/builder.py +5 -3
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +38 -19
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +15 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/kv_cache.py
@@ -12,72 +12,184 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# `nn.Module` which implements a KV cache.
 
-
+"""Utility functions for externalized KV Cache."""
+
+import dataclasses
+from typing import List, Tuple
+
+from ai_edge_torch import hlfb
+from ai_edge_torch.generative.layers import model_config
 import torch
-
+import torch.utils._pytree as pytree
 
+BATCH_SIZE = 1
 
-class KVCache(nn.Module):
 
-
-
-
-  """Initializes the KVCache layer.
+@dataclasses.dataclass
+class KVCacheEntry:
+  """A single cache entry that includes K and V caches.
 
-
-
-
-      n_heads (int): number of kv heads.
-      head_dim (int): the head dimension size.
-      enable_hlfb (bool): whether hlfb is enabled or not.
-    """
-    super().__init__()
-    cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
-    self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
-    self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
-    self.enable_hlfb = enable_hlfb
-    self.kv_cache_max = kv_cache_max
+  The chaches are built based on the provided config with the shape of
+  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  """
 
-
-
+  k_cache: torch.Tensor
+  v_cache: torch.Tensor
 
-
-
-
-
+  @classmethod
+  def from_model_config(
+      cls,
+      kv_cache_max: int,
+      config: model_config.AttentionConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCacheEntry":
+    """Build an instance of the class based on model config."""
+    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+    k = torch.zeros(shape, dtype=dtype, device=device)
+    v = torch.zeros(shape, dtype=dtype, device=device)
+    obj = cls(k_cache=k, v_cache=v)
+    return obj
 
-    Returns:
-      The updated key and value tensor.
-    """
-    if self.enable_hlfb:
-      return self.update_cache_with_hlfb(input_pos, k_val, v_val)
 
-
-
-
-    return torch.clone(updated_k), torch.clone(updated_v)
+@dataclasses.dataclass
+class KVCache:
+  """A utility class for holding KV cache entries per layer."""
 
-
-
+  caches: Tuple[KVCacheEntry, ...]
+
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCache":
+    """Build an instance of the class based on model config.
 
     Args:
-
-
-
+      config (ModelConfig): Model config used for building the cache.
+      dtype (torch.dtype, optional): The data type of the cache tensor.
+        Defaults to torch.float32.
+      device (torch.device, optional): The device placement of the cache
+        tensors. Defaults to None.
 
     Returns:
-
+      KVCache: The created cache object.
     """
+    caches = [
+        KVCacheEntry.from_model_config(
+            config.kv_cache_max,
+            config.block_config(idx).attn_config,
+            dtype,
+            device,
+        )
+        for idx in range(config.num_layers)
+    ]
+    obj = cls(caches=tuple(caches))
+    return obj
 
-
-
-    )
-
-
+  def flatten(self) -> List[torch.Tensor]:
+    """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+    flattened, _ = _flatten_kvc(self)
+    return flattened
+
+
+def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
+  flattened = []
+  flat_names = []
+  none_names = []
+  for i, kv_entry in enumerate(kvc.caches):
+    flattened.append(kv_entry.k_cache)
+    flat_names.append(f"k_{i}")
+    flattened.append(kv_entry.v_cache)
+    flat_names.append(f"v_{i}")
+  return flattened, [flat_names, none_names]
+
+
+def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
+  flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+  return [
+      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+  ], flat_names
+
+
+def _unflatten_kvc(
+    values: List[torch.Tensor], context: Tuple[List, List]
+) -> KVCache:
+  assert len(values) % 2 == 0, "Found odd number of K and V entries."
+  num_layers = len(values) // 2
+  flat_names = context[0]
+  kv_entries = []
+  for i in range(num_layers):
+    k_cache_idx = flat_names.index(f"k_{i}")
+    v_cache_idx = flat_names.index(f"v_{i}")
+    kv_entries.append(
+        KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
     )
-
-
-
-
+  obj = KVCache(tuple(kv_entries))
+  return obj
+
+
+pytree.register_pytree_node(
+    KVCache,
+    _flatten_kvc,
+    _unflatten_kvc,
+    flatten_with_keys_fn=_flatten_kvc_with_keys,
+    serialized_type_name="",
+)
+
+
+def update(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+    enable_hlfb: bool = True,
+) -> KVCacheEntry:
+  """Out of place update of Cache buffer.
+
+  Args:
+    cache (KVCacheEntry): The original cache buffer.
+    input_pos (torch.Tensor): The update slice positions.
+    k_slice (torch.Tensor): The K slice to be updated in the new cache.
+    v_slice (torch.Tensor): The V slice to be updated in the new cache.
+    enable_hlfb (bool, optional): Whether the op is annotated for export with
+      High Level Function Boundary. Defaults to True.
+
+  Returns:
+    KVCacheEntry: The updated KVCache entry based on the passed inputs.
+  """
+  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
+  return update_func(cache, input_pos, k_slice, v_slice)
+
+
+def _update_kv_base_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer without High Level Function Boundary annotation."""
+  k = cache.k_cache.index_copy(1, input_pos, k_slice)
+  v = cache.v_cache.index_copy(1, input_pos, v_slice)
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
+
+
+def _update_kv_hlfb_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer with High Level Function Boundary annotation."""
+  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
+  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
+      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
+  )
+  k = k_cache.index_copy(1, input_pos, k_slice)
+  v = v_cache.index_copy(1, input_pos, v_slice)
+  k, v = builder.mark_outputs(k, v)
+  return KVCacheEntry(k, v)
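For orientation, here is a minimal usage sketch of the externalized KV cache API added above. The config values are illustrative placeholders that mirror the updated unit tests in this release, not any shipped model:

```python
# Sketch only: illustrative sizes, mirroring the updated test_kv_cache.py.
import torch

from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg

attn_config = cfg.AttentionConfig(
    num_heads=1, head_dim=4, num_query_groups=1
)
block_config = cfg.TransformerBlockConfig(
    attn_config=attn_config, ff_config=None
)
config = cfg.ModelConfig(
    vocab_size=None,
    num_layers=2,
    max_seq_len=None,
    embedding_dim=4,
    block_configs=block_config,
    kv_cache_max_len=8,
)

# One KVCacheEntry per layer, each shaped
# (batch_size=1, kv_cache_max, num_query_groups, head_dim).
kv = kv_utils.KVCache.from_model_config(config)

# Out-of-place update of layer 0 at position 1. enable_hlfb=False takes the
# plain index_copy path; enable_hlfb=True wraps the update in the
# "odml.update_external_kv_cache" composite for export.
input_pos = torch.tensor([1])
k_slice = torch.ones((1, 1, 1, 4))
v_slice = torch.ones((1, 1, 1, 4))
updated_entry = kv_utils.update(
    kv.caches[0], input_pos, k_slice, v_slice, enable_hlfb=False
)
```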
ai_edge_torch/generative/layers/model_config.py
@@ -16,7 +16,7 @@
 from dataclasses import dataclass
 from dataclasses import field
 import enum
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union
 
 
 @enum.unique
@@ -85,8 +85,8 @@ class AttentionConfig:
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
   logit_softcap: Optional[float] = None
-  # The
-
+  # The type of attention.
+  attn_type: Optional[AttentionType] = None
   # The size of the sliding window used for local attention.
   sliding_window_size: Optional[int] = None
 
@@ -104,6 +104,7 @@ class NormalizationConfig:
   """Normalizater parameters."""
 
   type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
   epsilon: float = 1e-5
   zero_centered: bool = False
   # Number of groups used in group normalization.
@@ -129,13 +130,8 @@ class FeedForwardConfig:
 
 
 @dataclass
-class
-  """
-
-  vocab_size: int
-  num_layers: int
-  max_seq_len: int
-  embedding_dim: int
+class TransformerBlockConfig:
+  """TransformerBlock module's parameters."""
 
   attn_config: AttentionConfig
   ff_config: FeedForwardConfig
@@ -147,15 +143,33 @@ class ModelConfig:
   post_attention_norm_config: NormalizationConfig = field(
       default_factory=NormalizationConfig
   )
+  # If set to True, only attn_config.pre_attention_norm is applied to the input
+  # and the decode's output is computed as `output = input + attn_out + ff_out`
+  # where attention and feed forward are called with pre_attention_norm's
+  # output.
+  parallel_residual: bool = False
+  # The Attention computation will include relative positional bias.
+  relative_attention: bool = False
+
+
+@dataclass
+class ModelConfig:
+  """Base configurations for building a transformer architecture."""
+
+  vocab_size: int
+  num_layers: int
+  max_seq_len: int
+  embedding_dim: int
+
+  # TransformerBlockConfig for each layer block. If a single
+  # TransformerBlockConfig is provided, it will be used for all layers.
+  block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]
+
   # The normalization applied before LM head.
   final_norm_config: NormalizationConfig = field(
       default_factory=NormalizationConfig
   )
 
-  # If set to True, only pre_attention_norm is applied to the input and the
-  # decode's output is computed as `output = input + attn_out + ff_out` where
-  # attention and feed forward are called with pre_attention_norm's output.
-  parallel_residual: bool = False
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
   # Whether to turn on high-level function boundary.
@@ -164,9 +178,6 @@ class ModelConfig:
   # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
   kv_cache_max_len: int = 0
 
-  # The Attention computation will include relative positional bias.
-  relative_attention: bool = False
-
   # Default batch size of the exported model. Default value is 1.
   batch_size: int = 1
 
@@ -177,5 +188,13 @@ class ModelConfig:
   def kv_cache_max(self) -> int:
     if self.kv_cache_max_len > 0:
       return self.kv_cache_max_len
-
-
+    return self.max_seq_len
+
+  def block_config(self, idx: int) -> TransformerBlockConfig:
+    if isinstance(self.block_configs, TransformerBlockConfig):
+      return self.block_configs
+    if idx < 0 or idx >= len(self.block_configs):
+      raise ValueError(
+          f"Index {idx} is out of range for layer configs: {self.block_configs}"
+      )
+    return self.block_configs[idx]
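The new `block_configs` field and `block_config(idx)` accessor let a model either share one `TransformerBlockConfig` across all layers or supply one per layer. A hedged sketch of the per-layer form follows; the sizes are placeholders, not taken from any released model:

```python
# Sketch only: per-layer block configs; field values are placeholders.
import ai_edge_torch.generative.layers.model_config as cfg

block_configs = [
    cfg.TransformerBlockConfig(
        attn_config=cfg.AttentionConfig(
            num_heads=8, head_dim=64, num_query_groups=1
        ),
        ff_config=None,
    )
    for _ in range(2)
]
config = cfg.ModelConfig(
    vocab_size=32000,
    num_layers=2,
    max_seq_len=1024,
    embedding_dim=512,
    block_configs=block_configs,  # a Sequence: one entry per layer
)

# Per-layer lookup; if a single TransformerBlockConfig is passed instead of a
# sequence, block_config(idx) returns that shared config for every idx.
attn_cfg_layer_0 = config.block_config(0).attn_config
```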
ai_edge_torch/generative/layers/normalization.py
@@ -14,7 +14,10 @@
 # ==============================================================================
 # Common normalization layers.
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
+from torch import nn
+import torch.nn.functional as F
 
 
 # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
       return output * (1 + self.weight)
     else:
       return output * self.weight
+
+
+class GroupNorm(torch.nn.Module):
+
+  def __init__(
+      self,
+      group_num: int,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+  ):
+    """Initialize the GroupNorm layer.
+
+    Args:
+      group_num (int): Number of groups to separate the channels into.
+      dim (int): Dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-6).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.group_num = group_num
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of GroupNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying GroupNorm.
+    """
+    if self.enable_hlfb:
+      return group_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.group_num,
+          self.eps,
+      )
+    else:
+      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(torch.nn.Module):
+
+  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+    """Initialize the LayerNorm layer.
+
+    Args:
+      dim (int): dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-6).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of LayerNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying LayerNorm.
+    """
+    if self.enable_hlfb:
+      return layer_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.eps,
+      )
+    else:
+      return F.layer_norm(
+          x,
+          x.shape,
+          self.weight.broadcast_to(x.shape),
+          self.bias.broadcast_to(x.shape),
+          self.eps,
+      )
+
+
+def group_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    num_groups: int,
+    eps: float,
+):
+  """Group Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    num_groups (int): Number of groups to separate the channels into.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Group Normalization.
+  """
+  x = torch.permute(x, (0, 2, 3, 1))
+
+  builder = StableHLOCompositeBuilder(
+      name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+  )
+  x, w, b = builder.mark_inputs(x, w, b)
+  x = torch.permute(x, (0, 3, 1, 2))
+  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
+  y = torch.permute(y, (0, 2, 3, 1))
+  y = builder.mark_outputs(y)
+
+  y = torch.permute(y, (0, 3, 1, 2))
+  return y
+
+
+def layer_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    eps: float,
+):
+  """Layer Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Layer Normalization.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Layer Normalization.
+  """
+  builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+  x, w, b = builder.mark_inputs(x, w, b)
+  y = F.layer_norm(
+      x,
+      x.shape,
+      weight=w.broadcast_to(x.shape),
+      bias=b.broadcast_to(x.shape),
+      eps=eps,
+  )
+  y = builder.mark_outputs(y)
+  return y
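A small sketch of the new normalization layers follows (tensor shapes are illustrative). With `enable_hlfb=True` the forward pass routes through `layer_norm_with_hlfb` / `group_norm_with_hlfb`, which wrap the op in an `odml.layer_norm` / `odml.group_norm` StableHLO composite during conversion; with the default `enable_hlfb=False` it falls back to the plain `torch.nn.functional` call shown in the diff:

```python
# Sketch only: exercising the functional (enable_hlfb=False) paths.
import torch

from ai_edge_torch.generative.layers import normalization

# LayerNorm normalizes over the full input shape, with weight/bias of size
# `dim` broadcast to the input.
layer_norm = normalization.LayerNorm(dim=8, eps=1e-5, enable_hlfb=False)
x = torch.randn(2, 8)
y = layer_norm(x)

# GroupNorm expects channel-first (BCHW) input; `dim` is the channel count.
group_norm = normalization.GroupNorm(group_num=4, dim=8, enable_hlfb=False)
img = torch.randn(1, 8, 16, 16)
z = group_norm(img)
```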
ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
         config.attention_batch_size,
         config.dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )
 
@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
         config.query_dim,
         config.cross_dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )
 
ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py}
@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A suite of tests to validate experimental external KV Cache layers and models.
 
-
-
-from ai_edge_torch.generative.
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+"""A suite of tests to validate KV Cache layer."""
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 
 from absl.testing import absltest as googletest
 
 
-class
+class TestKVLayers(googletest.TestCase):
 
   def _get_test_config(
       self, num_layers, head_dim, num_query_groups, kv_cache_max_len
@@ -32,14 +30,16 @@ class TestExternalKVLayers(googletest.TestCase):
     attn_config = cfg.AttentionConfig(
         num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
     )
+    block_config = cfg.TransformerBlockConfig(
+        attn_config=attn_config, ff_config=None
+    )
     config = cfg.ModelConfig(
         kv_cache_max_len=kv_cache_max_len,
         embedding_dim=head_dim,
-
+        block_configs=block_config,
         num_layers=num_layers,
         max_seq_len=None,
         vocab_size=None,
-        ff_config=None,
     )
     return config
 
@@ -54,7 +54,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.
+    kv = kv_utils.KVCache.from_model_config(config)
     entry = kv.caches[0]
     # single-slice update
     input_pos = torch.tensor([1])
@@ -88,14 +88,14 @@ class TestExternalKVLayers(googletest.TestCase):
   def test_serialization(self):
     class TestModel(torch.nn.Module):
 
-      def forward(self, kv: kv_utils.
+      def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
         updated_kv_entries = [
             kv_utils.KVCacheEntry(
                 torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
            )
            for entry in kv.caches
        ]
-        return kv_utils.
+        return kv_utils.KVCache(updated_kv_entries)
 
     N = 1
     HEAD_DIM = 2
@@ -107,7 +107,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.
+    kv = kv_utils.KVCache.from_model_config(config)
     model = TestModel()
     exported_program = torch.export.export(model, (kv,))
     input_specs = exported_program.graph_signature.input_specs
@@ -116,17 +116,5 @@ class TestExternalKVLayers(googletest.TestCase):
     self.assertEqual(input_specs[1].arg.name, "kv_v_0")
 
 
-class TestExternalKVModels(googletest.TestCase):
-
-  def test_can_build_gemma(self):
-    gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
-
-  def test_can_build_phi2(self):
-    phi2.define_and_run(checkpoint_path=None, test_model=True)
-
-  def test_can_build_tinyllama(self):
-    tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
-
-
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/test/test_loader.py
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.
+    model = tiny_llama.TinyLlama(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initiallized.