ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240927__py3-none-any.whl

Files changed (23)
  1. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  2. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  3. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  4. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  5. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  6. ai_edge_torch/generative/examples/llama/llama.py +203 -0
  7. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  8. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  9. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  10. ai_edge_torch/generative/examples/phi/phi3.py +15 -21
  11. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  12. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  13. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  14. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  16. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  17. ai_edge_torch/generative/utilities/verifier.py +117 -97
  18. ai_edge_torch/version.py +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/RECORD +23 -16
  21. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/verify_gemma2.py
@@ -20,7 +20,6 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.gemma import verify_util
-from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 
 
ai_edge_torch/generative/examples/gemma/verify_util.py
@@ -15,7 +15,6 @@
 
 """Utility functions to verify the reauthored Gemma model."""
 
-import dataclasses
 import logging
 import os
 from typing import List, Tuple
@@ -27,26 +26,17 @@ from gemma import model as gemma_model
 import torch
 
 
-@dataclasses.dataclass
-class _Output:
-  logits: torch.Tensor
-
-
 class GemmaWrapper(verifier.ModelWrapper):
   """Gemma model wrapper for verification.
 
   Verifier calls model.forward() with maximum sequence length (1024) expecting
-  the output has 'logits' field while Gemma gets the input tokens with the
-  actual length and returns logits in a tuple.
+  the output is logits while Gemma gets the input tokens with the actual length
+  and returns logits in a tuple.
 
   Verifier runs tokenizer before model.generate() while Gemma runs the tokenizer
   inside model.generate().
   """
 
-  def __init__(self, model: torch.nn.Module, max_new_tokens: int):
-    super().__init__(model)
-    self.max_new_tokens = max_new_tokens
-
   def _get_actual_input_len(self, tokens: torch.Tensor) -> int:
     for i in range(tokens.shape[1]):
       if tokens[0, i] == 0:
@@ -63,7 +53,7 @@ class GemmaWrapper(verifier.ModelWrapper):
         (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
     ]
 
-  def forward(self, tokens: torch.Tensor) -> _Output:
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
     """Forwards the model after reducing input tokens to the actual length."""
     actual_input_len = self._get_actual_input_len(tokens)
     input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
@@ -79,28 +69,26 @@ class GemmaWrapper(verifier.ModelWrapper):
         top_ps=torch.tensor([1.0], dtype=torch.float),
         top_ks=torch.tensor([1], dtype=torch.long),
     )
-    return _Output(logits.float())
+    return logits
 
-  def generate(self, tokens: torch.Tensor) -> torch.Tensor:
+  def generate(
+      self, tokens: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
     """Generates the response after decoding the tokens into a string."""
     prompts = self.model.tokenizer.decode(tokens[0].tolist())
     response = self.model.generate(
-        prompts, device="cpu", output_len=self.max_new_tokens, top_k=1
+        prompts, device="cpu", output_len=max_new_tokens, top_k=1
     )
     return torch.tensor([self.model.tokenizer.encode(prompts + response)])
 
 
-class TokenizerWrapper(torch.nn.Module):
+class GemmaTokenizerWrapper(verifier.TokenizerWrapper):
   """Tokenizer wrapper for verification.
 
   Verifier expects the tokenizer to handle tokens in torch.Tensor while Gemma
   tokenizer expects tokens in a list.
   """
 
-  def __init__(self, tokenizer: torch.nn.Module):
-    super().__init__()
-    self.tokenizer = tokenizer
-
   def encode(self, text: str, **_) -> torch.Tensor:
     """Adds one more dimension to the output of the tokenizer."""
     return torch.tensor([self.tokenizer.encode(text)])
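A minimal usage sketch (not part of the diff; original_model stands for a loaded Gemma model, as built in verify_reauthored_gemma_model below) of the reworked contract: forward() now returns a plain torch.Tensor instead of the removed _Output dataclass, and max_new_tokens travels with each generate() call rather than being fixed at construction time.

wrapper = GemmaWrapper(original_model)            # no max_new_tokens at construction
tokenizer = GemmaTokenizerWrapper(original_model.tokenizer)
tokens = tokenizer.encode("What is the meaning of life?")   # shape (1, seq_len)
response_ids = wrapper.generate(tokens, max_new_tokens=30)  # generation budget per call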
@@ -133,10 +121,11 @@ def verify_reauthored_gemma_model(
   original_model.load_weights(os.path.join(checkpoint, weight_filename))
 
   verifier.verify_reauthored_model(
-      original_model=GemmaWrapper(original_model, max_new_tokens),
-      reauthored_model=reauthored_model,
-      tokenizer=TokenizerWrapper(original_model.tokenizer),
+      original_model=GemmaWrapper(original_model),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
       generate_prompts=generate_prompts,
+      max_new_tokens=max_new_tokens,
       forward_input_ids=forward_input_ids,
       rtol=rtol,
       atol=atol,
ai_edge_torch/generative/examples/llama/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting Llama 3.2 3B model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = llama.build_3b_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'llama_3b_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
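As a usage note (not part of the diff), the script is driven entirely by the absl flags above; a hedged programmatic equivalent using the default flag values, and assuming a 3B checkpoint exists at the hypothetical path below, looks like:

import os

from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.utilities import converter

# Hypothetical checkpoint location; mirrors the checkpoint_path flag default.
checkpoint = os.path.expanduser("~/Downloads/llm_data/llama")
pytorch_model = llama.build_3b_model(checkpoint, kv_cache_max_len=1280)
# With quantize=True the script would name the output llama_3b_q8_seq1024_ekv1280.tflite.
converter.convert_to_tflite(
    pytorch_model,
    tflite_path="/tmp/llama_3b_q8_seq1024_ekv1280.tflite",
    prefill_seq_len=1024,
    quantize=True,
)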
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting Llama 3.2 1B model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = llama.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'llama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
ai_edge_torch/generative/examples/llama/llama.py
@@ -0,0 +1,203 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building Llama 3.2 models."""
+
+import copy
+import math
+from typing import Tuple
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmolLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+def _build_llama3_rope_cache(
+    size: int,
+    dim: int,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    factor: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    max_seq_len: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Precomputes Rotary Positional Embeddings for Llama 3.2 model.
+
+  It's a modified version of attn_utils.build_rope_cache with additional
+  arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
+  and Cos values with scaling factors for quick lookup during the inference.
+
+  Reference:
+    https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
+
+  Args:
+    size (int): The size of the built cache.
+    dim (int): Each sequence's dimension.
+    base (int, optional): Rope base value.
+    condense_ratio (int, optional): The ratio by which sequence indices are
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's device.
+    factor (float): Factor to scale theta down for tokens in long range in the
+      sequence.
+    low_freq_factor (float): Factor to determine if tokens are in long range
+      in the sequence.
+    high_freq_factor (float): Factor to determine if tokens are in short range
+      in the sequence.
+    max_seq_len (int): The original token sequence length before extending
+      ROPE to support longer sequence.
+
+  Returns:
+    Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
+  """
+  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  low_freq_wavelen = max_seq_len / low_freq_factor
+  high_freq_wavelen = max_seq_len / high_freq_factor
+  wavelen = 2 * math.pi / theta
+  # wavelen < high_freq_wavelen: do nothing
+  # wavelen > low_freq_wavelen: divide by factor
+  theta = torch.where(wavelen > low_freq_wavelen, theta / factor, theta)
+  # otherwise: interpolate between the two, using a smooth factor
+  smooth_factor = (max_seq_len / wavelen - low_freq_factor) / (
+      high_freq_factor - low_freq_factor
+  )
+  smoothed_theta = (1 - smooth_factor) * theta / factor + smooth_factor * theta
+  is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+  theta = torch.where(is_medium, smoothed_theta, theta)
+
+  seq_idx = torch.arange(size) / condense_ratio
+  idx_theta = torch.outer(seq_idx, theta)
+  cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
+  sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
+  return cos, sin
+
+
+class Llama(tiny_llama.TinyLlama):
+  """A Llama model built from the Edge Generative API layers.
+
+  Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # Llama 3.2 re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    # Llama has only one block config.
+    attn_config = self.config.block_config(0).attn_config
+    self.rope_cache = _build_llama3_rope_cache(
+        size=self.config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=500_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+        factor=32.0,
+        low_freq_factor=1.0,
+        high_freq_factor=4.0,
+        max_seq_len=self.config.max_seq_len,
+    )
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Llama 3.2-1B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      head_dim=64,
+      num_query_groups=8,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=8192,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=128256,
+      num_layers=16,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Llama 3.2-3B model."""
+  config = get_model_config(kv_cache_max_len)
+  # Llama 3.2 has only one block config.
+  attn_config = config.block_config(0).attn_config
+  attn_config.num_heads = 24
+  attn_config.head_dim = 128
+  config.num_layers = 28
+  config.embedding_dim = 3072
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = Llama(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_3b_model_config(**kwargs)
+  model = Llama(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
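A minimal sketch (not part of the diff) of what _build_llama3_rope_cache computes, calling the private helper directly with the same scaling constants that Llama.__init__ passes above; the tiny size/dim values are illustrative only:

import torch

from ai_edge_torch.generative.examples.llama import llama

# Same scaling constants as Llama.__init__ above; size and dim shrunk for illustration.
cos, sin = llama._build_llama3_rope_cache(
    size=16,
    dim=8,
    base=500_000,
    condense_ratio=1,
    dtype=torch.float32,
    device=torch.device("cpu"),
    factor=32.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    max_seq_len=8192,
)
# One cos/sin value per position per rotary frequency: dim // 2 = 4 frequencies.
print(cos.shape, sin.shape)  # torch.Size([16, 4]) torch.Size([16, 4])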
ai_edge_torch/generative/examples/llama/verify.py
@@ -0,0 +1,73 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Llama 3.2-1B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = llama.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
+  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
+  # available.
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/llama/verify_3b.py
@@ -0,0 +1,73 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Llama 3.2-3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = llama.build_3b_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
+  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
+  # available.
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/openelm/verify.py
@@ -20,6 +20,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
@@ -29,15 +30,18 @@ _PROMPTS = flags.DEFINE_multi_string(
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
   logging.info("Loading the original model from: %s", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
 
   # Locate the cached dir.
@@ -53,10 +57,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 