ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240929__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (40)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  12. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  13. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  14. ai_edge_torch/generative/examples/phi/phi3.py +17 -24
  15. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  16. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  17. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +81 -0
  19. ai_edge_torch/generative/examples/qwen/qwen.py +141 -0
  20. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  21. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  22. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  23. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  24. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
  25. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
  26. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  27. ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
  28. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
  29. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  30. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  31. ai_edge_torch/generative/layers/model_config.py +2 -0
  32. ai_edge_torch/generative/test/test_model_conversion_large.py +20 -0
  33. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  34. ai_edge_torch/generative/utilities/verifier.py +117 -97
  35. ai_edge_torch/version.py +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/METADATA +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/RECORD +40 -29
  38. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/LICENSE +0 -0
  39. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/WHEEL +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ import pathlib
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.openelm import openelm
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers

@@ -29,15 +30,18 @@ _PROMPTS = flags.DEFINE_multi_string(
  "What is the meaning of life?",
  "The input prompts to generate answers.",
  )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+ "max_new_tokens",
+ 30,
+ "The maximum size of the generated tokens.",
+ )


  def main(_):
  checkpoint = "apple/OpenELM-3B"
  logging.info("Loading the original model from: %s", checkpoint)
- wrapper_model = verifier.ModelWrapper(
- model=transformers.AutoModelForCausalLM.from_pretrained(
- checkpoint, trust_remote_code=True
- ),
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(
+ checkpoint, trust_remote_code=True
  )

  # Locate the cached dir.
@@ -53,10 +57,13 @@ def main(_):
  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)

  verifier.verify_reauthored_model(
- original_model=wrapper_model,
- reauthored_model=reauthored_model,
- tokenizer=tokenizer,
+ original_model=transformers_verifier.TransformersModelWrapper(
+ original_model
+ ),
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
  generate_prompts=_PROMPTS.value,
+ max_new_tokens=_MAX_NEW_TOKENS.value,
  )


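The hunks above show the reworked verification API: instead of a single verifier.ModelWrapper, the original HuggingFace model, the reauthored model, and the tokenizer are each wrapped explicitly, and the token budget moves into a max_new_tokens argument. As a rough, self-contained sketch of the new flow (shown with the SmolLM-135M checkpoint from the smollm hunks further down, and assuming smollm.build_model as the reauthored builder):

    import pathlib
    import transformers
    from ai_edge_torch.generative.examples.smollm import smollm
    from ai_edge_torch.generative.utilities import transformers_verifier
    from ai_edge_torch.generative.utilities import verifier

    checkpoint = "HuggingFaceTB/SmolLM-135M"
    original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

    # Rebuild the model from the locally cached HF snapshot, as the verify
    # scripts in this release do.
    cached_config_file = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
    )
    reauthored_model = smollm.build_model(
        str(pathlib.Path(cached_config_file).parent)
    )

    verifier.verify_reauthored_model(
        original_model=transformers_verifier.TransformersModelWrapper(original_model),
        reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
        tokenizer=verifier.TokenizerWrapper(tokenizer),
        generate_prompts=["What is the meaning of life?"],
        max_new_tokens=30,
    )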
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.kv_cache_max,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device("cpu"),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
  size=config.kv_cache_max,
- dtype=torch.float32,
- device=torch.device("cpu"),
  )
  self.config = config

@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  num_heads=32,
  head_dim=80,
  num_query_groups=32,
+ rotary_base=10000,
  rotary_percentage=0.4,
  qkv_use_bias=True,
  output_proj_use_bias=True,
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
  ]


- def build_rope_cache(
+ def _build_rope_cache(
  size: int,
  dim: int,
- base: int = 10000,
- condense_ratio: int = 1,
- dtype: torch.dtype = torch.float32,
- device: torch.device = None,
- theta_factors: torch.Tensor = None,
- scale: float = 1.0,
+ base: int,
+ condense_ratio: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ theta_factors: torch.Tensor,
+ scale: float,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.

@@ -116,26 +116,20 @@ def build_rope_cache(
  Args:
  size (int): The size of the built cache.
  dim (int): Each sequence's dimension.
- base (int, optional): Rope base value. Defaults to 10000.
+ base (int, optional): Rope base value.
  condense_ratio (int, optional): The ratio by which sequence indices are
- condensed. Defaults to 1.
- dtype (torch.dtype, optional): Output tensor's data type. Defaults to
- torch.float32.
- device (torch.device, optional): Output tensor's data type. Defaults to
- None in which case "cpu" is used.
+ condensed.
+ dtype (torch.dtype, optional): Output tensor's data type.
+ device (torch.device, optional): Output tensor's device.
  theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
- scale the theta values. Defaults to None.
- scale (float, optional): A float used to scale the rope values. Defaults
- to 1.0.
+ scale the theta values.
+ scale (float, optional): A float used to scale the rope values.

  Returns:
  Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
  """
- if device is None:
- device = torch.device('cpu')
  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
- if theta_factors is not None:
- theta = theta / theta_factors
+ theta = theta / theta_factors
  seq_idx = torch.arange(size) / condense_ratio
  idx_theta = torch.outer(seq_idx, theta)
  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
@@ -167,10 +161,10 @@ class Phi3_5Mini(nn.Module):
  config.final_norm_config,
  )
  attn_config = block_config.attn_config
- self.rope_cache = build_rope_cache(
+ self.rope_cache = _build_rope_cache(
  size=config.kv_cache_max,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
+ base=attn_config.rotary_base,
  condense_ratio=1,
  dtype=torch.float32,
  device=torch.device("cpu"),
@@ -181,8 +175,6 @@ class Phi3_5Mini(nn.Module):
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
  size=config.kv_cache_max,
- dtype=torch.float32,
- device=torch.device("cpu"),
  )
  self.config = config

@@ -238,6 +230,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  num_heads=32,
  head_dim=96,
  num_query_groups=32,
+ rotary_base=10000,
  rotary_percentage=1.0,
  qkv_transpose_before_split=True,
  )
@@ -19,6 +19,7 @@ import logging
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import kagglehub
  import transformers
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
  def main(_):
  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
  logging.info("Loading the original model from: %s", checkpoint)
- generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
- generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
- wrapper_model = verifier.ModelWrapper(
- model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
- hf_generation_config=generation_config,
- )
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

  logging.info("Building the reauthored model from: %s", checkpoint)
  reauthored_model = phi2.build_model(checkpoint)
@@ -53,10 +49,13 @@ def main(_):
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

  verifier.verify_reauthored_model(
- original_model=wrapper_model,
- reauthored_model=reauthored_model,
- tokenizer=tokenizer,
+ original_model=transformers_verifier.TransformersModelWrapper(
+ original_model
+ ),
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
  generate_prompts=_PROMPTS.value,
+ max_new_tokens=_MAX_NEW_TOKENS.value,
  atol=1e-03,
  )

@@ -21,6 +21,7 @@ import pathlib
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers

@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
  def main(_):
  checkpoint = "microsoft/Phi-3.5-mini-instruct"
  logging.info("Loading the original model from: %s", checkpoint)
- generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
- generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
- wrapper_model = verifier.ModelWrapper(
- model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
- hf_generation_config=generation_config,
- )
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
@@ -59,10 +55,13 @@ def main(_):
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

  verifier.verify_reauthored_model(
- original_model=wrapper_model,
- reauthored_model=reauthored_model,
- tokenizer=tokenizer,
+ original_model=transformers_verifier.TransformersModelWrapper(
+ original_model
+ ),
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
  generate_prompts=_PROMPTS.value,
+ max_new_tokens=_MAX_NEW_TOKENS.value,
  )


@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
@@ -0,0 +1,81 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting Qwen 2.5 models to multi-signature tflite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.qwen import qwen
+ from ai_edge_torch.generative.utilities import converter
+
+ _MODEL_SIZE = flags.DEFINE_enum(
+ 'model_size',
+ '3b',
+ ['0.5b', '1.5b', '3b'],
+ 'The size of the model to convert.',
+ )
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+ 'checkpoint_path',
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
+ 'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _TFLITE_PATH = flags.DEFINE_string(
+ 'tflite_path',
+ '/tmp/',
+ 'The tflite file path to export.',
+ )
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
+ 'prefill_seq_len',
+ 1024,
+ 'The maximum size of prefill input tensor.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+ 'kv_cache_max_len',
+ 1280,
+ 'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+ 'quantize',
+ True,
+ 'Whether the model should be quantized.',
+ )
+
+ _BUILDER = {
+ '0.5b': qwen.build_0_5b_model,
+ '1.5b': qwen.build_1_5b_model,
+ '3b': qwen.build_3b_model,
+ }
+
+
+ def main(_):
+ pytorch_model = _BUILDER[_MODEL_SIZE.value](
+ _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+ )
+ quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+ model_size = _MODEL_SIZE.value.replace('.', '_')
+ output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ converter.convert_to_tflite(
+ pytorch_model,
+ tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+ prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ quantize=_QUANTIZE.value,
+ )
+
+
+ if __name__ == '__main__':
+ app.run(main)
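The new script converts any of the three Qwen 2.5 sizes into a multi-signature .tflite file. Driven programmatically instead of through the absl flags, the equivalent call is roughly the sketch below (the checkpoint directory is a placeholder standing in for the script's --checkpoint_path default of ~/Downloads/llm_data/qwen):

    from ai_edge_torch.generative.examples.qwen import qwen
    from ai_edge_torch.generative.utilities import converter

    # Hypothetical local directory holding the downloaded Qwen 2.5 0.5B weights.
    checkpoint_path = "/path/to/qwen2.5-0.5b-instruct"

    pytorch_model = qwen.build_0_5b_model(checkpoint_path, kv_cache_max_len=1280)
    converter.convert_to_tflite(
        pytorch_model,
        tflite_path="/tmp/qwen_0_5b_q8_seq1024_ekv1280.tflite",
        prefill_seq_len=1024,
        quantize=True,
    )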
@@ -0,0 +1,141 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building Qwen 2.5 models."""
+
+ import copy
+
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ from torch import nn
+
+ TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+ # Qwen re-uses the embedding as the head projection layer.
+ TENSOR_NAMES.lm_head = None
+
+
+ class Qwen(tiny_llama.TinyLlama):
+ """A Qwen model built from the Edge Generative API layers.
+
+ Qwen 2.5 shares the same architecture as TinyLlama.
+ """
+
+ def __init__(self, config: cfg.ModelConfig):
+ super().__init__(config)
+ # Qwen re-uses the embedding as the head projection layer.
+ self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a Qwen 2.5 3B model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a Qwen 2.5 3B model.
+ """
+ attn_config = cfg.AttentionConfig(
+ num_heads=16,
+ head_dim=128,
+ num_query_groups=2,
+ rotary_base=1000000,
+ rotary_percentage=1.0,
+ qkv_use_bias=True,
+ )
+ ff_config = cfg.FeedForwardConfig(
+ type=cfg.FeedForwardType.GATED,
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+ intermediate_size=11008,
+ )
+ norm_config = cfg.NormalizationConfig(
+ type=cfg.NormalizationType.RMS_NORM,
+ epsilon=1e-06,
+ )
+ block_config = cfg.TransformerBlockConfig(
+ attn_config=attn_config,
+ ff_config=ff_config,
+ pre_attention_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
+ )
+ config = cfg.ModelConfig(
+ vocab_size=151936,
+ num_layers=36,
+ max_seq_len=32768,
+ embedding_dim=2048,
+ kv_cache_max_len=kv_cache_max_len,
+ block_configs=block_config,
+ final_norm_config=norm_config,
+ enable_hlfb=True,
+ )
+ return config
+
+
+ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a Qwen 2.5 1.5B model."""
+ config = get_3b_model_config(kv_cache_max_len)
+ # Qwen has only one block config.
+ block_config = config.block_config(0)
+ block_config.attn_config.num_heads = 12
+ block_config.ff_config.intermediate_size = 8960
+ config.num_layers = 28
+ config.embedding_dim = 1536
+ return config
+
+
+ def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a Qwen 2.5 0.5B model."""
+ config = get_3b_model_config(kv_cache_max_len)
+ # Qwen has only one block config.
+ block_config = config.block_config(0)
+ block_config.attn_config.num_heads = 14
+ block_config.attn_config.head_dim = 64
+ block_config.ff_config.intermediate_size = 4864
+ config.num_layers = 24
+ config.embedding_dim = 896
+ return config
+
+
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+ config = get_3b_model_config(**kwargs)
+ config.vocab_size = 128
+ config.num_layers = 2
+ # Qwen has only one block config.
+ config.block_config(0).ff_config.intermediate_size = 64
+ return config
+
+
+ def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
+ model = Qwen(config)
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+ # Since embedding and lm-head use the same weight, we need to set strict
+ # to False.
+ loader.load(model, strict=False)
+ model.eval()
+ return model
+
+
+ def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+ return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+
+
+ def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+ return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
+
+
+ def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+ return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
@@ -0,0 +1,88 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Qwen 2.5 0.5B, 1.5B, and 3B models."""
+
+ import logging
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.qwen import qwen
+ from ai_edge_torch.generative.utilities import transformers_verifier
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+
+ _MODEL_SIZE = flags.DEFINE_enum(
+ "model_size",
+ "3b",
+ ["0.5b", "1.5b", "3b"],
+ "The size of the model to verify.",
+ )
+ _PROMPTS = flags.DEFINE_multi_string(
+ "prompts",
+ "What is the meaning of life?",
+ "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+ "max_new_tokens",
+ 30,
+ "The maximum size of the generated tokens.",
+ )
+
+ _CHECKPOINT = {
+ "0.5b": "Qwen/Qwen2.5-0.5B-Instruct",
+ "1.5b": "Qwen/Qwen2.5-1.5B-Instruct",
+ "3b": "Qwen/Qwen2.5-3B-Instruct",
+ }
+
+ _BUILDER = {
+ "0.5b": qwen.build_0_5b_model,
+ "1.5b": qwen.build_1_5b_model,
+ "3b": qwen.build_3b_model,
+ }
+
+
+ def main(_):
+ checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
+ logging.info("Loading the original model from: %s", checkpoint)
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+ # Locate the cached dir.
+ cached_config_file = transformers.utils.cached_file(
+ checkpoint, transformers.utils.CONFIG_NAME
+ )
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+ reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
+
+ logging.info("Loading the tokenizer from: %s", checkpoint)
+ tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+ verifier.verify_reauthored_model(
+ original_model=transformers_verifier.TransformersModelWrapper(
+ original_model
+ ),
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
+ generate_prompts=_PROMPTS.value,
+ max_new_tokens=_MAX_NEW_TOKENS.value,
+ atol=1e-04,
+ )
+
+
+ if __name__ == "__main__":
+ app.run(main)
@@ -54,6 +54,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  num_heads=9,
  head_dim=64,
  num_query_groups=3,
+ rotary_base=10000,
  rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(
@@ -21,6 +21,7 @@ import pathlib
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.smollm import smollm
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers

@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
  "What is the meaning of life?",
  "The input prompts to generate answers.",
  )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+ "max_new_tokens",
+ 30,
+ "The maximum size of the generated tokens.",
+ )


  def main(_):
  checkpoint = "HuggingFaceTB/SmolLM-135M"
  logging.info("Loading the original model from: %s", checkpoint)
- wrapper_model = verifier.ModelWrapper(
- model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
- )
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
  checkpoint, transformers.utils.CONFIG_NAME
@@ -50,10 +55,13 @@ def main(_):
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

  verifier.verify_reauthored_model(
- original_model=wrapper_model,
- reauthored_model=reauthored_model,
- tokenizer=tokenizer,
+ original_model=transformers_verifier.TransformersModelWrapper(
+ original_model
+ ),
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
  generate_prompts=_PROMPTS.value,
+ max_new_tokens=_MAX_NEW_TOKENS.value,
  atol=1e-04,
  )

@@ -98,6 +98,7 @@ def get_model_config() -> cfg.ModelConfig:
  num_heads=num_heads,
  head_dim=embedding_dim // num_heads,
  num_query_groups=num_query_groups,
+ rotary_base=0,
  rotary_percentage=0.0,
  qkv_use_bias=True,
  qkv_transpose_before_split=True,
@@ -148,6 +149,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
  num_heads=num_heads,
  head_dim=embedding_dim // num_heads,
  num_query_groups=num_query_groups,
+ rotary_base=0,
  rotary_percentage=0.0,
  qkv_use_bias=True,
  qkv_transpose_before_split=True,
@@ -295,6 +295,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
  enable_kv_cache=False,
  qkv_transpose_before_split=True,
  qkv_fused_interleaved=False,
+ rotary_base=0,
  rotary_percentage=0.0,
  ),
  enable_hlfb=False,
@@ -351,6 +352,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
  enable_kv_cache=False,
  qkv_transpose_before_split=True,
  qkv_fused_interleaved=False,
+ rotary_base=0,
  rotary_percentage=0.0,
  ),
  enable_hlfb=False,
@@ -199,6 +199,7 @@ def build_attention_config(
  num_heads,
  dim,
  num_query_groups,
+ rotary_base=0,
  rotary_percentage=0.0,
  qkv_transpose_before_split=True,
  qkv_use_bias=False,
@@ -211,6 +212,7 @@ def build_attention_config(
  num_heads=num_heads,
  head_dim=dim // num_heads,
  num_query_groups=num_query_groups,
+ rotary_base=rotary_base,
  rotary_percentage=rotary_percentage,
  qkv_transpose_before_split=qkv_transpose_before_split,
  qkv_use_bias=qkv_use_bias,
@@ -335,8 +335,6 @@ class T5Decoder(nn.Module):

  self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
  size=config.kv_cache_max,
- dtype=torch.float32,
- device=torch.device("cpu"),
  )

  @torch.inference_mode