ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240929__py3-none-any.whl

Files changed (40)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  12. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  13. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  14. ai_edge_torch/generative/examples/phi/phi3.py +17 -24
  15. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  16. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  17. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +81 -0
  19. ai_edge_torch/generative/examples/qwen/qwen.py +141 -0
  20. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  21. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  22. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  23. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  24. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
  25. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
  26. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  27. ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
  28. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
  29. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  30. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  31. ai_edge_torch/generative/layers/model_config.py +2 -0
  32. ai_edge_torch/generative/test/test_model_conversion_large.py +20 -0
  33. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  34. ai_edge_torch/generative/utilities/verifier.py +117 -97
  35. ai_edge_torch/version.py +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/METADATA +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/RECORD +40 -29
  38. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/LICENSE +0 -0
  39. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/WHEEL +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers

@@ -29,15 +30,18 @@ _PROMPTS = flags.DEFINE_multi_string(
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)


 def main(_):
   checkpoint = "apple/OpenELM-3B"
   logging.info("Loading the original model from: %s", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )

   # Locate the cached dir.
@@ -53,10 +57,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)

   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )

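All of the verify scripts touched in this release (OpenELM, Phi-2, Phi-3.5, SmolLM, and the new Llama and Qwen ones) move to the same wrapper-based API: the HuggingFace model goes through transformers_verifier.TransformersModelWrapper, the reauthored model through verifier.ReauthoredModelWrapper, the tokenizer through verifier.TokenizerWrapper, and max_new_tokens is passed explicitly. A minimal sketch of that call pattern, with the wrapper signatures inferred from the diffs above rather than from library docs:

# Sketch only: argument names are taken from the verify.py diffs in this
# release; `hf_model`, `edge_model`, and `hf_tokenizer` are assumed to be an
# already-loaded HF model, the reauthored nn.Module, and an HF tokenizer.
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier


def verify(hf_model, edge_model, hf_tokenizer, prompts, max_new_tokens=30):
  verifier.verify_reauthored_model(
      original_model=transformers_verifier.TransformersModelWrapper(hf_model),
      reauthored_model=verifier.ReauthoredModelWrapper(edge_model),
      tokenizer=verifier.TokenizerWrapper(hf_tokenizer),
      generate_prompts=prompts,
      max_new_tokens=max_new_tokens,
  )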
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config

@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=80,
       num_query_groups=32,
+      rotary_base=10000,
       rotary_percentage=0.4,
       qkv_use_bias=True,
       output_proj_use_bias=True,
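The phi2.py hunks above, and several hunks below (phi3, smollm, tiny_llama, and the stable_diffusion configs), make the same mechanical change: the RoPE base becomes a rotary_base field on cfg.AttentionConfig instead of a 10_000 literal at each build_rope_cache call site, and the dtype/device arguments of the cache builders fall back to their defaults. A hedged sketch of the new call pattern (the attention_utils import path is an assumption; the field and argument names are taken from these diffs):

# Assumed sketch: rotary_base now lives on the attention config and is passed
# through to the RoPE cache builder instead of being hard-coded.
from ai_edge_torch.generative.layers import attention_utils as attn_utils  # assumed path
import ai_edge_torch.generative.layers.model_config as cfg

attn_config = cfg.AttentionConfig(
    num_heads=32,
    head_dim=80,
    num_query_groups=32,
    rotary_base=10000,  # new field; previously a 10_000 literal at the call site
    rotary_percentage=0.4,
)
rope_cache = attn_utils.build_rope_cache(
    size=1024,
    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
    base=attn_config.rotary_base,
)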
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
 ]


-def build_rope_cache(
+def _build_rope_cache(
     size: int,
     dim: int,
-    base: int = 10000,
-    condense_ratio: int = 1,
-    dtype: torch.dtype = torch.float32,
-    device: torch.device = None,
-    theta_factors: torch.Tensor = None,
-    scale: float = 1.0,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Precomputes Rotary Positional Embeddings for Phi-3.5 model.

@@ -116,26 +116,20 @@ def build_rope_cache(
   Args:
     size (int): The size of the built cache.
     dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value. Defaults to 10000.
+    base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed. Defaults to 1.
-    dtype (torch.dtype, optional): Output tensor's data type. Defaults to
-      torch.float32.
-    device (torch.device, optional): Output tensor's data type. Defaults to
-      None in which case "cpu" is used.
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
     theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values. Defaults to None.
-    scale (float, optional): A float used to scale the rope values. Defaults
-      to 1.0.
+      scale the theta values.
+    scale (float, optional): A float used to scale the rope values.

   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  if device is None:
-    device = torch.device('cpu')
   theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-  if theta_factors is not None:
-    theta = theta / theta_factors
+  theta = theta / theta_factors
   seq_idx = torch.arange(size) / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
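The body of _build_rope_cache is short enough to restate outside the diff: inverse frequencies derived from the base, divided by the per-dimension theta_factors, outer-producted with the (optionally condensed) position indices, then cos/sin scaled. A standalone sketch with toy sizes, mirroring the lines above:

# Toy-sized restatement of the RoPE precomputation shown in the diff.
import torch

size, dim, base, condense_ratio, scale = 8, 4, 10000, 1, 1.0
theta_factors = torch.ones(dim // 2)  # no long/short-context rescaling here

theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
theta = theta / theta_factors
seq_idx = torch.arange(size) / condense_ratio
idx_theta = torch.outer(seq_idx, theta)
cos = torch.cos(idx_theta) * scale  # shape (size, dim // 2)
sin = torch.sin(idx_theta) * scale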
@@ -167,10 +161,10 @@ class Phi3_5Mini(nn.Module):
         config.final_norm_config,
     )
     attn_config = block_config.attn_config
-    self.rope_cache = build_rope_cache(
+    self.rope_cache = _build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
+        base=attn_config.rotary_base,
         condense_ratio=1,
         dtype=torch.float32,
         device=torch.device("cpu"),
@@ -181,8 +175,6 @@ class Phi3_5Mini(nn.Module):
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config

@@ -238,6 +230,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=96,
       num_query_groups=32,
+      rotary_base=10000,
       rotary_percentage=1.0,
       qkv_transpose_before_split=True,
   )
@@ -19,6 +19,7 @@ import logging
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 import transformers
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
   logging.info("Loading the original model from: %s", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

   logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
@@ -53,10 +49,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-03,
   )

@@ -21,6 +21,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers

@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = "microsoft/Phi-3.5-mini-instruct"
   logging.info("Loading the original model from: %s", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
@@ -59,10 +55,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
@@ -0,0 +1,81 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting Qwen 2.5 models to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.utilities import converter
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '3b',
+    ['0.5b', '1.5b', '3b'],
+    'The size of the model to convert.',
+)
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+_BUILDER = {
+    '0.5b': qwen.build_0_5b_model,
+    '1.5b': qwen.build_1_5b_model,
+    '3b': qwen.build_3b_model,
+}
+
+
+def main(_):
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  model_size = _MODEL_SIZE.value.replace('.', '_')
+  output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
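Stripped of the flag handling, the new Qwen conversion script builds one of the three model variants and hands it to converter.convert_to_tflite. A programmatic sketch of the same flow, with a placeholder checkpoint path:

# Sketch of the conversion flow above; the checkpoint path is a placeholder.
import os

from ai_edge_torch.generative.examples.qwen import qwen
from ai_edge_torch.generative.utilities import converter

checkpoint_path = "/path/to/qwen2.5-0.5b-checkpoint"  # placeholder
model = qwen.build_0_5b_model(checkpoint_path, kv_cache_max_len=1280)
converter.convert_to_tflite(
    model,
    tflite_path=os.path.join("/tmp", "qwen_0_5b_q8_seq1024_ekv1280.tflite"),
    prefill_seq_len=1024,
    quantize=True,
)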
@@ -0,0 +1,141 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building Qwen 2.5 models."""
+
+import copy
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# Qwen re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class Qwen(tiny_llama.TinyLlama):
+  """A Qwen model built from the Edge Generative API layers.
+
+  Qwen 2.5 shares the same architecture as TinyLlama.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # Qwen re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=128,
+      num_query_groups=2,
+      rotary_base=1000000,
+      rotary_percentage=1.0,
+      qkv_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=11008,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-06,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=151936,
+      num_layers=36,
+      max_seq_len=32768,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 1B model."""
+  config = get_3b_model_config(kv_cache_max_len)
+  # Qwen has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 12
+  block_config.ff_config.intermediate_size = 8960
+  config.num_layers = 28
+  config.embedding_dim = 1536
+  return config
+
+
+def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 0.5B model."""
+  config = get_3b_model_config(kv_cache_max_len)
+  # Qwen has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 14
+  block_config.attn_config.head_dim = 64
+  block_config.ff_config.intermediate_size = 4864
+  config.num_layers = 24
+  config.embedding_dim = 896
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_3b_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # Qwen has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
+  model = Qwen(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+
+
+def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
+
+
+def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
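Because Qwen ties the output projection to the embedding (TENSOR_NAMES.lm_head = None, the weight copy in __init__, and strict=False in the loader), a quick structural check can be run without downloading a checkpoint by instantiating the model from the fake config. This is a sketch only and assumes the fake config is enough to construct the inherited TinyLlama module:

# Sketch: build Qwen from the tiny fake config and confirm the head projection
# shares the embedding weights, as the loader comment above relies on.
from ai_edge_torch.generative.examples.qwen import qwen

config = qwen.get_fake_model_config(kv_cache_max_len=128)
model = qwen.Qwen(config)
assert (
    model.lm_head.weight.data.data_ptr()
    == model.tok_embedding.weight.data.data_ptr()
)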
@@ -0,0 +1,88 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Qwen 2.5 0.5B, 1.5B, and 3B models."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "3b",
+    ["0.5b", "1.5b", "3b"],
+    "The size of the model to verify.",
+)
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+_CHECKPOINT = {
+    "0.5b": "Qwen/Qwen2.5-0.5B-Instruct",
+    "1.5b": "Qwen/Qwen2.5-1.5B-Instruct",
+    "3b": "Qwen/Qwen2.5-3B-Instruct",
+}
+
+_BUILDER = {
+    "0.5b": qwen.build_0_5b_model,
+    "1.5b": qwen.build_1_5b_model,
+    "3b": qwen.build_3b_model,
+}
+
+
+def main(_):
+  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
@@ -54,6 +54,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=9,
       head_dim=64,
       num_query_groups=3,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
@@ -21,6 +21,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers

@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)


 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
   logging.info("Loading the original model from: %s", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -50,10 +55,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )

@@ -98,6 +98,7 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=num_heads,
       head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=0,
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
@@ -148,6 +149,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
      num_heads=num_heads,
      head_dim=embedding_dim // num_heads,
      num_query_groups=num_query_groups,
+      rotary_base=0,
      rotary_percentage=0.0,
      qkv_use_bias=True,
      qkv_transpose_before_split=True,
@@ -295,6 +295,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
          enable_kv_cache=False,
          qkv_transpose_before_split=True,
          qkv_fused_interleaved=False,
+          rotary_base=0,
          rotary_percentage=0.0,
      ),
      enable_hlfb=False,
@@ -351,6 +352,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
          enable_kv_cache=False,
          qkv_transpose_before_split=True,
          qkv_fused_interleaved=False,
+          rotary_base=0,
          rotary_percentage=0.0,
      ),
      enable_hlfb=False,
@@ -199,6 +199,7 @@ def build_attention_config(
     num_heads,
     dim,
     num_query_groups,
+    rotary_base=0,
     rotary_percentage=0.0,
     qkv_transpose_before_split=True,
     qkv_use_bias=False,
@@ -211,6 +212,7 @@
       num_heads=num_heads,
       head_dim=dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=rotary_base,
       rotary_percentage=rotary_percentage,
       qkv_transpose_before_split=qkv_transpose_before_split,
       qkv_use_bias=qkv_use_bias,
@@ -335,8 +335,6 @@ class T5Decoder(nn.Module):

    self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
    )

  @torch.inference_mode