ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

Files changed (41)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
  3. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  6. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  7. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  9. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  10. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  11. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  12. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  13. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  15. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  16. ai_edge_torch/generative/examples/phi/phi3.py +279 -0
  17. ai_edge_torch/generative/examples/phi/verify.py +13 -13
  18. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  19. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  20. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  21. ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
  22. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
  23. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
  24. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  25. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  26. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
  27. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  29. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  30. ai_edge_torch/generative/layers/model_config.py +2 -0
  31. ai_edge_torch/generative/layers/normalization.py +2 -2
  32. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  33. ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
  34. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  35. ai_edge_torch/generative/utilities/verifier.py +130 -114
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
  39. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smollm/verify.py
@@ -15,43 +15,53 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = smollm.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
      atol=1e-04,
   )
 
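Note on the pattern above: the script no longer adapts models inline; it delegates to wrapper classes, with transformers_verifier.TransformersModelWrapper (a new 42-line utility in this release) covering the HuggingFace side and verifier.ReauthoredModelWrapper / verifier.TokenizerWrapper covering the reauthored side. A rough sketch of the adapter idea, inferred from the call sites only (the base class, the self.model attribute, and the method names are assumptions, not the shipped implementation):

    import torch
    from ai_edge_torch.generative.utilities import verifier

    class TransformersModelWrapper(verifier.ModelWrapper):
      """Hypothetical sketch of adapting a transformers causal LM."""

      def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # HuggingFace causal LMs return an output object; the verifier only
        # needs the raw logits to compare against the reauthored model.
        return self.model.forward(tokens).logits

      def generate(self, inputs: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        # max_new_tokens is the knob now exposed via the _MAX_NEW_TOKENS flag.
        return self.model.generate(inputs=inputs, max_new_tokens=max_new_tokens)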
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -48,7 +48,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class CLIP(nn.Module):
-  """CLIP text encoder
+  """CLIP text encoder.
 
   For details, see https://arxiv.org/abs/2103.00020
   """
@@ -86,6 +86,7 @@ class CLIP(nn.Module):
 
 
 def get_model_config() -> cfg.ModelConfig:
+  """Get configs for the CLIP of Stable Diffusion v1.5."""
   max_seq_len = 77
   vocab_size = 49408
   num_layers = 12
@@ -97,6 +98,58 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=num_heads,
       head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=0,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
+
+
+def get_fake_model_config() -> cfg.ModelConfig:
+  """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
+  max_seq_len = 6
+  vocab_size = 100
+  num_layers = 2
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 24
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      head_dim=embedding_dim // num_heads,
+      num_query_groups=num_query_groups,
+      rotary_base=0,
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
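get_fake_model_config() shrinks every dimension of the real CLIP config (77 to 6 tokens, 49408 to 100 vocabulary entries, 12 to 2 layers, 768 to 24 embedding width) so the new test_model_conversion_large.py tests can run quickly. A hypothetical smoke test under that config (the CLIP constructor and forward signature are assumed, not shown in this diff):

    import torch
    from ai_edge_torch.generative.examples.stable_diffusion import clip

    config = clip.get_fake_model_config()
    model = clip.CLIP(config)  # assumed: the example CLIP takes a cfg.ModelConfig
    # A 6-token prompt over the 100-entry fake vocabulary.
    tokens = torch.zeros((1, config.max_seq_len), dtype=torch.int32)
    with torch.inference_mode():
      embeddings = model(tokens)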
ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -295,6 +295,64 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           enable_kv_cache=False,
           qkv_transpose_before_split=True,
           qkv_fused_interleaved=False,
+          rotary_base=0,
+          rotary_percentage=0.0,
+      ),
+      enable_hlfb=False,
+  )
+
+  mid_block_config = unet_cfg.MidBlock2DConfig(
+      in_channels=block_out_channels[-1],
+      normalization_config=norm_config,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      num_layers=1,
+      attention_block_config=att_config,
+  )
+
+  config = unet_cfg.AutoEncoderConfig(
+      in_channels=in_channels,
+      latent_channels=latent_channels,
+      out_channels=out_channels,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      block_out_channels=block_out_channels,
+      scaling_factor=scaling_factor,
+      layers_per_block=layers_per_block,
+      normalization_config=norm_config,
+      mid_block_config=mid_block_config,
+  )
+  return config
+
+
+def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+  """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
+  in_channels = 3
+  latent_channels = 4
+  out_channels = 3
+  block_out_channels = [2, 4]
+  scaling_factor = 0.18215
+  layers_per_block = 2
+
+  norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+
+  att_config = unet_cfg.AttentionBlock2DConfig(
+      dim=block_out_channels[-1],
+      normalization_config=norm_config,
+      attention_config=layers_cfg.AttentionConfig(
+          num_heads=1,
+          head_dim=block_out_channels[-1],
+          num_query_groups=1,
+          qkv_use_bias=True,
+          output_proj_use_bias=True,
+          enable_kv_cache=False,
+          qkv_transpose_before_split=True,
+          qkv_fused_interleaved=False,
+          rotary_base=0,
           rotary_percentage=0.0,
       ),
       enable_hlfb=False,
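As with CLIP, the decoder gains a miniature config for tests: two blocks of width 2 and 4 instead of the full VAE widths, with group_num=2 so GroupNorm stays valid on those small channel counts. A hypothetical smoke test (the Decoder class name, constructor, and latent input layout are assumptions, not taken from this diff):

    import torch
    from ai_edge_torch.generative.examples.stable_diffusion import decoder

    config = decoder.get_fake_model_config()
    model = decoder.Decoder(config)  # assumed constructor
    # 4-channel latents; the 8x8 spatial size is illustrative.
    latents = torch.randn(1, config.latent_channels, 8, 8)
    with torch.inference_mode():
      image = model(latents)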
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -199,6 +199,7 @@ def build_attention_config(
     num_heads,
     dim,
     num_query_groups,
+    rotary_base=0,
     rotary_percentage=0.0,
     qkv_transpose_before_split=True,
     qkv_use_bias=False,
@@ -211,6 +212,7 @@ def build_attention_config(
       num_heads=num_heads,
       head_dim=dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=rotary_base,
       rotary_percentage=rotary_percentage,
       qkv_transpose_before_split=qkv_transpose_before_split,
       qkv_use_bias=qkv_use_bias,
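build_attention_config() now threads a rotary_base argument (default 0) into AttentionConfig, matching the two-line model_config.py change in this release. In RoPE, the base sets the per-pair rotation frequencies (theta_i = base^(-2i/dim)); the diffusion model keeps rotary_percentage=0.0, so RoPE stays disabled and the default base is never consulted. An illustrative call (parameter values are examples, not from this diff):

    attn_config = build_attention_config(
        num_heads=8,
        dim=320,
        num_query_groups=8,
        rotary_base=0,  # default; diffusion attention does not use RoPE
        rotary_percentage=0.0,
    )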
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -603,7 +605,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
   # Transformer configs.
   transformer_num_attention_heads = 8
   transformer_batch_size = batch_size
-  transformer_cross_attention_dim = 768  # Embedding fomr CLIP model
+  transformer_cross_attention_dim = 768  # Embedding from CLIP model
   transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
       layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
   )
@@ -645,3 +647,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
       final_norm_config=final_norm_config,
       final_activation_type=final_activation_type,
   )
+
+
+def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+  """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
+
+  Args:
+    batch_size (int): the batch size of input.
+
+  Returns:
+    The configuration of the diffusion model of Stable Diffusion v1.5.
+  """
+  in_channels = 4
+  out_channels = 4
+  block_out_channels = [2, 4, 8, 8]
+  layers_per_block = 1
+  downsample_padding = 1
+
+  # Residual configs.
+  residual_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+  residual_activation_type = layers_cfg.ActivationType.SILU
+
+  # Transformer configs.
+  transformer_num_attention_heads = 1
+  transformer_batch_size = batch_size
+  transformer_cross_attention_dim = 4  # Embedding from CLIP model
+  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
+  )
+  transformer_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.LAYER_NORM
+  )
+  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+  # Time embedding configs.
+  time_embedding_dim = 2
+  time_embedding_blocks_dim = 4
+
+  # Mid block configs.
+  mid_block_layers = 1
+
+  # Final layer configs.
+  final_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+  final_activation_type = layers_cfg.ActivationType.SILU
+
+  return unet_cfg.DiffusionModelConfig(
+      in_channels=in_channels,
+      out_channels=out_channels,
+      block_out_channels=block_out_channels,
+      layers_per_block=layers_per_block,
+      downsample_padding=downsample_padding,
+      residual_norm_config=residual_norm_config,
+      residual_activation_type=residual_activation_type,
+      transformer_batch_size=transformer_batch_size,
+      transformer_num_attention_heads=transformer_num_attention_heads,
+      transformer_cross_attention_dim=transformer_cross_attention_dim,
+      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+      transformer_norm_config=transformer_norm_config,
+      transformer_ff_activation_type=transformer_ff_activation_type,
+      mid_block_layers=mid_block_layers,
+      time_embedding_dim=time_embedding_dim,
+      time_embedding_blocks_dim=time_embedding_blocks_dim,
+      final_norm_config=final_norm_config,
+      final_activation_type=final_activation_type,
+  )
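A quick way to see what the fake diffusion config changes relative to the real one, with values read straight from the two functions above (this assumes DiffusionModelConfig exposes its constructor arguments as attributes):

    from ai_edge_torch.generative.examples.stable_diffusion import diffusion

    real = diffusion.get_model_config(batch_size=1)
    fake = diffusion.get_fake_model_config(batch_size=1)
    # Same structure, toy dimensions: the cross-attention width shrinks from
    # the CLIP embedding size (768) to 4, and attention heads from 8 to 1.
    assert real.transformer_cross_attention_dim == 768
    assert fake.transformer_cross_attention_dim == 4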
ai_edge_torch/generative/examples/t5/t5.py
@@ -335,8 +335,6 @@ class T5Decoder(nn.Module):
 
     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
 
   @torch.inference_mode
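Throughout this release the explicit dtype=torch.float32 / device="cpu" arguments are dropped from build_causal_mask_cache and build_rope_cache, leaving them to the utilities' defaults. For intuition, a causal mask cache is a square matrix with -inf strictly above the diagonal, so position i can attend only to positions <= i; a minimal standalone equivalent (not the library's implementation):

    import torch

    def causal_mask(size: int) -> torch.Tensor:
      # Zeros on and below the diagonal, -inf above; added to attention
      # scores, it blocks every position from attending to the future.
      return torch.full((size, size), float("-inf")).triu(diagonal=1)

    print(causal_mask(4))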
ai_edge_torch/generative/examples/test_models/convert_toy_model.py (new file)
@@ -0,0 +1,105 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A toy example which has a single-layer transformer block.
+from absl import app
+import ai_edge_torch
+from ai_edge_torch import lowertools
+from ai_edge_torch.generative.examples.test_models import toy_model
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import torch
+
+KV_CACHE_MAX_LEN = 100
+
+
+def convert_toy_model(_) -> None:
+  """Converts a toy model to tflite."""
+  model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+  print('running an inference')
+  print(
+      model.forward(
+          idx,
+          input_pos,
+      )
+  )
+
+  # Convert model to tflite.
+  print('converting model to tflite')
+  edge_model = ai_edge_torch.convert(
+      model,
+      (
+          idx,
+          input_pos,
+      ),
+  )
+  edge_model.export('/tmp/toy_model.tflite')
+
+
+def _export_stablehlo_mlir(model, args):
+  ep = torch.export.export(model, args)
+  return lowertools.exported_program_to_mlir_text(ep)
+
+
+def convert_toy_model_with_kv_cache(_) -> None:
+  """Converts a toy model with kv cache to tflite."""
+  dump_mlir = False
+
+  config = toy_model_with_kv_cache.get_model_config()
+  model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+  model.eval()
+  print('running an inference')
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+  decode_token, decode_input_pos = (
+      toy_model_with_kv_cache.get_sample_decode_inputs()
+  )
+  print(model.forward(tokens, input_pos, kv))
+
+  if dump_mlir:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+      f.write(mlir_text)
+
+  # Convert model to tflite with 2 signatures (prefill + decode).
+  print('converting toy model to tflite with 2 signatures (prefill + decode)')
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert()
+  )
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+if __name__ == '__main__':
+  app.run(convert_toy_model)
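Both converters write their artifacts under /tmp. To sanity-check an export, the converted model can be reloaded and run on the same inputs; this sketch follows the load/export API shown in the ai-edge-torch README rather than anything in this diff, so treat the load call as an assumption:

    import torch
    import ai_edge_torch

    # Assumes convert_toy_model() above has already written the file.
    reloaded = ai_edge_torch.load('/tmp/toy_model.tflite')
    idx = torch.unsqueeze(torch.arange(0, 100), 0)
    input_pos = torch.arange(0, 100)
    print(reloaded(idx, input_pos))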
ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -15,13 +15,12 @@
 # A toy example which has a single-layer transformer block.
 from typing import Tuple
 
-import ai_edge_torch
+from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-import torch.nn as nn
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KV_CACHE_MAX_LEN = 100
@@ -45,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -94,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -125,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=32,
       head_dim=4,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
       enable_kv_cache=False,
   )
@@ -149,31 +143,3 @@ def get_model_config() -> cfg.ModelConfig:
       final_norm_config=norm_config,
   )
   return config
-
-
-def define_and_run() -> None:
-  model = ToySingleLayerModel(get_model_config())
-  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
-  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
-  print('running an inference')
-  print(
-      model.forward(
-          idx,
-          input_pos,
-      )
-  )
-
-  # Convert model to tflite.
-  print('converting model to tflite')
-  edge_model = ai_edge_torch.convert(
-      model,
-      (
-          idx,
-          input_pos,
-      ),
-  )
-  edge_model.export('/tmp/toy_model.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -17,15 +17,14 @@
 
 from typing import Tuple
 
-import ai_edge_torch
-from ai_edge_torch import lowertools
+from absl import app
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-import torch.nn as nn
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -52,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -87,16 +83,12 @@ class ToyModelWithKVCache(torch.nn.Module):
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
-def _export_stablehlo_mlir(model, args):
-  ep = torch.export.export(model, args)
-  return lowertools.exported_program_to_mlir_text(ep)
-
-
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=4,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
@@ -133,51 +125,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
   tokens = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return tokens, input_pos
-
-
-def define_and_run() -> None:
-  dump_mlir = False
-
-  config = get_model_config()
-  model = ToyModelWithExternalKV(config)
-  model.eval()
-  print('running an inference')
-  kv = kv_utils.KVCache.from_model_config(config)
-
-  tokens, input_pos = get_sample_prefill_inputs()
-  decode_token, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(tokens, input_pos, kv))
-
-  if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-      f.write(mlir_text)
-
-  # Convert model to tflite with 2 signatures (prefill + decode).
-  print('converting toy model to tflite with 2 signatures (prefill + decode)')
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          model,
-          sample_kwargs={
-              'tokens': tokens,
-              'input_pos': input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert()
-  )
-  edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
      head_dim=64,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
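The same migration appears in gemma1, gemma2, openelm, and phi2 (each +2 -6 in the file list): the hard-coded base=10_000 and the condense_ratio/dtype/device arguments disappear from the build_rope_cache call sites, and the base moves into AttentionConfig.rotary_base. One payoff is that a model can change its RoPE base purely in config; an illustrative config (the 500_000 value is an example in the style of Llama 3 checkpoints, which the new llama examples in this release target, and is not taken from this diff):

    import ai_edge_torch.generative.layers.model_config as cfg

    attn_config = cfg.AttentionConfig(
        num_heads=32,
        head_dim=64,
        num_query_groups=4,
        rotary_base=500_000,  # e.g. a Llama 3 style base instead of 10_000
        rotary_percentage=1.0,
    )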
ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -15,45 +15,55 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Show me the program to add 2 and 3.",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
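Both updated verify scripts are absl apps, so the new flag is set on the command line, and --prompts can be repeated since it is a DEFINE_multi_string flag:

    python verify.py --max_new_tokens=50 --prompts="Show me the program to add 2 and 3."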