ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
  3. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  6. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  7. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  9. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  10. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  11. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  12. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  13. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  15. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  16. ai_edge_torch/generative/examples/phi/phi3.py +279 -0
  17. ai_edge_torch/generative/examples/phi/verify.py +13 -13
  18. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  19. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  20. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  21. ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
  22. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
  23. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
  24. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  25. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  26. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
  27. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  29. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  30. ai_edge_torch/generative/layers/model_config.py +2 -0
  31. ai_edge_torch/generative/layers/normalization.py +2 -2
  32. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  33. ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
  34. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  35. ai_edge_torch/generative/utilities/verifier.py +130 -114
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
  39. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,73 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored Llama 3.2-3B model."""
17
+
18
+ import logging
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.llama import llama
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
25
+ from ai_edge_torch.generative.utilities import verifier
26
+ import transformers
27
+
28
+
29
+ _PROMPTS = flags.DEFINE_multi_string(
30
+ "prompts",
31
+ "What is the meaning of life?",
32
+ "The input prompts to generate answers.",
33
+ )
34
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
35
+ "max_new_tokens",
36
+ 30,
37
+ "The maximum size of the generated tokens.",
38
+ )
39
+
40
+
41
+ def main(_):
42
+ checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
43
+ logging.info("Loading the original model from: %s", checkpoint)
44
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
45
+
46
+ # Locate the cached dir.
47
+ cached_config_file = transformers.utils.cached_file(
48
+ checkpoint, transformers.utils.CONFIG_NAME
49
+ )
50
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
51
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
52
+ reauthored_model = llama.build_3b_model(reauthored_checkpoint)
53
+
54
+ logging.info("Loading the tokenizer from: %s", checkpoint)
55
+ # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
56
+ # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
57
+ # available.
58
+ tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
59
+
60
+ verifier.verify_reauthored_model(
61
+ original_model=transformers_verifier.TransformersModelWrapper(
62
+ original_model
63
+ ),
64
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
65
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
66
+ generate_prompts=_PROMPTS.value,
67
+ max_new_tokens=_MAX_NEW_TOKENS.value,
68
+ atol=1e-04,
69
+ )
70
+
71
+
72
+ if __name__ == "__main__":
73
+ app.run(main)
@@ -68,15 +68,10 @@ class OpenELM(nn.Module):
68
68
  self.rope_cache = attn_utils.build_rope_cache(
69
69
  size=config.kv_cache_max,
70
70
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
71
- base=10_000,
72
- condense_ratio=1,
73
- dtype=torch.float32,
74
- device=torch.device("cpu"),
71
+ base=attn_config.rotary_base,
75
72
  )
76
73
  self.mask_cache = attn_utils.build_causal_mask_cache(
77
74
  size=config.kv_cache_max,
78
- dtype=torch.float32,
79
- device=torch.device("cpu"),
80
75
  )
81
76
  self.config = config
82
77
 
@@ -154,6 +149,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
154
149
  num_heads=num_heads[idx],
155
150
  head_dim=128,
156
151
  num_query_groups=num_query_groups[idx],
152
+ rotary_base=10000,
157
153
  rotary_percentage=1.0,
158
154
  qkv_transpose_before_split=True,
159
155
  query_norm_config=norm_config,
@@ -15,28 +15,33 @@
15
15
 
16
16
  """Verifies the reauthored OpenELM-3B model."""
17
17
 
18
+ import logging
18
19
  import pathlib
19
-
20
20
  from absl import app
21
21
  from absl import flags
22
22
  from ai_edge_torch.generative.examples.openelm import openelm
23
+ from ai_edge_torch.generative.utilities import transformers_verifier
23
24
  from ai_edge_torch.generative.utilities import verifier
24
25
  import transformers
25
26
 
27
+
26
28
  _PROMPTS = flags.DEFINE_multi_string(
27
29
  "prompts",
28
30
  "What is the meaning of life?",
29
31
  "The input prompts to generate answers.",
30
32
  )
33
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
34
+ "max_new_tokens",
35
+ 30,
36
+ "The maximum size of the generated tokens.",
37
+ )
31
38
 
32
39
 
33
40
  def main(_):
34
41
  checkpoint = "apple/OpenELM-3B"
35
- verifier.log_msg("Loading the original model from", checkpoint)
36
- wrapper_model = verifier.ModelWrapper(
37
- model=transformers.AutoModelForCausalLM.from_pretrained(
38
- checkpoint, trust_remote_code=True
39
- ),
42
+ logging.info("Loading the original model from: %s", checkpoint)
43
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(
44
+ checkpoint, trust_remote_code=True
40
45
  )
41
46
 
42
47
  # Locate the cached dir.
@@ -44,18 +49,21 @@ def main(_):
44
49
  checkpoint, transformers.utils.CONFIG_NAME
45
50
  )
46
51
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
47
- verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
52
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
48
53
  reauthored_model = openelm.build_model(reauthored_checkpoint)
49
54
 
50
55
  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
51
- verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
56
+ logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
52
57
  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
53
58
 
54
59
  verifier.verify_reauthored_model(
55
- original_model=wrapper_model,
56
- reauthored_model=reauthored_model,
57
- tokenizer=tokenizer,
60
+ original_model=transformers_verifier.TransformersModelWrapper(
61
+ original_model
62
+ ),
63
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
64
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
58
65
  generate_prompts=_PROMPTS.value,
66
+ max_new_tokens=_MAX_NEW_TOKENS.value,
59
67
  )
60
68
 
61
69
 
@@ -0,0 +1,68 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting a Phi-3.5 model to multi-signature tflite model."""
17
+
18
+ import os
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.phi import phi3
24
+ from ai_edge_torch.generative.utilities import converter
25
+
26
+ _CHECKPOINT_PATH = flags.DEFINE_string(
27
+ 'checkpoint_path',
28
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
29
+ 'The path to the model checkpoint, or directory holding the checkpoint.',
30
+ )
31
+ _TFLITE_PATH = flags.DEFINE_string(
32
+ 'tflite_path',
33
+ '/tmp/',
34
+ 'The tflite file path to export.',
35
+ )
36
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
37
+ 'prefill_seq_len',
38
+ 1024,
39
+ 'The maximum size of prefill input tensor.',
40
+ )
41
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
42
+ 'kv_cache_max_len',
43
+ 1280,
44
+ 'The maximum size of KV cache buffer, including both prefill and decode.',
45
+ )
46
+ _QUANTIZE = flags.DEFINE_bool(
47
+ 'quantize',
48
+ True,
49
+ 'Whether the model should be quantized.',
50
+ )
51
+
52
+
53
+ def main(_):
54
+ pytorch_model = phi3.build_model(
55
+ _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
56
+ )
57
+ quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
58
+ output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
59
+ converter.convert_to_tflite(
60
+ pytorch_model,
61
+ tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
+ prefill_seq_len=_PREFILL_SEQ_LEN.value,
63
+ quantize=_QUANTIZE.value,
64
+ )
65
+
66
+
67
+ if __name__ == '__main__':
68
+ app.run(main)
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
65
65
  self.rope_cache = attn_utils.build_rope_cache(
66
66
  size=config.kv_cache_max,
67
67
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
68
- base=10_000,
69
- condense_ratio=1,
70
- dtype=torch.float32,
71
- device=torch.device("cpu"),
68
+ base=attn_config.rotary_base,
72
69
  )
