ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +30 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +330 -0
- ai_edge_torch/convert/converter.py +171 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
- ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +273 -0
- ai_edge_torch/convert/test/test_convert_composites.py +171 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/debug/__init__.py +16 -0
- ai_edge_torch/debug/culprit.py +423 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +288 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +103 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +135 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_model_conversion.py +201 -0
- ai_edge_torch/generative/test/test_quantize.py +109 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +290 -0
- ai_edge_torch/generative/utilities/t5_loader.py +467 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +134 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
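Before the per-file hunks: the examples in this diff all funnel through two top-level entry points, ai_edge_torch.convert for single-signature models and ai_edge_torch.signature(...).convert() for multi-signature ones. A minimal sketch of the single-signature path, assuming only what those examples exercise (the torchvision model is an illustrative stand-in, not part of this package):

import torch
import torchvision

import ai_edge_torch

# Any torch.nn.Module in eval mode can serve as input; resnet18 here is only
# a placeholder to show the call shape.
model = torchvision.models.resnet18().eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# convert() traces the model with the sample inputs and returns an edge model
# that can be serialized to a .tflite flatbuffer.
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model.export('/tmp/resnet18.tflite')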
@@ -0,0 +1,255 @@ ai_edge_torch/generative/examples/t5/t5_attention.py
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Attention modules for the T5 encoder-decoder model family.

from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention  # NOQA
from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention_with_hlfb  # NOQA
import ai_edge_torch.generative.layers.builder as builder
from ai_edge_torch.generative.layers.kv_cache import KVCache
import ai_edge_torch.generative.layers.model_config as cfg


class EncoderDecoderBlock(nn.Module):

  def __init__(
      self, config: cfg.ModelConfig, has_relative_attention_bias: bool = False
  ) -> None:
    """Initialize an instance of the EncoderDecoderBlock.

    Args:
      config (cfg.ModelConfig): the configuration object for this transformer
        block.
      has_relative_attention_bias (bool): whether the self attention block has
        relative bias.
    """

    super().__init__()
    self.atten_func = T5Attention(
        config.embedding_dim,
        config.attn_config,
        config.pre_attention_norm_config,
        config.kv_cache_max,
        config.enable_hlfb,
        has_relative_attention_bias=has_relative_attention_bias,
    )
    # For a decoder, we add a cross attention.
    if config.is_decoder:
      self.cross_atten_func = T5Attention(
          config.embedding_dim,
          config.attn_config,
          config.pre_attention_norm_config,
          config.kv_cache_max,
          config.enable_hlfb,
          # Cross attention does not have relative attention bias.
          has_relative_attention_bias=False,
      )
    else:
      self.cross_atten_func = None

    self.pre_ff_norm = builder.build_norm(
        config.embedding_dim, config.pre_ff_norm_config
    )
    self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
    self.config = config

  def forward(
      self,
      x: torch.Tensor,
      input_pos: Optional[torch.Tensor] = None,
      mask: Optional[torch.Tensor] = None,
      relative_position: Optional[torch.Tensor] = None,
      position_bias: Optional[torch.Tensor] = None,
      encoder_hidden_states: Optional[torch.Tensor] = None,
      encoder_attention_mask: Optional[torch.Tensor] = None,
      encoder_decoder_position_bias: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward function of the EncoderDecoderBlock.

    Args:
      x (torch.Tensor): the input tensor.
      input_pos (torch.Tensor): the optional input position tensor.
      mask (torch.Tensor): the optional mask tensor for self attention.
      relative_position (torch.Tensor): the optional relative position tensor.
      position_bias (torch.Tensor): the optional position bias for self attention.
      encoder_hidden_states (torch.Tensor): the optional encoder outputs used
        as keys/values for cross attention.
      encoder_attention_mask (torch.Tensor): the optional mask for cross attention.
      encoder_decoder_position_bias (torch.Tensor): the optional position bias
        for cross attention.

    Returns:
      output activation from this transformer block, plus the self attention
      and cross attention position biases.
    """

    hidden_states, position_bias = self.atten_func(
        x,
        input_pos=input_pos,
        mask=mask,
        relative_position=relative_position,
        position_bias=position_bias,
    )

    attn_out = hidden_states + x

    if self.cross_atten_func:
      hidden_states, encoder_decoder_position_bias = self.cross_atten_func(
          attn_out,
          input_pos=input_pos,
          key_value_states=encoder_hidden_states,
          mask=encoder_attention_mask,
          relative_position=relative_position,
          position_bias=encoder_decoder_position_bias,
      )
      attn_out = hidden_states + attn_out

    forwarded = self.pre_ff_norm(attn_out)
    forwarded = self.ff(forwarded)
    hidden_states = attn_out + forwarded

    # encoder_decoder_position_bias is from cross attention.
    return hidden_states, position_bias, encoder_decoder_position_bias


class T5Attention(nn.Module):

  def __init__(
      self,
      dim: int,
      config: cfg.AttentionConfig,
      norm_config: cfg.NormalizationConfig,
      kv_cache_max: int,
      enable_hlfb: bool,
      has_relative_attention_bias=False,
  ) -> None:
    """Initialize an instance of T5Attention.

    Args:
      dim (int): causal attention's input/output dimension.
      config (cfg.AttentionConfig): attention specific configurations.
      norm_config (cfg.NormalizationConfig): configuration for the
        pre-attention normalization layer.
      kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
      enable_hlfb (bool): whether hlfb is enabled or not.
      has_relative_attention_bias (bool): whether we compute relative bias.
    """
    super().__init__()
    self.pre_atten_norm = builder.build_norm(dim, norm_config)

    self.has_relative_attention_bias = has_relative_attention_bias
    self.relative_attention_num_buckets = config.relative_attention_num_buckets
    self.d_model = dim
    self.head_dim = dim // config.num_heads
    self.n_heads = config.num_heads
    self.inner_dim = self.n_heads * self.head_dim

    self.q = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
    self.k = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
    self.v = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
    # Output projection.
    self.proj = nn.Linear(
        self.inner_dim, self.d_model, bias=config.output_proj_use_bias
    )

    if self.has_relative_attention_bias:
      self.relative_attention_bias = nn.Embedding(
          self.relative_attention_num_buckets, self.n_heads
      )

    self.config = config
    self.kv_cache = None
    # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
    # Now only supports a max batch_size of 1.
    if config.enable_kv_cache:
      self.kv_cache = KVCache(
          1,
          kv_cache_max,
          config.num_query_groups,
          self.head_dim,
          enable_hlfb,
      )

    if enable_hlfb:
      self.sdpa_func = scaled_dot_product_attention_with_hlfb
    else:
      self.sdpa_func = scaled_dot_product_attention

  def forward(
      self,
      x: torch.Tensor,
      input_pos: Optional[torch.Tensor] = None,
      key_value_states: Optional[torch.Tensor] = None,
      mask: Optional[torch.Tensor] = None,
      relative_position: Optional[torch.Tensor] = None,
      position_bias: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward function of the T5Attention layer.

    Args:
      x (torch.Tensor): the input tensor.
      input_pos (torch.Tensor): the optional input position tensor.
      key_value_states (torch.Tensor): the optional encoder outputs used as
        keys and values for cross attention.
      mask (torch.Tensor): the optional mask tensor.
      relative_position (torch.Tensor): the optional relative position tensor.
      position_bias (torch.Tensor): the optional position bias tensor.

    Returns:
      output activation from this attention layer, and the position bias.
    """

    x = self.pre_atten_norm(x)
    B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
    query_states = self.q(x)
    query_states = query_states.reshape(B, T, -1, self.head_dim)  # (B, T, nh_q, hs)

    if key_value_states is not None:
      (
          kvB,
          kvT,
          kvC,
      ) = key_value_states.size()  # batch size, sequence length, embedding dimensionality (n_embd)
      key_states = self.k(key_value_states)
      value_states = self.v(key_value_states)
      key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
      value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
    else:
      key_states = self.k(x)
      value_states = self.v(x)
      key_states = key_states.reshape(B, T, -1, self.head_dim)
      value_states = value_states.reshape(B, T, -1, self.head_dim)

    if key_value_states is None and self.kv_cache is not None:
      key_states, value_states = self.kv_cache.update_cache(
          input_pos, key_states, value_states
      )

    if position_bias is None:
      # Handle the encoder case first.
      if self.has_relative_attention_bias:
        position_bias = self.relative_attention_bias(
            relative_position
        )  # shape (1, 1, query_length, key_length, num_heads)
        position_bias = position_bias.permute([0, 1, 4, 2, 3]).squeeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
      else:
        position_bias = torch.zeros_like(mask, dtype=torch.float32)

    mask = mask + position_bias
    y = self.sdpa_func(
        query_states, key_states, value_states, self.head_dim, mask=mask, scale=1.0
    )
    y = y.reshape(B, T, C)  # re-assemble all head outputs side by side
    # Output projection.
    y = self.proj(y)
    return y, position_bias

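The shape gymnastics around relative_attention_bias in T5Attention.forward are easier to follow in isolation. A standalone sketch with illustrative sizes; the bucketing that produces the relative_position indices is assumed to live elsewhere in the package (e.g. attention_utils) and is faked here with random indices:

import torch
from torch import nn

# Illustrative sizes only.
num_buckets, n_heads, q_len, k_len = 32, 8, 4, 4
relative_attention_bias = nn.Embedding(num_buckets, n_heads)

# Stand-in for precomputed bucketed relative distances, shape (1, 1, q, k).
relative_position = torch.randint(0, num_buckets, (1, 1, q_len, k_len))

bias = relative_attention_bias(relative_position)  # (1, 1, q, k, n_heads)
bias = bias.permute([0, 1, 4, 2, 3]).squeeze(0)    # (1, n_heads, q, k)
print(bias.shape)  # torch.Size([1, 8, 4, 4]), broadcastable against the mask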
@@ -0,0 +1,14 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

@@ -0,0 +1,119 @@ ai_edge_torch/generative/examples/test_models/toy_model.py
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# A toy example which has a single-layer transformer block.
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn

import ai_edge_torch
from ai_edge_torch.generative.layers.attention import TransformerBlock
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

RoPECache = Tuple[torch.Tensor, torch.Tensor]
KV_CACHE_MAX_LEN = 100


class ToySingleLayerModel(torch.nn.Module):

  def __init__(self, config: cfg.ModelConfig) -> None:
    super().__init__()
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
    self.transformer_block = TransformerBlock(config)
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.max_seq_len,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device('cpu'),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
    )
    self.config = config

  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    x = self.tok_embedding(idx)
    cos, sin = self.rope_cache

    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.max_seq_len]

    x = self.transformer_block(x, (cos, sin), mask, input_pos)
    x = self.final_norm(x)
    return self.lm_head(x)


def define_and_run() -> None:
  attn_config = cfg.AttentionConfig(
      num_heads=32, num_query_groups=4, rotary_percentage=1.0, enable_kv_cache=False
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationType.SILU,
      intermediate_size=256,
  )
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
  config = cfg.ModelConfig(
      vocab_size=400,
      num_layers=1,
      max_seq_len=KV_CACHE_MAX_LEN,
      embedding_dim=128,
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      pre_ff_norm_config=norm_config,
      final_norm_config=norm_config,
  )

  model = ToySingleLayerModel(config)
  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
  print('running an inference')
  print(
      model.forward(
          idx,
          input_pos,
      )
  )

  # Convert the model to tflite.
  print('converting model to tflite')
  edge_model = ai_edge_torch.convert(
      model,
      (
          idx,
          input_pos,
      ),
  )
  edge_model.export('/tmp/toy_model.tflite')


if __name__ == '__main__':
  define_and_run()

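The mask handling in ToySingleLayerModel.forward (index_select on dim 2, then a slice on the last dim) is worth seeing on a toy tensor. Below is a hand-rolled stand-in for attn_utils.build_causal_mask_cache, for intuition only; the real helper lives in attention_utils.py of this package and its exact signature is not shown here:

import torch

# Causal mask cache: -inf strictly above the diagonal, 0 elsewhere.
size = 6
mask_cache = torch.triu(
    torch.full((1, 1, size, size), float('-inf')), diagonal=1
)

# Select the mask rows for the current positions, as forward() does.
input_pos = torch.arange(0, 3)
mask = mask_cache.index_select(2, input_pos)  # shape (1, 1, 3, 6)
print(mask[0, 0])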
@@ -0,0 +1,143 @@ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# A toy example which has a basic transformer block (w/ KV-Cache).
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch_xla

import ai_edge_torch
from ai_edge_torch.generative.layers.attention import TransformerBlock
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

RoPECache = Tuple[torch.Tensor, torch.Tensor]


class ToyModelWithKV(torch.nn.Module):

  def __init__(self, config: cfg.ModelConfig) -> None:
    super().__init__()
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.max_seq_len,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device('cpu'),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
    )
    self.config = config

  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    x = self.tok_embedding(idx)
    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.max_seq_len]
    for i, block in enumerate(self.transformer_blocks):
      x = block(x, (cos, sin), mask, input_pos)
    x = self.final_norm(x)
    return self.lm_head(x)


def _export_stablehlo_mlir(model, args):
  ep = torch.export.export(model, args)
  stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
  return stablehlo_gm.get_stablehlo_text()


def get_model_config() -> cfg.ModelConfig:
  attn_config = cfg.AttentionConfig(
      num_heads=32, num_query_groups=4, rotary_percentage=1.0
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationType.SILU,
      intermediate_size=256,
  )
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
  config = cfg.ModelConfig(
      vocab_size=150,
      num_layers=2,
      max_seq_len=500,
      embedding_dim=128,
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      pre_ff_norm_config=norm_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
  )
  return config


def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
  idx = torch.unsqueeze(torch.arange(0, 100), 0)
  input_pos = torch.arange(0, 100)
  return idx, input_pos


def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
  idx = torch.tensor([[1]], dtype=torch.long)
  input_pos = torch.tensor([10], dtype=torch.int64)
  return idx, input_pos


def define_and_run() -> None:
  dump_mlir = False

  config = get_model_config()
  model = ToyModelWithKV(config)
  print('running an inference')
  idx, input_pos = get_sample_prefill_inputs()
  decode_idx, decode_input_pos = get_sample_decode_inputs()
  print(model.forward(idx, input_pos))

  if dump_mlir:
    mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
    with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
      f.write(mlir_text)

  # Convert the model to tflite with 2 signatures (prefill + decode).
  print('converting toy model to tflite with 2 signatures (prefill + decode)')
  edge_model = (
      ai_edge_torch.signature('prefill', model, (idx, input_pos))
      .signature('decode', model, (decode_idx, decode_input_pos))
      .convert()
  )
  edge_model.export('/tmp/toy_kv_cache.tflite')


if __name__ == '__main__':
  define_and_run()

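Once /tmp/toy_kv_cache.tflite exists, the two signatures can be driven from TensorFlow's interpreter. A hedged sketch: the input names ('idx', 'input_pos') are assumed to mirror forward()'s argument names; verify with the runner's get_input_details() before relying on them:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='/tmp/toy_kv_cache.tflite')
prefill = interpreter.get_signature_runner('prefill')
decode = interpreter.get_signature_runner('decode')

# Prefill the whole prompt, matching the sample shapes used at conversion
# time above; decode would then be called one token at a time.
outputs = prefill(
    idx=np.arange(0, 100, dtype=np.int64)[None, :],
    input_pos=np.arange(0, 100, dtype=np.int64),
)
print({name: value.shape for name, value in outputs.items()})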
File without changes: ai_edge_torch/generative/examples/tiny_llama/__init__.py (+0 -0)
@@ -0,0 +1,66 @@ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
from pathlib import Path

import torch

import ai_edge_torch
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.quantize import quant_recipes


def convert_tiny_llama_to_tflite(
    checkpoint_path: str,
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
):
  """An example method for converting a TinyLlama model to a multi-signature
  tflite model.

  Args:
    checkpoint_path (str): The filepath to the model checkpoint, or directory
      holding the checkpoint.
    prefill_seq_len (int, optional): The maximum size of the prefill input
      tensor. Defaults to 512.
    kv_cache_max_len (int, optional): The maximum size of the KV cache buffer,
      including both prefill and decode. Defaults to 1024.
    quantize (bool, optional): Whether the model should be quantized.
      Defaults to True.
  """
  pytorch_model = tiny_llama.build_model(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)

  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
      )
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
  edge_model.export(f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')


if __name__ == '__main__':
  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
  convert_tiny_llama_to_tflite(checkpoint_path)