ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl

Files changed (34)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
  3. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
  4. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  5. ai_edge_torch/generative/examples/openelm/verify.py +61 -0
  6. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
  7. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  8. ai_edge_torch/generative/examples/phi/verify.py +53 -0
  9. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
  10. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  11. ai_edge_torch/generative/examples/smollm/verify.py +59 -0
  12. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
  14. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  15. ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
  16. ai_edge_torch/generative/layers/attention.py +8 -4
  17. ai_edge_torch/generative/layers/builder.py +3 -1
  18. ai_edge_torch/generative/layers/model_config.py +3 -0
  19. ai_edge_torch/generative/layers/normalization.py +31 -20
  20. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +11 -4
  22. ai_edge_torch/generative/layers/unet/model_config.py +3 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
  24. ai_edge_torch/generative/utilities/converter.py +82 -0
  25. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +10 -0
  26. ai_edge_torch/generative/utilities/verifier.py +200 -0
  27. ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
  28. ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
  29. ai_edge_torch/version.py +1 -1
  30. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
  31. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +34 -28
  32. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
  33. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma2 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma2.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma2_to_tflite(path)
+  app.run(main)
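
The new shared helper, ai_edge_torch/generative/utilities/converter.py (+82 lines), does not itself appear in this diff; only its call sites do. Reconstructed from the inline conversion code it replaces above, convert_to_tflite plausibly looks like the following sketch (a hypothetical reconstruction, not the actual utility, which may differ in defaults or extra signatures):

# Hypothetical sketch of converter.convert_to_tflite, reconstructed from the
# inline conversion code removed above. Not the actual +82-line utility.
import ai_edge_torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.quantize import quant_recipes
import torch


def convert_to_tflite(
    pytorch_model,
    tflite_path: str,
    prefill_seq_len: int = 512,
    quantize: bool = True,
):
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
  decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int)
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill',
          pytorch_model,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
              'kv_cache': kv,
          },
      )
      .signature(
          'decode',
          pytorch_model,
          sample_kwargs={
              'tokens': decode_token,
              'input_pos': decode_input_pos,
              'kv_cache': kv,
          },
      )
      .convert(quant_config=quant_config)
  )
  # The caller now supplies the output path instead of an f-string template.
  edge_model.export(tflite_path)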
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(path)
+  app.run(main)
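
All of these conversion scripts follow the same absl structure: module-level flags.DEFINE_* calls return holders whose .value is only readable after app.run(main) has parsed the command line (e.g. python convert_to_tflite.py --checkpoint_path=... --quantize=false). A minimal, self-contained illustration of the pattern, with hypothetical names:

# Minimal illustration of the absl flags pattern these scripts adopt.
from absl import app
from absl import flags

# Hypothetical flag, for illustration only.
_NAME = flags.DEFINE_string('name', 'world', 'Who to greet.')


def main(_):
  # Flag values are only readable after app.run() has parsed the command line.
  print(f'Hello, {_NAME.value}!')


if __name__ == '__main__':
  app.run(main)  # e.g. python greet.py --name=Gemma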
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_openelm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts OpenELM model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = openelm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
-  convert_openelm_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,16 +15,12 @@
 
 """Example of building an OpenELM model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs an OpenELM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/openelm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/openelm/verify.py (new file)
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored OpenELM-3B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "apple/OpenELM-3B"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
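
ai_edge_torch/generative/utilities/verifier.py (+200 lines) is likewise absent from this diff; only its call sites appear. Judging from the signature used here and the golden-logit checks it replaces, verify_reauthored_model plausibly runs the original and reauthored models on each prompt and compares logits within atol. A rough sketch under those assumptions (internals guessed, not the actual implementation):

# Rough, hypothetical sketch of the verifier; the signature is inferred from
# the call sites and the internals are guessed from the removed golden tests.
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import torch


def log_msg(*args):
  print("[verifier]", *args)


def verify_reauthored_model(
    original_model, reauthored_model, tokenizer, prompts, atol=1e-05
):
  for prompt in prompts:
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids  # (1, seq)
    with torch.no_grad():
      golden_logits = original_model(token_ids).logits  # (1, seq, vocab)

    # Reauthored models take (tokens, input_pos, kv_cache) and return a dict
    # with "logits", as in the removed define_and_run() helpers.
    tokens = token_ids.to(torch.int)
    input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
    kv = kv_utils.KVCache.from_model_config(reauthored_model.config)
    output = reauthored_model.forward(tokens, input_pos, kv)

    if not torch.allclose(golden_logits, output["logits"], atol=atol):
      raise AssertionError(f"Logits mismatch for prompt: {prompt!r}")
    log_msg("Logits match for prompt:", prompt)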
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
-from ai_edge_torch.generative.layers import kv_cache
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/phi2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_phi2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Phi-2 model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = phi2.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,16 +15,12 @@
 
 """Example of building a Phi-2 model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/phi2"
-  )
-  define_and_run(input_checkpoint_path)
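
The use_input_shape=False knob in the norm_config hunk above maps to new fields threaded through model_config.py and normalization.py (+31 -20, both changed in this release but not shown here). Presumably it selects which normalized_shape layer-norm statistics are computed over; in plain PyTorch the distinction looks roughly like this (illustrative only, not the library's code):

# Illustrative only: normalizing over the weight's shape (the last dim, what
# Phi-2 needs) versus over the full input shape.
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 2560)  # (batch, seq, embedding_dim)
w = torch.ones(2560)         # layer-norm scale
b = torch.zeros(2560)        # layer-norm bias

# use_input_shape=False: per-token statistics over the last dimension only.
y_weight_shape = F.layer_norm(x, w.shape, w, b)

# use_input_shape=True: statistics over the whole input tensor, which mixes
# sequence positions and gives a different result.
y_input_shape = F.layer_norm(
    x, x.shape, w.broadcast_to(x.shape), b.broadcast_to(x.shape)
)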
ai_edge_torch/generative/examples/phi/verify.py (new file)
@@ -0,0 +1,53 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = phi2.build_model(checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-03,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)