73
70
  self.mask_cache = attn_utils.build_causal_mask_cache(
74
71
  size=config.kv_cache_max,
75
- dtype=torch.float32,
76
- device=torch.device("cpu"),
77
72
  )
78
73
  self.config = config
79
74
 
@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
129
124
  num_heads=32,
130
125
  head_dim=80,
131
126
  num_query_groups=32,
127
+ rotary_base=10000,
132
128
  rotary_percentage=0.4,
133
129
  qkv_use_bias=True,
134
130
  output_proj_use_bias=True,
@@ -0,0 +1,279 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
17
+
18
+ import math
19
+ from typing import Tuple
20
+
21
+ from ai_edge_torch.generative.layers import attention
22
+ from ai_edge_torch.generative.layers import builder
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
24
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
25
+ import ai_edge_torch.generative.layers.model_config as cfg
26
+ import ai_edge_torch.generative.utilities.loader as loading_utils
27
+ import torch
28
+ from torch import nn
29
+
30
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
31
+ ff_up_proj="model.layers.{}.mlp.gate_up_proj",
32
+ ff_down_proj="model.layers.{}.mlp.down_proj",
33
+ attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
34
+ attn_output_proj="model.layers.{}.self_attn.o_proj",
35
+ pre_attn_norm="model.layers.{}.input_layernorm",
36
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
37
+ embedding="model.embed_tokens",
38
+ final_norm="model.norm",
39
+ lm_head="lm_head",
40
+ )
41
+
42
+ # max_position_embeddings / original_max_position_embeddings in Phi-3.5 config.
43
+ ROPE_SCALE_FACTOR = 32
44
+
45
+ # ROPE short factor in Phi-3.5 config. According to LOPE paper and its code in
46
+ # https://github.com/microsoft/LongRoPE, these values had been searched with
47
+ # min=1.0, step-0.01 to optimize the errors of sample dataset.
48
+ ROPE_SHORT_FACTOR = [
49
+ 1.0,
50
+ 1.0199999809265137,
51
+ 1.0299999713897705,
52
+ 1.0299999713897705,
53
+ 1.0499999523162842,
54
+ 1.0499999523162842,
55
+ 1.0499999523162842,
56
+ 1.0499999523162842,
57
+ 1.0499999523162842,
58
+ 1.0699999332427979,
59
+ 1.0999999046325684,
60
+ 1.1099998950958252,
61
+ 1.1599998474121094,
62
+ 1.1599998474121094,
63
+ 1.1699998378753662,
64
+ 1.2899998426437378,
65
+ 1.339999794960022,
66
+ 1.679999828338623,
67
+ 1.7899998426437378,
68
+ 1.8199998140335083,
69
+ 1.8499997854232788,
70
+ 1.8799997568130493,
71
+ 1.9099997282028198,
72
+ 1.9399996995925903,
73
+ 1.9899996519088745,
74
+ 2.0199997425079346,
75
+ 2.0199997425079346,
76
+ 2.0199997425079346,
77
+ 2.0199997425079346,
78
+ 2.0199997425079346,
79
+ 2.0199997425079346,
80
+ 2.0299997329711914,
81
+ 2.0299997329711914,
82
+ 2.0299997329711914,
83
+ 2.0299997329711914,
84
+ 2.0299997329711914,
85
+ 2.0299997329711914,
86
+ 2.0299997329711914,
87
+ 2.0299997329711914,
88
+ 2.0299997329711914,
89
+ 2.0799996852874756,
90
+ 2.0899996757507324,
91
+ 2.189999580383301,
92
+ 2.2199995517730713,
93
+ 2.5899994373321533,
94
+ 2.729999542236328,
95
+ 2.749999523162842,
96
+ 2.8399994373321533,
97
+ ]
98
+
99
+
100
+ def _build_rope_cache(
101
+ size: int,
102
+ dim: int,
103
+ base: int,
104
+ condense_ratio: int,
105
+ dtype: torch.dtype,
106
+ device: torch.device,
107
+ theta_factors: torch.Tensor,
108
+ scale: float,
109
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
110
+ """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
111
+
112
+ It's a modified version of attn_utils.build_rope_cache with additional
113
+ arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
114
+ Cos values with scaling factors for quick lookup during the inference.
115
+
116
+ Args:
117
+ size (int): The size of the built cache.
118
+ dim (int): Each sequence's dimmension.
119
+ base (int, optional): Rope base value.
120
+ condense_ratio (int, optional): The ratio by which sequence indicies are
121
+ condensed.
122
+ dtype (torch.dtype, optional): Output tensor's data type.
123
+ device (torch.device, optional): Output tensor's data type.
124
+ theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
125
+ scale the theta values.
126
+ scale (float, optional): A float used to scale the rope values.
127
+
128
+ Returns:
129
+ Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
130
+ """
131
+ theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
132
+ theta = theta / theta_factors
133
+ seq_idx = torch.arange(size) / condense_ratio
134
+ idx_theta = torch.outer(seq_idx, theta)
135
+ cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
136
+ sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
137
+ return cos, sin
138
+
139
+
140
+ class Phi3_5Mini(nn.Module):
141
+ """A Phi-3.5 model built from the Edge Generative API layers."""
142
+
143
+ def __init__(self, config: cfg.ModelConfig):
144
+ super().__init__()
145
+
146
+ # Construct model layers.
147
+ self.lm_head = nn.Linear(
148
+ config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
149
+ )
150
+ self.tok_embedding = nn.Embedding(
151
+ config.vocab_size, config.embedding_dim, padding_idx=0
152
+ )
153
+ # Phi-3.5 has only one block config.
154
+ block_config = config.block_config(0)
155
+ self.transformer_blocks = nn.ModuleList(
156
+ attention.TransformerBlock(block_config, config)
157
+ for _ in range(config.num_layers)
158
+ )
159
+ self.final_norm = builder.build_norm(
160
+ config.embedding_dim,
161
+ config.final_norm_config,
162
+ )
163
+ attn_config = block_config.attn_config
164
+ self.rope_cache = _build_rope_cache(
165
+ size=config.kv_cache_max,
166
+ dim=int(attn_config.rotary_percentage * attn_config.head_dim),
167
+ base=attn_config.rotary_base,
168
+ condense_ratio=1,
169
+ dtype=torch.float32,
170
+ device=torch.device("cpu"),
171
+ theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
172
+ scale=math.sqrt(
173
+ 1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
174
+ ),
175
+ )
176
+ self.mask_cache = attn_utils.build_causal_mask_cache(
177
+ size=config.kv_cache_max,
178
+ )
179
+ self.config = config
180
+
181
+ @torch.inference_mode
182
+ def forward(
183
+ self,
184
+ tokens: torch.Tensor,
185
+ input_pos: torch.Tensor,
186
+ kv_cache: kv_utils.KVCache,
187
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
188
+ _, seq_len = tokens.size()
189
+ assert self.config.max_seq_len >= seq_len, (
190
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
191
+ f" {self.config.max_seq_len}"
192
+ )
193
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
194
+ "The number of transformer blocks and the number of KV cache entries"
195
+ " must be the same."
196
+ )
197
+
198
+ cos, sin = self.rope_cache
199
+ cos = cos.index_select(0, input_pos)
200
+ sin = sin.index_select(0, input_pos)
201
+ mask = self.mask_cache.index_select(2, input_pos)
202
+ mask = mask[:, :, :, : self.config.kv_cache_max]
203
+
204
+ x = self.tok_embedding(tokens)
205
+
206
+ updated_kv_entires = []
207
+ for i, block in enumerate(self.transformer_blocks):
208
+ kv_entry = kv_cache.caches[i] if kv_cache else None
209
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
210
+ if kv_entry:
211
+ updated_kv_entires.append(kv_entry)
212
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
213
+
214
+ x = self.final_norm(x)
215
+ logits = self.lm_head(x) # (b, t, vocab_size)
216
+ return {"logits": logits, "kv_cache": updated_kv_cache}
217
+
218
+
219
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
220
+ """Returns the model config for a Phi-3.5 model.
221
+
222
+ Args:
223
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
224
+ is 1024.
225
+
226
+ Returns:
227
+ The model config for a Phi-2 model.
228
+ """
229
+ attn_config = cfg.AttentionConfig(
230
+ num_heads=32,
231
+ head_dim=96,
232
+ num_query_groups=32,
233
+ rotary_base=10000,
234
+ rotary_percentage=1.0,
235
+ qkv_transpose_before_split=True,
236
+ )
237
+ ff_config = cfg.FeedForwardConfig(
238
+ type=cfg.FeedForwardType.SEQUENTIAL,
239
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
240
+ intermediate_size=8192,
241
+ )
242
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
243
+ block_config = cfg.TransformerBlockConfig(
244
+ attn_config=attn_config,
245
+ ff_config=ff_config,
246
+ pre_attention_norm_config=norm_config,
247
+ post_attention_norm_config=norm_config,
248
+ )
249
+ config = cfg.ModelConfig(
250
+ vocab_size=32064,
251
+ num_layers=32,
252
+ max_seq_len=4096,
253
+ kv_cache_max_len=kv_cache_max_len,
254
+ embedding_dim=3072,
255
+ block_configs=block_config,
256
+ final_norm_config=norm_config,
257
+ enable_hlfb=True,
258
+ )
259
+ return config
260
+
261
+
262
+ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
263
+ config = get_model_config(kv_cache_max_len)
264
+ config.vocab_size = 128
265
+ config.num_layers = 2
266
+ config.max_seq_len = 2 * kv_cache_max_len
267
+ # Phi-3.5 has only one block config.
268
+ config.block_config(0).ff_config.intermediate_size = 128
269
+ return config
270
+
271
+
272
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
273
+ """Instantiates the model instance and load checkpoint if provided."""
274
+ config = get_model_config(**kwargs)
275
+ model = Phi3_5Mini(config)
276
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
277
+ loader.load(model)
278
+ model.eval()
279
+ return model
@@ -14,20 +14,22 @@
14
14
  # ==============================================================================
15
15
 
16
16
  """Verifies the reauthored Phi-2 model."""
17
+ import logging
17
18
 
18
19
  from absl import app
19
20
  from absl import flags
20
21
  from ai_edge_torch.generative.examples.phi import phi2
22
+ from ai_edge_torch.generative.utilities import transformers_verifier
21
23
  from ai_edge_torch.generative.utilities import verifier
22
24
  import kagglehub
23
25
  import transformers
24
26
 
27
+
25
28
  _PROMPTS = flags.DEFINE_multi_string(
26
29
  "prompts",
27
30
  "Instruct: Write an email about the weather Output:",
28
31
  "The input prompts to generate answers.",
29
32
  )
30
-
31
33
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
32
34
  "max_new_tokens",
33
35
  30,
@@ -37,25 +39,23 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
37
39
 
38
40
  def main(_):
39
41
  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
40
- verifier.log_msg("Loading the original model from", checkpoint)
41
- generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
42
- generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
43
- wrapper_model = verifier.ModelWrapper(
44
- model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
45
- hf_generation_config=generation_config,
46
- )
42
+ logging.info("Loading the original model from: %s", checkpoint)
43
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
47
44
 
48
- verifier.log_msg("Building the reauthored model from", checkpoint)
45
+ logging.info("Building the reauthored model from: %s", checkpoint)
49
46
  reauthored_model = phi2.build_model(checkpoint)
50
47
 
51
- verifier.log_msg("Loading the tokenizer from", checkpoint)
48
+ logging.info("Loading the tokenizer from: %s", checkpoint)
52
49
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
53
50
 
54
51
  verifier.verify_reauthored_model(
55
- original_model=wrapper_model,
56
- reauthored_model=reauthored_model,
57
- tokenizer=tokenizer,
52
+ original_model=transformers_verifier.TransformersModelWrapper(
53
+ original_model
54
+ ),
55
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
56
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
58
57
  generate_prompts=_PROMPTS.value,
58
+ max_new_tokens=_MAX_NEW_TOKENS.value,
59
59
  atol=1e-03,
60
60
  )
61
61
 
@@ -0,0 +1,69 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored Phi-3.5 model."""
17
+
18
+ import logging
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.phi import phi3
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
25
+ from ai_edge_torch.generative.utilities import verifier
26
+ import transformers
27
+
28
+
29
+ _PROMPTS = flags.DEFINE_multi_string(
30
+ "prompts",
31
+ "Instruct: Write an email about the weather Output:",
32
+ "The input prompts to generate answers.",
33
+ )
34
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
35
+ "max_new_tokens",
36
+ 30,
37
+ "The maximum size of the generated tokens.",
38
+ )
39
+
40
+
41
+ def main(_):
42
+ checkpoint = "microsoft/Phi-3.5-mini-instruct"
43
+ logging.info("Loading the original model from: %s", checkpoint)
44
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
45
+
46
+ # Locate the cached dir.
47
+ cached_config_file = transformers.utils.cached_file(
48
+ checkpoint, transformers.utils.CONFIG_NAME
49
+ )
50
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
51
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
52
+ reauthored_model = phi3.build_model(reauthored_checkpoint)
53
+
54
+ logging.info("Loading the tokenizer from: %s", checkpoint)
55
+ tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
56
+
57
+ verifier.verify_reauthored_model(
58
+ original_model=transformers_verifier.TransformersModelWrapper(
59
+ original_model
60
+ ),
61
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
62
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
63
+ generate_prompts=_PROMPTS.value,
64
+ max_new_tokens=_MAX_NEW_TOKENS.value,
65
+ )
66
+
67
+
68
+ if __name__ == "__main__":
69
+ app.run(main)
@@ -54,6 +54,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
54
54
  num_heads=9,
55
55
  head_dim=64,
56
56
  num_query_groups=3,
57
+ rotary_base=10000,
57
58
  rotary_percentage=1.0,
58
59
  )
59
60
  ff_config = cfg.FeedForwardConfig(