ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (41)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
  3. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  6. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  7. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  9. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  10. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  11. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  12. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  13. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  15. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  16. ai_edge_torch/generative/examples/phi/phi3.py +279 -0
  17. ai_edge_torch/generative/examples/phi/verify.py +13 -13
  18. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  19. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  20. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  21. ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
  22. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
  23. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
  24. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  25. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  26. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
  27. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  29. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  30. ai_edge_torch/generative/layers/model_config.py +2 -0
  31. ai_edge_torch/generative/layers/normalization.py +2 -2
  32. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  33. ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
  34. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  35. ai_edge_torch/generative/utilities/verifier.py +130 -114
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
  39. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
@@ -15,43 +15,53 @@
 
  """Verifies the reauthored SmolLM-135M model."""
 
+ import logging
  import pathlib
 
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.smollm import smollm
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers
 
+
  _PROMPTS = flags.DEFINE_multi_string(
  "prompts",
  "What is the meaning of life?",
  "The input prompts to generate answers.",
  )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+ "max_new_tokens",
+ 30,
+ "The maximum size of the generated tokens.",
+ )
 
 
  def main(_):
  checkpoint = "HuggingFaceTB/SmolLM-135M"
- verifier.log_msg("Loading the original model from", checkpoint)
- wrapper_model = verifier.ModelWrapper(
- model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
- )
+ logging.info("Loading the original model from: %s", checkpoint)
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
  checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
- verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = smollm.build_model(reauthored_checkpoint)
 
- verifier.log_msg("Loading the tokenizer from", checkpoint)
+ logging.info("Loading the tokenizer from: %s", checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
  verifier.verify_reauthored_model(
- original_model=wrapper_model,
- reauthored_model=reauthored_model,
- tokenizer=tokenizer,
+ original_model=transformers_verifier.TransformersModelWrapper(
+ original_model
+ ),
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
  generate_prompts=_PROMPTS.value,
+ max_new_tokens=_MAX_NEW_TOKENS.value,
  atol=1e-04,
  )
 
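Read without diff markers and flag plumbing, the verification flow this hunk moves to looks roughly like the sketch below. This is a condensed illustration assembled from the lines above, not code copied from the package; the checkpoint, prompt, and token count are the defaults shown in the hunk.

    import logging
    import pathlib

    import transformers

    from ai_edge_torch.generative.examples.smollm import smollm
    from ai_edge_torch.generative.utilities import transformers_verifier
    from ai_edge_torch.generative.utilities import verifier

    checkpoint = "HuggingFaceTB/SmolLM-135M"
    logging.info("Loading the original model from: %s", checkpoint)
    original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

    # The reauthored model is rebuilt from the same Hugging Face cache dir.
    cached_config_file = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
    )
    reauthored_model = smollm.build_model(pathlib.Path(cached_config_file).parent)
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

    # Both models and the tokenizer are now handed over behind wrapper classes.
    verifier.verify_reauthored_model(
        original_model=transformers_verifier.TransformersModelWrapper(original_model),
        reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
        tokenizer=verifier.TokenizerWrapper(tokenizer),
        generate_prompts=["What is the meaning of life?"],
        max_new_tokens=30,
        atol=1e-04,
    )

The same wrapper pattern appears in the openelm, phi, and tiny_llama verify scripts touched by this release (see the tiny_llama hunk near the end of this diff).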
@@ -48,7 +48,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
  class CLIP(nn.Module):
- """CLIP text encoder
+ """CLIP text encoder.
 
  For details, see https://arxiv.org/abs/2103.00020
  """
@@ -86,6 +86,7 @@ class CLIP(nn.Module):
 
 
  def get_model_config() -> cfg.ModelConfig:
+ """Get configs for the CLIP of Stable Diffusion v1.5."""
  max_seq_len = 77
  vocab_size = 49408
  num_layers = 12
@@ -97,6 +98,58 @@ def get_model_config() -> cfg.ModelConfig:
  num_heads=num_heads,
  head_dim=embedding_dim // num_heads,
  num_query_groups=num_query_groups,
+ rotary_base=0,
+ rotary_percentage=0.0,
+ qkv_use_bias=True,
+ qkv_transpose_before_split=True,
+ qkv_fused_interleaved=False,
+ output_proj_use_bias=True,
+ enable_kv_cache=False,
+ )
+
+ ff_config = cfg.FeedForwardConfig(
+ type=cfg.FeedForwardType.SEQUENTIAL,
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+ intermediate_size=embedding_dim * 4,
+ use_bias=True,
+ )
+
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+ block_config = cfg.TransformerBlockConfig(
+ attn_config=attn_config,
+ ff_config=ff_config,
+ pre_attention_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
+ )
+
+ config = cfg.ModelConfig(
+ vocab_size=vocab_size,
+ num_layers=num_layers,
+ max_seq_len=max_seq_len,
+ embedding_dim=embedding_dim,
+ block_configs=block_config,
+ final_norm_config=norm_config,
+ enable_hlfb=True,
+ )
+
+ return config
+
+
+ def get_fake_model_config() -> cfg.ModelConfig:
+ """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
+ max_seq_len = 6
+ vocab_size = 100
+ num_layers = 2
+ num_heads = 12
+ num_query_groups = 12
+ embedding_dim = 24
+
+ attn_config = cfg.AttentionConfig(
+ num_heads=num_heads,
+ head_dim=embedding_dim // num_heads,
+ num_query_groups=num_query_groups,
+ rotary_base=0,
  rotary_percentage=0.0,
  qkv_use_bias=True,
  qkv_transpose_before_split=True,
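The new get_fake_model_config mirrors get_model_config but with tiny dimensions, so the conversion tests added in this release (test_model_conversion_large.py in the file list above) can build a small CLIP cheaply. A hypothetical instantiation, assuming the CLIP constructor takes a cfg.ModelConfig as the class definition above suggests; this is not copied from the package's test code:

    from ai_edge_torch.generative.examples.stable_diffusion import clip

    # Tiny model for tests only; real weights would not fit this config.
    tiny_clip = clip.CLIP(clip.get_fake_model_config())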
@@ -295,6 +295,64 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
  enable_kv_cache=False,
  qkv_transpose_before_split=True,
  qkv_fused_interleaved=False,
+ rotary_base=0,
+ rotary_percentage=0.0,
+ ),
+ enable_hlfb=False,
+ )
+
+ mid_block_config = unet_cfg.MidBlock2DConfig(
+ in_channels=block_out_channels[-1],
+ normalization_config=norm_config,
+ activation_config=layers_cfg.ActivationConfig(
+ layers_cfg.ActivationType.SILU
+ ),
+ num_layers=1,
+ attention_block_config=att_config,
+ )
+
+ config = unet_cfg.AutoEncoderConfig(
+ in_channels=in_channels,
+ latent_channels=latent_channels,
+ out_channels=out_channels,
+ activation_config=layers_cfg.ActivationConfig(
+ layers_cfg.ActivationType.SILU
+ ),
+ block_out_channels=block_out_channels,
+ scaling_factor=scaling_factor,
+ layers_per_block=layers_per_block,
+ normalization_config=norm_config,
+ mid_block_config=mid_block_config,
+ )
+ return config
+
+
+ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+ """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
+ in_channels = 3
+ latent_channels = 4
+ out_channels = 3
+ block_out_channels = [2, 4]
+ scaling_factor = 0.18215
+ layers_per_block = 2
+
+ norm_config = layers_cfg.NormalizationConfig(
+ layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+ )
+
+ att_config = unet_cfg.AttentionBlock2DConfig(
+ dim=block_out_channels[-1],
+ normalization_config=norm_config,
+ attention_config=layers_cfg.AttentionConfig(
+ num_heads=1,
+ head_dim=block_out_channels[-1],
+ num_query_groups=1,
+ qkv_use_bias=True,
+ output_proj_use_bias=True,
+ enable_kv_cache=False,
+ qkv_transpose_before_split=True,
+ qkv_fused_interleaved=False,
+ rotary_base=0,
  rotary_percentage=0.0,
  ),
  enable_hlfb=False,
@@ -199,6 +199,7 @@ def build_attention_config(
  num_heads,
  dim,
  num_query_groups,
+ rotary_base=0,
  rotary_percentage=0.0,
  qkv_transpose_before_split=True,
  qkv_use_bias=False,
@@ -211,6 +212,7 @@ def build_attention_config(
  num_heads=num_heads,
  head_dim=dim // num_heads,
  num_query_groups=num_query_groups,
+ rotary_base=rotary_base,
  rotary_percentage=rotary_percentage,
  qkv_transpose_before_split=qkv_transpose_before_split,
  qkv_use_bias=qkv_use_bias,
@@ -603,7 +605,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
  # Transformer configs.
  transformer_num_attention_heads = 8
  transformer_batch_size = batch_size
- transformer_cross_attention_dim = 768 # Embedding fomr CLIP model
+ transformer_cross_attention_dim = 768 # Embedding from CLIP model
  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
  layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
  )
@@ -645,3 +647,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
  final_norm_config=final_norm_config,
  final_activation_type=final_activation_type,
  )
+
+
+ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+ """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
+
+ Args:
+ batch_size (int): the batch size of input.
+
+ Retruns:
+ The configuration of diffusion model of Stable Diffusion v1.5.
+ """
+ in_channels = 4
+ out_channels = 4
+ block_out_channels = [2, 4, 8, 8]
+ layers_per_block = 1
+ downsample_padding = 1
+
+ # Residual configs.
+ residual_norm_config = layers_cfg.NormalizationConfig(
+ layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+ )
+ residual_activation_type = layers_cfg.ActivationType.SILU
+
+ # Transformer configs.
+ transformer_num_attention_heads = 1
+ transformer_batch_size = batch_size
+ transformer_cross_attention_dim = 4 # Embedding from CLIP model
+ transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+ layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
+ )
+ transformer_norm_config = layers_cfg.NormalizationConfig(
+ layers_cfg.NormalizationType.LAYER_NORM
+ )
+ transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+ # Time embedding configs.
+ time_embedding_dim = 2
+ time_embedding_blocks_dim = 4
+
+ # Mid block configs.
+ mid_block_layers = 1
+
+ # Finaly layer configs.
+ final_norm_config = layers_cfg.NormalizationConfig(
+ layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+ )
+ final_activation_type = layers_cfg.ActivationType.SILU
+
+ return unet_cfg.DiffusionModelConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ downsample_padding=downsample_padding,
+ residual_norm_config=residual_norm_config,
+ residual_activation_type=residual_activation_type,
+ transformer_batch_size=transformer_batch_size,
+ transformer_num_attention_heads=transformer_num_attention_heads,
+ transformer_cross_attention_dim=transformer_cross_attention_dim,
+ transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+ transformer_norm_config=transformer_norm_config,
+ transformer_ff_activation_type=transformer_ff_activation_type,
+ mid_block_layers=mid_block_layers,
+ time_embedding_dim=time_embedding_dim,
+ time_embedding_blocks_dim=time_embedding_blocks_dim,
+ final_norm_config=final_norm_config,
+ final_activation_type=final_activation_type,
+ )
@@ -335,8 +335,6 @@ class T5Decoder(nn.Module):
 
  self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
  size=config.kv_cache_max,
- dtype=torch.float32,
- device=torch.device("cpu"),
  )
 
  @torch.inference_mode
@@ -0,0 +1,105 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A toy example which has a single-layer transformer block.
+ from absl import app
+ import ai_edge_torch
+ from ai_edge_torch import lowertools
+ from ai_edge_torch.generative.examples.test_models import toy_model
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import torch
+
+ KV_CACHE_MAX_LEN = 100
+
+
+ def convert_toy_model(_) -> None:
+ """Converts a toy model to tflite."""
+ model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+ idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+ input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+ print('running an inference')
+ print(
+ model.forward(
+ idx,
+ input_pos,
+ )
+ )
+
+ # Convert model to tflite.
+ print('converting model to tflite')
+ edge_model = ai_edge_torch.convert(
+ model,
+ (
+ idx,
+ input_pos,
+ ),
+ )
+ edge_model.export('/tmp/toy_model.tflite')
+
+
+ def _export_stablehlo_mlir(model, args):
+ ep = torch.export.export(model, args)
+ return lowertools.exported_program_to_mlir_text(ep)
+
+
+ def convert_toy_model_with_kv_cache(_) -> None:
+ """Converts a toy model with kv cache to tflite."""
+ dump_mlir = False
+
+ config = toy_model_with_kv_cache.get_model_config()
+ model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+ model.eval()
+ print('running an inference')
+ kv = kv_utils.KVCache.from_model_config(config)
+
+ tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+ decode_token, decode_input_pos = (
+ toy_model_with_kv_cache.get_sample_decode_inputs()
+ )
+ print(model.forward(tokens, input_pos, kv))
+
+ if dump_mlir:
+ mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+ with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+ f.write(mlir_text)
+
+ # Convert model to tflite with 2 signatures (prefill + decode).
+ print('converting toy model to tflite with 2 signatures (prefill + decode)')
+ edge_model = (
+ ai_edge_torch.signature(
+ 'prefill',
+ model,
+ sample_kwargs={
+ 'tokens': tokens,
+ 'input_pos': input_pos,
+ 'kv_cache': kv,
+ },
+ )
+ .signature(
+ 'decode',
+ model,
+ sample_kwargs={
+ 'tokens': decode_token,
+ 'input_pos': decode_input_pos,
+ 'kv_cache': kv,
+ },
+ )
+ .convert()
+ )
+ edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+ if __name__ == '__main__':
+ app.run(convert_toy_model)
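This new script absorbs the define_and_run helpers that the two hunks below delete from toy_model.py and toy_model_with_kv_cache.py. As shipped it converts the single-layer toy model; a hypothetical one-line local edit (not part of the package) would exercise the KV-cache path instead:

    # Run the KV-cache conversion entry point defined in this same module.
    if __name__ == '__main__':
      app.run(convert_toy_model_with_kv_cache)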
@@ -15,13 +15,12 @@
  # A toy example which has a single-layer transformer block.
  from typing import Tuple
 
- import ai_edge_torch
+ from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
- import torch.nn as nn
+ from torch import nn
 
  RoPECache = Tuple[torch.Tensor, torch.Tensor]
  KV_CACHE_MAX_LEN = 100
@@ -45,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.max_seq_len,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device('cpu'),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+ size=config.max_seq_len,
  )
  self.config = config
 
@@ -94,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.max_seq_len,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device('cpu'),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+ size=config.max_seq_len,
  )
  self.config = config
 
@@ -125,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
  num_heads=32,
  head_dim=4,
  num_query_groups=4,
+ rotary_base=10000,
  rotary_percentage=1.0,
  enable_kv_cache=False,
  )
@@ -149,31 +143,3 @@ def get_model_config() -> cfg.ModelConfig:
  final_norm_config=norm_config,
  )
  return config
-
-
- def define_and_run() -> None:
- model = ToySingleLayerModel(get_model_config())
- idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
- input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
- print('running an inference')
- print(
- model.forward(
- idx,
- input_pos,
- )
- )
-
- # Convert model to tflite.
- print('converting model to tflite')
- edge_model = ai_edge_torch.convert(
- model,
- (
- idx,
- input_pos,
- ),
- )
- edge_model.export('/tmp/toy_model.tflite')
-
-
- if __name__ == '__main__':
- define_and_run()
@@ -17,15 +17,14 @@
 
  from typing import Tuple
 
- import ai_edge_torch
- from ai_edge_torch import lowertools
+ from absl import app
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
- import torch.nn as nn
+ from torch import nn
 
  RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -52,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.max_seq_len,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device('cpu'),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+ size=config.max_seq_len,
  )
  self.config = config
 
@@ -87,16 +83,12 @@ class ToyModelWithKVCache(torch.nn.Module):
  return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
- def _export_stablehlo_mlir(model, args):
- ep = torch.export.export(model, args)
- return lowertools.exported_program_to_mlir_text(ep)
-
-
  def get_model_config() -> cfg.ModelConfig:
  attn_config = cfg.AttentionConfig(
  num_heads=32,
  head_dim=4,
  num_query_groups=4,
+ rotary_base=10000,
  rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(
@@ -133,51 +125,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
  tokens = torch.tensor([[1]], dtype=torch.int)
  input_pos = torch.tensor([10])
  return tokens, input_pos
-
-
- def define_and_run() -> None:
- dump_mlir = False
-
- config = get_model_config()
- model = ToyModelWithExternalKV(config)
- model.eval()
- print('running an inference')
- kv = kv_utils.KVCache.from_model_config(config)
-
- tokens, input_pos = get_sample_prefill_inputs()
- decode_token, decode_input_pos = get_sample_decode_inputs()
- print(model.forward(tokens, input_pos, kv))
-
- if dump_mlir:
- mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
- with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
- f.write(mlir_text)
-
- # Convert model to tflite with 2 signatures (prefill + decode).
- print('converting toy model to tflite with 2 signatures (prefill + decode)')
- edge_model = (
- ai_edge_torch.signature(
- 'prefill',
- model,
- sample_kwargs={
- 'tokens': tokens,
- 'input_pos': input_pos,
- 'kv_cache': kv,
- },
- )
- .signature(
- 'decode',
- model,
- sample_kwargs={
- 'tokens': decode_token,
- 'input_pos': decode_input_pos,
- 'kv_cache': kv,
- },
- )
- .convert()
- )
- edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
- if __name__ == '__main__':
- define_and_run()
@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.kv_cache_max,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device("cpu"),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
  size=config.kv_cache_max,
- dtype=torch.float32,
- device=torch.device("cpu"),
  )
  self.config = config
 
@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  num_heads=32,
  head_dim=64,
  num_query_groups=4,
+ rotary_base=10000,
  rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(
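The tiny_llama hunks above, like the toy-model hunks earlier, replace the hard-coded base=10_000 in build_rope_cache (and drop the condense_ratio, dtype, and device arguments) with a rotary_base field read from AttentionConfig. A minimal sketch of the resulting pattern, using only names visible in this diff; the cache size 1024 here is just an example value:

    import ai_edge_torch.generative.layers.attention_utils as attn_utils
    import ai_edge_torch.generative.layers.model_config as cfg

    attn_config = cfg.AttentionConfig(
        num_heads=32,
        head_dim=64,
        num_query_groups=4,
        rotary_base=10000,  # RoPE base now lives in the config.
        rotary_percentage=1.0,
    )
    rope_cache = attn_utils.build_rope_cache(
        size=1024,  # e.g. config.kv_cache_max
        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
        base=attn_config.rotary_base,  # Passed through instead of hard-coded.
    )
    mask_cache = attn_utils.build_causal_mask_cache(size=1024)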
@@ -15,45 +15,55 @@
 
  """Verifies the reauthored TinyLlama-1.1B model."""
 
+ import logging
  import pathlib
 
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers
 
+
  _PROMPTS = flags.DEFINE_multi_string(
  "prompts",
  "Show me the program to add 2 and 3.",
  "The input prompts to generate answers.",
  )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+ "max_new_tokens",
+ 30,
+ "The maximum size of the generated tokens.",
+ )
 
 
  def main(_):
  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
- verifier.log_msg("Loading the original model from", checkpoint)
- wrapper_model = verifier.ModelWrapper(
- model=transformers.AutoModelForCausalLM.from_pretrained(
- checkpoint, trust_remote_code=True
- ),
+ logging.info("Loading the original model from: %s", checkpoint)
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(
+ checkpoint, trust_remote_code=True
  )
+
  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
  checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
- verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
 
- verifier.log_msg("Loading the tokenizer from", checkpoint)
+ logging.info("Loading the tokenizer from: %s", checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
  verifier.verify_reauthored_model(
- original_model=wrapper_model,
- reauthored_model=reauthored_model,
- tokenizer=tokenizer,
+ original_model=transformers_verifier.TransformersModelWrapper(
+ original_model
+ ),
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
  generate_prompts=_PROMPTS.value,
+ max_new_tokens=_MAX_NEW_TOKENS.value,
  atol=1e-04,
  )
 