ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

Files changed (41)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
  3. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  6. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  7. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  9. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  10. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  11. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  12. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  13. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  15. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  16. ai_edge_torch/generative/examples/phi/phi3.py +279 -0
  17. ai_edge_torch/generative/examples/phi/verify.py +13 -13
  18. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  19. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  20. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  21. ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
  22. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
  23. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
  24. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  25. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  26. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
  27. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  29. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  30. ai_edge_torch/generative/layers/model_config.py +2 -0
  31. ai_edge_torch/generative/layers/normalization.py +2 -2
  32. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  33. ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
  34. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  35. ai_edge_torch/generative/utilities/verifier.py +130 -114
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
  39. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/llama/verify_3b.py
@@ -0,0 +1,73 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Llama 3.2-3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = llama.build_3b_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
+  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
+  # available.
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/generative/examples/openelm/openelm.py
@@ -68,15 +68,10 @@ class OpenELM(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -154,6 +149,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
             num_heads=num_heads[idx],
             head_dim=128,
             num_query_groups=num_query_groups[idx],
+            rotary_base=10000,
             rotary_percentage=1.0,
             qkv_transpose_before_split=True,
             query_norm_config=norm_config,
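
Note on the pattern above (illustrative, not part of the diff): the RoPE base now comes from AttentionConfig.rotary_base instead of being hard-coded as 10_000 at each build_rope_cache call site. A minimal sketch of the resulting call, using only the fields and arguments that appear in the hunks; head counts and cache size are placeholder values:

import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg

# Placeholder config: field names match the diffs above, values are illustrative.
attn_config = cfg.AttentionConfig(
    num_heads=8,
    head_dim=64,
    num_query_groups=8,
    rotary_base=10000,  # per-model rope base, read by the model at build time
    rotary_percentage=1.0,
)
# build_rope_cache now only needs the cache size, rotary dim, and the
# config-provided base; dtype/device defaults are used.
cos, sin = attn_utils.build_rope_cache(
    size=1024,  # kv_cache_max
    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
    base=attn_config.rotary_base,
)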

ai_edge_torch/generative/examples/openelm/verify.py
@@ -15,28 +15,33 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
+import logging
 import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
 
   # Locate the cached dir.
@@ -44,18 +49,21 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = openelm.build_model(reauthored_checkpoint)
 
   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 

ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Phi-3.5 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = phi3.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
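
For orientation, a minimal sketch (not part of the diff) of driving the same Phi-3.5 conversion from Python instead of through the flags. It uses only the phi3.build_model and converter.convert_to_tflite calls from the script above; the checkpoint path is a placeholder that must point at a downloaded Phi-3.5 checkpoint. With the script's default flags the output file is /tmp/phi3_q8_seq1024_ekv1280.tflite:

import os
import pathlib

from ai_edge_torch.generative.examples.phi import phi3
from ai_edge_torch.generative.utilities import converter

# Placeholder checkpoint location; mirrors the script's default flag value.
checkpoint_path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3')

# Build the reauthored PyTorch model, then export both prefill and decode
# signatures to a single quantized tflite file.
pytorch_model = phi3.build_model(checkpoint_path, kv_cache_max_len=1280)
converter.convert_to_tflite(
    pytorch_model,
    tflite_path='/tmp/phi3_q8_seq1024_ekv1280.tflite',
    prefill_seq_len=1024,
    quantize=True,
)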

ai_edge_torch/generative/examples/phi/phi2.py
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=80,
       num_query_groups=32,
+      rotary_base=10000,
      rotary_percentage=0.4,
       qkv_use_bias=True,
       output_proj_use_bias=True,

ai_edge_torch/generative/examples/phi/phi3.py
@@ -0,0 +1,279 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
+
+import math
+from typing import Tuple
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.gate_up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head="lm_head",
+)
+
+# max_position_embeddings / original_max_position_embeddings in Phi-3.5 config.
+ROPE_SCALE_FACTOR = 32
+
+# ROPE short factor in Phi-3.5 config. According to LOPE paper and its code in
+# https://github.com/microsoft/LongRoPE, these values had been searched with
+# min=1.0, step-0.01 to optimize the errors of sample dataset.
+ROPE_SHORT_FACTOR = [
+    1.0,
+    1.0199999809265137,
+    1.0299999713897705,
+    1.0299999713897705,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0699999332427979,
+    1.0999999046325684,
+    1.1099998950958252,
+    1.1599998474121094,
+    1.1599998474121094,
+    1.1699998378753662,
+    1.2899998426437378,
+    1.339999794960022,
+    1.679999828338623,
+    1.7899998426437378,
+    1.8199998140335083,
+    1.8499997854232788,
+    1.8799997568130493,
+    1.9099997282028198,
+    1.9399996995925903,
+    1.9899996519088745,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0799996852874756,
+    2.0899996757507324,
+    2.189999580383301,
+    2.2199995517730713,
+    2.5899994373321533,
+    2.729999542236328,
+    2.749999523162842,
+    2.8399994373321533,
+]
+
+
+def _build_rope_cache(
+    size: int,
+    dim: int,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
+
+  It's a modified version of attn_utils.build_rope_cache with additional
+  arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
+  Cos values with scaling factors for quick lookup during the inference.
+
+  Args:
+    size (int): The size of the built cache.
+    dim (int): Each sequence's dimmension.
+    base (int, optional): Rope base value.
+    condense_ratio (int, optional): The ratio by which sequence indicies are
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
+    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
+      scale the theta values.
+    scale (float, optional): A float used to scale the rope values.
+
+  Returns:
+    Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
+  """
+  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = theta / theta_factors
+  seq_idx = torch.arange(size) / condense_ratio
+  idx_theta = torch.outer(seq_idx, theta)
+  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
+  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
+  return cos, sin
+
+
+class Phi3_5Mini(nn.Module):
+  """A Phi-3.5 model built from the Edge Generative API layers."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    # Phi-3.5 has only one block config.
+    block_config = config.block_config(0)
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    attn_config = block_config.attn_config
+    self.rope_cache = _build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+        scale=math.sqrt(
+            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
+        ),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Phi-3.5 model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Phi-2 model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      head_dim=96,
+      num_query_groups=32,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+      qkv_transpose_before_split=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
+      intermediate_size=8192,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=32064,
+      num_layers=32,
+      max_seq_len=4096,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=3072,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  # Phi-3.5 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  """Instantiates the model instance and load checkpoint if provided."""
+  config = get_model_config(**kwargs)
+  model = Phi3_5Mini(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
+  model.eval()
+  return model
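
To make the scaling in _build_rope_cache concrete, here is a small worked example (not part of the diff) using the constants defined above: with ROPE_SCALE_FACTOR = 32 and max_seq_len = 4096, the factor applied to the cos/sin tables is sqrt(1 + ln 32 / ln 4096) = sqrt(1 + 5/12) ≈ 1.19, and each rotary frequency is first divided by its per-dimension LongRoPE short factor. The snippet below is standalone and substitutes an all-ones tensor for the 48-entry ROPE_SHORT_FACTOR list; head_dim = 96 matches get_model_config() above, so DIM // 2 = 48 lines up with the number of short factors:

import math

import torch

ROPE_SCALE_FACTOR = 32  # max_position_embeddings / original_max_position_embeddings
MAX_SEQ_LEN = 4096      # the "short" 4K context this reauthoring targets
DIM = 96                # head_dim * rotary_percentage from get_model_config()

# Scale applied to the cos/sin caches, as in Phi3_5Mini.__init__ above.
scale = math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(MAX_SEQ_LEN))
print(f"rope scale: {scale:.4f}")  # ~1.1902

# Rotary frequencies, each divided by its short factor; a dummy all-ones tensor
# stands in for ROPE_SHORT_FACTOR in this sketch.
theta = 1.0 / (10000 ** (torch.arange(0, DIM, 2).float() / DIM))
theta = theta / torch.ones(DIM // 2)

seq_idx = torch.arange(MAX_SEQ_LEN).float()  # condense_ratio == 1
idx_theta = torch.outer(seq_idx, theta)
cos = torch.cos(idx_theta) * scale  # shape: (MAX_SEQ_LEN, DIM // 2)
sin = torch.sin(idx_theta) * scale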

ai_edge_torch/generative/examples/phi/verify.py
@@ -14,20 +14,22 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
+import logging
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
     "The input prompts to generate answers.",
 )
-
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "max_new_tokens",
     30,
@@ -37,25 +39,23 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-  verifier.log_msg("Loading the original model from", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
-  verifier.log_msg("Building the reauthored model from", checkpoint)
+  logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-03,
   )
 

ai_edge_torch/generative/examples/phi/verify_phi3.py
@@ -0,0 +1,69 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-3.5 model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Instruct: Write an email about the weather Output:",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "microsoft/Phi-3.5-mini-instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = phi3.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/generative/examples/smollm/smollm.py
@@ -54,6 +54,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=9,
       head_dim=64,
       num_query_groups=3,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(