ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ai-edge-torch-nightly might be problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
@@ -0,0 +1,255 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Attention modules for the T5 encoder-decoder model family.
+
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention  # NOQA
+ from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention_with_hlfb  # NOQA
+ import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.kv_cache import KVCache
+ import ai_edge_torch.generative.layers.model_config as cfg
+
+
+ class EncoderDecoderBlock(nn.Module):
+
+   def __init__(
+       self, config: cfg.ModelConfig, has_relative_attention_bias: bool = False
+   ) -> None:
+     """Initialize an instance of the EncoderDecoderBlock.
+
+     Args:
+       config (cfg.ModelConfig): the configuration object for this transformer
+         block.
+       has_relative_attention_bias (bool): whether the self attention block has
+         relative bias.
+     """
+
+     super().__init__()
+     self.atten_func = T5Attention(
+         config.embedding_dim,
+         config.attn_config,
+         config.pre_attention_norm_config,
+         config.kv_cache_max,
+         config.enable_hlfb,
+         has_relative_attention_bias=has_relative_attention_bias,
+     )
+     # For a decoder, we add a cross attention.
+     if config.is_decoder:
+       self.cross_atten_func = T5Attention(
+           config.embedding_dim,
+           config.attn_config,
+           config.pre_attention_norm_config,
+           config.kv_cache_max,
+           config.enable_hlfb,
+           # Cross attention does not have relative attention bias.
+           has_relative_attention_bias=False,
+       )
+     else:
+       self.cross_atten_func = None
+
+     self.pre_ff_norm = builder.build_norm(
+         config.embedding_dim, config.pre_ff_norm_config
+     )
+     self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+     self.config = config
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       input_pos: Optional[torch.Tensor] = None,
+       mask: Optional[torch.Tensor] = None,
+       relative_position: Optional[torch.Tensor] = None,
+       position_bias: Optional[torch.Tensor] = None,
+       encoder_hidden_states: Optional[torch.Tensor] = None,
+       encoder_attention_mask: Optional[torch.Tensor] = None,
+       encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """Forward function of the EncoderDecoderBlock.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       mask (torch.Tensor): the optional mask tensor for self attention.
+       relative_position (torch.Tensor): the optional relative position tensor.
+       position_bias (torch.Tensor): the optional self attention position bias.
+       encoder_hidden_states (torch.Tensor): the optional encoder output, used
+         as the key/value source for cross attention.
+       encoder_attention_mask (torch.Tensor): the optional mask tensor for
+         cross attention.
+       encoder_decoder_position_bias (torch.Tensor): the optional cross
+         attention position bias.
+
+     Returns:
+       output activation from this transformer block, along with the self
+       attention and cross attention position biases.
+     """
+
+     hidden_states, position_bias = self.atten_func(
+         x,
+         input_pos=input_pos,
+         mask=mask,
+         relative_position=relative_position,
+         position_bias=position_bias,
+     )
+
+     attn_out = hidden_states + x
+
+     if self.cross_atten_func:
+       hidden_states, encoder_decoder_position_bias = self.cross_atten_func(
+           attn_out,
+           input_pos=input_pos,
+           key_value_states=encoder_hidden_states,
+           mask=encoder_attention_mask,
+           relative_position=relative_position,
+           position_bias=encoder_decoder_position_bias,
+       )
+       attn_out = hidden_states + attn_out
+
+     forwarded = self.pre_ff_norm(attn_out)
+     forwarded = self.ff(forwarded)
+     hidden_states = attn_out + forwarded
+
+     # encoder_decoder_position_bias is from cross attention.
+     return hidden_states, position_bias, encoder_decoder_position_bias
+
+
+ class T5Attention(nn.Module):
+
+   def __init__(
+       self,
+       dim: int,
+       config: cfg.AttentionConfig,
+       norm_config: cfg.NormalizationConfig,
+       kv_cache_max: int,
+       enable_hlfb: bool,
+       has_relative_attention_bias: bool = False,
+   ) -> None:
+     """Initialize an instance of T5Attention.
+
+     Args:
+       dim (int): causal attention's input/output dimension.
+       config (cfg.AttentionConfig): attention specific configurations.
+       norm_config (cfg.NormalizationConfig): configuration for the
+         pre-attention normalization layer.
+       kv_cache_max (int): determines the size of the KV Cache buffer, if
+         enabled.
+       enable_hlfb (bool): whether hlfb is enabled or not.
+       has_relative_attention_bias (bool): whether we compute relative bias.
+     """
+     super().__init__()
+     self.pre_atten_norm = builder.build_norm(dim, norm_config)
+
+     self.has_relative_attention_bias = has_relative_attention_bias
+     self.relative_attention_num_buckets = config.relative_attention_num_buckets
+     self.d_model = dim
+     self.head_dim = dim // config.num_heads
+     self.n_heads = config.num_heads
+     self.inner_dim = self.n_heads * self.head_dim
+
+     self.q = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
+     self.k = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
+     self.v = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
+     # Output projection.
+     self.proj = nn.Linear(
+         self.inner_dim, self.d_model, bias=config.output_proj_use_bias
+     )
+
+     if self.has_relative_attention_bias:
+       self.relative_attention_bias = nn.Embedding(
+           self.relative_attention_num_buckets, self.n_heads
+       )
+
+     self.config = config
+     self.kv_cache = None
+     # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
+     # Now only supports a max batch_size of 1.
+     if config.enable_kv_cache:
+       self.kv_cache = KVCache(
+           1,
+           kv_cache_max,
+           config.num_query_groups,
+           self.head_dim,
+           enable_hlfb,
+       )
+
+     if enable_hlfb:
+       self.sdpa_func = scaled_dot_product_attention_with_hlfb
+     else:
+       self.sdpa_func = scaled_dot_product_attention
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       input_pos: Optional[torch.Tensor] = None,
+       key_value_states: Optional[torch.Tensor] = None,
+       mask: Optional[torch.Tensor] = None,
+       relative_position: Optional[torch.Tensor] = None,
+       position_bias: Optional[torch.Tensor] = None,
+   ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Forward function of the T5Attention layer.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       key_value_states (torch.Tensor): the optional key/value source tensor
+         for cross attention; self attention is used when it is None.
+       mask (torch.Tensor): the optional mask tensor.
+       relative_position (torch.Tensor): the optional relative position tensor.
+       position_bias (torch.Tensor): the optional precomputed position bias.
+
+     Returns:
+       output activation from this attention layer, along with the position
+       bias.
+     """
+
+     x = self.pre_atten_norm(x)
+     B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+     query_states = self.q(x)
+     query_states = query_states.reshape(B, T, -1, self.head_dim)  # (B, T, nh_q, hs)
+
+     if key_value_states is not None:
+       # Batch size, sequence length, embedding dimensionality of the
+       # key/value source (cross attention).
+       kvB, kvT, kvC = key_value_states.size()
+       key_states = self.k(key_value_states)
+       value_states = self.v(key_value_states)
+       key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
+       value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
+     else:
+       key_states = self.k(x)
+       value_states = self.v(x)
+       key_states = key_states.reshape(B, T, -1, self.head_dim)
+       value_states = value_states.reshape(B, T, -1, self.head_dim)
+
+     if key_value_states is None and self.kv_cache is not None:
+       key_states, value_states = self.kv_cache.update_cache(
+           input_pos, key_states, value_states
+       )
+
+     if position_bias is None:
+       # Handle the encoder case first.
+       if self.has_relative_attention_bias:
+         # shape (query_length, key_length, num_heads)
+         position_bias = self.relative_attention_bias(relative_position)
+         # shape (1, num_heads, query_length, key_length)
+         position_bias = position_bias.permute([0, 1, 4, 2, 3]).squeeze(0)
+       else:
+         position_bias = torch.zeros_like(mask, dtype=torch.float32)
+
+     mask = mask + position_bias
+     y = self.sdpa_func(
+         query_states, key_states, value_states, self.head_dim, mask=mask, scale=1.0
+     )
+     y = y.reshape(B, T, C)  # Re-assemble all head outputs side by side.
+     # Output projection.
+     y = self.proj(y)
+     return y, position_bias
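
Note on the bias handling above: a T5-style relative position bias is purely additive, so folding it into the attention mask (mask = mask + position_bias) before the SDPA call is equivalent to adding it to the raw attention scores. A minimal standalone sketch in plain torch, with illustrative shapes and names only (it does not use this package's API):

import torch
import torch.nn.functional as F

# Illustrative sizes: batch 1, 4 heads, sequence length 8, head dim 16.
B, H, T, D = 1, 4, 8, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Causal mask: 0 where visible, -inf where masked.
mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)

# One learned scalar per head per (query, key) pair, as an nn.Embedding over
# bucketed relative distances would produce.
position_bias = torch.randn(1, H, T, T)

# The T5 code above passes scale=1.0, i.e. no 1/sqrt(D) factor on the scores.
scores = q @ k.transpose(-2, -1) + (mask + position_bias)
out = F.softmax(scores, dim=-1) @ v  # (B, H, T, D)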
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
@@ -0,0 +1,119 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A toy example which has a single-layer transformer block.
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.layers.attention import TransformerBlock
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ import ai_edge_torch.generative.layers.model_config as cfg
+
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
+ KV_CACHE_MAX_LEN = 100
+
+
+ class ToySingleLayerModel(torch.nn.Module):
+
+   def __init__(self, config: cfg.ModelConfig) -> None:
+     super().__init__()
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+     self.transformer_block = TransformerBlock(config)
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.max_seq_len,
+         dim=int(config.attn_config.rotary_percentage * config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device('cpu'),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+     x = self.tok_embedding(idx)
+     cos, sin = self.rope_cache
+
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.max_seq_len]
+
+     x = self.transformer_block(x, (cos, sin), mask, input_pos)
+     x = self.final_norm(x)
+     return self.lm_head(x)
+
+
+ def define_and_run() -> None:
+   attn_config = cfg.AttentionConfig(
+       num_heads=32, num_query_groups=4, rotary_percentage=1.0, enable_kv_cache=False
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationType.SILU,
+       intermediate_size=256,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   config = cfg.ModelConfig(
+       vocab_size=400,
+       num_layers=1,
+       max_seq_len=KV_CACHE_MAX_LEN,
+       embedding_dim=128,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       pre_ff_norm_config=norm_config,
+       final_norm_config=norm_config,
+   )
+
+   model = ToySingleLayerModel(config)
+   idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+   print('running an inference')
+   print(
+       model.forward(
+           idx,
+           input_pos,
+       )
+   )
+
+   # Convert model to tflite.
+   print('converting model to tflite')
+   edge_model = ai_edge_torch.convert(
+       model,
+       (
+           idx,
+           input_pos,
+       ),
+   )
+   edge_model.export('/tmp/toy_model.tflite')
+
+
+ if __name__ == '__main__':
+   define_and_run()
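
If TensorFlow is installed, the exported model can be sanity-checked with the TFLite interpreter. A sketch, assuming the /tmp/toy_model.tflite path written by define_and_run() above; the signature and input names are assigned by the converter, so this discovers them at runtime rather than hard-coding them:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='/tmp/toy_model.tflite')
sig_name, sig = next(iter(interpreter.get_signature_list().items()))
runner = interpreter.get_signature_runner(sig_name)

# Same shapes as the trace inputs used during conversion.
idx = np.arange(0, 100, dtype=np.int64).reshape(1, 100)
input_pos = np.arange(0, 100, dtype=np.int64)

# Assumption: the signature lists its inputs in positional order; verify with
# print(sig['inputs']) before relying on this mapping.
outputs = runner(**dict(zip(sig['inputs'], (idx, input_pos))))
print({name: t.shape for name, t in outputs.items()})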
@@ -0,0 +1,143 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A toy example with a basic transformer block (w/ KV-Cache).
+ from typing import List, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch_xla
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.layers.attention import TransformerBlock
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ import ai_edge_torch.generative.layers.model_config as cfg
+
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
+
+
+ class ToyModelWithKV(torch.nn.Module):
+
+   def __init__(self, config: cfg.ModelConfig) -> None:
+     super().__init__()
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+     self.transformer_blocks = nn.ModuleList(
+         TransformerBlock(config) for _ in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.max_seq_len,
+         dim=int(config.attn_config.rotary_percentage * config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device('cpu'),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+     x = self.tok_embedding(idx)
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.max_seq_len]
+     for i, block in enumerate(self.transformer_blocks):
+       x = block(x, (cos, sin), mask, input_pos)
+     x = self.final_norm(x)
+     return self.lm_head(x)
+
+
+ def _export_stablehlo_mlir(model, args):
+   ep = torch.export.export(model, args)
+   stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
+   return stablehlo_gm.get_stablehlo_text()
+
+
+ def get_model_config() -> cfg.ModelConfig:
+   attn_config = cfg.AttentionConfig(
+       num_heads=32, num_query_groups=4, rotary_percentage=1.0
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationType.SILU,
+       intermediate_size=256,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   config = cfg.ModelConfig(
+       vocab_size=150,
+       num_layers=2,
+       max_seq_len=500,
+       embedding_dim=128,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       pre_ff_norm_config=norm_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+   idx = torch.unsqueeze(torch.arange(0, 100), 0)
+   input_pos = torch.arange(0, 100)
+   return idx, input_pos
+
+
+ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+   idx = torch.tensor([[1]], dtype=torch.long)
+   input_pos = torch.tensor([10], dtype=torch.int64)
+   return idx, input_pos
+
+
+ def define_and_run() -> None:
+   dump_mlir = False
+
+   config = get_model_config()
+   model = ToyModelWithKV(config)
+   print('running an inference')
+   idx, input_pos = get_sample_prefill_inputs()
+   decode_idx, decode_input_pos = get_sample_decode_inputs()
+   print(model.forward(idx, input_pos))
+
+   if dump_mlir:
+     mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
+     with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
+       f.write(mlir_text)
+
+   # Convert model to tflite with 2 signatures (prefill + decode).
+   print('converting toy model to tflite with 2 signatures (prefill + decode)')
+   edge_model = (
+       ai_edge_torch.signature('prefill', model, (idx, input_pos))
+       .signature('decode', model, (decode_idx, decode_input_pos))
+       .convert()
+   )
+   edge_model.export('/tmp/toy_kv_cache.tflite')
+
+
+ if __name__ == '__main__':
+   define_and_run()
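
The two signatures map onto two calling patterns of the same forward(): prefill runs the whole prompt once to fill the KV cache, then decode feeds one token at a time with an advancing position. A sketch of that flow against the eager PyTorch model defined above (greedy sampling is an illustrative choice, not something this example implements):

import torch

model = ToyModelWithKV(get_model_config())

# Prefill: process the 100-token sample prompt in one pass.
idx, input_pos = get_sample_prefill_inputs()
logits = model.forward(idx, input_pos)
token = logits[:, -1:].argmax(-1)  # greedy pick at the last position

# Decode: one token per step, with input_pos advancing past the prompt.
for pos in range(100, 105):
    logits = model.forward(token, torch.tensor([pos], dtype=torch.int64))
    token = logits[:, -1:].argmax(-1)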
@@ -0,0 +1,66 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import os
+ from pathlib import Path
+
+ import torch
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.quantize import quant_recipes
+
+
+ def convert_tiny_llama_to_tflite(
+     checkpoint_path: str,
+     prefill_seq_len: int = 512,
+     kv_cache_max_len: int = 1024,
+     quantize: bool = True,
+ ):
+   """An example method for converting a TinyLlama model to a multi-signature
+   tflite model.
+
+   Args:
+     checkpoint_path (str): The filepath to the model checkpoint, or directory
+       holding the checkpoint.
+     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+       Defaults to 512.
+     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+       including both prefill and decode. Defaults to 1024.
+     quantize (bool, optional): Whether the model should be quantized.
+       Defaults to True.
+   """
+   pytorch_model = tiny_llama.build_model(
+       checkpoint_path, kv_cache_max_len=kv_cache_max_len
+   )
+   # Tensors used to trace the model graph during conversion.
+   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+   prefill_input_pos = torch.arange(0, prefill_seq_len)
+   decode_token = torch.tensor([[0]], dtype=torch.long)
+   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+   quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+   edge_model = (
+       ai_edge_torch.signature(
+           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+       )
+       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+       .convert(quant_config=quant_config)
+   )
+   edge_model.export(
+       f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+   )
+
+
+ if __name__ == '__main__':
+   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
+   convert_tiny_llama_to_tflite(checkpoint_path)
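
A hypothetical invocation with a smaller prefill window and quantization disabled; the checkpoint path is a placeholder, and the output filename follows the export path in the function above:

from ai_edge_torch.generative.examples.tiny_llama import convert_to_tflite

convert_to_tflite.convert_tiny_llama_to_tflite(
    '/path/to/tiny_llama',  # placeholder checkpoint location
    prefill_seq_len=256,
    kv_cache_max_len=512,
    quantize=False,
)
# Writes /tmp/tiny_llama_seq256_kv512.tflite.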